In [1]:
# %load_ext autoreload
# %autoreload 2

In [1]:
import sys
sys.path.append('../../../src')

In [2]:
import torch

from torch.utils.data import DataLoader

from data.scannet.utils_scannet_fast import ScanNetDataset
from DEPO.depo_ablations import A6
from training.train_depo_pose_and_flow import train, validate
from training.loss_depo import LossMixedDetermininstic
from utils.model import load_checkpoint, plot_schedule
import numpy as np

from transformers import get_scheduler

#### Data

In [8]:
# Shared locations for both splits (absolute paths on the training machine).
SCANNET_ROOT = '/home/project/data/scans/'
SPLITS_DIR = '/home/project/code/data/scannet_splits/'
INTRINSICS_PATH = '/home/project/ScanNet/scannet_indices/intrinsics.npz'

# DataLoader options common to the train and val loaders.
loader_kwargs = dict(batch_size=32, pin_memory=True, num_workers=4)

# Training split: flow targets are computed (calculate_flow=True); shuffled, and
# the last incomplete batch is dropped.
train_data = ScanNetDataset(
    root_dir=SCANNET_ROOT,
    npz_path=SPLITS_DIR + 'train_subset_ablations.npz',
    intrinsics_path=INTRINSICS_PATH,
    calculate_flow=True,
)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)

# Validation split: no flow computation; deterministic order and every sample kept.
val_data = ScanNetDataset(
    root_dir=SCANNET_ROOT,
    npz_path=SPLITS_DIR + 'val_subset_ablations.npz',
    intrinsics_path=INTRINSICS_PATH,
    calculate_flow=False,
)
val_loader = DataLoader(val_data, shuffle=False, drop_last=False, **loader_kwargs)

#### Config

In [9]:
# Allow TF32 for CUDA matmuls and cuDNN convolutions (faster on supporting GPUs,
# at reduced floating-point precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

config = dict(
    # B=128 presumably is the effective batch: batch_size (32) * n_accum_steps (4).
    experiment_name='A9:A6,B=128,lr=1e-3,wd=1e-5,step-lr,1wu',
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    n_epochs=10,
    n_accum_steps=4,
    batch_size=train_loader.batch_size,
    n_steps_per_epoch=len(train_loader.dataset) // train_loader.batch_size,
    swa=False,
    n_epochs_swa=None,
    n_steps_between_swa_updates=None,
    repeat_val_epoch=1,
    # NOTE(review): larger than n_epochs (10), so a save every 20 epochs presumably
    # never fires during this run — confirm against `train`.
    repeat_save_epoch=20,
    # NOTE(review): sys.path uses '../../../src' (three levels up) while this path
    # goes two levels up — confirm it points at the intended weights directory.
    model_save_path='../../src/weights/A9'
)

# Optimizer steps per epoch under gradient accumulation. np.ceil returns a float,
# so cast to int: every step count derived below is then a plain integer.
config['n_effective_steps_per_epoch'] = int(np.ceil(
    len(train_loader.dataset) / (train_loader.batch_size * config['n_accum_steps'])
))
N_WARMUP_EPOCHS = 1  # matches the '1wu' tag in experiment_name
config['n_warmup_steps'] = config['n_effective_steps_per_epoch'] * N_WARMUP_EPOCHS
config['n_training_steps'] = config['n_effective_steps_per_epoch'] * config['n_epochs']

#### Model

In [6]:
# Ablation A9 trains the A6 architecture (see experiment_name). Only A6 is
# imported from DEPO.depo_ablations, so `A7()` raised NameError on a fresh kernel.
model = A6().to(config['device'])

# Freeze every parameter whose name contains 'self_encoder'; train all others.
for name, param in model.named_parameters():
    param.requires_grad = 'self_encoder' not in name

#### Loss, optimizer & scheduler

In [7]:
# Evaluation criterion (validation mode, default options).
val_loss = LossMixedDetermininstic(
    mode='val',
)
# Training criterion. add_l2=False presumably disables an extra L2 term —
# confirm in training.loss_depo.
train_loss = LossMixedDetermininstic(
    mode='train',
    add_l2=False,
)

