### Setup RoMa

In [None]:
# Clone the RoMa repository, cd into it, and pip-install it in editable mode (-e)
# so that files written into its romatch/ package later in this notebook
# (e.g. romatch/datasets/nerf_roma_dataset.py) are importable without reinstalling.
!git clone https://github.com/Parskatt/RoMa.git
%cd RoMa
!pip install -q -e .

### Setup the dataset

In [None]:
# Scene identifier: selects the archive on Drive and the extraction directory.
SCENE = "0080"
SCENE_DIR = f"/content/{SCENE}"

# Quietly extract the scene archive into SCENE_DIR.
# NOTE(review): assumes Google Drive is already mounted at /content/drive
# (mounting is not done in this section).
!unzip -q /content/drive/MyDrive/thesis/Datasets/{SCENE}-nerf.zip -d {SCENE_DIR}

In [None]:
import os
import json
import numpy as np
from itertools import combinations
from sklearn.model_selection import train_test_split
import random
from collections import defaultdict

In [None]:
scene_dir = f"/content/{SCENE}"
images_dir = os.path.join(scene_dir, 'images')
depths_dir = os.path.join(scene_dir, 'depths')
camera_json = os.path.join(scene_dir, 'cameras.json')

with open(camera_json, 'r') as f:
    cam = json.load(f)

# Extract camera parameters: a single pinhole intrinsics matrix, read from the
# top level of cameras.json and shared by every frame.
# NOTE(review): any per-frame intrinsics inside cam['frames'] are ignored.
fl_x, fl_y = cam['fl_x'], cam['fl_y']
cx, cy = cam['cx'], cam['cy']
K = [
    [fl_x, 0.0, cx],
    [0.0, fl_y, cy],
    [0.0, 0.0, 1.0]
]

# Collect all images; only "<key>.color.png" files define frame keys.
imgs = sorted([
    fn for fn in os.listdir(images_dir)
    if fn.lower().endswith(('.png', '.jpg', '.jpeg'))
])
keys = [fn.replace('.color.png', '') for fn in imgs if fn.endswith('.color.png')]

# Frame key -> depth file name ("<key>.depth.npy" is keyed as "<key>").
depth_map = {}
for fn in os.listdir(depths_dir):
    if fn.lower().endswith('.npy'):
        base = fn.rsplit('.depth.npy', 1)[0]
        depth_map[base] = fn

# Frame key -> 4x4 transform_matrix of the matching frame in cameras.json.
poses = {}
for frame in cam.get('frames', []):
    key = os.path.basename(frame['file_path']).replace('.jpg', '').replace('.color.png', '')
    poses[key] = np.array(frame['transform_matrix'])

# Keep only frames that have both a depth map and a pose. Pair generation
# rejects every pair touching an incomplete frame, so letting such frames into
# the split only shrinks the usable train/val sets. When all frames are
# complete this is a no-op and the split below is exactly the same as before.
complete_keys = [k for k in keys if k in depth_map and k in poses]
if len(complete_keys) != len(keys):
    print(f"Skipping {len(keys) - len(complete_keys)} of {len(keys)} frames without a depth map or pose")
keys = complete_keys

# Select the train and test frames (10% held out, fixed seed).
train_keys, test_keys = train_test_split(keys, test_size=0.1, random_state=42)

