# Mamba UAV Detector - Training Notebook

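This notebook drives the full workflow for the Mamba UAV detector: configuration, data loading, training and validation, test-set evaluation, sample inference, and model export. Helper functions live in `mamba/` and `shared/`.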

In [1]:
import os
import sys
import torch
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Resolve the parent of the current working directory and make it importable,
# so the project-local `mamba` and `shared` packages imported below can be found.
REPO_ROOT = os.path.abspath(os.pardir)
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from mamba.config import Config
from mamba.dataset import create_dataloaders
from mamba.trainer import MambaDetectorModule
from mamba.model import MambaUAVDetector
import shared.visualization as viz

# Seed every RNG that Lightning manages so reruns are reproducible.
pl.seed_everything(42)

# Human-readable label for the banner only: CUDA device name first, then MPS, then CPU.
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
else:
    device_name = 'MPS' if torch.backends.mps.is_available() else 'CPU'

# One call, one line per item, so the banner text is unchanged.
print(
    '✅ Setup complete',
    f'PyTorch version: {torch.__version__}',
    f'Lightning version: {pl.__version__}',
    f'Device: {device_name}',
    sep='\n',
)

  data = fetch_version_info()
Seed set to 42


✅ Setup complete
PyTorch version: 2.10.0
Lightning version: 2.6.0
Device: MPS


## Configuration

In [2]:
# Start from the project defaults and override only what this run changes.
config = Config()

# --- data ---
config.data.data_root = '../data/MMFW-UAV/sample'
config.data.sequence_length = 10
config.data.batch_size = 4

# --- model ---
config.model.backbone = 'mobilevit_s'
config.model.d_model = 256
config.model.mamba_layers = 4

# --- optimisation ---
config.training.max_epochs = 50
config.training.lr = 1e-3

# --- experiment tracking ---
config.experiment_name = 'mamba-mobilevit-s10'
config.use_wandb = True

# Pick the accelerator in priority order: CUDA, then Apple MPS, then CPU.
config.accelerator = (
    'gpu' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Echo the effective settings so the run is self-describing.
config_summary = (
    ('Data root', config.data.data_root),
    ('Sequence length', config.data.sequence_length),
    ('Batch size', config.data.batch_size),
    ('Backbone', config.model.backbone),
    ('d_model', config.model.d_model),
    ('Mamba layers', config.model.mamba_layers),
    ('Max epochs', config.training.max_epochs),
    ('Accelerator', config.accelerator),
)
print('Configuration:')
for label, value in config_summary:
    print(f'  {label}: {value}')

Configuration:
  Data root: ../data/MMFW-UAV/sample
  Sequence length: 10
  Batch size: 4
  Backbone: mobilevit_s
  d_model: 256
  Mamba layers: 4
  Max epochs: 50
  Accelerator: mps


## Data Loading

In [3]:
# Collect the loader settings from the config in one place, then build the
# train / validation / test loaders in a single call.
loader_kwargs = dict(
    data_root=config.data.data_root,
    batch_size=config.data.batch_size,
    num_workers=config.data.num_workers,
    sequence_length=config.data.sequence_length,
    sensor_type=config.data.sensor_type,
    view=config.data.view,
    img_size=config.data.img_size,
    stride=config.data.stride,
)
train_loader, val_loader, test_loader = create_dataloaders(**loader_kwargs)

print('✅ Data loaded:')
for split_name, split_loader in (
    ('Train', train_loader),
    ('Val', val_loader),
    ('Test', test_loader),
):
    print(f'  {split_name} batches: {len(split_loader)}')

# Pull one training batch; later cells reuse `images` / `targets` for
# visualisation and for the model's sanity forward pass.
images, targets = next(iter(train_loader))
print('\nBatch shape:')
print(f'  Images: {images.shape}')
print(f'  Targets: {targets.shape}')

FileNotFoundError: Split file not found: ../data/MMFW-UAV/splits/train.json. Run scripts/prepare_data.py to generate splits.

## Visualize Sample Sequence

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Per-channel statistics used to undo the input normalisation for display
# (pixel = value * std + mean). These match the common ImageNet values;
# NOTE(review): confirm they agree with the transform in mamba/dataset.
norm_mean = np.array([0.485, 0.456, 0.406])
norm_std = np.array([0.229, 0.224, 0.225])

# First sequence of the batch and its per-frame targets.
sample_seq = images[0]
sample_targets = targets[0]

# Fixed 2x5 grid: at most ten frames are shown.
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()

n_frames = min(len(axes), len(sample_seq))
for i in range(n_frames):
    # (C, H, W) tensor -> (H, W, C) array; detach/cpu keeps this working even
    # if the batch carries gradients or lives on an accelerator.
    img = sample_seq[i].detach().cpu().permute(1, 2, 0).numpy()
    img = np.clip(img * norm_std + norm_mean, 0, 1)

    axes[i].imshow(img)
    axes[i].set_title(f'Frame {i}')
    axes[i].axis('off')

    # Each target row unpacks to (x, y, w, h, conf). Plain floats are handed
    # to matplotlib instead of 0-d tensors.
    x, y, w, h, conf = sample_targets[i].tolist()
    if conf > 0.5:
        # x/y are treated as the box centre and scaled by img_size
        # (presumably normalised coordinates; verify against the dataset).
        rect_x = (x - w / 2) * config.data.img_size
        rect_y = (y - h / 2) * config.data.img_size
        rect_w = w * config.data.img_size
        rect_h = h * config.data.img_size
        rect = Rectangle((rect_x, rect_y), rect_w, rect_h, fill=False, edgecolor='red', linewidth=2)
        axes[i].add_patch(rect)

# Sequences shorter than the grid would otherwise leave empty framed panels.
for unused_ax in axes[n_frames:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## Initialize Model

In [None]:
model = MambaDetectorModule(config)

# Parameter counts: everything vs. only what the optimiser will update.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print('✅ Model created:')
print(f'  Total parameters: {total_params:,}')
print(f'  Trainable parameters: {trainable_params:,}')

# Sanity forward pass on two sequences from the batch loaded earlier.
# torch.no_grad() only disables gradient tracking; layers such as BatchNorm
# and Dropout (if the model contains any) still behave as in training, which
# would make this check stochastic and could update running statistics.
# Run the check in eval mode, then restore every sub-module's original
# train/eval flag so any deliberately frozen layers keep their setting.
module_modes = [(module, module.training) for module in model.modules()]
model.eval()
try:
    with torch.no_grad():
        test_output = model(images[:2])
finally:
    for module, was_training in module_modes:
        module.training = was_training

print('\nTest forward pass:')
print(f'  Input shape: {images[:2].shape}')
print(f'  Output shape: {test_output.shape}')

## Setup Training

In [None]:
# Checkpoints are written below ../outputs; create the folder up front.
checkpoint_dir = '../outputs/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the best `save_top_k` checkpoints for the monitored metric, plus the last one.
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=f'{config.experiment_name}-{{epoch:02d}}-{{val_loss:.3f}}',
    monitor=config.training.monitor,
    mode=config.training.mode,
    save_top_k=config.training.save_top_k,
    save_last=True,
)

# Stop once the monitored metric has not improved for 10 validation checks.
early_stop_cb = EarlyStopping(
    monitor=config.training.monitor,
    patience=10,
    mode=config.training.mode,
)

# Record the learning rate once per epoch.
lr_monitor_cb = LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_cb, early_stop_cb, lr_monitor_cb]

# Weights & Biases when enabled; otherwise `True` selects Lightning's default logger.
logger = (
    WandbLogger(
        project=config.project_name,
        name=config.experiment_name,
        log_model=True,
    )
    if config.use_wandb
    else True
)

trainer_kwargs = {
    'max_epochs': config.training.max_epochs,
    'accelerator': config.accelerator,
    'devices': config.devices,
    'callbacks': callbacks,
    'logger': logger,
    'log_every_n_steps': config.training.log_every_n_steps,
    'val_check_interval': config.training.val_check_interval,
    'gradient_clip_val': config.training.gradient_clip_val,
    'deterministic': True,
}
trainer = pl.Trainer(**trainer_kwargs)

print('✅ Trainer configured')

## Train Model

In [None]:
# Run the training loop, validating on the validation loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Report where the checkpoint callback stored the best-scoring weights.
print(
    '✅ Training complete!',
    f'Best model: {trainer.checkpoint_callback.best_model_path}',
    sep='\n',
)

## Evaluation

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback.
best_model_path = trainer.checkpoint_callback.best_model_path

# The callback leaves this as an empty string when no checkpoint was saved
# (for example, training was interrupted early or the monitored metric was
# never logged). Fail here with a clear message rather than letting
# load_from_checkpoint('') raise something misleading.
if not best_model_path:
    raise RuntimeError(
        'No best checkpoint is available: the checkpoint callback has not saved '
        f'a model for monitor={config.training.monitor!r}. Run the training cell '
        'to completion and check that the monitored metric is being logged.'
    )

# Reload the best weights and switch to inference behaviour before testing.
model = MambaDetectorModule.load_from_checkpoint(best_model_path, config=config)
model.eval()

trainer.test(model, test_loader)

## Inference on Sample

In [None]:
def draw_bbox(ax, bbox, confidence, color, threshold=0.5):
    """Outline ``bbox`` on ``ax`` when ``confidence`` exceeds ``threshold``.

    Args:
        ax: Matplotlib axes to draw on.
        bbox: Four values ``(x, y, w, h)``. ``x``/``y`` are treated as the box
            centre and every value is scaled by ``config.data.img_size``
            (presumably normalised coordinates; verify against the dataset).
        confidence: Score compared against ``threshold``.
        color: Edge colour of the rectangle.
        threshold: Boxes with ``confidence <= threshold`` are not drawn.
    """
    if not confidence > threshold:
        return
    # Plain floats keep matplotlib independent of tensor type and device.
    x, y, w, h = (float(v) for v in bbox)
    size = config.data.img_size
    ax.add_patch(
        Rectangle(
            ((x - w / 2) * size, (y - h / 2) * size),
            w * size,
            h * size,
            fill=False,
            edgecolor=color,
            linewidth=2,
        )
    )


# First sequence of a test batch and its per-frame targets.
images, targets = next(iter(test_loader))
sample_img = images[0]
sample_target = targets[0]

# Same priority as the rest of the notebook: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = model.to(device)
model.eval()  # make inference behaviour explicit for this cell
sample_img = sample_img.to(device)

with torch.no_grad():
    # Add a leading batch dimension of 1 for the single sequence.
    pred_dict = model.model.predict(sample_img.unsqueeze(0))

pred_bbox = pred_dict['bbox'][0].detach().cpu()
# float() accepts any one-element tensor, whereas applying ':.3f' directly to
# a tensor only works when it is exactly 0-dimensional.
pred_conf = float(pred_dict['confidence'][0])
gt_bbox = sample_target[-1, :4]
gt_conf = float(sample_target[-1, 4])

print('Prediction:')
print(f"  Bbox: {pred_dict['bbox'][0]}")
print(f"  Confidence: {pred_conf:.3f}")

print('\nGround Truth:')
print(f"  Bbox: {gt_bbox}")
print(f"  Confidence: {gt_conf:.3f}")

# Last frame of the sequence, de-normalised (value * std + mean) for display.
img_gt = sample_img[-1].detach().cpu().permute(1, 2, 0).numpy()
img_gt = img_gt * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
img_gt = np.clip(img_gt, 0, 1)

fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(12, 6))

ax_gt.imshow(img_gt)
ax_gt.set_title('Ground Truth')
draw_bbox(ax_gt, gt_bbox, gt_conf, 'green')

ax_pred.imshow(img_gt)
ax_pred.set_title('Prediction')
draw_bbox(ax_pred, pred_bbox, pred_conf, 'red')

plt.tight_layout()
plt.show()

## Export Model

In [None]:
os.makedirs('../outputs', exist_ok=True)

# Export from CPU in eval mode. An earlier cell may have moved the model to
# CUDA/MPS, while the dummy input below is created on CPU; tracing with
# mismatched devices raises a RuntimeError. Note that .cpu()/.eval() act in
# place, so `model` itself ends up on CPU after this cell.
export_model = model.model.cpu().eval()

# --- TorchScript ---
torchscript_path = f'../outputs/{config.experiment_name}.pt'
scripted_model = torch.jit.script(export_model)
torch.jit.save(scripted_model, torchscript_path)
print(f"✅ Model exported to: {torchscript_path}")

# --- ONNX ---
# Example input of shape (1, sequence_length, 3, img_size, img_size).
onnx_path = f'../outputs/{config.experiment_name}.onnx'
dummy_input = torch.randn(1, config.data.sequence_length, 3, config.data.img_size, config.data.img_size)
# NOTE(review): opset 11 is kept from the original; confirm the target runtime
# requires it and that every operator in the model is available at that opset.
torch.onnx.export(
    export_model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
)
print(f"✅ Model exported to: {onnx_path}")