# **Starting Work**

In [None]:
from os import chdir, environ

In [None]:
# Sync the working copy with its remote before running the experiment.
!git pull

In [None]:
# List the working directory, most recently modified first, to check which files are present.
!ls -lt --time-style='+%y-%m-%d %H:%M:%S'
!dir

# **Experiment Init**

In [None]:
import json
import h5py
import torch
import os
from encoder import Encoder, MLP
from segment import Segment
from res_encoder import ResSegment
from dataset import *
# from encoders.model_cross import Encoder, init_logging, MLP
from build_tree import get_directions, init_directions
import logging

# Directory that receives training.log, cached dataset chunks and checkpoints.
OUTPUT = 'scratch_shapenetpart'

# Encoder hyper-parameters (forwarded to Encoder(...) in new_model).
model_size = 2 ** 11
sample_layers = 50 # 2
channel = 1
dim = 2048 # 2048
dim_layer0 = 16 # 16
dim_repeat_cut = 5
use_sym = True  # passed as Encoder(use_symmetry_loss=...)

# Segmentation-head hyper-parameters.
ancestor_dim = 512 # 512
# "res_segment" builds ResSegment in new_model(); any other value builds Segment.
# train_together() only supports "res_segment".
model_name = "segment" #"res_segment"
use_dyn_tree = False  # only forwarded to ResSegment
part_cls_dropout = None #0.5

# Data pipeline. `prefix` tags the cached-data file names and the "best" checkpoint name.
prefix = "_affine_iter_xinf"
transform = affine_transform
augment = 1
no_prealign = False #True  (selects make_data_no_prealign; takes precedence over rotate_only)
rotate_only = False #False  (selects make_data_rotate_only)
load_pretrain = False
augment_fn = lambda pts: augment_generator(pts) #, shift=True, scale=True, rotate_y_axis=True, agg_coef=0.5)
pca_augment = False  # has no effect when no_prealign is True
test_as_valid = False #False  (train on train+valid, validate on test)
# Alternative preset: no_transform with rotation-only construction.
# prefix = "_orig_xinf"
# transform = no_transform
# augment = 1
# no_prealign = False #True
# rotate_only = True #False
# load_pretrain = False
# augment_fn = lambda pts: augment_generator(pts, shift=True, scale=True, rotate_y_axis=True, agg_coef=0.5)
# pca_augment = False
# test_as_valid = False #False
# Alternative preset: homo_transform with augment = 30 and PCA augmentation.
# prefix = "_homo_x30"
# transform = homo_transform
# augment = 30
# no_prealign = False
# load_pretrain = False
# pca_augment = True

sample_child_first = False # True in l7s1
num_classes = 16  # object categories (size of the per-class mIoU table)
num_parts = 50    # part labels across all categories
DATASET = './datasets/ShapeNetPart'
chaos_limit = 0  # forwarded to init_directions()


In [None]:
# Set once the root logger has been configured, so re-running init_logging
# does not attach duplicate handlers.
logging_init_flag = False


def init_logging(OUTPUT):
    """Attach a console handler and a file handler to the root logger, once.

    The root logger is set to DEBUG. The console handler emits DEBUG and
    above; ``{OUTPUT}/training.log`` receives INFO and above. Calls after the
    first successful one are no-ops.

    Args:
        OUTPUT: directory for ``training.log``; created if it does not exist.
    """
    global logging_init_flag
    if logging_init_flag:
        return

    # FileHandler opens the log file immediately, so the directory must exist.
    os.makedirs(OUTPUT, exist_ok=True)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s:\t%(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

    fh = logging.FileHandler(f"{OUTPUT}/training.log")
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)
    logger.addHandler(fh)

    # Mark as initialised only once handlers are attached, so a failed attempt
    # (e.g. an unwritable directory) can be retried.
    logging_init_flag = True

# Configure logging into OUTPUT and record which data prefix this run uses.
init_logging(OUTPUT)
logging.info("prefix = %s", prefix)
# Initialise build_tree's directions (calc_dmap disabled); assigning to `_`
# keeps the notebook from echoing the return value.
_ = init_directions(chaos_limit, calc_dmap=False)

In [None]:
import json