def generate_balanced_pairs(keys, max_pairs_per_image=25, *, depth_lookup=None,
                            pose_lookup=None, intrinsics=None, convert_opengl=False):
    """Sample RoMa training pairs while capping how often each image is used.

    Keys are visited in sorted order. For each key, up to ``max_pairs_per_image``
    partners are drawn with ``random.sample`` from the module-level ``random``
    generator, so results are reproducible after ``random.seed(...)``. A
    candidate pair is kept only if neither image has already reached the cap
    and both images have a depth file and a pose. Both (a, b) and (b, a) may
    appear in the output.

    Args:
        keys: Iterable of frame keys (images are ``images/<key>.color.png``).
        max_pairs_per_image: Maximum number of pairs any single image may join.
        depth_lookup: Mapping key -> depth file name inside ``depths/``.
            Defaults to the module-level ``depth_map``.
        pose_lookup: Mapping key -> 4x4 pose matrix. Defaults to ``poses``.
        intrinsics: 3x3 intrinsics stored as ``K_A``/``K_B`` for every pair.
            Defaults to the module-level ``K``.
        convert_opengl: If True, conjugate the relative pose with
            diag(1, -1, -1, 1), converting it from the OpenGL camera convention
            (y up, looking down -z) to the OpenCV one (y down, looking down +z).

    Returns:
        List of dicts with keys ``img1``, ``img2``, ``im_A_depth``,
        ``im_B_depth``, ``K_A``, ``K_B`` and ``rel_pose`` (nested lists).

    NOTE(review): ``rel_pose = inv(T2) @ T1`` maps camera-1 coordinates into
    camera-2 coordinates only if the poses are camera-to-world matrices.
    NeRF-style ``transform_matrix`` entries commonly use the OpenGL convention,
    whereas depth-based warping with K assumes OpenCV axes. Confirm this
    dataset's convention and pass ``convert_opengl=True`` if needed.
    """
    depth_lookup = depth_map if depth_lookup is None else depth_lookup
    pose_lookup = poses if pose_lookup is None else pose_lookup
    intrinsics = K if intrinsics is None else intrinsics
    # Self-inverse axis flip: D @ T @ D re-expresses T in flipped camera axes.
    flip = np.diag([1.0, -1.0, -1.0, 1.0])

    key_list = sorted(keys)
    used_counts = defaultdict(int)
    roma_pairs = []

    for key1 in key_list:
        available = [k for k in key_list if k != key1]
        sampled_keys = random.sample(available, min(max_pairs_per_image, len(available)))

        for key2 in sampled_keys:
            # Enforce the cap on both endpoints of the pair.
            if used_counts[key1] >= max_pairs_per_image or used_counts[key2] >= max_pairs_per_image:
                continue

            d1 = depth_lookup.get(key1)
            d2 = depth_lookup.get(key2)
            T1 = pose_lookup.get(key1)
            T2 = pose_lookup.get(key2)
            if any(x is None for x in (d1, d2, T1, T2)):
                continue

            rel = np.linalg.inv(T2) @ T1
            if convert_opengl:
                rel = flip @ rel @ flip

            roma_pairs.append({
                'img1': os.path.join('images', f"{key1}.color.png"),
                'img2': os.path.join('images', f"{key2}.color.png"),
                'im_A_depth': os.path.join('depths', d1),
                'im_B_depth': os.path.join('depths', d2),
                'K_A': intrinsics,
                'K_B': intrinsics,
                'rel_pose': rel.tolist()
            })

            used_counts[key1] += 1
            used_counts[key2] += 1

    return roma_pairs

# Re-seed the global RNG right before sampling so both splits are reproducible.
random.seed(42)
train_pairs = generate_balanced_pairs(train_keys, max_pairs_per_image=25)
val_pairs = generate_balanced_pairs(test_keys, max_pairs_per_image=25)

# One JSON file per split, named roma_pairs_<split>.json inside the scene dir.
out_train = os.path.join(scene_dir, 'roma_pairs_train.json')
out_val = os.path.join(scene_dir, 'roma_pairs_val.json')

for out_path, split_pairs in ((out_train, train_pairs), (out_val, val_pairs)):
    with open(out_path, 'w') as f:
        json.dump(split_pairs, f, indent=2)

print(f"Exported {len(train_pairs)} training pairs to {out_train}")
print(f"Exported {len(val_pairs)} validation pairs to {out_val}")

* Write the `NeRFRoMaDataset` class to `romatch/datasets/nerf_roma_dataset.py` so it can be imported from the installed `romatch` package

In [None]:
%%bash
cat > romatch/datasets/nerf_roma_dataset.py << 'EOF'
import os
import json
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

