# 03 - 3D Voxel CNN Training
# AutonomousVehiclePerception/notebooks/03_3d_voxel_training.ipynb

Train a 3D CNN on voxelized LiDAR point clouds to predict a bird's-eye-view (BEV) class grid, as a building block for BEV object detection.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time
import numpy as np

from src.model.cnn_3d_voxel import VoxelBackbone3D
from src.data.kitti_dataset import voxelize_points
from src.data.augmentations import get_lidar_augmentations

# Pick the GPU when one is available and report its name and total memory.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    # The device-properties attribute is ``total_memory`` (in bytes); the
    # previous ``total_mem`` does not exist and raised AttributeError on GPUs.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {gpu_props.name}')
    print(f'Memory: {gpu_props.total_memory / 1e9:.1f} GB')

## Configuration

In [None]:
# Hyper-parameters and filesystem locations used by every later cell.
# Voxel size and point range are (x, y, z) and (x_min, y_min, z_min,
# x_max, y_max, z_max) as passed straight to voxelize_points.
CONFIG = dict(
    num_classes=5,
    voxel_size=(0.2, 0.2, 0.2),
    point_range=(-40, -40, -3, 40, 40, 1),
    batch_size=4,
    num_epochs=30,
    learning_rate=1e-3,
    weight_decay=1e-4,
    val_split=0.2,
    kitti_root='../data/raw/kitti',
    checkpoint_dir='../checkpoints',
    log_dir='../logs/tensorboard/voxel3d',
)

# Make sure the output directories exist before training writes to them.
for _output_key in ('checkpoint_dir', 'log_dir'):
    Path(CONFIG[_output_key]).mkdir(parents=True, exist_ok=True)
del _output_key

## Data Loading

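Voxelize LiDAR point clouds into dense 3D grids for Conv3d processing. If no KITTI Velodyne scans are found, synthetic voxel grids are generated instead. In both cases the BEV labels are random placeholders, so the metrics below only confirm that the pipeline runs end to end.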

In [None]:
# Build (voxels, labels) tensors from KITTI Velodyne scans, or from synthetic
# data when no scans are available, then split into train/val loaders.
kitti_root = Path(CONFIG['kitti_root'])
lidar_dir = kitti_root / 'training' / 'velodyne'
# Cap at 200 scans: every scan becomes a dense voxel grid and the whole set is
# held in memory at once. An existing-but-empty directory now falls through to
# the synthetic branch instead of crashing in np.stack([]).
bin_files = sorted(lidar_dir.glob('*.bin'))[:200] if lidar_dir.exists() else []

if bin_files:
    from src.data.kitti_dataset import load_lidar_points
    print(f'Loading {len(bin_files)} LiDAR scans...')

    voxel_grids = [
        voxelize_points(load_lidar_points(str(bf)), CONFIG['voxel_size'], CONFIG['point_range'])
        for bf in bin_files
    ]

    # Cast to float32 (PyTorch's default parameter dtype) whatever dtype
    # voxelize_points returns, so the real-data path matches the synthetic one.
    X = torch.from_numpy(np.stack(voxel_grids)).float().unsqueeze(1)  # (N, 1, D, H, W)
    # Placeholder labels for BEV grid: random class IDs at 1/4 of the H x W size.
    Y = torch.randint(0, CONFIG['num_classes'], (len(voxel_grids), X.shape[3] // 4, X.shape[4] // 4))
else:
    print('KITTI LiDAR not found. Using synthetic voxel data.')
    N = 100
    D, H, W = 20, 128, 128
    X = torch.randn(N, 1, D, H, W)
    Y = torch.randint(0, CONFIG['num_classes'], (N, H // 4, W // 4))

print(f'Voxel data shape: {X.shape}')
print(f'Labels shape: {Y.shape}')

dataset = TensorDataset(X, Y)
val_size = int(len(dataset) * CONFIG['val_split'])
train_size = len(dataset) - val_size
# Seeded generator -> identical train/val split on every run (override with CONFIG['seed']).
split_generator = torch.Generator().manual_seed(CONFIG.get('seed', 42))
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

# Pinned host memory only helps host->GPU copies; on CPU it just emits a warning.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, pin_memory=pin_memory)
print(f'Train: {train_size}, Val: {val_size}')

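## LiDAR Augmentations

The available point-cloud augmentations are listed here for reference only. They are not applied during training below, because the dataset is built from precomputed voxel grids.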

In [None]:
# Show which LiDAR augmentations the project exposes (one entry per line).
lidar_augs = get_lidar_augmentations()
print('Available LiDAR augmentations:')
for aug_name in lidar_augs:
    print(f'  - {aug_name}')

## Model Setup

In [None]:
# Single input channel, matching the (N, 1, D, H, W) voxel tensors.
model = VoxelBackbone3D(
    in_channels=1,
    num_classes=CONFIG['num_classes'],
).to(device)
num_params = sum(tensor.numel() for tensor in model.parameters())
print(f'VoxelBackbone3D parameters: {num_params:,}')

# Per-cell classification loss over the BEV output grid.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
# Stepped once per epoch by the training loop, so T_max spans the full run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs'])
writer = SummaryWriter(log_dir=CONFIG['log_dir'])

## Training Loop

In [None]:
def _match_targets(outputs, targets):
    """Resize integer label maps to the spatial size of ``outputs`` if needed.

    Args:
        outputs: model logits; dims 2 onward are taken as the spatial size.
        targets: class-ID labels whose dims 1 onward are compared to it.

    Returns:
        ``targets`` unchanged when the sizes already agree. Otherwise, a
        nearest-neighbour resize cast back to ``long``: nearest mode only
        copies existing values, so no new class IDs are invented.
    """
    if outputs.shape[2:] == targets.shape[1:]:
        return targets
    return torch.nn.functional.interpolate(
        targets.unsqueeze(1).float(), size=outputs.shape[2:], mode='nearest'
    ).squeeze(1).long()


def _run_epoch(loader, train):
    """Run one pass of ``model`` over ``loader``.

    Args:
        loader: DataLoader yielding ``(voxels, targets)`` batches.
        train: if True, update the weights with ``optimizer``. Otherwise,
            run in eval mode with gradients disabled.

    Returns:
        ``(mean_loss, pixel_accuracy)``. The loss is weighted by batch size,
        so a short final batch does not count as much as a full one. Both
        values are 0.0 for an empty loader.
    """
    model.train(train)
    loss_sum, n_samples, n_correct, n_pixels = 0.0, 0, 0, 0
    with torch.set_grad_enabled(train):
        for voxels, targets in loader:
            voxels, targets = voxels.to(device), targets.to(device)
            outputs = model(voxels)
            targets = _match_targets(outputs, targets)
            loss = criterion(outputs, targets)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_len = voxels.size(0)
            loss_sum += loss.item() * batch_len
            n_samples += batch_len
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()
            n_pixels += targets.numel()
    return loss_sum / max(n_samples, 1), n_correct / max(n_pixels, 1)


best_val_loss = float('inf')
checkpoint_path = Path(CONFIG['checkpoint_dir']) / 'voxel3d_best.pth'

try:
    for epoch in range(CONFIG['num_epochs']):
        start = time.perf_counter()

        train_loss, train_acc = _run_epoch(train_loader, train=True)
        val_loss, val_acc = _run_epoch(val_loader, train=False)

        scheduler.step()
        elapsed = time.perf_counter() - start

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/val', val_acc, epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Optimizer/scheduler state and config are saved too, so a run can
            # be resumed or reproduced from the best checkpoint.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'config': CONFIG,
            }, checkpoint_path)

        if epoch % 5 == 0 or epoch == CONFIG['num_epochs'] - 1:
            # Display the epoch 1-based so the final line reads N/N.
            print(f'Epoch {epoch + 1:3d}/{CONFIG["num_epochs"]} | '
                  f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | '
                  f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | '
                  f'Time: {elapsed:.1f}s')
finally:
    # Flush TensorBoard events even if the cell is interrupted mid-training.
    writer.close()

print(f'\nTraining complete. Best val loss: {best_val_loss:.4f}')