# **Starting Work**

In [None]:
from os import chdir, environ

In [None]:
# IPython shell escape: update the working copy before starting a session.
!git pull

In [None]:
# IPython shell escapes: list the working directory sorted by modification
# time with an explicit timestamp format (`--time-style` is a GNU ls option),
# then list it a second time through `dir`.
!ls -lt --time-style='+%y-%m-%d %H:%M:%S'
!dir

# **Experiment Init**

In [None]:
import json
import h5py
import torch
import os
from encoder import Encoder, MLP
from segment import Segment
from res_encoder import ResSegment
from dataset import *
# from encoders.model_cross import Encoder, init_logging, MLP
from build_tree import get_directions, init_directions
import logging

# Directory that receives training.log, the cached dataset tensors and the
# checkpoints; it is also handed to Encoder as its fourth positional argument.
OUTPUT = 'scratch_s3dis'

#encoder
# model_size, sample_layers and dim are the first three positional Encoder
# arguments; model_size is additionally used as `sample_points` for the datasets.
# Trailing values in the comments are presumably previously tried settings.
model_size = 4096
sample_layers = 50 # 2
channel = 1
dim = 4096 # 2048
dim_layer0 = 16 # 16
dim_repeat_cut = 5
# Forwarded to Encoder as `use_symmetry_loss`.
use_sym = False #True
# When False, new_model() replaces encoder.layers[0].pts_align with nn.Identity.
use_tnet = True # True

#decoder 
# First positional argument of Segment / ResSegment (and `layer0_mlp_dim` of the
# encoder in the 'res_segment' variant).
ancestor_dim = 512 # 512
# 'res_segment' selects ResSegment in new_model(); any other value selects Segment.
model_name = "segment"
# Only read by the ResSegment branch of new_model().
use_dyn_tree = False

#homo w/o pa
# prefix = "_homo_nopa_xinf"
# transform = homo_transform
# augment = 1
# no_prealign = True 
# rotate_only = False 
# load_pretrain = False
# augment_fn = lambda pts: augment_generator(pts) #, fetch_perm=True), dropout=0.5) #, shift=True) #, scale=True, agg_coef=0.5)
# pca_augment = False
# test_as_valid = True #True
# trunc = 6
# part_cls_dropout = None


# #homo
# prefix = "_homo_xinf"
# transform = homo_transform
# augment = 1
# no_prealign = False 
# rotate_only = False 
# load_pretrain = False
# augment_fn = lambda pts: augment_generator(pts) #, fetch_perm=True), dropout=0.5) #, shift=True) #, scale=True, agg_coef=0.5)
# pca_augment = False
# test_as_valid = True #True
# trunc = 6
# part_cls_dropout = None

# # orig
# prefix = "_orig_xinf"
# transform = no_transform
# augment = 1
# no_prealign = False #True 
# rotate_only = True #False
# load_pretrain = False
# augment_fn = lambda pts: augment_generator(pts, fetch_perm=True, dropout=0.5) #, shift=True) #, scale=True, agg_coef=0.5)
# pca_augment = False
# test_as_valid = True #True
# trunc = 9
# part_cls_dropout = 0.5 #0.5

# #affine
# Active experiment variant. `prefix` is logged at start-up and embedded in the
# cached-dataset file names and in the name of the best checkpoint.
prefix = "_affine_xinf"
# Passed to both datasets as `transform`.
transform = affine_transform
# Passed to the training dataset as `augment` (the test dataset always uses 1).
augment = 1
# no_prealign / rotate_only select the `make` routine for the datasets;
# no_prealign takes precedence when both are set, and it also disables
# pca_augment inside train().
no_prealign = False 
rotate_only = False 
# Not referenced anywhere in the visible part of this file.
load_pretrain = False
# Passed to the training dataset only.
augment_fn = lambda pts: augment_generator(pts) #, fetch_perm=True), dropout=0.5) #, shift=True) #, scale=True, agg_coef=0.5)
# Forwarded to make_batch_generator (and-ed with `not no_prealign`) in train().
pca_augment = False
# When True, train() reports the best validation score at save time instead of
# running a separate pass over the test loader.
test_as_valid = True #True
# Passed to the datasets as `trunc`; the encoder receives extra_dim = trunc - 3.
trunc = 6
# Forwarded to Segment / ResSegment as `part_cls_dropout`.
part_cls_dropout = None

