## Dataset Preparation

In [1]:
import os
import time
import numpy as np
import pandas as pd
import torch
import multiprocessing as mp
from datetime import datetime
from tqdm.auto import tqdm
from scipy.spatial.transform import Rotation as R
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, random_split

from model import VIOCNN
from dataset import VIOPreprocessedDataset

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Every tunable read by the cells below lives in this one mapping.
config = dict(
    # --- data and batching ---
    data_root="./output",
    batch_size=8,
    sequence_length=8,
    imu_per_frame=50,
    image_size=(480, 640),      # create_datasets() passes this to the dataset reversed
    # --- training run ---
    learning_rate=1e-3,
    weight_decay=1e-5,
    epochs=150,
    train_split=0.8,            # fraction of samples that goes to the training split
    num_workers=8,
    # --- checkpointing / logging ---
    checkpoint_dir="./checkpoints",
    resume_training=True,
    use_tensorboard=True,
    # --- augmentation settings forwarded to VIOPreprocessedDataset ---
    imu_noise_std=0.01,
    imu_dropout_prob=0.05,
    p_invert=0.05,
    p_bgr=0.05,
    p_gray=0.05,
    p_mask=0.05,
    # NOTE(review): no visible cell reads this flag -- confirm it is still needed.
    use_quat=False,
)

# The training loop saves checkpoints here, so make sure the folder exists.
os.makedirs(config["checkpoint_dir"], exist_ok=True)