class NeRFRoMaDataset(Dataset):
    """Image-pair dataset for fine-tuning RoMa on a NeRF-style scene.

    Pairs are read from ``<root_dir>/roma_pairs_<split>.json``. Each entry holds:

    - paths relative to ``root_dir`` to two RGB images (``img1``/``img2``);
    - paths to their ``.npy`` depth maps (``im_A_depth``/``im_B_depth``);
    - their 3x3 intrinsics (``K_A``/``K_B``);
    - a 4x4 relative pose (``rel_pose``).

    Args:
        cfg: Mapping that provides ``'root_dir'``.
        split: Split name embedded in the JSON file name (e.g. "train", "val").
        transform: Optional callable mapping a ``(PIL, PIL)`` pair to a pair of
            image tensors. Without it, images become float tensors in [0, 1]
            laid out as (3, H, W).
        depth_transform: Optional callable mapping a pair of float32 numpy depth
            maps to a pair of tensors. Without it, the maps are converted with
            ``torch.from_numpy``.

    Raises:
        FileNotFoundError: If the pairs JSON for ``split`` is missing.
    """

    def __init__(self, cfg, split="train", transform=None, depth_transform=None):
        self.root_dir = cfg['root_dir']
        self.transform = transform
        self.depth_transform = depth_transform

        json_path = os.path.join(self.root_dir, f"roma_pairs_{split}.json")
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"Missing file: {json_path}")
        with open(json_path, "r") as f:
            self.pairs = json.load(f)

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Open *path* as an RGB PIL image and close the file handle afterwards."""
        # convert() returns a new, fully loaded image, so it remains usable after
        # the context manager has closed the file.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _scale_intrinsics(K, sx, sy):
        """Return ``diag(sx, sy, 1) @ K``, i.e. K for an image resized by (sx, sy)."""
        scale = torch.tensor(
            [[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, 1.0]], dtype=K.dtype
        )
        return scale @ K

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        im_A = self._load_rgb(os.path.join(self.root_dir, pair["img1"]))
        im_B = self._load_rgb(os.path.join(self.root_dir, pair["img2"]))
        # Native (width, height) of each image, captured before any resizing.
        (w_A, h_A), (w_B, h_B) = im_A.size, im_B.size

        if self.transform:
            im_A, im_B = self.transform((im_A, im_B))
        else:
            im_A = torch.from_numpy(np.array(im_A)).permute(2, 0, 1).float().div(255.0)
            im_B = torch.from_numpy(np.array(im_B)).permute(2, 0, 1).float().div(255.0)

        depth_A_np = np.load(os.path.join(self.root_dir, pair["im_A_depth"])).astype(np.float32)
        depth_B_np = np.load(os.path.join(self.root_dir, pair["im_B_depth"])).astype(np.float32)

        if self.depth_transform:
            depth_A, depth_B = self.depth_transform((depth_A_np, depth_B_np))
        else:
            depth_A = torch.from_numpy(depth_A_np)
            depth_B = torch.from_numpy(depth_B_np)

        # The stored intrinsics refer to the native image resolution, but the
        # depth maps may have been resized. Express K at the final depth
        # resolution, which is the same rescaling RoMa's MegaDepth dataset applies
        # after resizing. This is the identity when nothing was resized.
        # NOTE(review): romatch's get_gt_warp/warp_kpts appear to convert
        # normalised coordinates to pixels using the depth map's (H, W), which is
        # why K must match that resolution. Confirm against the installed romatch.
        K1 = self._scale_intrinsics(
            torch.tensor(pair["K_A"], dtype=torch.float32),
            depth_A.shape[-1] / w_A,
            depth_A.shape[-2] / h_A,
        )
        K2 = self._scale_intrinsics(
            torch.tensor(pair["K_B"], dtype=torch.float32),
            depth_B.shape[-1] / w_B,
            depth_B.shape[-2] / h_B,
        )
        T_1to2 = torch.tensor(pair["rel_pose"], dtype=torch.float32)

        return {
            'im_A': im_A.clone(),
            'im_B': im_B.clone(),
            'im_A_depth': depth_A.clone(),
            'im_B_depth': depth_B.clone(),
            'K1': K1.clone(),
            'K2': K2.clone(),
            'T_1to2': T_1to2.clone(),
        }
EOF

In [None]:
# ------- IF RELOAD IS NEEDED -------
# In case the kernel still holds a cached (stale) copy of the module.
# reload() re-executes the freshly (re)written nerf_roma_dataset.py, and the
# from-import afterwards rebinds NeRFRoMaDataset to the reloaded class. Keep
# this order: importing the name before the reload would keep the old class.

import importlib
import romatch.datasets.nerf_roma_dataset
importlib.reload(romatch.datasets.nerf_roma_dataset)

from romatch.datasets.nerf_roma_dataset import NeRFRoMaDataset

In [None]:
import wandb
# mode="disabled" turns wandb calls into no-ops, so nothing is logged or uploaded.
wandb.init(mode="disabled")

import os
from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader
import json
import matplotlib.pyplot as plt
import glob
import sys
import gc
from PIL import Image
import torch.nn.functional as F
from romatch.utils import tensor_to_pil

import romatch
from romatch.datasets.nerf_roma_dataset import NeRFRoMaDataset
from romatch.losses.robust_loss import RobustLosses
from romatch import tiny_roma_v1_outdoor, roma_outdoor
from romatch.utils import get_tuple_transform_ops, get_depth_tuple_transform_ops
from romatch.utils.utils import to_cuda

### Training function definitions

In [None]:
# -----------------------------------------------------------------------------
# GLOBALS
# -----------------------------------------------------------------------------
# Checkpoints and JSON logs for this scene, kept on Google Drive (presumably so
# they persist across Colab sessions).
# NOTE(review): the "_tiny" suffix is hard-coded and does not follow
# config['model']['name'].
save_dir = f"/content/drive/MyDrive/thesis/RoMa/Checkpoints_Metrics/{SCENE}_tiny"
os.makedirs(save_dir, exist_ok=True)


def _load_json_log(path):
    """Return the JSON object stored at *path*, or an empty dict if the file is absent."""
    if not os.path.exists(path):
        return {}
    with open(path, "r") as fh:
        return json.load(fh)


# Per-epoch training losses, resumed from an earlier session when present.
loss_log_path = os.path.join(save_dir, "losses.json")
loss_log = _load_json_log(loss_log_path)

# Per-epoch validation metrics, resumed in the same way.
metrics_log_path = os.path.join(save_dir, "metrics_log.json")
metrics_log = _load_json_log(metrics_log_path)

# -----------------------------------------------------------------------------
# TRAIN STEP
# -----------------------------------------------------------------------------
def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=0.1, **kwargs):
    """Run one fp16-autocast optimisation step and return the outputs and loss.

    Args:
        train_batch: Batch passed unchanged to ``model`` and ``objective``.
        model: Module called as ``model(train_batch)``.
        objective: Callable ``objective(model_output, train_batch)`` returning a
            scalar loss tensor.
        optimizer: Optimizer over ``model.parameters()``.
        grad_scaler: AMP gradient scaler (``torch.amp.GradScaler`` interface).
        grad_clip_norm: Maximum global L2 norm of the (unscaled) gradients.
            The argument used to be ignored in favour of a hard-coded 0.1. Its
            default is now 0.1, so callers that do not pass it behave as before.
        **kwargs: Accepted and ignored, for call-site compatibility.

    Returns:
        dict with ``"train_out"`` (raw model output) and ``"train_loss"``
        (the clamped loss as a Python float).
    """
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):
        out = model(train_batch)
        # Clamp the loss into [0, 10].
        # NOTE(review): where the clamp is active (loss > 10) its gradient is
        # zero, so such batches do not update the model. Confirm this is intended.
        loss = torch.clamp(objective(out, train_batch), min=0.0, max=10.0)
    grad_scaler.scale(loss).backward()
    # Unscale first so clipping operates on the true gradient magnitudes.
    grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_norm)
    grad_scaler.step(optimizer)
    grad_scaler.update()
    return {"train_out": out, "train_loss": loss.item()}

# -----------------------------------------------------------------------------
# EVAL
# -----------------------------------------------------------------------------
def compute_pck(pred_kpts, gt_kpts, thresholds=(1, 3, 5)):
    """Percentage of Correct Keypoints at each pixel threshold.

    Args:
        pred_kpts: Predicted keypoints; the last axis holds (x, y) pixel
            coordinates, e.g. shape (B, N, 2) or (N, 2).
        gt_kpts: Ground-truth keypoints with the same shape as ``pred_kpts``.
        thresholds: Pixel thresholds; a keypoint counts as correct when its
            Euclidean error is strictly below the threshold.

    Returns:
        dict mapping ``'PCK@<t>px'`` to the fraction of correct keypoints, in
        the order of ``thresholds``. Every value is NaN when there are no keypoints.
    """
    dists = np.linalg.norm(np.asarray(pred_kpts) - np.asarray(gt_kpts), axis=-1).ravel()
    if dists.size == 0:
        return {f'PCK@{t}px': float('nan') for t in thresholds}
    return {f'PCK@{t}px': np.mean(dists < t) for t in thresholds}

def compute_auc(errors, max_threshold=10):
    """Area under the cumulative error (recall) curve up to ``max_threshold``, in [0, 1].

    This is the standard pose-AUC formulation (as in SuperGlue/LoFTR/RoMa):

    1. Sort the errors and set recall to (i + 1) / N.
    2. Prepend the point (0, 0) and truncate the curve at ``max_threshold``.
    3. Integrate recall over error with the trapezoidal rule.
    4. Divide by ``max_threshold``.

    The previous version integrated recall against evenly spaced thresholds
    unrelated to the error values. For example, all-zero errors scored 0.5
    instead of 1.0.

    Args:
        errors: Sequence/array of non-negative errors (e.g. pixel distances).
        max_threshold: Upper integration limit (must be > 0).

    Returns:
        The normalised AUC as a float; 0.0 for empty input.
    """
    errs = np.sort(np.asarray(errors, dtype=np.float64).ravel())
    if errs.size == 0:
        return 0.0
    recall = (np.arange(errs.size) + 1) / errs.size
    errs = np.r_[0.0, errs]
    recall = np.r_[0.0, recall]
    # Index of the first error >= max_threshold; the curve is cut there and
    # extended flat to max_threshold.
    last = np.searchsorted(errs, max_threshold)
    r = np.r_[recall[:last], recall[last - 1]]
    e = np.r_[errs[:last], max_threshold]
    # np.trapz was renamed to np.trapezoid in NumPy 2.0.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(r, x=e) / max_threshold)

def compute_maa(errors, thresholds=None):
    """Mean accuracy over a set of error thresholds.

    For each threshold ``t``, accuracy is the fraction of errors strictly below
    ``t``. The result is the mean of those accuracies.

    Args:
        errors: Sequence or array of errors. Plain Python sequences are accepted;
            previously they raised a TypeError on the ``<`` comparison.
        thresholds: Iterable of thresholds; defaults to 1, 2, ..., 10.

    Returns:
        The mean accuracy (NaN if ``errors`` is empty).
    """
    errs = np.asarray(errors)
    if thresholds is None:
        thresholds = np.linspace(1, 10, 10)
    return np.mean([np.mean(errs < t) for t in thresholds])

@torch.no_grad()
def eval_epoch(dataloader, model, objective, epoch, config):
    """Validate ``model`` for one epoch and record the loss and matching metrics.

    For every batch, the clamped objective is computed with autocast disabled.
    If the batch also carries ``'gt_kpts'``, further steps run:

    - each pair is matched with ``model.match``;
    - the predicted coordinates (warp columns 2 and 3) are mapped from [-1, 1]
      to pixels of ``config['dataset']['coarse_res']``;
    - the predictions are compared with the ground truth, scoring only
      ground-truth points that lie inside the image for PCK, AUC and mAA.

    Averages are stored in ``metrics_log[f"epoch_{epoch}"]`` and written to
    ``metrics_log_path``.

    ``model.exact_softmax`` is True during validation and is reset to False
    afterwards, also when an exception propagates.

    Returns:
        The average validation loss, or ``float('inf')`` if the dataloader
        yielded no batches. Previously ``inf`` was also returned whenever no
        batch had valid correspondences, which made loss-based early stopping
        meaningless when ground-truth keypoints are unavailable.
    """
    H, W = config['dataset']['coarse_res']
    model.eval()
    model.exact_softmax = True
    try:
        all_losses, all_pcks, all_aucs, all_maas = [], [], [], []
        warned_no_gt = False

        print(f"\nRunning validation at epoch {epoch}")
        pbar = tqdm(dataloader, desc="Validation", mininterval=10.0)
        for batch in pbar:
            batch_cuda = {k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            with torch.amp.autocast("cuda", enabled=False):
                out = model(batch_cuda)
                loss = torch.clamp(objective(out, batch_cuda), min=0.0, max=10.0)
            all_losses.append(loss.item())

            # The NeRFRoMaDataset written in this notebook does not return
            # 'gt_kpts'. Without it only the loss can be evaluated, so the
            # expensive per-pair matching is skipped.
            if 'gt_kpts' not in batch:
                if not warned_no_gt:
                    tqdm.write("Validation batches have no 'gt_kpts'; reporting loss only.")
                    warned_no_gt = True
                continue

            B = batch['im_A'].shape[0]
            pred_batch = []
            for i in range(B):
                pilA = tensor_to_pil(batch['im_A'][i], unnormalize=True)
                pilB = tensor_to_pil(batch['im_B'][i], unnormalize=True)
                w, _ = model.match(pilA, pilB)
                if isinstance(w, torch.Tensor):
                    w = w.cpu().numpy()
                # Presumably columns 2:4 are the matched image-B coordinates in [-1, 1].
                x_pred = (w[..., 2].ravel() + 1.0) * (W / 2.0)
                y_pred = (w[..., 3].ravel() + 1.0) * (H / 2.0)
                pred_batch.append(np.stack([x_pred, y_pred], axis=1))

            pred_all = np.stack(pred_batch, axis=0)
            gt_all = batch['gt_kpts'].cpu().numpy().reshape(B, -1, 2)

            valid_mask = (
                (gt_all[..., 0] >= 0) & (gt_all[..., 0] < W) &
                (gt_all[..., 1] >= 0) & (gt_all[..., 1] < H)
            )
            dists = np.linalg.norm(pred_all - gt_all, axis=2)[valid_mask]

            if dists.size == 0:
                pbar.set_postfix_str("No valid correspondences")
                continue

            pbar.set_postfix(min=dists.min(), mean=dists.mean(), max=dists.max())
            all_aucs.append(compute_auc(dists))
            all_maas.append(compute_maa(dists))
            # Score PCK on the same in-image points as AUC/mAA. The added
            # leading axis keeps the (batch, points, 2) layout for compute_pck.
            all_pcks.append(compute_pck(pred_all[valid_mask][None], gt_all[valid_mask][None]))

        if not all_losses:
            print("No valid samples found during evaluation.")
            return float('inf')

        avg_loss = float(np.mean(all_losses))
        entry = {"avg_loss": avg_loss}
        summary = f"Eval epoch {epoch}: loss={avg_loss:.4f}"
        if all_pcks:
            pck_keys = list(all_pcks[0].keys())
            avg_pck = {k: float(np.mean([m[k] for m in all_pcks])) for k in pck_keys}
            avg_auc = float(np.mean(all_aucs))
            avg_maa = float(np.mean(all_maas))
            entry.update(avg_pck)
            entry["AUC"] = avg_auc
            entry["mAA"] = avg_maa
            pck_str = " ".join(f"{k}={v:.3f}" for k, v in avg_pck.items())
            summary += f" AUC={avg_auc:.4f} mAA={avg_maa:.4f} {pck_str}"
        else:
            print("No valid correspondences found; only the loss was evaluated.")

        metrics_log[f"epoch_{epoch}"] = entry
        with open(metrics_log_path, "w") as f:
            json.dump(metrics_log, f, indent=2)

        print(summary)
        return avg_loss
    finally:
        model.exact_softmax = False

# -----------------------------------------------------------------------------
# TRAIN EPOCH
# -----------------------------------------------------------------------------
def train_epoch(dataloader, model, objective, optimizer, lr_scheduler, scaler, epoch, config):
    """Train for one epoch, save a checkpoint, and append the epoch to the loss log.

    For every batch, this function:

    - runs ``train_step``;
    - steps ``lr_scheduler``;
    - advances ``romatch.GLOBAL_STEP`` by ``romatch.STEP_SIZE``.

    After the epoch, it writes a checkpoint with model, optimizer, grad-scaler
    and scheduler state to ``save_dir``. It then stores the per-batch and
    average losses in ``loss_log`` and writes that log to ``loss_log_path``.

    Raises:
        FloatingPointError: If a batch yields a NaN/Inf loss. This is a subclass
            of ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    import math

    model.train(True)
    print(f"\nStarting epoch {epoch}/{config['train']['epochs']}")

    epoch_losses = []
    pbar = tqdm(
        dataloader,
        desc=f"Epoch {epoch}/{config['train']['epochs']}",
        mininterval=10.0,
        leave=False,
        file=sys.stdout
    )

    for step, batch in enumerate(pbar):
        batch = to_cuda(batch)
        # Keep only the scalar loss. Holding the whole result dict would keep
        # this step's output tensors alive during the next forward pass.
        loss = train_step(batch, model, objective, optimizer, scaler)["train_loss"]

        if not math.isfinite(loss):
            raise FloatingPointError(f"NaN/Inf loss at epoch {epoch}, step {step}")

        pbar.set_postfix(loss=f"{loss:.4f}")
        epoch_losses.append(loss)

        # The scheduler is stepped per batch, not per epoch.
        lr_scheduler.step()
        romatch.GLOBAL_STEP += romatch.STEP_SIZE

    # Save to a temporary file, then rename atomically. An interrupted save
    # (e.g. a Colab disconnect) cannot leave a truncated "*_epNNN.pth" that a
    # later resume would pick as the latest checkpoint.
    ckpt_path = os.path.join(save_dir, f"{config['model']['name']}_ep{epoch:03d}.pth")
    tmp_path = ckpt_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'grad_scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, ckpt_path)

    # Guard against an empty loader (np.mean([]) warns and returns NaN).
    avg_loss = float(np.mean(epoch_losses)) if epoch_losses else float('nan')
    print(f"Finished epoch {epoch}: avg batch loss = {avg_loss:.4f}")

    loss_log[f"epoch_{epoch}"] = {
        "batch_losses": epoch_losses,
        "avg_loss": avg_loss
    }
    with open(loss_log_path, "w") as f:
        json.dump(loss_log, f, indent=2)