# Per-category part metadata. Each entry of id2parts.json is a list of global
# part ids (it is used as an index list below).
with open(f'{DATASET}/id2parts.json', encoding='utf-8') as fp:
    class_parts = json.load(fp)

class_parts_list = []  # per category: tensor of its global part ids
part_mapping = []      # per category: global part id -> local index, -1 elsewhere
for i, p in enumerate(class_parts):
    class_parts_list.append(torch.tensor(p).cuda())

    # Boolean mask over all num_parts global ids selecting this category's parts.
    mask = torch.zeros(num_parts).cuda().bool()
    mask[p] = True
    class_parts[i] = mask

    pmap = torch.full([num_parts], -1).cuda()
    pmap[p] = torch.arange(len(p)).cuda()
    part_mapping.append(pmap)


# One boolean row per entry of id2parts.json, each of length num_parts.
class_parts = torch.stack(class_parts, dim=0)


In [None]:
torch.manual_seed(674433238)


def new_model(num_parts=num_parts):
    """Build a fresh segmentation model on the GPU.

    ``model_name == 'res_segment'`` yields a ResSegment that receives an
    encoder *factory* (called with ``point_dim``); any other value yields a
    Segment wrapping a single Encoder. All encoder hyper-parameters are read
    from the notebook globals at call time. After construction,
    ``model.tree.use_sym`` is forced to False.

    Args:
        num_parts: number of part classes the head predicts.
    """
    def build_encoder(**extra):
        # Shared Encoder construction; `extra` carries branch-specific kwargs.
        return Encoder(
            model_size, sample_layers, dim, OUTPUT,
            channel=channel,
            sample_child_first=sample_child_first,
            dim_layer0=dim_layer0,
            dim_repeat_cut=dim_repeat_cut,
            use_symmetry_loss=use_sym,
            **extra,
        ).cuda()

    if model_name != 'res_segment':
        net = Segment(ancestor_dim, build_encoder(), num_parts=num_parts,
                      part_cls_dropout=part_cls_dropout).cuda()
    else:
        def encoder_fn(point_dim):
            return build_encoder(point_dim=point_dim)

        net = ResSegment(ancestor_dim, encoder_fn, use_dyn_tree=use_dyn_tree,
                         carry_dim_seg1=dim // 16, num_parts=num_parts,
                         part_cls_dropout=part_cls_dropout).cuda()

    net.tree.use_sym = False
    return net

model = new_model()

In [None]:
# Show the model's repr as this cell's output.
model

# **Data**

In [None]:
from dataset import *
import numpy as np
from math import ceil

In [None]:
# Choose how samples are constructed; no_prealign takes precedence over
# rotate_only when both flags are set.
if no_prealign:
    make = make_data_no_prealign
elif rotate_only:
    make = make_data_rotate_only
else:
    make = make_data_default

In [None]:
# Load the training split from train0.h5 .. train{n_train-1}.h5.
n_train = 6
clouds = []
labels = []
extra_labels = []
for i in range(n_train):
    # Open explicitly read-only (h5py < 3 defaults to append mode, which can
    # silently create an empty file for a wrong path) and close the handle
    # once the arrays have been copied out.
    with h5py.File(f'{DATASET}/train{i}.h5', 'r') as data_file:
        clouds.append(torch.tensor(np.array(data_file['data'])))
        labels.append(torch.tensor(np.array(data_file['seg'])))
        extra_labels.append(torch.tensor(np.array(data_file['label'])))

clouds = torch.cat(clouds, dim=0)
labels = torch.cat(labels, dim=0)
extra_labels = torch.cat(extra_labels, dim=0)

train_dataset = PointCloudDataset(clouds, labels, model.tree.arrange, augment=augment, transform=transform, make=make, extra_labels=extra_labels, augment_fn=augment_fn)


In [None]:
# Load the validation split from val0.h5 .. val{n_valid-1}.h5 (no augmentation).
n_valid = 1
clouds = []
labels = []
extra_labels = []
for i in range(n_valid):
    # Read-only and closed after copying (h5py < 3 defaults to append mode).
    with h5py.File(f'{DATASET}/val{i}.h5', 'r') as data_file:
        clouds.append(torch.tensor(np.array(data_file['data'])))
        labels.append(torch.tensor(np.array(data_file['seg'])))
        extra_labels.append(torch.tensor(np.array(data_file['label'])))