# One TensorBoard run directory per session, named after the start time.
# `writer` only exists when TensorBoard logging is enabled.
if config["use_tensorboard"]:
    writer = SummaryWriter(
        os.path.join(
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
    )

def create_csv_from_preprocessed(root):
    """Export the per-frame ``.npy`` IMU and label arrays under ``root`` to CSV.

    Reads ``image_timestamps.npy``, ``imu/imu_XXXXXX.npy`` and
    ``labels/label_XXXXXX.npy`` and writes ``imu_data.csv`` and
    ``odom_data.csv`` into ``root``. A CSV that already exists is left alone.

    Each CSV is written to a ``.tmp`` sibling first and renamed into place only
    after the last row is written. An interrupted export (missing input file,
    kernel restart, full disk) therefore never leaves a truncated CSV behind
    that the existence check would later mistake for a complete one.

    Args:
        root: Directory holding the preprocessed arrays; the CSVs are written
            here as well.
    """
    imu_dir, lbl_dir = os.path.join(root, "imu"), os.path.join(root, "labels")
    ts = np.load(os.path.join(root, "image_timestamps.npy"))
    imu_out, odo_out = os.path.join(root, "imu_data.csv"), os.path.join(root, "odom_data.csv")

    if not os.path.exists(imu_out):
        imu_tmp = imu_out + ".tmp"
        with open(imu_tmp, "w") as f:
            f.write("timestamp,ax,ay,az,gx,gy,gz\n")
            for i, t in enumerate(ts):
                data = np.load(os.path.join(imu_dir, f"imu_{i:06d}.npy"))
                # Rows are walked in reverse; the k-th row from the end is
                # stamped k * 1e6 before the frame timestamp.
                for k, s in enumerate(data[::-1]):
                    f.write(f"{t - k * 1e6},{','.join(map(str, s))}\n")
        # Publish the finished file in a single step.
        os.replace(imu_tmp, imu_out)

    if not os.path.exists(odo_out):
        odo_tmp = odo_out + ".tmp"
        with open(odo_tmp, "w") as f:
            f.write("timestamp,x,y,z,roll,pitch,yaw\n")
            for i, t in enumerate(ts):
                lbl = np.load(os.path.join(lbl_dir, f"label_{i:06d}.npy"))
                x, y, z, qx, qy, qz, qw = lbl
                q = np.array([qx, qy, qz, qw])
                n = np.linalg.norm(q)
                if n < 1e-8:
                    # A (near-)zero quaternion cannot be normalised; write zero angles.
                    r = p = yaw = 0
                else:
                    r, p, yaw = R.from_quat(q / n).as_euler('xyz')
                f.write(f"{t},{x},{y},{z},{r},{p},{yaw}\n")
        os.replace(odo_tmp, odo_out)

def create_datasets(cf):
    """Build the preprocessed VIO dataset and split it into train/validation parts.

    The CSV exports are (re)created first if missing, then the dataset is
    constructed from the settings in ``cf`` and divided with a fixed-seed
    ``random_split`` so the partition is the same on every run.

    Args:
        cf: Configuration mapping; reads ``data_root``, ``sequence_length``,
            ``imu_per_frame``, ``image_size``, ``train_split`` and the
            augmentation settings.

    Returns:
        The two subsets produced by ``random_split``, training first.
    """
    root = cf["data_root"]
    create_csv_from_preprocessed(root)

    # The config's image_size tuple is handed over in reverse order.
    size_reversed = cf["image_size"][::-1]

    # Settings forwarded to the dataset under the same name they have in cf.
    forwarded = (
        "sequence_length",
        "imu_per_frame",
        "imu_noise_std",
        "imu_dropout_prob",
        "p_invert",
        "p_bgr",
        "p_gray",
        "p_mask",
    )
    full_ds = VIOPreprocessedDataset(
        data_root=root,
        imu_csv=os.path.join(root, "imu_data.csv"),
        odom_csv=os.path.join(root, "odom_data.csv"),
        image_size=size_reversed,
        **{key: cf[key] for key in forwarded},
    )

    n_total = len(full_ds)
    n_train = int(cf["train_split"] * n_total)
    split_rng = torch.Generator().manual_seed(42)
    return random_split(full_ds, [n_train, n_total - n_train], generator=split_rng)

In [2]:
# Build the dataset from `config` (exporting the CSV files first if they are
# missing) and keep the fixed-seed train / validation subsets it returns.
train_ds, val_ds = create_datasets(config)

In [3]:
import pandas as pd

# Sanity-check the exported CSVs by counting their rows. The paths come from
# config["data_root"], so this cell always inspects the same files that
# create_datasets() uses, even when the data root is changed.
imu_df = pd.read_csv(os.path.join(config["data_root"], "imu_data.csv"))
odom_df = pd.read_csv(os.path.join(config["data_root"], "odom_data.csv"))
print("IMU rows :", len(imu_df))
print("Odom rows:", len(odom_df))

IMU rows : 1844250
Odom rows: 36885


In [4]:
# Switch the process-wide multiprocessing start method to "spawn" unless it is
# already selected; force=True overrides a method that was set earlier.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method("spawn", force=True)

In [5]:
# Explicit "spawn" context; both DataLoaders receive it as multiprocessing_context.
ctx = mp.get_context("spawn")

In [6]:
# Options shared by both loaders. The worker-related ones are only legal when
# worker processes are used: DataLoader raises ValueError for
# persistent_workers=True or a multiprocessing_context when num_workers == 0,
# so they are switched off in that case (e.g. when debugging in-process).
_n_workers = config["num_workers"]
_loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=_n_workers,
    persistent_workers=_n_workers > 0,
    multiprocessing_context=ctx if _n_workers > 0 else None,
    pin_memory=True,
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [7]:
class GeodesicLoss(nn.Module):
    """Mean geodesic angle (radians) between two batches of rotation 3-vectors.

    Every 3-vector is turned into a rotation matrix with Rodrigues' formula,
    using its direction as the axis and its norm as the angle. The loss is the
    mean rotation angle of ``R_pred @ R_gt^T``.

    Intermediate matrices are created with the dtype of the input, so float64
    inputs are evaluated in full double precision instead of being routed
    through float32 buffers. Inputs are flattened with ``reshape``, so
    non-contiguous tensors are accepted too.
    """

    @staticmethod
    def _rot_mat(r):
        """Rodrigues' formula: map ``(..., 3)`` vectors to ``(..., 3, 3)`` matrices."""
        # Clamp the angle so the axis division stays finite for zero rotations.
        theta = torch.linalg.norm(r, dim=-1, keepdim=True).clamp(min=1e-6)
        axis = r / theta
        # Skew-symmetric cross-product matrix of the unit axis.
        K = torch.zeros(*r.shape[:-1], 3, 3, device=r.device, dtype=r.dtype)
        K[..., 0, 1], K[..., 0, 2] = -axis[..., 2],  axis[..., 1]
        K[..., 1, 0], K[..., 1, 2] =  axis[..., 2], -axis[..., 0]
        K[..., 2, 0], K[..., 2, 1] = -axis[..., 1],  axis[..., 0]
        I = torch.eye(3, device=r.device, dtype=r.dtype)
        return I + torch.sin(theta)[..., None] * K + (1 - torch.cos(theta))[..., None] * (K @ K)

    def forward(self, pred, target):
        """Return the mean rotation angle between ``pred`` and ``target``.

        Args:
            pred: Tensor whose elements can be grouped into 3-vectors.
            target: Tensor with the same number of 3-vectors as ``pred``.

        Returns:
            Scalar tensor holding the mean angle in radians.
        """
        r = pred.reshape(-1, 3)
        t = target.reshape(-1, 3)
        R_pred = self._rot_mat(r)
        R_gt = self._rot_mat(t)
        R_diff = R_pred @ R_gt.transpose(-1, -2)
        trace = R_diff.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)
        # Keep the acos argument strictly inside (-1, 1): its gradient is
        # unbounded at the end points.
        angle = torch.acos(torch.clamp((trace - 1) / 2, -1+1e-6, 1-1e-6))
        return angle.mean()