# #affine w/o pa
# prefix = "_affine_nopa_xinf"
# transform = affine_transform
# augment = 1
# no_prealign = True 
# rotate_only = False 
# load_pretrain = False
# augment_fn = lambda pts: augment_generator(pts) #, fetch_perm=True), dropout=0.5) #, shift=True) #, scale=True, agg_coef=0.5)
# pca_augment = False
# test_as_valid = True #True
# trunc = 6
# part_cls_dropout = None


# Forwarded to Encoder as `sample_child_first`.
sample_child_first = False # True in l7s1
# Number of output classes of the part classifier (width of the logits).
num_parts = 13
# Rooms whose file name starts with "Area_<test_area>_" form the test split.
test_area = 5
# Directory holding ply_data_all_<i>.h5 and room_filelist.txt.
DATASET = './datasets/S3DIS_hdf5'
# First argument of build_tree.init_directions.
chaos_limit = 0


In [None]:
# Guard so that re-running the cell does not attach duplicate handlers to the
# root logger.
logging_init_flag = False


def init_logging(OUTPUT):
    """Configure the root logger to write to the console and to a log file.

    The console handler emits DEBUG and above; the file handler writes INFO
    and above to ``{OUTPUT}/training.log``. Calling the function again after
    a successful setup is a no-op.

    Args:
        OUTPUT: Directory for ``training.log``. It is created if missing.

    Raises:
        OSError: If the directory or the log file cannot be created. The
            guard flag stays unset in that case so a later call can retry.
    """
    global logging_init_flag
    if logging_init_flag is True:
        return

    # FileHandler does not create parent directories; without this the very
    # first run in a fresh checkout fails with FileNotFoundError.
    if OUTPUT:
        os.makedirs(OUTPUT, exist_ok=True)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s:\t%(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

    # Build both handlers before touching the root logger so that a failure
    # here leaves the logging configuration unchanged.
    fh = logging.FileHandler(f"{OUTPUT}/training.log")
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)
    logger.addHandler(fh)

    # Only mark as initialised once everything above succeeded.
    logging_init_flag = True

# Attach the log handlers (guarded against repeated execution of this cell)
# and record which experiment variant is active.
init_logging(OUTPUT)
logging.info(f"prefix = {prefix}")
# Run build_tree.init_directions with this experiment's chaos_limit; the
# return value is deliberately discarded.
_ = init_directions(chaos_limit, calc_dmap=False)

In [None]:
# Fixed seed so that model initialisation is repeatable across runs.
torch.manual_seed(674433238)