clouds = torch.cat(clouds, dim=0)
labels = torch.cat(labels, dim=0)
extra_labels = torch.cat(extra_labels, dim=0)

valid_dataset = PointCloudDataset(clouds, labels, model.tree.arrange, augment=1, transform=transform, make=make, extra_labels=extra_labels)

In [None]:
# Load the test split from test0.h5 .. test{n_test-1}.h5 (no augmentation).
n_test = 2
clouds = []
labels = []
extra_labels = []
for i in range(n_test):
    # Read-only and closed after copying (h5py < 3 defaults to append mode).
    with h5py.File(f'{DATASET}/test{i}.h5', 'r') as data_file:
        clouds.append(torch.tensor(np.array(data_file['data'])))
        labels.append(torch.tensor(np.array(data_file['seg'])))
        extra_labels.append(torch.tensor(np.array(data_file['label'])))

clouds = torch.cat(clouds, dim=0)
labels = torch.cat(labels, dim=0)
extra_labels = torch.cat(extra_labels, dim=0)

test_dataset = PointCloudDataset(clouds, labels, model.tree.arrange, augment=1, make=make, transform=transform, extra_labels=extra_labels)

In [None]:
if test_as_valid:
    # Fold the validation split into training and validate on the test split.
    # The merged training set uses the same augmentation settings as
    # train_dataset (augment / augment_fn), so this mode trains under the same
    # conditions as the regular split.
    # NOTE(review): assumes `.clouds` holds the raw clouds passed at
    # construction (not already-expanded copies) — confirm in dataset.py.
    a = train_dataset
    b = valid_dataset
    train_dataset = PointCloudDataset(
        torch.cat([a.clouds, b.clouds], dim=0),
        torch.cat([a.labels, b.labels], dim=0),
        model.tree.arrange,
        augment=augment,
        transform=transform,
        make=make,
        extra_labels=torch.cat([a.extra_labels, b.extra_labels], dim=0),
        augment_fn=augment_fn,
    )
    valid_dataset = test_dataset


In [None]:
# Show the (train, valid, test) dataset sizes as this cell's output.
len(train_dataset), len(valid_dataset), len(test_dataset)

In [None]:
# Optional offline pre-computation of dataset samples. Put split names in
# `target` to (re)build their caches; with the default empty list every split
# is skipped.
target = [] # ['train', 'valid', 'test']

import gc
gc.collect()

for name, dataset in zip(['train', 'valid', 'test'], [train_dataset, valid_dataset, test_dataset]):
    from tqdm import tqdm

    if name not in target:
        continue

    # NOTE: these rebind notebook-level globals; batch_size is reassigned to
    # 128 in the training-definitions cell before it is used again.
    num_workers = 16
    batch_size = 8

    # `mem += data` below assumes the `placeholder` collate_fn returns a list
    # of per-sample items — TODO confirm in dataset.py.
    data_init = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=placeholder, pin_memory=False, prefetch_factor=32, drop_last=False)
    logging.info(f"Init {name}")

    counter = 0
    mem = []

    for i, data in enumerate(tqdm(data_init)):
        mem += data
        # logging.debug(f"Init {name}: {i+1}/{len(data_init)}")

        # Flush to a numbered chunk file once >= 16384 items are buffered, to
        # bound memory; the loader cell reads chunks until the first missing index.
        if len(mem) >= 16384:
            torch.save(mem, f'{OUTPUT}/{name}_data{prefix}.{counter}.pth')
            counter += 1
            mem.clear()
            gc.collect()

    # Save the remainder (possibly empty) as the final chunk.
    torch.save(mem, f'{OUTPUT}/{name}_data{prefix}.{counter}.pth')
    del mem
    gc.collect()


# **Train**

## Defs

In [None]:
from dataset import make_batch_train, make_batch_eval
import itertools
import torch.nn as nn

batch_size = 128