In [8]:
# AdamW over all model parameters; values match experiment_name (lr=1e-3, wd=1e-5).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [24]:
class WarmupStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear learning-rate warmup followed by step decay with a floor.

    The schedule is indexed by the scheduler's step counter ``last_epoch``:

    * ``0 <= t <= warmup_steps``: lr rises linearly from ``warmup_lr_init`` to the
      param group's base lr, which is reached exactly at ``t == warmup_steps``.
      With ``warmup_steps == 0`` the base lr is used from the start.
    * ``t > warmup_steps``: whenever ``t - warmup_steps`` is a multiple of
      ``step_size``, the current lr is multiplied by ``gamma`` and clipped from
      below at ``min_lr``; on all other steps it is left unchanged.

    Args:
        optimizer: Wrapped optimizer.
        step_size: Positive number of scheduler steps between decays after warmup.
        gamma: Multiplicative decay factor applied at each decay step.
        min_lr: Lower bound applied at each decay step.
        warmup_steps: Non-negative length of the linear warmup; 0 disables it.
        warmup_lr_init: Learning rate at step 0 of the warmup.
        last_epoch: Index of the last step; -1 starts a fresh schedule.
        verbose: Forwarded to the base class only when truthy, because recent
            PyTorch releases warn whenever the (deprecated) argument is passed.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``step_size`` is not positive or ``warmup_steps`` is negative.
    """

    def __init__(self, optimizer, step_size, gamma, min_lr, warmup_steps=5, warmup_lr_init=1e-7,
                 last_epoch=-1, verbose=False, **kwargs):
        # Validate early: otherwise these fail later as ZeroDivisionError inside get_lr.
        if step_size <= 0:
            raise ValueError(f'step_size must be positive, got {step_size}')
        if warmup_steps < 0:
            raise ValueError(f'warmup_steps must be non-negative, got {warmup_steps}')
        self.warmup_steps = warmup_steps
        self.warmup_lr_init = warmup_lr_init
        self.step_size = step_size
        self.gamma = gamma
        self.min_lr = min_lr
        # Attributes must be set before super().__init__, which calls get_lr().
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each param group for step ``last_epoch``."""
        step = self.last_epoch
        if step <= self.warmup_steps:
            if self.warmup_steps == 0:
                # No warmup: start directly from the base learning rates.
                return list(self.base_lrs)
            return [self.warmup_lr_init + (base_lr - self.warmup_lr_init) * step / self.warmup_steps
                    for base_lr in self.base_lrs]
        current_lrs = [group['lr'] for group in self.optimizer.param_groups]
        if (step - self.warmup_steps) % self.step_size:
            # Between decay boundaries: keep the current rates (chainable form).
            return current_lrs
        return [max(lr * self.gamma, self.min_lr) for lr in current_lrs]


In [29]:
# Linear warmup over `n_warmup_steps` scheduler steps (one epoch's worth of
# optimizer steps, assuming `train` steps the scheduler once per optimizer step),
# then multiply the lr by 0.75 every two epochs, never dropping below 1e-7.
# Keyword arguments keep min_lr and warmup_lr_init (both 1e-7) from being swapped;
# int() guards against the np.float64 that np.ceil produces for the step count.
scheduler = WarmupStepLR(
    optimizer,
    step_size=int(config['n_effective_steps_per_epoch']) * 2,
    gamma=0.75,
    min_lr=1e-7,
    warmup_steps=config['n_warmup_steps'],
    warmup_lr_init=1e-7,
)

#### Training & validation

In [None]:
# Run the full training loop with periodic validation.
# NOTE(review): `config` is passed both positionally and unpacked as keyword
# arguments — presumably `train` accepts both; confirm against its signature.
train(
    model, optimizer, scheduler,
    train_loss, val_loss,
    train_loader, val_loader,
    config,
    **config,
)

[34m[1mwandb[0m: Currently logged in as: [33mkovanic[0m. Use [1m`wandb login --relogin`[0m to force relogin


 11%|████████████▋                                                                                                     | 697/6250 [20:16<2:46:28,  1.80s/it]