class VIOLoss(nn.Module):
    """Combined pose loss: Huber on translation plus weighted geodesic rotation loss.

    The last dimension of ``pred`` / ``target`` is split at index 3: the first
    three entries go to the Huber term, the remaining ones to ``GeodesicLoss``.

    Args:
        huber_delta: ``delta`` passed to ``nn.HuberLoss`` for the position term.
        rot_weight: Multiplier applied to the rotation term. The default of 0.5
            matches the weight that used to be hard-coded.

    Note:
        ``compute_component_losses`` in this notebook applies its own factor of
        0.5, so keep the two in sync when changing ``rot_weight``.
    """

    def __init__(self, huber_delta=0.1, rot_weight=0.5):
        super().__init__()
        self.huber = nn.HuberLoss(delta=huber_delta)
        self.geo = GeodesicLoss()
        self.rot_weight = rot_weight

    def forward(self, pred, target):
        """Return ``huber(position) + rot_weight * geodesic(rotation)`` as a scalar."""
        pos_p, rot_p = pred[..., :3], pred[..., 3:]
        pos_t, rot_t = target[..., :3], target[..., 3:]
        return self.huber(pos_p, pos_t) + self.rot_weight * self.geo(rot_p, rot_t)

## Model Initialization

In [8]:
# Network under training, moved to the selected device.
model = VIOCNN(
    img_channels=3,
    imu_dim=6,
    emb_dim=192,
    hidden_size=256,
    gru_layers=3,
    dropout_p=0.2,
)
model = model.to(device)

# Position + rotation loss with its default settings.
criterion = VIOLoss()

# AdamW with the learning rate and weight decay from the config.
optimiser = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

from torch.optim.lr_scheduler import ReduceLROnPlateau

# Halve the learning rate after 4 epochs without a sufficient drop in the
# monitored value, never going below 1e-6.
scheduler = ReduceLROnPlateau(
    optimiser, mode='min', factor=0.5, patience=4, threshold=2e-4, min_lr=1e-6
)

# Early-stopping bookkeeping used by the training loop.
patience = 30               # epochs without improvement before stopping
min_delta = 0.002           # relative improvement needed to count as "better"
best_val = float("inf")
pat_cnt = 0

# Defaults for a fresh run; replaced below when a checkpoint is resumed.
start_ep = 0
history = {"loss": [], "val_loss": []}

