In [1]:
# %load_ext autoreload
# %autoreload 2

In [1]:
import sys
sys.path.append('../../src')

In [2]:
import torch

from torch.utils.data import DataLoader
from data.scannet.utils_scannet_fast import ScanNetDataset
from DEPO.depo import depo_best
from training.train_depo_pose_and_flow_weighted import train, validate
from training.loss_depo import LossMixedRelativeWeighted
from utils.model import load_checkpoint, plot_schedule
import numpy as np

from transformers import get_scheduler

#### Data

In [4]:
# Both splits read the same scans and camera intrinsics; only the split index
# file and the flow flag differ.
shared_dataset_kwargs = dict(
    root_dir='/home/project/data/scans/',
    intrinsics_path='/home/project/ScanNet/scannet_indices/intrinsics.npz',
)

# Loader options common to training and validation.
shared_loader_kwargs = dict(batch_size=8, pin_memory=True, num_workers=4)

# Training split: flow is computed, batches are shuffled, and a trailing
# partial batch is dropped.
train_data = ScanNetDataset(
    npz_path='/home/project/code/data/scannet_splits/smart_sample_train_ft.npz',
    calculate_flow=True,
    **shared_dataset_kwargs,
)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **shared_loader_kwargs)

# Validation split: no flow, fixed order, every sample is kept.
val_data = ScanNetDataset(
    npz_path='/home/project/code/data/scannet_splits/smart_sample_val.npz',
    calculate_flow=False,
    **shared_dataset_kwargs,
)
val_loader = DataLoader(val_data, shuffle=False, drop_last=False, **shared_loader_kwargs)

#### Config

In [5]:
# Allow TF32 in CUDA matmuls and cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Run configuration; it is handed to train() both as a dict and as keyword arguments.
config = dict(
    experiment_name='flow_and_pose_best_relative:2',
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    n_epochs=5,
    n_accum_steps=8,
    batch_size=train_loader.batch_size,
    # Full batches per epoch; floor division matches the train loader's drop_last=True.
    n_steps_per_epoch=len(train_loader.dataset) // train_loader.batch_size,
    swa=False,
    n_epochs_swa=None,
    n_steps_between_swa_updates=None,
    repeat_val_epoch=1,
    repeat_save_epoch=1,
    model_save_path='../../src/weights/flow_and_pose_best_relative_part2'
)

# Optimizer updates per epoch with gradient accumulation, kept as a plain int.
# It is derived from the number of batches actually produced (n_steps_per_epoch)
# rather than the raw dataset size, so a dropped partial batch is not counted.
# Ceil division assumes the trainer also steps on a trailing partial
# accumulation window — TODO confirm against train().
config['n_effective_steps_per_epoch'] = (
    (config['n_steps_per_epoch'] + config['n_accum_steps'] - 1) // config['n_accum_steps']
)
# Warm up over the first 20% of one epoch's optimizer updates.
config['n_warmup_steps'] = int(config['n_effective_steps_per_epoch'] * 0.2)
config['n_training_steps'] = config['n_effective_steps_per_epoch'] * config['n_epochs']

#### Model

In [8]:
# Checkpoint to resume from; it is loaded with the configured device.
checkpoint_path = '/home/project/code/src/weights/flow_and_pose_best_relative_2.pth'
checkpoint = load_checkpoint(checkpoint_path, config['device'])

# Build the network, restore its saved weights, then move it to the device.
model = depo_best()
model.load_state_dict(checkpoint['model'])
model.to(config['device']);  # trailing ';' keeps the module repr out of the cell output

#### Loss & Optimizer & Scheduler

In [9]:
# Training criterion, starting from all-zero weights that are excluded from
# autograd.
initial_loss_weights = [0., 0., 0.]
train_loss = LossMixedRelativeWeighted(mode='train', weights=initial_loss_weights)
train_loss.weights.requires_grad = False

# Validation criterion.
val_loss = LossMixedRelativeWeighted(mode='val')