# Attach pre-computed samples to each split (`dataset.mem`): first try a single
# cache file, then numbered chunk files; if neither yields a cache of the right
# size, fall back to generating samples online.
for name, dataset in zip(['train', 'valid', 'test'], [train_dataset, valid_dataset, test_dataset]):
    print(f"Loading {name}")
    dataset.mem = None

    # 1) Single-file cache.
    try:
        tmp = torch.load(f'{OUTPUT}/{name}_data{prefix}.pth')
    except Exception as err:  # best-effort: missing/unreadable cache falls through
        print(f"Try part mode ({type(err).__name__})")
    else:
        dataset.mem = tmp
        if len(dataset.mem) == len(dataset):
            continue
        print(f'Size unmatch: {len(dataset.mem)} != {len(dataset)}')
        print("Try part mode")

    # 2) Chunked cache written by the pre-computation cell:
    #    {OUTPUT}/{name}_data{prefix}.{i}.pth for i = 0, 1, 2, ...
    dataset.mem = []
    for i in itertools.count():
        try:
            tmp = torch.load(f'{OUTPUT}/{name}_data{prefix}.{i}.pth')
        except FileNotFoundError:
            break  # the first missing index marks the end of the chunks
        except Exception as err:
            # An unreadable chunk must not silently truncate the split.
            print(f"Failed to load part {i} ({type(err).__name__}); discarding parts")
            dataset.mem = []
            break
        dataset.mem += tmp
        print(f"Loaded part {i} # = {len(tmp)}")
    if len(dataset.mem) > 0:
        if len(dataset.mem) == len(dataset):
            continue
        # Stale or foreign chunks (e.g. from a differently sized split) are not used.
        print(f'Size unmatch: {len(dataset.mem)} != {len(dataset)}')

    # 3) No usable cache: generate samples online.
    print("Use force online")
    dataset.mem = None
    dataset.force_online = True

def inf_iter(a):
    """Yield items from the re-iterable ``a`` forever, restarting after each pass.

    Every pass iterates ``a`` afresh, so e.g. a shuffling DataLoader draws a
    new order on each pass.

    Raises:
        ValueError: if a complete pass over ``a`` yields nothing (for example a
            DataLoader with ``drop_last=True`` over fewer samples than its
            batch size), instead of looping forever without producing an item.
    """
    while True:
        produced = False
        for k in a:
            produced = True
            yield k
        if not produced:
            raise ValueError("inf_iter: iterable produced no items in a full pass")


In [None]:
class AccStat:
    """Running weighted average used for accuracy / mIoU bookkeeping.

    ``correct`` holds the weighted sum of added scores and ``total`` the sum of
    weights. ``result()`` is their ratio, with the denominator floored at 1e-5
    so an empty accumulator reports 0.0 instead of dividing by zero.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Reset the accumulated sum and weight to zero."""
        self.correct, self.total = 0, 0

    def add(self, cor, num=1, mean=False):
        """Accumulate score ``cor`` with weight ``num``.

        With ``mean=True``, ``cor`` is an average over ``num`` items and is
        scaled back to a sum (in place, via ``*=``) before accumulating.
        """
        if mean:
            cor *= num
        self.correct += cor
        self.total += num

    def result(self, clear=False):
        """Return the current average, optionally resetting afterwards."""
        denominator = max(self.total, 1e-5)
        score = self.correct / denominator
        if clear:
            self.clear()
        return score

    def __str__(self):
        return "%.4lf" % (self.result(),)

def miou(logits, part_label):
    """Mean IoU of argmax predictions, averaged over the shapes in a batch.

    Args:
        logits: part scores of rank 2 ``(points, parts)`` or rank 3
            ``(batch, points, parts)``; the last dimension indexes parts.
        part_label: integer labels ``(points,)`` or ``(batch, points)`` in
            ``[0, parts)``.

    Returns:
        Python float: for each shape, the mean IoU over all ``logits.size(-1)``
        parts (a part absent from both prediction and label counts as IoU 1),
        then averaged over the batch.

    The computation runs on ``logits``'s device; labels are moved there.
    """
    if logits.dim() < 3:
        logits = logits.unsqueeze(0)
    if part_label.dim() < 2:
        part_label = part_label.unsqueeze(0)

    device = logits.device
    pred = logits.argmax(dim=-1)                          # (B, N)
    part_label = part_label.to(device)
    part = torch.arange(logits.size(-1), device=device)   # (P,)

    # (B, P, N) masks: does point n belong to part p in prediction / label?
    in_pred = part[None, :, None] == pred[:, None, :]
    in_label = part[None, :, None] == part_label[:, None, :]

    I = (in_pred & in_label).sum(dim=-1)
    U = (in_pred | in_label).sum(dim=-1)

    part_iou = I / U.clamp(min=1).float()
    part_iou[U == 0] = 1  # part neither predicted nor present: perfect score

    return part_iou.mean(dim=-1).mean().item()

