# VoVNet + LSS v2 + Transformer Training on Google Colab

**Multi-modal BEV Perception for Autonomous Driving**

This notebook trains VoVNetV2-39 + LSS v2 + Lightweight Transformer on the nu-A2D dataset.

**Expected Performance:**
- BEV mIoU: 52-54% (baseline: 47%)
- Action F1: 75-78% (baseline: 72%)
- Description F1: 71-74% (baseline: 68%)

**Hardware Requirements:**
- GPU: T4 (15GB) or better
- RAM: 25GB+
- Disk: 50GB+
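
The RAM and disk figures above can be checked before anything is downloaded. The following is a minimal sketch, not part of the project: `check_resources` is a hypothetical helper, and it assumes a Linux runtime (as on Colab) where `os.sysconf` exposes the physical page count.

```python
import os
import shutil

GIB = 1024 ** 3


def check_resources(path=".", min_ram_gb=25, min_disk_gb=50):
    """Return total RAM and free disk (GiB) with pass/fail flags.

    Hypothetical helper; the thresholds default to the requirements listed above.
    """
    # Total physical memory via POSIX sysconf (available on Linux runtimes)
    ram_gb = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / GIB
    # Free space on the filesystem that holds `path`
    disk_gb = shutil.disk_usage(path).free / GIB
    return {
        "ram_gb": ram_gb,
        "disk_gb": disk_gb,
        "ram_ok": ram_gb >= min_ram_gb,
        "disk_ok": disk_gb >= min_disk_gb,
    }


report = check_resources()
print(f"RAM:  {report['ram_gb']:.1f} GiB ({'ok' if report['ram_ok'] else 'below requirement'})")
print(f"Disk: {report['disk_gb']:.1f} GiB free ({'ok' if report['disk_ok'] else 'below requirement'})")
```

GPU type and memory are easier to read from `!nvidia-smi`, which section 6 already runs.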

## 1. Setup Conda Environment with Konda

In [None]:
# Install konda for conda environment management in Colab
!pip install -q konda
import konda
konda.install()


In [None]:
# Verify conda installation
!conda --version


## 2. Clone Repository and Setup Project

In [None]:
# Mount Google Drive (if using Drive for data storage)
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Clone your repository (replace with your repo URL)
!git clone https://github.com/YOUR_USERNAME/Multimodal-XAD.git
%cd Multimodal-XAD


## 3. Create Conda Environment

In [None]:
# Create conda environment from environment.yaml
!conda env create -f environment.yaml


In [None]:
# Activate environment (konda method)
import konda
konda.activate('Multimodal_XAD')


In [None]:
# Verify installation
!python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
!python -c "import timm; print(f'timm: {timm.__version__}')"


## 4. Prepare Dataset

**Dataset Structure Required:**
```
data/trainval/nu-A2D-*/nu-A2D/trainval/
├── v1.0-trainval/       # Metadata
├── samples/             # Camera & LIDAR
│   ├── CAM_FRONT/
│   ├── CAM_BACK/
│   ├── LIDAR_TOP/
│   └── ...
├── maps/                # Map masks
├── action_all/
├── desc_all/
└── local_binmap/
```
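
The `nu-A2D-*` component of the path differs between downloads, so it may be more robust to resolve it with a glob than to hard-code it. A small sketch under that assumption; `find_dataset_root` and `missing_folders` are hypothetical helpers, not functions from the repository:

```python
import glob
import os

# Folder names taken from the structure shown above
REQUIRED = ("v1.0-trainval", "samples", "maps", "action_all", "desc_all", "local_binmap")


def find_dataset_root(base="./data/trainval"):
    """Return the first directory matching nu-A2D-*/nu-A2D/trainval under base, or None."""
    pattern = os.path.join(base, "nu-A2D-*", "nu-A2D", "trainval")
    matches = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    return matches[0] if matches else None


def missing_folders(root, required=REQUIRED):
    """List the required sub-folders that do not exist under root."""
    return [name for name in required if not os.path.isdir(os.path.join(root, name))]


root = find_dataset_root()
if root is None:
    print("No dataset directory matching nu-A2D-* found yet.")
else:
    print(f"Dataset root: {root}")
    print(f"Missing folders: {missing_folders(root) or 'none'}")
```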

In [None]:
# Option 1: Copy the dataset from Google Drive
# Replace the source path with your actual Drive location
!mkdir -p ./data/trainval
!cp -r /content/drive/MyDrive/nu-A2D-dataset ./data/trainval/


In [None]:
# Option 2: Upload dataset as ZIP and extract
# Uncomment if uploading manually
# from google.colab import files
# uploaded = files.upload()  # Upload your nu-A2D-dataset.zip
# !unzip -q nu-A2D-dataset.zip -d ./data/trainval/


In [None]:
# Verify data structure
import os

data_path = './data/trainval/nu-A2D-20260129T100537Z-3-001/nu-A2D/trainval'
required_folders = ['v1.0-trainval', 'samples', 'maps', 'action_all', 'desc_all', 'local_binmap']

print("Checking dataset structure...")
for folder in required_folders:
    path = os.path.join(data_path, folder)
    exists = os.path.exists(path)
    print(f"{'✓' if exists else '✗'} {folder}: {exists}")

# Count samples
cam_path = os.path.join(data_path, 'samples/CAM_FRONT')
if os.path.exists(cam_path):
    num_samples = len([f for f in os.listdir(cam_path) if f.endswith('.jpg')])
    print(f"\nTotal CAM_FRONT images: {num_samples}")


## 5. Test Model Architecture

In [None]:
# Quick test to verify model can be created
!python src/model_vovnet_transformer.py


## 6. Configure Training Parameters

In [None]:
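# Training configuration (reference values)
# Note: the training cell in section 7 runs the script without arguments, so this
# dict is not passed to it; keep it in sync with the settings in train_vovnet_transformer.py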
config = {
    'vovnet_type': 'vovnet39',  # or 'vovnet57' for larger model
    'pretrained': True,
    'batch_size': 4,  # Reduce to 3 if OOM on T4
    'epochs': 80,
    'lr': 2e-4,
    'num_workers': 2,  # Reduce for Colab
    'save_dir': './checkpoints_vovnet_colab'
}

print("Training Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")


In [None]:
# Check GPU memory
!nvidia-smi


## 7. Start Training

In [None]:
# Run training script
# This will train for 80 epochs with automatic checkpointing
!python train_vovnet_transformer.py


## 8. Monitor Training (Run in Separate Cell)

In [None]:
# Take a one-off snapshot of GPU usage. `watch` never exits, so it would block
# this cell (and every cell queued behind it) until interrupted.
!nvidia-smi


In [None]:
# View training logs (run after training starts)
import os
import re

import matplotlib.pyplot as plt

def parse_log(log_file='training.log'):
    """Parse training log and plot metrics"""
    if not os.path.exists(log_file):
        print("Log file not found. Training may not have started yet.")
        return
    
    epochs, train_loss, val_loss, bev_miou = [], [], [], []
    
    with open(log_file, 'r') as f:
        for line in f:
            if 'Epoch' in line and 'Train Loss' in line:
                match = re.search(r'Epoch (\d+).*Train Loss: ([\d.]+).*Val Loss: ([\d.]+).*BEV mIoU: ([\d.]+)', line)
                if match:
                    epochs.append(int(match.group(1)))
                    train_loss.append(float(match.group(2)))
                    val_loss.append(float(match.group(3)))
                    bev_miou.append(float(match.group(4)))
    
    if not epochs:
        print("No training data found in log.")
        return
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
    
    ax1.plot(epochs, train_loss, label='Train Loss', marker='o')
    ax1.plot(epochs, val_loss, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training & Validation Loss')
    ax1.legend()
    ax1.grid(True)
    
    ax2.plot(epochs, bev_miou, label='BEV mIoU', marker='o', color='green')
    ax2.axhline(y=47, color='r', linestyle='--', label='Baseline (47%)')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('mIoU (%)')
    ax2.set_title('BEV Segmentation Performance')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Best BEV mIoU: {max(bev_miou):.2f}% at Epoch {epochs[bev_miou.index(max(bev_miou))]}")

# Call this periodically to monitor progress
parse_log()
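
The parser above only picks up lines where the epoch number, both losses, and the mIoU appear on one line in that order. The sketch below applies the same regular expression to a made-up log line to show the format it expects; the sample line and the `parse_line` helper are illustrative assumptions, since the actual log format is defined by the training script.

```python
import re

# Same pattern as in parse_log() above
LOG_PATTERN = re.compile(
    r'Epoch (\d+).*Train Loss: ([\d.]+).*Val Loss: ([\d.]+).*BEV mIoU: ([\d.]+)'
)


def parse_line(line):
    """Return (epoch, train_loss, val_loss, bev_miou), or None if the line does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return (
        int(match.group(1)),
        float(match.group(2)),
        float(match.group(3)),
        float(match.group(4)),
    )


# Hypothetical log line, for illustration only
sample = "Epoch 12 | Train Loss: 0.8421 | Val Loss: 0.9130 | BEV mIoU: 49.75"
print(parse_line(sample))
print(parse_line("Epoch 12 | Train Loss: 0.8421"))  # incomplete line: no match
```

If the script logs these values on separate lines, the pattern would need to be split per metric.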


## 9. Resume Training (If Interrupted)

In [None]:
# Find latest checkpoint
import glob
import os

checkpoints = glob.glob('./checkpoints_vovnet_colab/checkpoint_epoch_*.pth')
if checkpoints:
    # Modification time reflects when the file was last written
    latest = max(checkpoints, key=os.path.getmtime)
    print(f"Latest checkpoint: {latest}")
    print("To resume, modify train_vovnet_transformer.py to load this checkpoint")
else:
    print("No checkpoints found.")
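
File timestamps can be misleading once checkpoints have been copied to or from Drive, because a copy resets them. An alternative sketch selects the latest checkpoint by the epoch number in its name instead. It assumes the `*` in `checkpoint_epoch_*.pth` is an integer epoch; `epoch_of` and `latest_by_epoch` are hypothetical helpers.

```python
import re

# Assumption: files are named checkpoint_epoch_<integer>.pth
EPOCH_RE = re.compile(r'checkpoint_epoch_(\d+)\.pth$')


def epoch_of(path):
    """Extract the epoch number from a checkpoint path, or -1 if it does not match."""
    match = EPOCH_RE.search(path)
    return int(match.group(1)) if match else -1


def latest_by_epoch(paths):
    """Return the path with the highest epoch number, or None if none match."""
    numbered = [p for p in paths if epoch_of(p) >= 0]
    return max(numbered, key=epoch_of) if numbered else None


# Illustrative file names only
example = [
    './checkpoints_vovnet_colab/checkpoint_epoch_5.pth',
    './checkpoints_vovnet_colab/checkpoint_epoch_45.pth',
    './checkpoints_vovnet_colab/checkpoint_epoch_10.pth',
]
print(latest_by_epoch(example))
```

Comparing the parsed integers avoids the string-ordering trap where `epoch_5` would sort after `epoch_45`.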


## 10. Save Results to Google Drive

In [None]:
# Copy checkpoints to Google Drive for persistence
!mkdir -p /content/drive/MyDrive/Multimodal-XAD-Results
!cp -r ./checkpoints_vovnet_colab /content/drive/MyDrive/Multimodal-XAD-Results/
!cp training.log /content/drive/MyDrive/Multimodal-XAD-Results/ 2>/dev/null || echo "No log file yet"

print("✓ Checkpoints saved to Google Drive!")
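
The shell commands above print the success message even when the copy fails, for example if Drive is not mounted or no checkpoints exist yet. A Python variant can report what was actually copied. This is a sketch only: `backup_dir` is a hypothetical helper, and `dirs_exist_ok` requires Python 3.8 or newer.

```python
import os
import shutil


def backup_dir(src, dst_root):
    """Copy src into dst_root/<name of src>; return the destination, or None if src is missing."""
    if not os.path.isdir(src):
        return None
    dst = os.path.join(dst_root, os.path.basename(os.path.normpath(src)))
    os.makedirs(dst_root, exist_ok=True)
    # dirs_exist_ok=True lets repeated backups overwrite earlier copies
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst


saved_to = backup_dir('./checkpoints_vovnet_colab',
                      '/content/drive/MyDrive/Multimodal-XAD-Results')
print(f"Saved to {saved_to}" if saved_to else "Nothing to back up yet.")
```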


## 11. Evaluate Best Model

In [None]:
# Load best checkpoint and report the metrics stored in it
import os

import torch
from src.model_vovnet_transformer import compile_model_vovnet_transformer

# Best checkpoint (highest validation performance)
best_ckpt = './checkpoints_vovnet_colab/best_model.pth'

if os.path.exists(best_ckpt):
    print(f"Loading best checkpoint: {best_ckpt}")
    # map_location lets the checkpoint load even when no GPU is attached
    checkpoint = torch.load(best_ckpt, map_location='cpu')
    
    print(f"\nBest Model Performance:")
    print(f"  Epoch: {checkpoint['epoch']}")
    print(f"  BEV mIoU: {checkpoint['bev_iou']:.2f}%")
    print(f"  Action F1: {checkpoint.get('action_f1', 0):.2f}%")
    print(f"  Description F1: {checkpoint.get('desc_f1', 0):.2f}%")
else:
    print("Best checkpoint not found. Training may still be in progress.")


## 12. Quick Inference Test

In [None]:
# Test inference on a sample
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.model_vovnet_transformer import compile_model_vovnet_transformer

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

grid_conf = {
    'xbound': [-50.0, 50.0, 0.5],
    'ybound': [-50.0, 50.0, 0.5],
    'zbound': [-10.0, 10.0, 20.0],
    'dbound': [4.0, 45.0, 1.0],
}

data_aug_conf = {
    'resize_lim': (0.193, 0.225),
    'final_dim': (128, 352),
    'rot_lim': (-5.4, 5.4),
    'H': 900, 'W': 1600,
    'rand_flip': False,
    'bot_pct_lim': (0.0, 0.22),
    'cams': ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
             'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'],
    'Ncams': 6,
}

model = compile_model_vovnet_transformer(
    bsize=1, grid_conf=grid_conf, data_aug_conf=data_aug_conf,
    outC=4, vovnet_type='vovnet39', pretrained=False
).to(device)

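# Load weights (same checkpoint path as in section 11)
best_ckpt = './checkpoints_vovnet_colab/best_model.pth'
if os.path.exists(best_ckpt):
    state = torch.load(best_ckpt, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    print("✓ Model loaded successfully!")
else:
    print("⚠ No checkpoint found, using random weights")

# Switch to inference mode whether or not weights were loaded
model.eval()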

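# Dummy inference with random inputs (checks shapes only; outputs are not meaningful)
with torch.no_grad():
    # Images carry a batch dimension like the other inputs: (batch, cams, C, H, W)
    imgs = torch.randn(1, 6, 3, 128, 352).to(device)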
    rots = torch.randn(1, 6, 3, 3).to(device)
    trans = torch.randn(1, 6, 3).to(device)
    intrins = torch.randn(1, 6, 3, 3).to(device)
    post_rots = torch.randn(1, 6, 3, 3).to(device)
    post_trans = torch.randn(1, 6, 3).to(device)
    
    bev_seg, action, desc = model(imgs, rots, trans, intrins, post_rots, post_trans)
    
    print(f"\nOutput shapes:")
    print(f"  BEV Segmentation: {bev_seg.shape}")
    print(f"  Action: {action.shape}")
    print(f"  Description: {desc.shape}")
    
    # Visualize BEV prediction
    bev_pred = bev_seg.argmax(1).cpu().numpy()[0]
    
    plt.figure(figsize=(8, 8))
    plt.imshow(bev_pred, cmap='tab10')
    plt.title('BEV Segmentation Prediction')
    plt.colorbar(label='Class')
    plt.xlabel('X (grid cells, 0.5 m each)')
    plt.ylabel('Y (grid cells, 0.5 m each)')
    plt.show()
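
The expected BEV output size follows from `grid_conf`. Assuming each bound is `[lower, upper, step]`, as in the original Lift-Splat-Shoot code (whether this repository keeps that convention is an assumption), the number of cells along an axis is `(upper - lower) / step`. A short worked calculation with a hypothetical `axis_cells` helper:

```python
def axis_cells(bound):
    """Number of cells along one axis for a [lower, upper, step] bound."""
    lower, upper, step = bound
    return int(round((upper - lower) / step))


# Values copied from the cell above
grid_conf = {
    'xbound': [-50.0, 50.0, 0.5],
    'ybound': [-50.0, 50.0, 0.5],
    'zbound': [-10.0, 10.0, 20.0],
    'dbound': [4.0, 45.0, 1.0],
}

nx = axis_cells(grid_conf['xbound'])      # 100 m / 0.5 m
ny = axis_cells(grid_conf['ybound'])      # 100 m / 0.5 m
nz = axis_cells(grid_conf['zbound'])      # 20 m / 20 m: a single height slice
depth_bins = axis_cells(grid_conf['dbound'])  # 41 m / 1 m

print(f"BEV grid: {nx} x {ny} x {nz}, depth bins: {depth_bins}")
```

Under that convention the segmentation map should be 200 x 200 cells covering 100 m x 100 m, which gives a quick sanity check against the printed `bev_seg.shape`.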


## 13. Clean Up (Optional)

In [None]:
# Free up disk space by removing intermediate checkpoints
# (keeps the 5 most recent epoch checkpoints; best_model.pth is not matched by the glob)
import glob
import os

checkpoints = glob.glob('./checkpoints_vovnet_colab/checkpoint_epoch_*.pth')
if len(checkpoints) > 5:
    checkpoints.sort(key=os.path.getctime)
    to_remove = checkpoints[:-5]  # Keep last 5
    
    for ckpt in to_remove:
        os.remove(ckpt)
        print(f"Removed: {ckpt}")
    
    print(f"\n✓ Cleaned up {len(to_remove)} old checkpoints")
else:
    print("No cleanup needed.")


---

## Tips for Google Colab:

1. **Prevent Disconnection**: Keep browser tab active or use Colab Pro
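2. **Save Frequently**: Checkpoints auto-save every 5 epochs to `./checkpoints_vovnet_colab` on the Colab VM; copy them to Drive (section 10) so they survive a disconnect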
3. **Monitor GPU**: Use `!nvidia-smi` to check memory usage
4. **Reduce Batch Size**: If OOM error, reduce from 4→3 or 3→2
5. **Use T4 or Better**: Request GPU in Runtime → Change runtime type

## Expected Timeline (T4 GPU):
- ~15 min/epoch
- 80 epochs = ~20 hours
- Best results typically at epoch 60-70
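
The 20-hour figure is simply 15 min x 80 epochs. The sketch below reproduces that arithmetic and estimates how many runtime sessions the run would span. The 12-hour session length is an illustrative assumption, not a Colab guarantee, and both function names are hypothetical.

```python
import math


def training_hours(minutes_per_epoch=15, epochs=80):
    """Total wall-clock training time in hours."""
    return minutes_per_epoch * epochs / 60


def sessions_needed(total_hours, session_hours):
    """Number of runtime sessions required if each lasts at most session_hours."""
    return math.ceil(total_hours / session_hours)


total = training_hours()
# session_hours=12 is an assumed limit; adjust it to your runtime
print(f"Total: {total:.1f} h, sessions at 12 h each: {sessions_needed(total, session_hours=12)}")
```

If the run spans more than one session, resuming from a checkpoint (section 9) becomes part of the normal workflow rather than an exception.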

## Troubleshooting:
- **OOM Error**: Reduce batch_size or use gradient checkpointing
- **Slow Training**: Reduce num_workers to 0-2
- **Data Not Found**: Check the dataset structure in section 4
- **Import Errors**: Re-run conda environment setup