# -----------------------------------------------------------------------------
# TRAIN K EPOCHS
# -----------------------------------------------------------------------------
def train_k_epochs(start_epoch, end_epoch, train_loader, val_loader, model, objective, optimizer, lr_scheduler, scaler, config):
    """Train epochs ``start_epoch``..``end_epoch`` (inclusive) with periodic validation.

    Validation runs on epoch 1, on ``end_epoch``, and on every epoch divisible
    by ``config['train']['eval_every']``. The default is 1 (every epoch), which
    matches what this loop did before: a leftover ``should_eval = True``
    overrode the schedule and made ``eval_every`` dead code.

    If ``config['train']['early_stopping_patience']`` is set, training stops
    after that many consecutive evaluations without a new best validation loss.
    NOTE: the best loss is not persisted, so it restarts at ``inf`` on resume.
    """
    best_val_loss = float('inf')
    patience = config["train"].get("early_stopping_patience", None)
    patience_counter = 0

    # Clamp to >= 1 so a 0/negative setting cannot cause a modulo-by-zero.
    eval_every = max(1, int(config["train"].get("eval_every", 1)))

    for ep in range(start_epoch, end_epoch + 1):
        train_epoch(train_loader, model, objective, optimizer, lr_scheduler, scaler, ep, config)

        should_eval = (ep % eval_every == 0) or (ep == end_epoch) or (ep == 1)
        if not should_eval:
            continue

        val_loss = eval_epoch(val_loader, model, objective, ep, config)
        if patience is None:
            continue

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"Early stopping patience {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