def evaluate(model, linear, loader, noprint=False, perms=(None,), together=False):
    """Evaluate part-segmentation mIoU of ``model`` followed by ``linear``.

    Args:
        model: network called as ``model(*inputs, perm=iperm)`` that exposes
            ``part_classfier``.
        linear: head applied to the model output before ``part_classfier``.
        loader: yields ``(inputs, part_label, cloud_label)`` batches.
        noprint: suppress progress and summary logging.
        perms: ``perm`` values each batch is evaluated under; every entry
            contributes to the average.
        together: if true, a batch may mix categories; each shape's logits are
            restricted to its own category's parts and per-class scores are
            tracked. Otherwise every batch must contain a single category.

    Returns:
        ``(inst_miou, summary)`` where ``summary`` is ``"<inst> <class>"`` with 6
        decimals; the class mIoU is 0.0 unless ``together`` is true.

    Models are put in eval mode and returned to train mode afterwards (also
    if evaluation raises).
    """
    stat = AccStat()
    class_stat = [AccStat() for _ in range(num_classes)]

    if not noprint:
        logging.info(f"loader # = {len(loader)}")

    print_epoch = 1

    model.eval()
    linear.eval()
    try:
        for epoch, (inputs, part_label, cloud_label) in enumerate(loader):
            part_label = part_label.cuda()
            cloud_label = cloud_label.squeeze(-1).cuda()
            with torch.no_grad():
                if not together:
                    # Single-category batch: map global part ids to local ids
                    # once, outside the perm loop, so additional perms do not
                    # re-map already-local labels.
                    class_index = cloud_label.unique().item()
                    local_label = part_mapping[class_index][part_label]
                    n_shapes = local_label.size(0) if local_label.dim() >= 2 else 1

                for iperm in perms:
                    features = linear(model(*inputs, perm=iperm))
                    logits_all = model.part_classfier(features)

                    if together:
                        for cloud_l, part_l, logits in zip(cloud_label, part_label, logits_all):
                            C = cloud_l.item()
                            x = miou(logits[:, class_parts[C]], part_mapping[C][part_l])
                            stat.add(x)
                            class_stat[C].add(x)
                    else:
                        # miou() averages over the shapes in the batch; weight it
                        # by the shape count so a short final batch does not
                        # count as much as a full one.
                        stat.add(miou(logits_all, local_label), num=n_shapes, mean=True)

            if not noprint and (epoch // batch_size + 1) % print_epoch == 0:
                logging.debug(f"test #{epoch} correct = {'%.6lf' % stat.result()}")
    finally:
        model.train()
        linear.train()

    inst_miou = stat.result()
    if not noprint:
        logging.info(f"Done: score = {'%.8lf' % inst_miou}")

    class_miou = sum(s.result() for s in class_stat) / len(class_stat) if together else 0.0

    if not noprint:
        logging.info(" ".join(map(str, class_stat)))

    return inst_miou, "%.6lf %.6lf" % (inst_miou, class_miou)


## Train function

In [None]:
import torch.nn.functional as F


def train(class_index, more_epoch=10000, valid_result_threshold=999.0):
    """Train a fresh model on the shapes of a single category.

    A new model is built with one output per part of ``class_index``. Global
    part labels are mapped to local indices via ``part_mapping``. Each epoch
    performs ``epoch_scale`` optimizer steps. Every ``valid_epoch`` epochs the
    model is evaluated on the category's validation subset, and improvements
    are saved as ``{OUTPUT}/trained_{class_index}_best{prefix}.pth``. Every
    ``save_epoch`` epochs a numbered checkpoint is written (unless
    ``test_as_valid``) and the test score is logged.

    Args:
        class_index: category whose shapes are trained on.
        more_epoch: number of epochs to run.
        valid_result_threshold: stop once validation mIoU reaches this value.
    """
    current_epoch = 0
    best_vres = -1.0

    from random import randint
    import build_tree

    model = new_model(len(class_parts_list[class_index]))
    linear = nn.Identity()
    model.train()
    linear.train()

    if load_pretrain:
        ckpt = torch.load(f"{OUTPUT}/trained_{class_index}_best_affined_pca_nosample_x10.pth")
        model.load_state_dict(ckpt['model'])
        logging.info("Pretrain loaded")

    def get_trans(n=3):
        """Return ``n`` random transform ids for ``perm=``, or ``[None]`` when n == 0."""
        if n == 0:
            return [None]
        return [randint(0, len(build_tree.transforms) - 1) for _ in range(n)]

    def save(epoch):
        """Write the model weights to ``{OUTPUT}/trained_{class_index}_{epoch}.pth``."""
        torch.save({
            'model': model.state_dict(),
        }, f"{OUTPUT}/trained_{class_index}_{epoch}.pth")

    global batch_size

    subset_crit = lambda x: x[-1].item() == class_index
    num_workers = 24
    mbtrain = make_batch_generator(pca_augment=pca_augment and not no_prealign)
    train_subset = Subset(train_dataset, subset_crit)
    # With drop_last=True, a category with fewer than batch_size training shapes
    # yields zero batches and next(train_iter) never returns; keep the partial
    # batch in that case.
    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        collate_fn=mbtrain, pin_memory=True, drop_last=len(train_subset) >= batch_size)
    valid_loader = torch.utils.data.DataLoader(Subset(valid_dataset, subset_crit), batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(Subset(test_dataset, subset_crit), batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)

    logging.info(f"train class_index = {class_index} epoch = {current_epoch + 1} ~ {more_epoch} threshold = {valid_result_threshold} count = {len(train_loader) * batch_size} num_parts = {len(class_parts_list[class_index])}")

    cum_loss = 0
    cum_inner_loss = 0

    batch_scale = 1   # batches accumulated per optimizer step
    epoch_scale = 8   # optimizer steps per epoch
    num_trans = 0     # random transforms per batch (0 -> one pass with perm=None)

    print_epoch = 20
    valid_epoch = 20
    epoch_since = 0
    save_epoch = 100
    cut_epoch = 10000000000  # effectively disables batch-size cutting

    accuracy = AccStat()
    class_spec_coef = 1

    threshold = -1.0

    opt = torch.optim.Adam(list(model.parameters()) + list(linear.parameters()), lr=1e-4)
    sch = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9999)
    save(current_epoch)

    train_iter = inf_iter(train_loader)

    for epoch in range(current_epoch + 1, current_epoch + more_epoch + 1):

        current_epoch = epoch

        for _ in range(epoch_scale):
            loss = torch.tensor(0.).cuda()
            inner_loss = torch.tensor(0.).cuda()
            for _ in range(batch_scale):
                inputs, part_label, cloud_label = next(train_iter)
                cloud_label = cloud_label.squeeze(-1).cuda()
                assert cloud_label.unique().item() == class_index
                # Map global part ids to local ids once per batch; mapping inside
                # the transform loop would re-map already-local labels on every
                # extra transform when num_trans > 1.
                part_label = part_mapping[class_index][part_label.cuda()]
                for iperm in get_trans(num_trans):
                    features = linear(model(*inputs, perm=iperm))
                    logits = model.part_classfier(features)

                    # F.cross_entropy expects the class dimension second.
                    loss += F.cross_entropy(logits.transpose(-1, -2), part_label) * class_spec_coef

                    if model_name == 'res_segment':
                        inner_logits = model.inner_part_classfier(model.inner_ans)
                        inner_loss += F.cross_entropy(inner_logits.transpose(-1, -2), part_label) * class_spec_coef

                    with torch.no_grad():
                        accuracy.add(miou(logits, part_label))
                        epoch_since += 1

            assert loss.isnan().sum() == 0
            cum_loss += loss.item()
            cum_inner_loss += inner_loss.item()
            opt.zero_grad()
            (loss + inner_loss).backward()
            opt.step()

        sch.step()

        if cum_loss / epoch_since < threshold:
            epoch_scale, batch_scale = batch_scale, epoch_scale

            logging.info("Threshold Reached")
            threshold = -1e10

        if epoch <= 5 or epoch % print_epoch == 0:
            valid_str = ""
            func = logging.debug

            stop_training = False
            if epoch % valid_epoch == 0:
                vres, valid_str = evaluate(model, linear, valid_loader, noprint=True)
                valid_str = "valid = " + valid_str
                stop_training = (vres >= valid_result_threshold)
                if vres > best_vres:
                    best_vres = vres
                    save(f"best{prefix}")
                    valid_str += " updated"

                func = logging.info
            func(f"train #{epoch} lr = {'%.2e' % sch.get_last_lr()[0]} loss = {'%.6lf / %.6lf' % (cum_loss / epoch_since, cum_inner_loss / epoch_since)} train = {accuracy} {valid_str}")
            epoch_since = cum_loss = cum_inner_loss = 0
            accuracy.clear()

            if stop_training:
                break

        if epoch % save_epoch == 0:
            if not test_as_valid:
                save(epoch)
            _, test_str = (0., best_vres) if test_as_valid else evaluate(model, linear, test_loader, noprint=True)
            logging.info(f"Saved test = {test_str}")

        if epoch % cut_epoch == 0:
            if batch_size > 8:
                batch_size //= 2
                epoch_scale *= 2

            logging.info(f"Cut batch_size = {batch_size} epoch_scale = {epoch_scale}")


