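In this notebook, a shallow head is trained on top of GMFlow to obtain the relative pose from its dense prediction. The confidences are computed from the raw cross-correlations. GMFlow is only partly kept fixed: the backbone, all `mlp` parameters and the first three transformer layers are frozen (see section 3). **Note**: I haven't fine-tuned the original GMFlow due to scarcity of resources.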

In [1]:
# Automatically re-import edited project modules (e.g. the ones under ../src) before executing code,
# so changes to the training/model sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project sources (data/, matching/, training/, utils/) importable.
# Guarded so that re-running this cell does not append duplicate entries to sys.path.
if '../src' not in sys.path:
    sys.path.append('../src')

In [3]:
import torch
from torch.utils.data import DataLoader

from data.scannet.utils import ScanNetDataset
from matching.gmflow_dense.gmflow_dense import GMflowDensePose

from training.loss_pose import LossPose
from training.train_dense import train
from utils.model import load_checkpoint

from tqdm.auto import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


### 1. Data

In [4]:
# ScanNet locations shared by the train and validation splits: the raw scans directory,
# the directory holding the split index (.npz) files, and the camera intrinsics file.
SCANNET_SCANS_DIR = '/home/project/data/ScanNet/scans/'
SCANNET_INDEX_DIR = '/home/project/ScanNet'
SCANNET_INTRINSICS_PATH = f'{SCANNET_INDEX_DIR}/scannet_indices/intrinsics.npz'

BATCH_SIZE = 2
NUM_WORKERS = 1

train_data = ScanNetDataset(
    root_dir=SCANNET_SCANS_DIR,
    npz_path=f'{SCANNET_INDEX_DIR}/train_indicies_subset.npz',
    intrinsics_path=SCANNET_INTRINSICS_PATH,
    mode='train'
)

# drop_last=True: every training batch has exactly BATCH_SIZE samples.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
                          pin_memory=True, num_workers=NUM_WORKERS)

val_data = ScanNetDataset(
    root_dir=SCANNET_SCANS_DIR,
    npz_path=f'{SCANNET_INDEX_DIR}/val_indicies_subset.npz',
    intrinsics_path=SCANNET_INTRINSICS_PATH,
    mode='val'
)

# Validation keeps a fixed order and does not drop the last (possibly smaller) batch.
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=False,
                        pin_memory=True, num_workers=NUM_WORKERS)

### 2. Configuration

In [5]:
# Run configuration. `config['general']` is also unpacked as keyword arguments of `train(...)`
# in the Experiment section, so its keys must match that function's parameters.
EXPERIMENT_NAME = '8_gmflow_dense_with_conf_ft'

config = dict(
    general=dict(
        experiment_name=EXPERIMENT_NAME,
        device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),

        n_epochs=3,
        # Number of batches the loader yields per epoch; with drop_last=True this equals
        # len(train_loader.dataset) // train_loader.batch_size.
        n_steps_per_epoch=len(train_loader),
        n_accum_steps=8,
        batch_size=train_loader.batch_size,

        # SWA is switched off; presumably train() ignores the two settings below then (TODO confirm).
        swa=False,
        n_epochs_swa=None,
        n_steps_between_swa_updates=None,

        repeat_val_epoch=1,
        repeat_save_epoch=1,

        # Derived from the experiment name so the two cannot drift apart.
        model_save_path=f'../src/weights/{EXPERIMENT_NAME}',
    )
)

### 3. Model

In [6]:
# Initial weights for the flow network, loaded onto the configured device.
# NOTE(review): per its file name this is the GMFlow "with_refine" checkpoint fine-tuned on KITTI;
# confirm it is the intended starting point for indoor ScanNet.
# To start from an earlier run of this project instead, point checkpoint_path at e.g.
# '/home/project/code/src/weights/3_gmflow_dense_with_conf_1.pth'.
checkpoint_path = '/home/project/code/src/matching/gmflow/weights/pretrained/gmflow_with_refine_kitti-8d3b9786.pth'
checkpoint = load_checkpoint(checkpoint_path, config['general']['device'])


In [7]:
# Dense-pose model built around GMFlow; conf_module=True presumably adds the confidence branch (TODO confirm).
model = GMflowDensePose(conf_module=True)

# Only the inner flow network receives the checkpoint weights here; nothing else is loaded in this cell.
flow_weights = checkpoint['model']
model.flow_model.load_state_dict(flow_weights)

# nn.Module.to moves parameters in place and returns the same module, so rebinding is equivalent.
model = model.to(config['general']['device'])

In [8]:
# Freeze part of GMFlow by parameter name. The per-layer patterns end in '.' so that
# 'transformer.layers.1.' cannot also match 'transformer.layers.10.' ... '.19.'.
# NOTE(review): 'backbone' and 'mlp' are plain substring tests, so they freeze every parameter
# whose name contains them anywhere (not only top-level modules) — confirm this is intended.
FROZEN_NAME_PATTERNS = (
    'backbone',
    'mlp',
    'transformer.layers.0.',
    'transformer.layers.1.',
    'transformer.layers.2.',
)

frozen_param_names = []
for param_name, param in model.flow_model.named_parameters():
    if any(pattern in param_name for pattern in FROZEN_NAME_PATTERNS):
        param.requires_grad_(False)
        frozen_param_names.append(param_name)

# Guard against silently freezing nothing if GMFlow's parameter naming changes.
assert frozen_param_names, 'no GMFlow parameters matched FROZEN_NAME_PATTERNS'

### 4. Loss, optimizer, scheduler

In [9]:
# Pose losses: validation uses LossPose's default aggregation, training aggregates with 'mean'.
val_loss, train_loss = LossPose(), LossPose(agg_type='mean')

In [10]:
# AdamW over all model parameters. Frozen parameters (requires_grad=False) never receive a
# .grad, so AdamW leaves them untouched.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)

# Record the optimizer hyper-parameters in the run config. Store a copy so that later edits to
# `config` cannot silently change the optimizer's own defaults dict (and vice versa).
config['optimizer'] = dict(optimizer.defaults)

# To resume a previous run, also restore the optimizer state:
# optimizer.load_state_dict(checkpoint['optimizer'])

In [11]:
# Multiply the learning rate by 0.8 on every scheduler.step() call (step_size=1);
# how often step() is called (per epoch or per iteration) is decided inside train().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

### 5. Experiment

In [None]:
# Launch training. The whole config is passed positionally (presumably for logging) and its
# 'general' section is additionally unpacked into train()'s keyword arguments.
train(
    model, optimizer, scheduler,
    train_loss, val_loss,
    train_loader, val_loader,
    config,
    **config['general'],
)

[34m[1mwandb[0m: Currently logged in as: [33mkovanic[0m. Use [1m`wandb login --relogin`[0m to force relogin


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  0%|▌                                                                                                                | 248/49710 [01:39<5:23:54,  2.55it/s]