In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Add the project's source directory to the import search path. The path is
# relative, so it resolves against the kernel's current working directory.
# NOTE(review): presumably the notebook is expected to be launched from its
# own folder, two levels below the directory that contains `src` — confirm.
sys.path.append('../../src')

In [3]:
import torch
from torch.utils.data import DataLoader
from torch.optim.swa_utils import SWALR

from data.seven_scenes.utils_7scenes import SevenScenesEvalDataset
from DEPO.depo import depo_best

from training.train_depo_pose_and_flow_weighted import MixedScheduler
from training.train_depo_pose import train, validate
from training.loss_pose import LossPose

from utils.model import load_checkpoint, plot_schedule
import numpy as np
import pandas as pd

from transformers import get_scheduler
from pathlib import Path
import os.path as osp
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

#### Data

In [5]:
# DataLoader settings shared by both splits; only shuffling and the handling
# of the last incomplete batch differ between training and validation.
loader_kwargs = dict(batch_size=4, pin_memory=True, num_workers=4)

# Training split: shuffled, incomplete final batch dropped.
train_data = SevenScenesEvalDataset(
    pairs_path='/home/project/data/7scenes/db_all_med_hard_train.txt',
    root_dir='/home/project/data/7scenes/',
)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)

# Validation split: fixed order, every sample kept.
val_data = SevenScenesEvalDataset(
    pairs_path='/home/project/data/7scenes/db_all_med_hard_valid.txt',
    root_dir='/home/project/data/7scenes/',
)
val_loader = DataLoader(val_data, shuffle=False, drop_last=False, **loader_kwargs)

#### Config

In [6]:
# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Experiment configuration. The same dict is later handed to `train` and is
# the source of every step count used to build the schedulers.
config = dict(
    experiment_name='7_scenes_ft_pretrained',
    device=device,
    n_epochs=11,
    n_accum_steps=16,
    batch_size=train_loader.batch_size,
    # Number of full batches per epoch (integer division by the batch size).
    n_steps_per_epoch=len(train_loader.dataset) // train_loader.batch_size,
    swa=True,
    swa_lr=5e-5,
    n_epochs_swa=1,
    repeat_val_epoch=1,
    repeat_save_epoch=1,
    scheduler_step='step',
    model_save_path='../../src/weights/7_scenes_ft_pretrained'
)

# Steps per epoch after dividing by `n_accum_steps` as well as the batch size.
# `np.ceil` returns a numpy float, so the result is cast to `int` right here:
# this value and everything derived from it are step counts, and leaving it as
# a float would silently turn `n_steps_between_swa_updates` (and the value
# passed to the scheduler as `n_steps_per_epoch`) into floats as well.
# NOTE(review): `train_loader` is built with drop_last=True, so the number of
# batches is floor(len / batch_size); this ceil over samples can be one step
# larger than ceil(n_batches / n_accum_steps) for some dataset sizes — confirm
# against how the training loop counts optimizer steps.
config['n_effective_steps_per_epoch'] = int(np.ceil(
    len(train_loader.dataset) / (train_loader.batch_size * config['n_accum_steps'])))

# Warmup spans one epoch's worth of effective steps.
config['n_warmup_steps'] = int(config['n_effective_steps_per_epoch'] * 1.)

# The base schedule covers every epoch except the trailing SWA epoch(s).
config['n_training_steps'] = int(
    config['n_effective_steps_per_epoch'] * (config['n_epochs'] - config['n_epochs_swa']))

# SWA annealing takes 20% of an epoch's effective steps; the remaining steps
# are divided by 10 to get the spacing between SWA updates.
config['n_swa_anneal_steps'] = int(config['n_effective_steps_per_epoch'] * 0.2)
config['n_steps_between_swa_updates'] = (
    config['n_effective_steps_per_epoch'] - config['n_swa_anneal_steps']) // 10

#### Model

In [7]:
# Checkpoint file that provides the initial weights for fine-tuning.
checkpoint_path = '/home/project/code/src/weights/flow_and_pose_best_abs_0.pth'
checkpoint = load_checkpoint(checkpoint_path, device)

# Build the network with calculate_flow=False, restore the weights stored
# under the checkpoint's 'model' key, then move the model to the target device.
model = depo_best(calculate_flow=False)
model.load_state_dict(checkpoint['model'])
model.to(device);

#### Loss & Optimizer & Scheduler

In [10]:
# Training criterion: agg_type='mean' with t_norm='l1'.
train_loss = LossPose(t_norm='l1', agg_type='mean')
# Validation criterion: no aggregation (agg_type=None) with t_norm='l2'.
val_loss = LossPose(t_norm='l2', agg_type=None)

In [11]:
# Only parameters whose name contains one of these substrings stay trainable;
# every other parameter is frozen.
trainable_tags = ('pose_regressor', 'intrinsics_mlp')
for param_name, param in model.named_parameters():
    param.requires_grad = any(tag in param_name for tag in trainable_tags)

# The optimizer is built over all model parameters, including the frozen ones.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [12]:
# Base learning-rate schedule: the "cosine" schedule from
# transformers.get_scheduler, with the warmup length and total step count
# taken from the config.
base_scheduler = get_scheduler(
    "cosine",    
    optimizer=optimizer,
    num_warmup_steps=config['n_warmup_steps'],
    num_training_steps=config['n_training_steps'])

# SWA learning-rate schedule on the same optimizer. `anneal_epochs` receives a
# step count (`n_swa_anneal_steps`).
# NOTE(review): presumably the scheduler is advanced once per optimizer step
# (config has scheduler_step='step') — confirm in MixedScheduler / train.
swa_scheduler = SWALR(
    optimizer,
    swa_lr=config['swa_lr'],
    anneal_epochs=config['n_swa_anneal_steps'])

# Project-defined wrapper that receives both schedulers plus the epoch/step
# bookkeeping from the config. Its switching logic is defined in
# training.train_depo_pose_and_flow_weighted and is not visible here.
scheduler = MixedScheduler(
    base_scheduler,
    swa_scheduler,
    n_epochs=config['n_epochs'],
    n_epochs_swa=config['n_epochs_swa'],
    n_steps_per_epoch=config['n_effective_steps_per_epoch'],
    n_swa_anneal_steps=config['n_swa_anneal_steps'],
    n_steps_between_swa_updates=config['n_steps_between_swa_updates']
)

#### Train & val

In [None]:
# Start fine-tuning. The config dict is passed both as the last positional
# argument and unpacked into keyword arguments.
train(
    model,
    optimizer,
    scheduler,
    train_loss,
    val_loss,
    train_loader,
    val_loader,
    config,
    **config,
)

[34m[1mwandb[0m: Currently logged in as: [33mkovanic[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10399/10399 [24:48<00:00,  6.99it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2601/2601 [05:17<00:00,  8.20it/s]


epoch 0: val loss(q)=0.18653163213945342, val loss(t)=0.633743319009341


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10399/10399 [24:50<00:00,  6.98it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2601/2601 [05:18<00:00,  8.17it/s]


epoch 1: val loss(q)=0.17704658279097588, val loss(t)=0.5815676708469893


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍| 10352/10399 [24:45<00:06,  6.93it/s]