In [10]:
# One param group per parameter, in named_parameters() order:
#   - 'self_encoder' parameters use lr 5e-5, everything else 5e-4;
#   - parameters whose name contains 'bias' get no weight decay, the rest 1e-6.
opt_parameters = [
    {
        'params': param,
        'weight_decay': 1e-6 if 'bias' not in param_name else 0.0,
        'lr': 5e-5 if 'self_encoder' in param_name else 5e-4,
    }
    for param_name, param in model.named_parameters()
]

# Model optimizer, with its state restored from the checkpoint.
optimizer = torch.optim.AdamW(opt_parameters)
optimizer.load_state_dict(checkpoint['optimizer'])

# Separate optimizer over the loss weights (their requires_grad is switched off
# in the loss cell).
weights_optimizer = torch.optim.SGD([train_loss.weights], lr=1e-4)

In [11]:
# Resume the LR schedule stored in the checkpoint instead of building a fresh one.
# A fresh schedule would be created like this:
# scheduler = get_scheduler(
#     "cosine",
#     optimizer=optimizer,
#     num_warmup_steps=config['n_warmup_steps'],
#     num_training_steps=config['n_training_steps']
# )
scheduler = checkpoint['lr_sched']

# NOTE(review): if the checkpoint holds a whole scheduler object, it still
# references the optimizer instance it was saved with, not the `optimizer`
# built in this notebook, so scheduler.step() would never change the learning
# rates actually used for training. Re-point it at the live optimizer. The
# guard leaves anything without an `optimizer` attribute untouched.
if hasattr(scheduler, 'optimizer') and scheduler.optimizer is not optimizer:
    scheduler.optimizer = optimizer

#### Train & val

In [12]:
# Launch training. `config` is passed once positionally and once unpacked as
# keyword arguments.
train(
    model,
    optimizer,
    weights_optimizer,
    scheduler,
    train_loss,
    val_loss,
    train_loader,
    val_loader,
    config,
    **config,
)

[34m[1mwandb[0m: Currently logged in as: [33mkovanic[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/12500 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

epoch 0: val loss(flow)=None,
 val loss(q)=0.04760650121175809, val loss(t)=0.2552890590518713


  0%|          | 0/12500 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

epoch 1: val loss(flow)=None,
 val loss(q)=0.04949806845308814, val loss(t)=0.2665364249430597


  0%|          | 0/12500 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

epoch 2: val loss(flow)=None,
 val loss(q)=0.04993301642691726, val loss(t)=0.26271077171340584


  0%|          | 0/12500 [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Train batch loss (flow),▅▃▃▅▅▄▃▅█▄▃▂▇▆▅▅▅▃▂▃▄▂▄▃▂▂▅▃▁▅▃▃▃▂▂▂▄▃▁▄
Train batch loss (q),▄▃▃█▆▄▂▅▇▃▆▄▇▃▃▃▄▂▆▂▅▅▂▁▁▄▃▅▂▆▂▅▅▅▆▅▅▅▄▄
Train batch loss (t),▅▇▇▇█▅▁▄▂▆▅▆█▂▆▃▃▃▆▅▅▅▅▆▄▆▄▅▄▃▄▇▂▄▄▄▄▅▇▅
Train batch loss (total),▅▃▃▅▆▄▃▅█▄▃▃▇▆▆▅▅▃▂▃▄▂▄▃▂▂▅▃▁▅▃▃▃▂▂▂▄▃▁▄
Train loss epoch (flow),█▃▁
Train loss epoch (q),▁▅█
Train loss epoch (t),█▄▁
Train loss epoch (total),█▃▁
Val loss epoch(q),▁▇█
Val loss epoch(t),▁█▆

0,1
Train batch loss (flow),7.8406
Train batch loss (q),0.03998
Train batch loss (t),0.33977
Train batch loss (total),8.22034
Train loss epoch (flow),9.05301
Train loss epoch (q),0.03971
Train loss epoch (t),0.27316
Train loss epoch (total),9.36589
Val loss epoch(q),0.04993
Val loss epoch(t),0.26271


KeyboardInterrupt: 