In [None]:
def train_together(more_epoch=100000, valid_result_threshold=999.0):
    """Train the notebook-global ``model`` jointly on all categories.

    Each shape's logits are restricted to its own category's parts
    (``class_parts``), and labels are mapped to local ids (``part_mapping``).
    Both the main head and the inner head (``model.inner_ans``) contribute a
    cross-entropy loss averaged over ``batch_size``. The learning rate decays
    by 0.998 every ``print_epoch`` epochs. Validation and checkpointing follow
    the same schedule as ``train``; checkpoints are written to
    ``{OUTPUT}/trained_together_{epoch}.pth``.

    Args:
        more_epoch: number of epochs to run.
        valid_result_threshold: stop once validation mIoU reaches this value.

    Raises:
        NotImplementedError: if ``model_name`` is not ``'res_segment'``; the
            other path is disabled in this notebook.
    """
    # Fail up front with an explicit exception. An `assert False` inside the
    # loop is stripped under `python -O`, which would silently run dead code.
    if model_name != 'res_segment':
        raise NotImplementedError(
            f"train_together only supports model_name='res_segment' (got {model_name!r})")

    current_epoch = 0
    best_vres = -1.0

    from random import randint
    import build_tree

    linear = nn.Identity()
    model.train()
    linear.train()

    if load_pretrain:
        ckpt = torch.load(f"{OUTPUT}/trained_together_best_affined_pca_nosample_x10.pth")
        model.load_state_dict(ckpt['model'])
        logging.info("Pretrain loaded")

    def get_trans(n=3):
        """Return ``n`` random transform ids for ``perm=``, or ``[None]`` when n == 0."""
        if n == 0:
            return [None]
        return [randint(0, len(build_tree.transforms) - 1) for _ in range(n)]

    global batch_size

    num_workers = 24
    mbtrain = make_batch_generator(pca_augment=pca_augment and not no_prealign)
    train_loader = torch.utils.data.DataLoader(BalanceDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=mbtrain, pin_memory=True, drop_last=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=make_batch_eval, pin_memory=True, drop_last=False)

    logging.info(f"train epoch = {current_epoch + 1} ~ {more_epoch} threshold = {valid_result_threshold} count = {len(train_loader) * batch_size} num_parts = {num_parts}")

    cum_loss = 0
    cum_inner_loss = 0

    batch_scale = 1   # batches accumulated per optimizer step
    epoch_scale = 4   # optimizer steps per epoch
    num_trans = 0     # random transforms per batch (0 -> one pass with perm=None)

    print_epoch = 20
    valid_epoch = 20
    epoch_since = 0
    save_epoch = 100
    cut_epoch = 10000000000  # effectively disables batch-size cutting

    accuracy = AccStat()
    class_spec_coef = 1

    threshold = -1.0

    temperature = 1  # logits are divided by this before the loss

    opt = torch.optim.Adam(list(model.parameters()) + list(linear.parameters()), lr=1e-4)
    sch = torch.optim.lr_scheduler.ExponentialLR(opt, 0.998)

    def save(epoch):
        """Write model/optimizer/scheduler state to ``{OUTPUT}/trained_together_{epoch}.pth``."""
        torch.save({
            'model': model.state_dict(),
            'linear': linear.state_dict(),
            'opt': opt.state_dict(),
            'sch': sch.state_dict(),
            'best_vres': best_vres,
        }, f"{OUTPUT}/trained_together_{epoch}.pth")

    save(current_epoch)

    train_iter = inf_iter(train_loader)

    for epoch in range(current_epoch + 1, current_epoch + more_epoch + 1):

        current_epoch = epoch

        for _ in range(epoch_scale):
            loss = torch.tensor(0.).cuda()
            inner_loss = torch.tensor(0.).cuda()
            for _ in range(batch_scale):
                inputs, part_label, cloud_label = next(train_iter)
                part_label = part_label.cuda()
                cloud_label = cloud_label.squeeze(-1).cuda()
                for iperm in get_trans(num_trans):
                    features = linear(model(*inputs, perm=iperm))
                    logits_all = model.part_classfier(features)
                    logits_inner_all = model.inner_part_classfier(model.inner_ans)

                    # Batches mix categories: per shape, keep only its own
                    # category's part columns and map labels to local ids.
                    for cloud_l, part_l, logits, logits_inner in zip(cloud_label, part_label, logits_all, logits_inner_all):
                        C = cloud_l.item()
                        part_l = part_mapping[C][part_l]
                        logits = logits[:, class_parts[C]]
                        logits_inner = logits_inner[:, class_parts[C]]

                        loss += F.cross_entropy(logits / temperature, part_l) * class_spec_coef / batch_size
                        inner_loss += F.cross_entropy(logits_inner / temperature, part_l) * class_spec_coef / batch_size

                        with torch.no_grad():
                            accuracy.add(miou(logits, part_l))

                    epoch_since += 1

            assert loss.isnan().sum() == 0
            cum_loss += loss.item()
            cum_inner_loss += inner_loss.item()
            opt.zero_grad()
            (loss + inner_loss).backward()
            opt.step()

        if cum_loss / epoch_since < threshold:
            epoch_scale, batch_scale = batch_scale, epoch_scale

            logging.info("Threshold Reached")
            threshold = -1e10

        if epoch <= 5 or epoch % print_epoch == 0:
            valid_str = ""
            func = logging.debug

            stop_training = False
            if epoch % valid_epoch == 0:
                vres, valid_str = evaluate(model, linear, valid_loader, noprint=True, together=True)
                valid_str = "valid = " + valid_str
                stop_training = (vres >= valid_result_threshold)
                if vres > best_vres:
                    best_vres = vres
                    save(f"best{prefix}")
                    valid_str += " updated"

                func = logging.info
            func(f"train #{epoch} lr = {'%.2e' % sch.get_last_lr()[0]} loss = {'%.6lf / %.6lf' % (cum_loss / epoch_since, cum_inner_loss / epoch_since)} train = {accuracy} {valid_str}")
            epoch_since = cum_loss = cum_inner_loss = 0
            accuracy.clear()

            if stop_training:
                break

            # Learning-rate decay is applied once every print_epoch epochs.
            if epoch % print_epoch == 0:
                sch.step()

        if epoch % save_epoch == 0:
            if not test_as_valid:
                save(epoch)
            _, test_str = (0., best_vres) if test_as_valid else evaluate(model, linear, test_loader, noprint=True, together=True)
            logging.info(f"Saved test = {test_str}")

        if epoch % cut_epoch == 0:
            if batch_size > 8:
                batch_size //= 2
                epoch_scale *= 2

            logging.info(f"Cut batch_size = {batch_size} epoch_scale = {epoch_scale}")


## Loops

In [None]:
# Train a per-category model for category index 0 for 2000 epochs.
train(0, 2000)

In [None]:
# Train the global `model` jointly on all categories with the default settings.
# NOTE: the joint path only runs when model_name == "res_segment".
train_together()