#### Config

In [None]:
config = {
    "dataset": {
        "name": "nerf_roma",
        # Follows SCENE (via SCENE_DIR) instead of a hard-coded "/content/0080",
        # so switching scenes only requires changing SCENE.
        "root_dir": SCENE_DIR,
        "train_batch_size": 40,
        "val_batch_size": 40,
        # (H, W) that images and depth maps are resized to.
        "coarse_res": (560, 560),
    },
    "model": {
        # "tiny_roma" selects tiny_roma_v1_outdoor; any other name selects roma_outdoor.
        "name": "tiny_roma",
        # NOTE(review): freeze_backbone and use_pretrained are not read anywhere
        # in this notebook section.
        "freeze_backbone": False,
        "use_pretrained": True
    },
    "train": {
        "epochs": 100,
        "lr": 2e-6,
        "weight_decay": 1e-4,
        "early_stopping_patience": 10,
        "lr_scheduler": {
            "name": "cosine",
            # NOTE(review): T_max is not read; the scheduler uses
            # epochs * len(train_loader) optimizer steps instead.
            "T_max": 20,
            "eta_min": 1e-6
        }
    }
}

### Training Script

In [None]:
# -----------------------------------------------------------------------------
# DEVICE, TRANSFORMS, DATA
# -----------------------------------------------------------------------------
# Release GPU/host memory left over from a previous run of this cell.
torch.cuda.empty_cache()
gc.collect()
# ipc_collect() initialises CUDA, so only call it when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.ipc_collect()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images and depth maps are resized to the same coarse resolution.
COARSE_RES = config["dataset"]["coarse_res"]
im_transform  = get_tuple_transform_ops(resize=COARSE_RES, normalize=True)
depth_trans   = get_depth_tuple_transform_ops(resize=COARSE_RES, normalize=False)


