In this notebook I add a module that computes confidences from the raw cross-feature scores. The module is trained with a negative log-likelihood (NLL) loss in which the density is modeled as a mixture of Laplace distributions. **Note:** the original GMFlow weights are kept frozen (not fine-tuned) because of limited compute resources.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project's `src` directory importable. Guard the append so that
# re-running this cell does not keep adding duplicate entries to sys.path.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [3]:
import torch
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel

from data.scannet.utils import ScanNetDataset
from matching.gmflow_confidence.gmflow_with_uncertainty import GMflowWithConfidence

from training.loss_gmflow_conf import LossGMflowWithConfidence
from training.train_gmflow_conf import train, CustomScheduler

from tqdm.auto import tqdm
import wandb

  from .autonotebook import tqdm as notebook_tqdm


### 1. Data

In [4]:
# ScanNet locations: raw scans, the train/val index subsets, and camera intrinsics.
SCANS_ROOT = '/home/project/data/scans/'
TRAIN_INDICES_NPZ = '/home/project/ScanNet/train_indicies_subset.npz'
VAL_INDICES_NPZ = '/home/project/ScanNet/val_indicies_subset.npz'
INTRINSICS_NPZ = '/home/project/ScanNet/scannet_indices/intrinsics.npz'

# Loader settings shared by both splits; only shuffle / drop_last differ.
LOADER_KWARGS = dict(batch_size=2, pin_memory=True, num_workers=1)

train_data = ScanNetDataset(root_dir=SCANS_ROOT, npz_path=TRAIN_INDICES_NPZ,
                            intrinsics_path=INTRINSICS_NPZ, mode='train')
# Training: reshuffle each epoch and drop the final incomplete batch.
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **LOADER_KWARGS)

val_data = ScanNetDataset(root_dir=SCANS_ROOT, npz_path=VAL_INDICES_NPZ,
                          intrinsics_path=INTRINSICS_NPZ, mode='val')
# Validation: fixed order, keep every sample.
val_loader = DataLoader(val_data, shuffle=False, drop_last=False, **LOADER_KWARGS)

### 2. Configuration

In [5]:
# Experiment settings. Everything sits under the 'general' key, which the
# training cell also unpacks as keyword arguments to `train`.
config = {
    'general': {
        # Run identification and target device.
        'experiment_name': '0_gmflow_with_confidence_ft',
        'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),

        # Optimisation schedule; per-epoch step count is derived from the loader.
        'n_epochs': 5,
        'n_steps_per_epoch': len(train_loader.dataset) // train_loader.batch_size,
        'n_accum_steps': 8,
        'batch_size': train_loader.batch_size,

        # SWA (cf. the AveragedModel import) is switched off for this run.
        'swa': False,
        'n_epochs_swa': None,
        'n_steps_between_swa_updates': None,

        # Presumably: validate / checkpoint every N epochs (here every epoch).
        'repeat_val_epoch': 1,
        'repeat_save_epoch': 1,

        'model_save_path': '../src/matching/gmflow_confidence/weights/0_gmflow_with_confidence_ft',
    },
}

### 3. Model

In [6]:
# Pretrained GMFlow checkpoint used to initialise the flow part of the model
# (file name suggests the KITTI-finetuned GMFlow with refinement).
PRETRAINED_GMFLOW_PATH = '../src/matching/gmflow/weights/pretrained/gmflow_with_refine_kitti-8d3b9786.pth'

model = GMflowWithConfidence(path_pretrained_gmflow=PRETRAINED_GMFLOW_PATH)
model.to(config['general']['device']);  # trailing ';' suppresses the module repr

In [7]:
# Freeze everything except the confidence head: only parameters whose name
# contains 'confidence' keep requires_grad=True.
for param_name, param in model.named_parameters():
    if 'confidence' in param_name:
        continue
    param.requires_grad = False

### 4. Loss, optimizer, scheduler

In [8]:
# One loss object per phase; the validation instance is constructed with mode='val'.
val_loss = LossGMflowWithConfidence(mode='val')
loss = LossGMflowWithConfidence()

In [9]:
# Two AdamW parameter groups: GMFlow ('flow' in the name) and the confidence
# head ('confidence' in the name). Each parameter is routed to exactly one
# group. 'confidence' is checked first, so a name containing both substrings
# (e.g. under a `gmflow_confidence` module) lands in the confidence group
# instead of being registered twice, which torch.optim rejects with
# "some parameters appear in more than one parameter group".
# Parameters matching neither substring are not handed to the optimizer.
flow_module_params = []
confidence_module_params = []

for name, param in model.named_parameters():
    if 'confidence' in name:
        confidence_module_params.append(param)
    elif 'flow' in name:
        # Frozen in the previous cell; AdamW skips params that get no gradient.
        flow_module_params.append(param)

optimizer = torch.optim.AdamW(
    [{'params': flow_module_params, 'weight_decay': 1e-6, 'lr': 1e-4},
     {'params': confidence_module_params, 'weight_decay': 1e-6, 'lr': 1e-4}]
)


In [10]:
# CustomScheduler receives the number of optimizer updates per epoch. With
# gradient accumulation, presumably one update happens every `n_accum_steps` batches.
n_opt_steps_per_epoch = config['general']['n_steps_per_epoch'] // config['general']['n_accum_steps']
scheduler = CustomScheduler(optimizer, n_opt_steps_per_epoch)

### 5. Experiment

In [None]:
# Launch training: `train` gets the full config positionally and the
# 'general' section unpacked as keyword arguments.
general_cfg = config['general']
train(model, optimizer, scheduler, loss, val_loss, train_loader, val_loader, config, **general_cfg)

[34m[1mwandb[0m: Currently logged in as: [33mkovanic[0m. Use [1m`wandb login --relogin`[0m to force relogin


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 34%|█████████████████████████▏                                               | 17135/49710 [1:41:58<3:13:25,  2.81it/s]

In [None]:
# model.eval()

# data = next(iter(train_loader))

# img_0 = data['image0'].cuda()
# img_1 = data['image1'].cuda()

# with torch.no_grad():

#     out = model(img_0,
#                 img_1,
#                 attn_splits_list=[2, 8],
#                 corr_radius_list=[-1, 4],
#                 prop_radius_list=[-1, 1]
#                )


In [None]:
# out['var'][:, 1].min()

In [None]:
# from training.nll_losses import NLLMixtureLaplace

# loss_ = NLLMixtureLaplace()

# loss_(data['flow_0to1'].cuda(), out['flow_preds'][0], torch.log(out['var'][:, :2]), out['var'][:, 2:], data['mask'].cuda())