# 02 - CNN Training
# AutonomousVehiclePerception/notebooks/02_cnn_training.ipynb

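Train a 2D CNN perception model on the KITTI dataset, using GPU acceleration when a CUDA device is available (the notebook falls back to CPU, and to synthetic data if KITTI is not found locally).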

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time
import warnings
import numpy as np

from src.model.cnn_2d import PerceptionCNN2D
from src.model.fpn_resnet import FPNDetector
from src.data.kitti_dataset import KITTIDataset
from src.data.augmentations import get_train_transforms, get_val_transforms

# Select the compute device: first CUDA GPU if present, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    # The device-properties attribute is ``total_memory`` (in bytes);
    # ``total_mem`` does not exist and raised AttributeError on GPU machines.
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'Memory: {gpu_total_gb:.1f} GB')

## Configuration

In [None]:
# Hyperparameters and filesystem locations for this training run.
CONFIG = dict(
    model='cnn_2d',                     # which architecture: 'cnn_2d' or 'fpn_resnet'
    num_classes=9,                      # number of KITTI classes predicted
    image_size=(480, 640),              # (height, width) passed to the transforms
    batch_size=8,
    num_epochs=50,
    learning_rate=1e-3,
    weight_decay=1e-4,
    val_split=0.2,                      # fraction of samples held out for validation
    kitti_root='../data/raw/kitti',
    checkpoint_dir='../checkpoints',
    log_dir='../logs/tensorboard',
)

# Make sure the output directories exist before anything writes to them.
for _dir_key in ('checkpoint_dir', 'log_dir'):
    Path(CONFIG[_dir_key]).mkdir(parents=True, exist_ok=True)

## Data Loading

In [None]:
# Fixed seed for the train/val partition so the split (and therefore the
# validation metrics used for checkpoint selection) is reproducible across runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_transforms = get_train_transforms(CONFIG['image_size'])
val_transforms = get_val_transforms(CONFIG['image_size'])

kitti_root = Path(CONFIG['kitti_root'])
if kitti_root.exists():
    from torch.utils.data import Subset

    # Build two views of the same split: one with training augmentations and one
    # with the validation transforms. A single random_split over one dataset
    # would apply train_transforms to the validation samples as well.
    # NOTE(review): assumes two KITTIDataset instances over the same root/split
    # enumerate samples in the same order — confirm in src.data.kitti_dataset.
    full_dataset = KITTIDataset(root=kitti_root, split='training', transform=train_transforms)
    val_view = KITTIDataset(root=kitti_root, split='training', transform=val_transforms)
    val_size = int(len(full_dataset) * CONFIG['val_split'])
    train_size = len(full_dataset) - val_size
    shuffled = torch.randperm(len(full_dataset), generator=split_generator).tolist()
    train_dataset = Subset(full_dataset, shuffled[:train_size])
    val_dataset = Subset(val_view, shuffled[train_size:])
    print(f'Train: {train_size}, Val: {val_size}')
else:
    # Use synthetic data for demonstration: random images and random per-cell
    # class labels on a grid 1/16 of the image resolution.
    print('KITTI not found. Using synthetic data for demo.')
    from torch.utils.data import TensorDataset
    X = torch.randn(200, 3, *CONFIG['image_size'])
    Y = torch.randint(0, CONFIG['num_classes'], (200, CONFIG['image_size'][0]//16, CONFIG['image_size'][1]//16))
    full_dataset = TensorDataset(X, Y)
    val_size = int(len(full_dataset) * CONFIG['val_split'])
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=split_generator)
    print(f'Synthetic Train: {train_size}, Val: {val_size}')

# Pinned host memory only helps (and only makes sense) when copying to a GPU.
use_pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2, pin_memory=use_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2, pin_memory=use_pin_memory)

## Model Setup

In [None]:
# Instantiate the configured architecture (anything other than 'cnn_2d' selects
# the FPN detector) and move it to the selected device.
model = (
    PerceptionCNN2D(num_classes=CONFIG['num_classes'])
    if CONFIG['model'] == 'cnn_2d'
    else FPNDetector(num_classes=CONFIG['num_classes'], pretrained=True)
).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f'Model: {CONFIG["model"]}')
print(f'Parameters: {num_params:,}')

# Loss, optimizer and a cosine LR schedule spanning the whole run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs'])

# TensorBoard event writer for the training curves.
writer = SummaryWriter(log_dir=CONFIG['log_dir'])

## Training Loop

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader``.

    Only batches that unpack to an ``(images, targets)`` pair with tensor
    targets are used; batches whose targets are dicts (e.g. KITTI detection
    annotations) or that are not a 2-element list/tuple are skipped. When the
    target's spatial size differs from the model output's, targets are resized
    with nearest-neighbour interpolation so they stay integer class ids.

    Args:
        model: network producing per-class scores; presumably shaped
            (N, C, H, W) since spatial dims are read from ``outputs.shape[2:]``
            and classes via ``argmax(dim=1)`` — confirm for new models.
        loader: iterable of batches (normally a ``DataLoader``).
        criterion: loss taking ``(outputs, targets)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        epoch: current epoch index; unused, kept for interface compatibility.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the batch loss averaged
        over processed samples (skipped batches and a smaller final batch no
        longer bias it) and ``accuracy`` is the fraction of target elements
        predicted correctly. Both are 0.0 if no batch was usable, in which case
        a ``UserWarning`` is emitted.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    correct = 0
    total = 0

    for batch in loader:
        if not (isinstance(batch, (list, tuple)) and len(batch) == 2):
            continue
        images, targets = batch
        if isinstance(targets, dict):
            continue  # Skip KITTI dict targets for now

        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Resize targets to match output spatial dims if needed
        if outputs.shape[2:] != targets.shape[1:]:
            targets = torch.nn.functional.interpolate(
                targets.unsqueeze(1).float(), size=outputs.shape[2:], mode='nearest'
            ).squeeze(1).long()

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is a per-sample mean.
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size
        preds = outputs.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()

    if num_samples == 0:
        warnings.warn(
            'train_one_epoch: no usable batches (targets were dicts or batches '
            'were not (images, targets) pairs); nothing was trained.'
        )

    avg_loss = loss_sum / max(num_samples, 1)
    accuracy = correct / max(total, 1)
    return avg_loss, accuracy


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the same batch filtering as training: only ``(images, targets)`` pairs
    with tensor targets are evaluated; dict targets or malformed batches are
    skipped. Targets are resized (nearest-neighbour) to the output's spatial
    size when they differ.

    Args:
        model: network producing per-class scores; presumably (N, C, H, W) —
            spatial dims are read from ``outputs.shape[2:]``, classes via
            ``argmax(dim=1)``.
        loader: iterable of batches (normally a ``DataLoader``).
        criterion: loss taking ``(outputs, targets)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)``: loss averaged over evaluated samples (so
        skipped batches and a smaller final batch do not bias checkpoint
        selection) and fraction of correctly predicted target elements. Both
        are 0.0 if no batch was usable, in which case a ``UserWarning`` is
        emitted.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            if not (isinstance(batch, (list, tuple)) and len(batch) == 2):
                continue
            images, targets = batch
            if isinstance(targets, dict):
                continue  # Skip KITTI dict targets for now

            images = images.to(device)
            targets = targets.to(device)
            outputs = model(images)

            if outputs.shape[2:] != targets.shape[1:]:
                targets = torch.nn.functional.interpolate(
                    targets.unsqueeze(1).float(), size=outputs.shape[2:], mode='nearest'
                ).squeeze(1).long()

            loss = criterion(outputs, targets)
            # Weight by batch size so the result is a per-sample mean.
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.numel()

    if num_samples == 0:
        warnings.warn(
            'validate: no usable batches (targets were dicts or batches were '
            'not (images, targets) pairs); reported metrics are meaningless.'
        )

    avg_loss = loss_sum / max(num_samples, 1)
    accuracy = correct / max(total, 1)
    return avg_loss, accuracy

In [None]:
# Main training loop: train, validate, step the LR schedule, log to
# TensorBoard and keep the checkpoint with the lowest validation loss.
best_val_loss = float('inf')
ckpt_path = Path(CONFIG['checkpoint_dir']) / f'{CONFIG["model"]}_best.pth'

try:
    for epoch in range(CONFIG['num_epochs']):
        start = time.perf_counter()
        # Capture the LR used for *this* epoch; scheduler.step() below advances
        # it, so reading param_groups afterwards would report the next epoch's LR.
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()

        elapsed = time.perf_counter() - start

        # TensorBoard logging
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/val', val_acc, epoch)
        writer.add_scalar('LR', current_lr, epoch)

        # Save best checkpoint, including scheduler state and config so the
        # run can be resumed or reproduced from the file alone.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
                'config': CONFIG,
            }, ckpt_path)

        if epoch % 5 == 0 or epoch == CONFIG['num_epochs'] - 1:
            print(f'Epoch {epoch:3d}/{CONFIG["num_epochs"]} | '
                  f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | '
                  f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | '
                  f'LR: {current_lr:.6f} | '
                  f'Time: {elapsed:.1f}s')
finally:
    # Flush pending TensorBoard events even if the run is interrupted
    # (e.g. KeyboardInterrupt from the notebook).
    writer.close()

print(f'\nTraining complete. Best val loss: {best_val_loss:.4f}')