# Resume from the most recent checkpoint when enabled and present.
cp = os.path.join(config["checkpoint_dir"], "latest_checkpoint.pth")
if config["resume_training"] and os.path.exists(cp):
    ckpt = torch.load(cp, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimiser.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    start_ep = ckpt["epoch"] + 1    # continue with the epoch after the saved one
    best_val = ckpt["best_val_loss"]
    history = ckpt["train_history"]

In [9]:
# Count the elements of every parameter tensor registered on the model
# (no filtering on requires_grad) and print the total with thousands separators.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")

Total parameters: 3,809,677


In [10]:
def compute_component_losses(pred, target):
    """Split the pose loss into its parts.

    The last dimension is cut at index 3: the leading slice is scored with
    ``criterion.huber``, the trailing slice with ``criterion.geo``.

    Returns:
        Dict with keys ``'pos'``, ``'rot'`` and ``'total'``, where
        ``total = pos + 0.5 * rot``.
    """
    pred_xyz, target_xyz = pred[..., :3], target[..., :3]
    pred_rot, target_rot = pred[..., 3:], target[..., 3:]

    parts = {
        'pos': criterion.huber(pred_xyz, target_xyz),   # Huber positional term
        'rot': criterion.geo(pred_rot, target_rot),     # geodesic rotational term
    }
    # Same 0.5 rotation weight that VIOLoss applies.
    parts['total'] = parts['pos'] + 0.5 * parts['rot']
    return parts

In [11]:
def run_epoch(loader, train=True, epoch_idx=0):
    """Run one pass over ``loader`` and return the sample-weighted mean losses.

    Args:
        loader: Iterable yielding ``(imgs, imu, poses, lens)`` batches.
        train: When True the model is put in training mode and each batch is
            back-propagated (gradient norm clipped to 1.0) and stepped; when
            False the model is put in eval mode and gradients are disabled.
        epoch_idx: Zero-based epoch number; only used to label the progress bar.

    Returns:
        Dict with keys ``'pos'``, ``'rot'`` and ``'total'`` holding the per-batch
        losses averaged with the batch size as weight.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages would
            be undefined.
    """
    model.train() if train else model.eval()
    sums = {'pos': 0., 'rot': 0., 'total': 0.}
    count = 0
    phase = "train" if train else "val"

    with torch.set_grad_enabled(train):
        pbar = tqdm(loader, leave=False, desc=f"{phase} {epoch_idx + 1}")
        for imgs, imu, poses, lens in pbar:
            imgs, imu, poses = imgs.to(device), imu.to(device), poses.to(device)
            optimiser.zero_grad()

            pred, _ = model(imgs, imu, lens)
            losses = compute_component_losses(pred, poses)

            if train:
                losses['total'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimiser.step()

            # Read each loss back to Python once; every .item() call waits for
            # the device, so reuse the values for both the sums and the display.
            values = {k: v.item() for k, v in losses.items()}

            # Weight by batch size so a smaller final batch does not skew the mean.
            bs = imgs.size(0)
            count += bs
            for k in sums:
                sums[k] += values[k] * bs

            pbar.set_postfix({f"{k}_loss": f"{v:.3f}" for k, v in values.items()})

    if count == 0:
        raise ValueError("run_epoch received a loader that produced no samples; "
                         "cannot average losses")

    return {k: total / count for k, total in sums.items()}

In [12]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from math import ceil

def map_view_val_batch(model, val_loader, device='cuda', max_seq: int = 9):
    """Plot up to `max_seq` sequences from one validation batch.

    Takes the first batch from ``val_loader``, runs the model on it and draws,
    per sequence, the cumulative sum of the first two pose components for the
    ground truth and for the prediction, both starting at the origin.

    Args:
        model: Network called as ``model(imgs, imu, seq_lens)``; its first
            return value is used as the predicted poses.
        val_loader: Iterable yielding ``(imgs, imu, gt_poses, seq_lens)``.
        device: Device the inputs are moved to. If a CUDA device is requested
            but CUDA is unavailable, the device of the model's first parameter
            is used instead (CPU for a model without parameters), so the
            default works on CPU-only machines.
        max_seq: Maximum number of sequences to draw.

    Note:
        The model is left in eval mode.
    """
    target_device = torch.device(device)
    if target_device.type == 'cuda' and not torch.cuda.is_available():
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    with torch.no_grad():
        imgs, imu, gt_poses, seq_lens = next(iter(val_loader))
        imgs, imu = imgs.to(target_device), imu.to(target_device)
        pred_poses, _ = model(imgs, imu, seq_lens)

    B = min(len(gt_poses), max_seq)
    cols = 3
    rows = ceil(B / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4),
                             squeeze=False)
    axes = axes.flatten()

    for idx in range(B):
        ax = axes[idx]
        gt = gt_poses[idx, :, :2].cpu().numpy()
        pr = pred_poses[idx, :, :2].cpu().numpy()

        # Per-step values are accumulated into positions, starting at (0, 0).
        gt_xy   = np.vstack([[0, 0], np.cumsum(gt, 0)])
        pred_xy = np.vstack([[0, 0], np.cumsum(pr, 0)])

        # Root-mean-square XY distance between the two accumulated tracks.
        rmse = np.sqrt(((gt_xy - pred_xy) ** 2).sum(1).mean())

        ax.plot(gt_xy[:, 0],   gt_xy[:, 1],  'o-',  lw=2,  label='GT')
        ax.plot(pred_xy[:, 0], pred_xy[:, 1], 'x--', lw=1.5, label='Pred')
        # Green square marks the start, red star the end of the ground truth.
        ax.scatter(gt_xy[0, 0],  gt_xy[0, 1],  c='green', marker='s', s=60)
        ax.scatter(gt_xy[-1,0],  gt_xy[-1, 1],  c='red',   marker='*', s=80)
        ax.set_aspect('equal')
        ax.grid(True, ls=':')
        ax.set_title(f'seq {idx}  |  XY-RMSE: {rmse:.2f} m')
        ax.set_xlabel('X (m)'); ax.set_ylabel('Y (m)')
        ax.legend(frameon=False, fontsize=8)

    # hide unused axes
    for j in range(B, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()

## Training Loop

In [None]:
# Main training loop: one training and one validation pass per epoch, followed by
# LR scheduling, logging, checkpointing and early stopping.
for ep in range(start_ep, config["epochs"]):
    t0 = time.time()
    tr = run_epoch(train_loader, True,  ep)
    vl = run_epoch(val_loader,   False, ep)

    # The plateau scheduler monitors the validation total.
    scheduler.step(vl['total'])

    history['loss'].append(tr['total'])
    history['val_loss'].append(vl['total'])

    if config['use_tensorboard']:
        writer.add_scalars("Loss/train", tr, ep)
        writer.add_scalars("Loss/val",   vl, ep)
        writer.add_scalar("LR", optimiser.param_groups[0]['lr'], ep)

    print(f"Epoch {ep+1:02d}/{config['epochs']} | {time.time()-t0:.1f}s "
          f"| train tot {tr['total']:.4f} (pos {tr['pos']:.4f}, rot {tr['rot']:.4f}) "
          f"| val tot {vl['total']:.4f} (pos {vl['pos']:.4f}, rot {vl['rot']:.4f}) "
          f"| lr {optimiser.param_groups[0]['lr']:.2e}")

    # Update the early-stopping state BEFORE building the checkpoint, so that
    # "best_val_loss" in both checkpoint files already reflects this epoch.
    # Otherwise a resume restores a stale best value and a later, worse epoch
    # could overwrite best_checkpoint.pth.
    improved = vl['total'] < best_val * (1 - min_delta)
    if improved:
        best_val, pat_cnt = vl['total'], 0
    else:
        pat_cnt += 1

    ck = {
        "epoch": ep,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimiser.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "train_history": history,
        "best_val_loss": best_val,
        "config": config
    }
    # The latest checkpoint is written every epoch, including the one that
    # triggers early stopping.
    torch.save(ck, cp)

    if improved:
        torch.save(ck, os.path.join(config["checkpoint_dir"], "best_checkpoint.pth"))
    elif pat_cnt >= patience:
        print(f"Early stop at epoch {ep+1}")
        break

if config['use_tensorboard']:
    writer.close()

  0%|          | 0/3688 [00:19<?, ?it/s]

  0%|          | 0/922 [00:20<?, ?it/s]

Epoch 86/150 | 1233.2s | train tot 0.0394 (pos 0.0251, rot 0.0286) | val tot 0.0402 (pos 0.0254, rot 0.0295) | lr 5.00e-05


  0%|          | 0/3688 [00:00<?, ?it/s]

In [None]:
# Plot predicted vs. ground-truth tracks for one validation batch. Pass the
# notebook's `device` explicitly: the model was moved there, and the function's
# own default of 'cuda' fails on machines without a GPU.
map_view_val_batch(model, val_loader, device=device)