def depth_transform_pair(im_tuple):
    """Resize a pair of numpy depth maps with ``depth_trans``.

    Each map is lifted to a float tensor with two leading singleton dims for the
    transform, and all singleton dims are squeezed away afterwards.
    Presumably the input maps are 2-D (H, W). TODO confirm.
    """
    tA = torch.from_numpy(im_tuple[0]).float()[None, None]
    tB = torch.from_numpy(im_tuple[1]).float()[None, None]
    oA, oB = depth_trans((tA, tB))
    return oA.squeeze(), oB.squeeze()


train_dataset = NeRFRoMaDataset(
    config["dataset"],
    split="train",
    transform=im_transform,
    depth_transform=depth_transform_pair,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=config["dataset"]["train_batch_size"],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True
)

val_dataset = NeRFRoMaDataset(
    config["dataset"],
    split="val",
    transform=im_transform,
    depth_transform=depth_transform_pair,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config["dataset"]["val_batch_size"],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True
)

# -----------------------------------------------------------------------------
# MODEL, OPTIMIZER, SCHEDULER, SCALER
# -----------------------------------------------------------------------------
is_model_tiny = config["model"]["name"] == "tiny_roma"
model = (tiny_roma_v1_outdoor if is_model_tiny else roma_outdoor)(device=device).to(device)