def new_model(num_parts=num_parts):
    """Build the segmentation network selected by ``model_name`` on the GPU.

    ``model_name == 'res_segment'`` yields a ResSegment that receives an
    encoder factory; any other value yields a Segment wrapping one encoder.
    In both cases ``tree.use_sym`` is forced to False on the result.

    Args:
        num_parts: Number of part classes (defaults to the value of the
            module-level ``num_parts`` at definition time).

    Returns:
        The constructed model, already moved to CUDA.
    """

    def make_encoder(**variant_kwargs):
        # Arguments shared by both variants; variant-specific ones are merged in.
        enc = Encoder(
            model_size, sample_layers, dim, OUTPUT,
            channel=channel,
            extra_dim=trunc - 3,
            sample_child_first=sample_child_first,
            dim_layer0=dim_layer0,
            dim_repeat_cut=dim_repeat_cut,
            use_symmetry_loss=use_sym,
            **variant_kwargs,
        ).cuda()
        if not use_tnet:
            # Disable the point-alignment module of the first layer.
            enc.layers[0].pts_align = torch.nn.Identity()
        return enc

    if model_name == 'res_segment':
        def encoder_fn(point_dim):
            return make_encoder(point_dim=point_dim, layer0_mlp_dim=ancestor_dim)

        net = ResSegment(
            ancestor_dim,
            encoder_fn,
            use_dyn_tree=use_dyn_tree,
            carry_dim_seg1=min(dim // 16, ancestor_dim),
            num_parts=num_parts,
            part_cls_dropout=part_cls_dropout,
        ).cuda()
    else:
        net = Segment(
            ancestor_dim,
            make_encoder(),
            num_parts=num_parts,
            part_cls_dropout=part_cls_dropout,
        ).cuda()

    net.tree.use_sym = False
    return net


model = new_model()

In [None]:
# Bare expression: the notebook displays the model's repr for a quick sanity check.
model

# **Data**

In [None]:
from dataset import *
import numpy as np
from math import ceil
from tqdm import tqdm

In [None]:
# Pick the sample-construction routine handed to the datasets.
# no_prealign takes precedence over rotate_only when both flags are set.
if no_prealign:
    make = make_data_no_prealign
elif rotate_only:
    make = make_data_rotate_only
else:
    make = make_data_default

In [None]:
# Read shards 0..23 of the dataset and concatenate them along the first axis.
clouds = []
labels = []

for i in tqdm(range(24)):
    # Open read-only and close each shard as soon as its arrays are copied
    # into memory; previously all 24 file handles stayed open.
    with h5py.File(f'{DATASET}/ply_data_all_{i}.h5', 'r') as data_file:
        # np.array(...) materialises the HDF5 dataset before the file closes.
        clouds.append(torch.tensor(np.array(data_file['data'])))
        labels.append(torch.tensor(np.array(data_file['label'])))

clouds = torch.cat(clouds, dim=0)
labels = torch.cat(labels, dim=0)

In [None]:
# Split entry indices of room_filelist.txt: entries whose name starts with
# "Area_<test_area>_" go to the test split, everything else to training.
train_ind, test_ind = [], []
held_out_prefix = f"Area_{test_area}_"

with open(f"{DATASET}/room_filelist.txt") as file:
    for i, filename in enumerate(file.read().split()):
        if filename.startswith(held_out_prefix):
            test_ind.append(i)
        else:
            train_ind.append(i)

train_ind = torch.tensor(train_ind)
test_ind = torch.tensor(test_ind)

In [None]:
# Keyword arguments common to the training and the test dataset.
shared_dataset_args = dict(
    trunc=trunc,
    use_norm=True,
    transform=transform,
    make=make,
    sample_points=model_size,
)

# Training split: configurable augmentation count plus an augmentation function.
train_dataset = PointCloudDataset(
    clouds, labels, model.tree.arrange,
    subset=train_ind,
    augment=augment,
    augment_fn=augment_fn,
    **shared_dataset_args,
)
# Test split: a single (augment=1) pass without an augmentation function.
test_dataset = PointCloudDataset(
    clouds, labels, model.tree.arrange,
    subset=test_ind,
    augment=1,
    **shared_dataset_args,
)
# Validation reuses the very same object as the test dataset.
valid_dataset = test_dataset

In [None]:
# Bare expression: the notebook displays the three split sizes
# (valid and test are the same object, so their sizes are equal).
len(train_dataset), len(valid_dataset), len(test_dataset)

# **Train**

## Defs

In [None]:
from dataset import make_batch_train, make_batch_eval
import torch.nn as nn

from itertools import count

# Mini-batch size; train() reads (and may halve) this module-level value.
batch_size = 64

# Attach pre-computed samples to each dataset if a cache exists in OUTPUT.
# Lookup order: a single file, then numbered part files, then no cache at all
# (the dataset is switched to force_online).
# NOTE(review): torch.load unpickles arbitrary objects - only point OUTPUT at
# cache files you created yourself.
for name, dataset in zip(['train', 'valid', 'test'], [train_dataset, valid_dataset, test_dataset]):
    print(f"Loading {name}")
    dataset.mem = None

    try:
        tmp = torch.load(f'{OUTPUT}/{name}_data{prefix}.pth')
        dataset.mem = tmp
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(dataset.mem) != len(dataset):
            raise ValueError(f'Size unmatch: {len(dataset.mem)} != {len(dataset)}')
        continue
    except Exception as err:
        # `except Exception` keeps the best-effort fallback but no longer
        # swallows KeyboardInterrupt / SystemExit like the bare `except:` did.
        if not isinstance(err, FileNotFoundError):
            print(f"Single-file cache rejected: {err!r}")
        print("Try part mode")

    dataset.mem = []
    for i in count():
        try:
            tmp = torch.load(f'{OUTPUT}/{name}_data{prefix}.{i}.pth')
            dataset.mem += tmp
            print(f"Loaded part {i} # = {len(tmp)}")
        except FileNotFoundError:
            # Normal termination: no further part file.
            break
        except Exception as err:
            # Still stop at the first unusable part, but say why.
            print(f"Stopped at part {i}: {err!r}")
            break
    if len(dataset.mem) > 0:
        continue

    print("Use force online")
    dataset.mem = None
    dataset.force_online = True

def inf_iter(a):
    """Yield the items of ``a`` endlessly, restarting ``a`` whenever it ends.

    Args:
        a: A re-iterable object (e.g. a list or a DataLoader).

    Yields:
        The items of ``a`` in order, pass after pass.

    Raises:
        ValueError: If a complete pass over ``a`` yields nothing (empty
            container or exhausted one-shot iterator). Without this check
            the generator would spin forever without producing a value.
    """
    while True:
        produced = False
        for k in a:
            produced = True
            yield k
        if not produced:
            raise ValueError("inf_iter: iterable produced no items")


In [None]:
class AccStat:
    """Running ratio ``correct / total`` accumulated over many ``add`` calls."""

    def __init__(self):
        self.clear()

    def clear(self):
        """Reset both accumulators to zero."""
        self.correct = 0
        self.total = 0

    def add(self, cor, num=1, mean=False):
        """Accumulate one observation.

        Args:
            cor: Amount to add to the numerator (number, array or tensor).
            num: Weight added to the denominator.
            mean: If True, ``cor`` is a mean over ``num`` items and is
                scaled by ``num`` before being accumulated.
        """
        if mean:
            # Out-of-place multiply: `cor *= num` would silently modify the
            # caller's tensor/array when a mutable object is passed in.
            cor = cor * num
        self.correct += cor
        self.total += num

    def result(self, clear=False):
        """Return the current ratio; optionally reset afterwards.

        The denominator is clamped to 1e-5, so an empty accumulator
        evaluates to 0 instead of raising ZeroDivisionError.
        """
        ret = self.correct / max(self.total, 1e-5)
        if clear:
            self.clear()
        return ret

    def __str__(self):
        return "%.4lf" % self.result()

def calc_miou(logits, part_label):
    """Mean intersection-over-union of arg-max predictions against labels.

    Runs on whatever device ``logits`` lives on (labels are moved there), so
    the metric also works on machines without CUDA.

    Args:
        logits: Class scores shaped (batch, points, classes); a 2-D tensor
            is treated as a single-sample batch.
        part_label: Integer labels shaped (batch, points); a 1-D tensor is
            treated as a single-sample batch.

    Returns:
        A Python float: IoU averaged over classes per sample, then over the
        batch. A class absent from both prediction and label (empty union)
        counts as IoU 1.
    """
    if logits.dim() < 3:
        logits = logits.unsqueeze(0)

    if part_label.dim() < 2:
        part_label = part_label.unsqueeze(0)

    device = logits.device
    part_label = part_label.to(device)

    pred = logits.argmax(dim=-1)
    part = torch.arange(logits.size(-1), device=device)

    # (batch, classes, points) membership masks, built by broadcasting.
    in_pred = part[None, :, None] == pred[:, None, :]
    in_label = part[None, :, None] == part_label[:, None, :]

    I = (in_pred & in_label).sum(dim=-1)
    U = (in_pred | in_label).sum(dim=-1)

    # clamp avoids 0/0; those entries are overwritten with 1 right after.
    part_IOU = I / U.clamp(min=1).float()
    part_IOU[U < 0.5] = 1
    shape_IOU = part_IOU.mean(dim=-1)
    return shape_IOU.mean().item()

def evaluate(model, linear, loader, noprint=False, perms=(None,)):
    """Score ``model`` on every batch of ``loader`` without gradients.

    Both modules are switched to eval mode for the duration of the call and
    switched back to train mode afterwards, even if evaluation raises.

    Args:
        model: Network called as ``model(*inputs, perm=...)`` that also
            exposes ``part_classfier``.
        linear: Module applied to the model output before classification.
        loader: Iterable of ``(inputs, part_label)`` batches.
        noprint: Suppress the progress and summary log lines.
        perms: Values passed as ``perm``; each batch is scored once per entry.

    Returns:
        ``(accuracy, text)`` where ``accuracy`` is the mean of the per-batch
        accuracies and ``text`` is ``"<accuracy> <mIoU>"`` with six decimals.
    """
    stat = AccStat()
    miou = AccStat()

    if not noprint:
        logging.info(f"loader # = {len(loader)}")

    print_epoch = 1

    model.eval()
    linear.eval()

    try:
        for epoch, (inputs, part_label) in enumerate(loader):
            part_label = part_label.cuda()
            with torch.no_grad():
                for iperm in perms:
                    features = linear(model(*inputs, perm=iperm))
                    logits_all = model.part_classfier(features)

                    stat.add((logits_all.max(dim=-1).indices == part_label).float().mean())
                    miou.add(calc_miou(logits_all, part_label))

            if not noprint:
                if (epoch // batch_size + 1) % print_epoch == 0:
                    logging.debug(f"test #{epoch} correct = {'%.6lf' % stat.result()}")

        if not noprint:
            logging.info(f"Done: score = {'%.8lf' % stat.result()}")
    finally:
        # Never leave the caller's modules stuck in eval mode.
        model.train()
        linear.train()

    return stat.result(), "%.6lf %.6lf" % (stat.result(), miou.result())


In [None]:
import torch.nn.functional as F
def train(more_epoch=100000, current_epoch=0, ckpt=None, valid_result_threshold=999.0):
    """Train a freshly built segmentation model on ``train_dataset``.

    Every outer iteration ("epoch" in the log) performs ``epoch_scale``
    optimizer steps, each on ``batch_scale`` batches drawn from an endless
    iterator over the training loader. Progress is logged for the first five
    iterations and every 20th one; validation runs at iteration 5 and every
    100th one, and the best-scoring weights are written to
    ``{OUTPUT}/trained_best{prefix}.pth``. ``{OUTPUT}/trained_0.pth`` is
    rewritten at start-up and every 100th iteration.

    Args:
        more_epoch: Number of outer iterations to run.
        current_epoch: Iteration counter to continue from.
        ckpt: Optional path of a checkpoint written by the inner ``save``;
            restores model, optimizer, scheduler and best validation score.
        valid_result_threshold: Training stops as soon as a validation
            accuracy reaches this value.
    """
    best_vres = -1.0

    from random import randint
    import build_tree
    import torch

    model = new_model(num_parts).cuda()
    linear = nn.Identity()
    model.train()
    linear.train()

    def get_trans(n=3):
        # n == 0 means "no permutation": a single pass with perm=None.
        if n == 0:
            return [None]
        return [randint(0, len(build_tree.transforms) - 1) for _ in range(n)]

    global batch_size

    num_workers = 12
    mbtrain = make_batch_generator(pca_augment=pca_augment and not no_prealign)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=mbtrain, pin_memory=True, drop_last=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)

    logging.info(f"train epoch = {current_epoch + 1} ~ {more_epoch} threshold = {valid_result_threshold} count = {len(train_loader) * batch_size} num_parts = {num_parts}")

    cum_loss = 0
    cum_inner_loss = 0

    batch_scale = 1
    # Number of optimizer steps per outer iteration (256 samples' worth).
    epoch_scale = max(1, 256 // batch_size)
    num_trans = 0

    print_epoch = 20
    valid_epoch = 100
    epoch_since = 0
    save_epoch = 100
    cut_epoch = 10000000000

    accuracy = AccStat()
    miou = AccStat()

    # Negative threshold: the swap of epoch_scale / batch_scale below is
    # effectively disabled.
    threshold = -1.0

    # Labels equal to -1 do not contribute to the loss.
    crit = torch.nn.CrossEntropyLoss(ignore_index=-1)
    opt = torch.optim.Adam(list(model.parameters()) + list(linear.parameters()), lr=1e-4)
    # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epoch / 5, eta_min=1e-5)
    # sch = torch.optim.lr_scheduler.ExponentialLR(opt, 0.998)
    sch = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9999)

    if ckpt is not None:
        filename = ckpt
        ckpt = torch.load(ckpt)
        model.load_state_dict(ckpt['model'])
        linear.load_state_dict(ckpt['linear'])
        opt.load_state_dict(ckpt['opt'])
        sch.load_state_dict(ckpt['sch'])
        best_vres = ckpt['best_vres']
        logging.info(f"Loaded ckpt {filename} best_vres = {best_vres}")

    def save(epoch):
        # `epoch` is only a file-name tag (0 or "best<prefix>").
        torch.save({
            'model': model.state_dict(),
            'linear': linear.state_dict(),
            'opt': opt.state_dict(),
            'sch': sch.state_dict(),
            'best_vres': best_vres,
        }, f"{OUTPUT}/trained_{epoch}.pth")

    save(0)

    train_iter = inf_iter(train_loader)

    for epoch in range(current_epoch + 1, current_epoch + more_epoch + 1):

        current_epoch = epoch

        for _ in range(epoch_scale):
            loss = torch.tensor(0.).cuda()
            inner_loss = torch.tensor(0.).cuda()
            for _ in range(batch_scale):
                inputs, part_label = next(train_iter)
                part_label = part_label.cuda()
                for iperm in get_trans(num_trans):

                    features = linear(model(*inputs, perm=iperm))
                    logits_all = model.part_classfier(features)
                    loss = loss + crit(logits_all.view(-1, num_parts), part_label.view(-1))

                    with torch.no_grad():
                        accuracy.add((logits_all.max(dim=-1).indices == part_label).float().mean())
                        miou.add(calc_miou(logits_all, part_label))

                    if model_name == 'res_segment':
                        # Auxiliary loss on the intermediate prediction head.
                        inner_features = model.inner_ans
                        logits_inner_all = model.inner_part_classfier(inner_features)
                        inner_loss = inner_loss + crit(logits_inner_all.view(-1, num_parts), part_label.view(-1))

                    epoch_since += 1

            assert loss.isnan().sum() == 0
            cum_loss += loss.item()
            cum_inner_loss += inner_loss.item()
            opt.zero_grad()
            (loss + inner_loss).backward()
            opt.step()

        if cum_loss / epoch_since < threshold:
            epoch_scale, batch_scale = batch_scale, epoch_scale

            logging.info("Threshold Reached")
            threshold = -1e10

        if epoch <= 5 or epoch % print_epoch == 0:
            valid_str = ""
            func = logging.debug

            stop_training = False
            if epoch == 5 or epoch % valid_epoch == 0:
                vres, valid_str = evaluate(model, linear, valid_loader, noprint=True)
                valid_str = "valid = " + valid_str
                stop_training = (vres >= valid_result_threshold)
                if vres > best_vres:
                    best_vres = vres
                    save(f"best{prefix}")
                    valid_str += " updated"

                func = logging.info
            func(f"train #{epoch} lr = {'%.2e' % sch.get_last_lr()[0]} loss = {'%.6lf / %.6lf' % (cum_loss / epoch_since, cum_inner_loss / epoch_since)} train = {accuracy} {miou} {valid_str}")
            # Running statistics cover only the span since the last report.
            epoch_since = cum_loss = cum_inner_loss = 0
            accuracy.clear()
            miou.clear()

            if stop_training:
                break

        sch.step()

        if epoch % save_epoch == 0:
            save(0)
            if test_as_valid:
                test_str = "%.6lf" % best_vres
            else:
                # evaluate() has no `together` parameter; passing it raised
                # TypeError whenever this branch was taken.
                _, test_str = evaluate(model, linear, test_loader, noprint=True)
            logging.info(f"Saved(Skipped) test = {test_str}")

        if epoch % cut_epoch == 0:
            if batch_size > 8:
                batch_size //= 2
                epoch_scale *= 2

            logging.info(f"Cut batch_size = {batch_size} epoch_scale = {epoch_scale}")


## Loop

In [None]:
# Start training with the defaults: no checkpoint, up to 100000 outer iterations.
train()