criterion = RobustLosses()
criterion.local_largest_scale = -1

# Note: Adam's weight_decay is a coupled L2 penalty (unlike AdamW).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["train"]["lr"],
    weight_decay=config["train"].get("weight_decay", 0)
)

# Cosine schedule over all optimizer steps; it is stepped once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max = config["train"]["epochs"] * len(train_loader),
    eta_min = config["train"]["lr_scheduler"].get("eta_min", 0)
)

scaler = torch.amp.GradScaler()

# -----------------------------------------------------------------------------
# LOAD CHECKPOINT IF AVAILABLE
# -----------------------------------------------------------------------------
# Epoch numbers are zero-padded to 3 digits, so the lexicographic sort is
# chronological up to epoch 999.
checkpoint_files = sorted(glob.glob(f"{save_dir}/{config['model']['name']}_ep*.pth"))
start_epoch = 1

if checkpoint_files:
    latest_ckpt = checkpoint_files[-1]
    print(f"🔁 Resuming from checkpoint: {latest_ckpt}")
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    if 'grad_scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['grad_scaler_state_dict'])

    OVERRIDE_LR = False

    if OVERRIDE_LR:
        new_lr = config["train"]["lr"]
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
            # The loaded optimizer state still carries the old 'initial_lr'.
            # A freshly built scheduler only setdefault()s it, so reset it here,
            # otherwise the new cosine schedule would keep the old base LR.
            param_group['initial_lr'] = new_lr
        print(f"Overriding LR to {new_lr:.2e} after loading checkpoint")

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max = config["train"]["epochs"] * len(train_loader),
            eta_min = config["train"]["lr_scheduler"].get("eta_min", 0)
        )
else:
    print("No checkpoint found, starting from scratch.")

# -----------------------------------------------------------------------------
# RUN TRAINING
# -----------------------------------------------------------------------------
print(f"Starting training: {config['train']['epochs']} epochs, {len(train_loader)} steps/epoch")
train_k_epochs(
    start_epoch = start_epoch,
    end_epoch   = config["train"]["epochs"],
    train_loader= train_loader,
    val_loader  = val_loader,
    model       = model,
    objective   = criterion,
    optimizer   = optimizer,
    lr_scheduler= scheduler,
    scaler      = scaler,
    config      = config
)