# Vehicle Detection with YOLO-SwinV2 Model

This notebook implements a modified YOLO model with SwinV2-Tiny as the backbone for vehicle detection. The model is trained on the AAU RainSnow dataset (vehicles in rainy/snowy conditions) and evaluated on the same highway video as the pre-trained YOLO-V5m model.

**Key Features:**
- SwinV2-Tiny backbone replacing the original CSPDarknet backbone
- Training on weather-degraded vehicle data for improved detection in adverse conditions
- Metrics collection for comparison with the pre-trained YOLO-V5m baseline


## How To Run

Google Colab is the recommended environment, but the notebook is written so that it also runs locally.

**To run this notebook in Google Colab:**
- Download the whole project folder (enhanced_vehicle_detection) from GitHub.
- Place it in MyDrive in Google Drive.
    - If the project folder is placed in a different path in Google Drive, the paths for the input video and outputs need to be edited accordingly.
- All set! You can now run the cells.

**To run this notebook in a local environment:**
- Fork or clone the GitHub repository.
- Run `pip install -r app/requirements.txt` to install all required libraries.
- Since the code requires video conversion, make sure to install **ffmpeg**:
    - macOS: `brew install ffmpeg`
    - Ubuntu/Linux: `sudo apt install ffmpeg`
    - Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)
- All set! You can now run the cells.

## Setup YOLO V5 

The code below installs all libraries required to load and use the YOLO-V5 model. It only needs to be run once per session.

In [None]:
!git clone -q https://github.com/ultralytics/yolov5
%cd yolov5
!pip -q install -r requirements.txt opencv-python-headless==4.10.0.84 timm

## Import Necessary Libraries

In [None]:
import cv2, torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from IPython.display import HTML, display
from base64 import b64encode
import timm
import json
import os
import random
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

## Environment Setup

Set up paths depending on whether the notebook runs in Google Colab or in a local environment.

In [None]:
# Check if running in Google Colab
IN_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython())

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Paths for Colab
    DATA_ROOT = '/content/drive/MyDrive/enhanced_vehicle_detection/data/training_data_vehicles_in_rain'
    VIDEO_PATH = '/content/drive/MyDrive/enhanced_vehicle_detection/data/rainy_highway_video.mp4'
    OUTPUT_DIR = '/content/drive/MyDrive/enhanced_vehicle_detection/outputs/YOLO_SwinV2'
    YOLOV5M_METRICS_PATH = '/content/drive/MyDrive/enhanced_vehicle_detection/outputs/YOLO_V5m/YOLO_V5m_metrics.json'
else:
    # Paths for local environment
    DATA_ROOT = '../data/training_data_vehicles_in_rain'
    VIDEO_PATH = '../data/rainy_highway_video.mp4'
    OUTPUT_DIR = '../outputs/YOLO_SwinV2'
    YOLOV5M_METRICS_PATH = '../outputs/YOLO_V5m/YOLO_V5m_metrics.json'

# Create output directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, 'visualizations'), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, 'checkpoints'), exist_ok=True)

# ============================================================================
# RANDOM SEED FOR REPRODUCIBILITY
# ============================================================================
SEED = 42  # Change this value to get different reproducible results

def set_seed(seed):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU
    
    # For deterministic behavior (may slow down training slightly)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # Note: PYTHONHASHSEED only affects Python processes started after this point;
    # it does not change string hashing in the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(SEED)
print(f"Random seed set to: {SEED}")
# ============================================================================

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Data root: {DATA_ROOT}")
print(f"Output directory: {OUTPUT_DIR}")


## Training Hyperparameters

**Modify these values to adjust training behavior.** For a class project with a ~90 min time limit, use 10-20 epochs.

| Parameter | Description | Recommended Range |
|-----------|-------------|-------------------|
| `num_epochs` | Number of training epochs | 10-20 for quick training, 50+ for better results |
| `batch_size` | Samples per batch | 8-32 (lower if GPU memory limited) |
| `learning_rate` | Initial learning rate | 1e-4 to 1e-3 |
| `optimizer` | Optimization algorithm | 'sgd' (faster) or 'adamw' (better convergence) |
| `img_size` | Input image size | **256** (required for SwinV2-Tiny window8_256) |

> **Note:** The SwinV2-Tiny model uses `swinv2_tiny_window8_256` which requires 256x256 input images. This also speeds up training compared to 640x640.
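
As a rough sanity check on the 256x256 requirement, the sketch below computes the feature-map side length at each backbone stage. It assumes the usual Swin layout (patch size 4, then a 2x downsampling before each of the later stages). `feature_map_sizes` is a hypothetical helper for illustration, not part of timm.

```python
def feature_map_sizes(img_size, patch_size=4, num_stages=4):
    """Return the feature-map side length at each stage of a Swin-style backbone.

    Assumes a patch embedding with stride `patch_size` followed by a 2x
    downsampling before each later stage (an assumed layout, see the text above).
    """
    total_stride = patch_size * 2 ** (num_stages - 1)
    if img_size % total_stride != 0:
        raise ValueError("img_size must be divisible by the total stride")
    return [img_size // (patch_size * 2 ** stage) for stage in range(num_stages)]


sizes_256 = feature_map_sizes(256)
print(sizes_256)  # [64, 32, 16, 8]
```

Under these assumptions the last three entries correspond to strides 8, 16 and 32, which are the P3, P4 and P5 grids used by the detection heads later in this notebook.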

In [None]:
# ============================================================================
# TRAINING HYPERPARAMETERS - MODIFY THESE VALUES AS NEEDED
# ============================================================================

# Training parameters (adjust for training time vs. performance trade-off)
NUM_EPOCHS = 15          # Number of epochs (10-20 for ~90 min training, 50+ for better results)
BATCH_SIZE = 16          # Batch size (reduce to 8 if GPU memory is limited)
LEARNING_RATE = 1e-3     # Initial learning rate (1e-3 for SGD, 1e-4 for AdamW)
WEIGHT_DECAY = 0.01      # Weight decay for regularization

# Optimizer selection: 'sgd' for faster training, 'adamw' for better convergence
OPTIMIZER = 'sgd'        # Options: 'sgd', 'adamw'

# SGD-specific parameters (used only if OPTIMIZER = 'sgd')
SGD_MOMENTUM = 0.937     # Momentum for SGD optimizer

# Image and model parameters
# NOTE: SwinV2-Tiny (window8_256) requires input size of 256
# Using 256x256 also speeds up training significantly
IMG_SIZE = 256           # Input image size (must be 256 for swinv2_tiny_window8_256)

# Detection thresholds
# Note: Use lower confidence for models trained from scratch
CONF_THRESH = 0.25       # Confidence threshold for detection (0.25-0.4 for trained models)
IOU_THRESH = 0.45        # IoU threshold for NMS
IOU_MATCH_THRESH = 0.3   # IoU threshold for tracking

# ============================================================================

# Build CONFIG dictionary from hyperparameters
CONFIG = {
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'optimizer': OPTIMIZER,
    'sgd_momentum': SGD_MOMENTUM,
    'img_size': IMG_SIZE,
    'conf_thresh': CONF_THRESH,
    'iou_thresh': IOU_THRESH,
    'iou_match_thresh': IOU_MATCH_THRESH,
    'vehicle_classes': {'car', 'truck', 'bus'},
    'num_classes': 3,  # car, truck, bus
    'seed': SEED,  # For reproducibility tracking
}

# Print configuration summary
print("=" * 50)
print("        TRAINING CONFIGURATION")
print("=" * 50)
print(f"  Epochs:         {CONFIG['num_epochs']}")
print(f"  Batch Size:     {CONFIG['batch_size']}")
print(f"  Learning Rate:  {CONFIG['learning_rate']}")
print(f"  Optimizer:      {CONFIG['optimizer'].upper()}")
print(f"  Image Size:     {CONFIG['img_size']}x{CONFIG['img_size']}")
print(f"  Conf Threshold: {CONFIG['conf_thresh']}")
print("=" * 50)


## Data Loading and Preparation

Load and prepare the AAU RainSnow dataset for training. The dataset is in COCO format and contains vehicle annotations for images captured in rain and snow conditions.
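
COCO stores each box as `[x_min, y_min, width, height]` in pixels, while the loss used later expects normalized `[x_center, y_center, width, height]`. The following minimal sketch shows that conversion on a made-up annotation; `coco_to_yolo` and the example values are hypothetical and only mirror the arithmetic done inside the dataset class.

```python
def coco_to_yolo(bbox, img_w, img_h):
    """Convert a COCO box [x_min, y_min, width, height] in pixels to
    normalized [x_center, y_center, width, height] in the range 0..1."""
    x_min, y_min, box_w, box_h = bbox
    x_center = (x_min + box_w / 2) / img_w
    y_center = (y_min + box_h / 2) / img_h
    return [x_center, y_center, box_w / img_w, box_h / img_h]


# Hypothetical annotation: a 160x120 px box whose top-left corner is at (80, 60)
# in a 640x480 image.
example = coco_to_yolo([80, 60, 160, 120], img_w=640, img_h=480)
print(example)  # [0.25, 0.25, 0.25, 0.25]
```

Because the coordinates are normalized by the original image size, they stay valid after the image is resized to the square training resolution.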


In [None]:
# COCO category IDs for vehicles
COCO_VEHICLE_CATS = {3: 'car', 6: 'bus', 8: 'truck'}
# Map COCO IDs to our class indices
COCO_TO_CLASS = {3: 0, 6: 1, 8: 2}  # car=0, bus=1, truck=2

def load_coco_annotations(json_path):
    """Load COCO format annotations and filter for vehicle classes."""
    with open(json_path, 'r') as f:
        coco_data = json.load(f)
    
    # Create image_id to filename mapping
    images = {img['id']: img for img in coco_data['images']}
    
    # Group annotations by image
    annotations_by_image = defaultdict(list)
    for ann in coco_data['annotations']:
        if ann['category_id'] in COCO_VEHICLE_CATS:
            annotations_by_image[ann['image_id']].append(ann)
    
    return images, annotations_by_image

class VehicleDataset(Dataset):
    """Dataset for vehicle detection from AAU RainSnow dataset."""
    
    def __init__(self, data_root, img_size=640, split='train', train_ratio=0.8):
        self.data_root = Path(data_root)
        self.img_size = img_size
        self.split = split
        
        # Load annotations
        json_path = self.data_root / 'aauRainSnow-rgb.json'
        self.images, self.annotations = load_coco_annotations(json_path)
        
        # Get list of image IDs with vehicle annotations
        self.image_ids = [img_id for img_id in self.annotations.keys() 
                         if len(self.annotations[img_id]) > 0]
        
        # Split into train/val with a local RNG so the global NumPy seed is left untouched
        rng = np.random.RandomState(42)
        rng.shuffle(self.image_ids)
        split_idx = int(len(self.image_ids) * train_ratio)
        
        if split == 'train':
            self.image_ids = self.image_ids[:split_idx]
        else:
            self.image_ids = self.image_ids[split_idx:]
        
        print(f"{split} set: {len(self.image_ids)} images with vehicle annotations")
    
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_info = self.images[img_id]
        
        # Construct image path
        img_path = self.data_root / img_info['file_name']
        
        # Read image
        img = cv2.imread(str(img_path))
        if img is None:
            # Return a blank image if not found
            img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
            return self._preprocess_image(img), torch.zeros((0, 5))
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]
        
        # Resize image
        img = cv2.resize(img, (self.img_size, self.img_size))
        
        # Get annotations
        anns = self.annotations[img_id]
        targets = []
        
        for ann in anns:
            if ann['category_id'] not in COCO_TO_CLASS:
                continue
            
            class_id = COCO_TO_CLASS[ann['category_id']]
            bbox = ann['bbox']  # [x, y, width, height]
            
            # Convert to normalized [x_center, y_center, width, height]
            x_center = (bbox[0] + bbox[2] / 2) / orig_w
            y_center = (bbox[1] + bbox[3] / 2) / orig_h
            width = bbox[2] / orig_w
            height = bbox[3] / orig_h
            
            targets.append([class_id, x_center, y_center, width, height])
        
        img_tensor = self._preprocess_image(img)
        
        if len(targets) > 0:
            targets = torch.tensor(targets, dtype=torch.float32)
        else:
            targets = torch.zeros((0, 5))
        
        return img_tensor, targets
    
    def _preprocess_image(self, img):
        """Normalize and convert image to tensor."""
        img = img.astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW
        return img

def collate_fn(batch):
    """Custom collate function to handle variable number of targets."""
    imgs, targets = zip(*batch)
    imgs = torch.stack(imgs, 0)
    
    # Add batch index to targets
    batch_targets = []
    for i, t in enumerate(targets):
        if len(t) > 0:
            batch_idx = torch.full((len(t), 1), i, dtype=t.dtype)
            batch_targets.append(torch.cat([batch_idx, t], dim=1))
    
    if len(batch_targets) > 0:
        batch_targets = torch.cat(batch_targets, 0)
    else:
        batch_targets = torch.zeros((0, 6))
    
    return imgs, batch_targets

# Create datasets and dataloaders
print("Loading dataset...")
train_dataset = VehicleDataset(DATA_ROOT, img_size=CONFIG['img_size'], split='train')
val_dataset = VehicleDataset(DATA_ROOT, img_size=CONFIG['img_size'], split='val')

# Create a generator with the seed for reproducible shuffling
g = torch.Generator()
g.manual_seed(SEED)

def seed_worker(worker_id):
    """Ensure each DataLoader worker has a deterministic seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], 
                          shuffle=True, collate_fn=collate_fn, num_workers=2,
                          worker_init_fn=seed_worker, generator=g)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], 
                        shuffle=False, collate_fn=collate_fn, num_workers=2,
                        worker_init_fn=seed_worker)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


## Dataset Visualization

Visualize sample images with their ground truth annotations to verify the data loading pipeline is working correctly. This helps identify any issues with:
- Image loading and preprocessing
- Bounding box coordinate conversion
- Class label mapping
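
To make the bounding box check concrete, here is a small sketch of the inverse conversion used for drawing: normalized center-format boxes back to clipped pixel corners. `yolo_to_corners` is a hypothetical helper and the boxes are made-up values.

```python
def yolo_to_corners(box, img_size):
    """Convert normalized [x_center, y_center, width, height] to integer pixel
    corners (x1, y1, x2, y2) on a square image, clipped to the image bounds."""
    x_center, y_center, width, height = box
    x1 = int((x_center - width / 2) * img_size)
    y1 = int((y_center - height / 2) * img_size)
    x2 = int((x_center + width / 2) * img_size)
    y2 = int((y_center + height / 2) * img_size)
    return max(0, x1), max(0, y1), min(img_size, x2), min(img_size, y2)


print(yolo_to_corners([0.5, 0.5, 0.25, 0.5], 256))  # (96, 64, 160, 192)
# A box centered on the left edge is clipped so x1 never goes negative.
print(yolo_to_corners([0.0, 0.5, 0.5, 0.5], 256))   # (0, 64, 64, 192)
```

If the drawn boxes land on the vehicles in the plotted samples, the coordinate conversion and the resize step agree with each other.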


In [None]:
# Visualize sample images with annotations to verify data loading
print("Visualizing sample images with ground truth annotations...")

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

CLASS_COLORS = {0: (255, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)}  # car=red, bus=green, truck=blue
CLASS_NAMES_VIZ = {0: 'car', 1: 'bus', 2: 'truck'}

# Get a few samples from the training dataset
sample_indices = np.random.choice(len(train_dataset), min(6, len(train_dataset)), replace=False)

for i, idx in enumerate(sample_indices):
    img_tensor, targets = train_dataset[idx]
    
    # Convert tensor to numpy image
    img = img_tensor.permute(1, 2, 0).numpy()  # CHW -> HWC
    img = (img * 255).astype(np.uint8).copy()
    
    # Draw bounding boxes
    img_size = CONFIG['img_size']
    for target in targets:
        cls_id, x_center, y_center, width, height = target.numpy()
        cls_id = int(cls_id)
        
        # Convert normalized coordinates to pixel coordinates
        x1 = int((x_center - width / 2) * img_size)
        y1 = int((y_center - height / 2) * img_size)
        x2 = int((x_center + width / 2) * img_size)
        y2 = int((y_center + height / 2) * img_size)
        
        # Clip to image boundaries
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_size, x2), min(img_size, y2)
        
        color = CLASS_COLORS.get(cls_id, (255, 255, 255))
        # The image is already RGB, so the RGB color tuples can be drawn directly
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        cv2.putText(img, CLASS_NAMES_VIZ.get(cls_id, 'unknown'), (x1, y1 - 5),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    
    axes[i].imshow(img)
    axes[i].set_title(f"Sample {idx}: {len(targets)} annotations")
    axes[i].axis('off')

plt.suptitle('Training Dataset Samples with Ground Truth Annotations', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'visualizations', 'dataset_samples.png'), dpi=150, bbox_inches='tight')
plt.show()

# Print dataset statistics
print(f"\n{'='*50}")
print("           DATASET STATISTICS")
print(f"{'='*50}")
print(f"Training images: {len(train_dataset)}")
print(f"Validation images: {len(val_dataset)}")

# Count annotations in a sample of the training set (up to the first 100 images)
num_sampled = min(len(train_dataset), 100)
total_train_annotations = 0
class_counts = {'car': 0, 'bus': 0, 'truck': 0}
for i in range(num_sampled):
    _, targets = train_dataset[i]
    total_train_annotations += len(targets)
    for t in targets:
        cls_id = int(t[0].item())
        cls_name = CLASS_NAMES_VIZ.get(cls_id, 'unknown')
        if cls_name in class_counts:
            class_counts[cls_name] += 1

print(f"\nAnnotations in first {num_sampled} training images:")
print(f"  - Cars: {class_counts['car']}")
print(f"  - Buses: {class_counts['bus']}")
print(f"  - Trucks: {class_counts['truck']}")
print(f"  - Total: {total_train_annotations}")
print(f"{'='*50}")


## SwinV2 Backbone Implementation

SwinV2-Tiny (Swin Transformer V2) replaces the original CSPDarknet backbone in YOLO. SwinV2 computes self-attention inside local windows and shifts the window grid between consecutive layers, so information can cross window boundaries while the attention cost stays linear in the number of image patches.

**Key advantages of SwinV2:**
- Hierarchical feature representation with multi-scale outputs
- Efficient window-based self-attention
- Better handling of various object scales
- Pre-trained on ImageNet for transfer learning
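
To give a feel for why window-based attention is cheaper, the sketch below counts query-key pairs for global attention versus attention restricted to non-overlapping 8x8 windows. It is a back-of-the-envelope count only: `attention_pairs` is a hypothetical helper, and the 64x64 token grid assumes a 256x256 input with a patch size of 4.

```python
def attention_pairs(grid_size, window_size=None):
    """Count query-key pairs for self-attention over a square token grid.

    With window_size=None every token attends to every token (global
    attention). Otherwise tokens attend only within non-overlapping
    window_size x window_size windows.
    """
    num_tokens = grid_size * grid_size
    if window_size is None:
        return num_tokens ** 2
    if grid_size % window_size != 0:
        raise ValueError("grid_size must be divisible by window_size")
    num_windows = (grid_size // window_size) ** 2
    tokens_per_window = window_size * window_size
    return num_windows * tokens_per_window ** 2


# 64x64 token grid: a 256x256 image with an assumed patch size of 4.
global_pairs = attention_pairs(64)
window_pairs = attention_pairs(64, window_size=8)
print(global_pairs, window_pairs, global_pairs // window_pairs)
# 16777216 262144 64
```

The count ignores heads, channels and the shifted-window masking, so it illustrates the scaling behavior rather than an exact FLOP figure.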

In [None]:
class SwinV2Backbone(nn.Module):
    """SwinV2-Tiny backbone for feature extraction."""
    
    def __init__(self, pretrained=True):
        super().__init__()
        
        # Load pre-trained SwinV2-Tiny
        self.swin = timm.create_model('swinv2_tiny_window8_256', 
                                       pretrained=pretrained,
                                       features_only=True,
                                       out_indices=(1, 2, 3))  # Multi-scale outputs
        
        # Get actual output channels from the model
        # SwinV2-Tiny output channels at different stages
        self.out_channels = self.swin.feature_info.channels()
        print(f"SwinV2 output channels: {self.out_channels}")
        
    def forward(self, x):
        # Get multi-scale features
        features = self.swin(x)
        
        # SwinV2 outputs in (B, H, W, C) format, convert to (B, C, H, W) for conv layers
        features = [f.permute(0, 3, 1, 2).contiguous() for f in features]
        
        return features  # List of [P3, P4, P5] features in NCHW format


class ConvBlock(nn.Module):
    """Convolutional block with BatchNorm and SiLU activation."""
    
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)
    
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class DetectionHead(nn.Module):
    """YOLO detection head for a single scale."""
    
    def __init__(self, in_channels, num_classes, num_anchors=3):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        
        # Output: (num_anchors * (5 + num_classes)) per grid cell
        # 5 = x, y, w, h, objectness
        out_channels = num_anchors * (5 + num_classes)
        
        self.conv = nn.Sequential(
            ConvBlock(in_channels, in_channels, 3, padding=1),
            ConvBlock(in_channels, in_channels, 3, padding=1),
            nn.Conv2d(in_channels, out_channels, 1)
        )
    
    def forward(self, x):
        return self.conv(x)


class YOLOSwinV2(nn.Module):
    """YOLO model with SwinV2-Tiny backbone."""
    
    def __init__(self, num_classes=3, pretrained_backbone=True):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = 3
        
        # SwinV2 Backbone
        self.backbone = SwinV2Backbone(pretrained=pretrained_backbone)
        
        # Get actual channel sizes from backbone
        backbone_channels = self.backbone.out_channels
        print(f"Backbone output channels: {backbone_channels}")
        
        # Neck: Feature Pyramid Network (FPN) style
        # Adapt SwinV2 channels to YOLO neck channels
        self.adapt_p3 = ConvBlock(backbone_channels[0], 256, 1)   # P3
        self.adapt_p4 = ConvBlock(backbone_channels[1], 512, 1)   # P4
        self.adapt_p5 = ConvBlock(backbone_channels[2], 1024, 1)  # P5
        
        # Upsample for FPN
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        
        # Lateral connections
        self.lateral_p4 = ConvBlock(512 + 1024, 512, 1)
        self.lateral_p3 = ConvBlock(256 + 512, 256, 1)
        
        # Detection heads for each scale
        self.head_p3 = DetectionHead(256, num_classes, self.num_anchors)
        self.head_p4 = DetectionHead(512, num_classes, self.num_anchors)
        self.head_p5 = DetectionHead(1024, num_classes, self.num_anchors)
        
        # Anchors for each scale (scaled for 256x256 input)
        # Original anchors were for 640x640, scaled by 256/640 = 0.4
        self.anchors = torch.tensor([
            [[4, 5], [6, 12], [13, 9]],      # P3 (small objects)
            [[12, 24], [25, 18], [24, 48]],  # P4 (medium objects)
            [[46, 36], [62, 79], [149, 130]] # P5 (large objects)
        ], dtype=torch.float32)
        
    def forward(self, x):
        # Backbone features
        p3, p4, p5 = self.backbone(x)
        
        # Adapt channels
        p3 = self.adapt_p3(p3)
        p4 = self.adapt_p4(p4)
        p5 = self.adapt_p5(p5)
        
        # FPN top-down pathway
        p4 = self.lateral_p4(torch.cat([p4, self.upsample(p5)], dim=1))
        p3 = self.lateral_p3(torch.cat([p3, self.upsample(p4)], dim=1))
        
        # Detection outputs
        out_p3 = self.head_p3(p3)
        out_p4 = self.head_p4(p4)
        out_p5 = self.head_p5(p5)
        
        return [out_p3, out_p4, out_p5]
    
    def decode_predictions(self, outputs, conf_thresh=0.5, img_size=640):
        """Decode raw outputs to bounding boxes."""
        batch_size = outputs[0].shape[0]
        all_boxes = []
        
        for batch_idx in range(batch_size):
            boxes = []
            
            for scale_idx, output in enumerate(outputs):
                _, _, h, w = output.shape
                stride = img_size // h
                
                pred = output[batch_idx].view(self.num_anchors, 5 + self.num_classes, h, w)
                pred = pred.permute(0, 2, 3, 1).contiguous()
                
                # Get objectness and class scores
                obj = torch.sigmoid(pred[..., 4])
                cls_scores = torch.sigmoid(pred[..., 5:])
                
                # Find high confidence predictions
                mask = obj > conf_thresh
                
                if mask.sum() == 0:
                    continue
                
                # Get coordinates
                for anchor_idx in range(self.num_anchors):
                    for yi in range(h):
                        for xi in range(w):
                            if mask[anchor_idx, yi, xi]:
                                tx, ty = pred[anchor_idx, yi, xi, :2]
                                tw, th = pred[anchor_idx, yi, xi, 2:4]
                                
                                # Decode bbox
                                x = (torch.sigmoid(tx) + xi) * stride
                                y = (torch.sigmoid(ty) + yi) * stride
                                w_box = torch.exp(tw) * self.anchors[scale_idx, anchor_idx, 0]
                                h_box = torch.exp(th) * self.anchors[scale_idx, anchor_idx, 1]
                                
                                conf = obj[anchor_idx, yi, xi]
                                cls_score, cls_id = cls_scores[anchor_idx, yi, xi].max(0)
                                
                                x1 = x - w_box / 2
                                y1 = y - h_box / 2
                                x2 = x + w_box / 2
                                y2 = y + h_box / 2
                                
                                boxes.append([x1.item(), y1.item(), x2.item(), y2.item(), 
                                            (conf * cls_score).item(), cls_id.item()])
                
            all_boxes.append(boxes)
        
        return all_boxes

# Initialize model
print("Initializing YOLO-SwinV2 model...")
model = YOLOSwinV2(num_classes=CONFIG['num_classes'], pretrained_backbone=True)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


## Training Loop

Train the YOLO-SwinV2 model on the AAU RainSnow dataset. The training loop tracks:
- Training loss per epoch
- Validation loss per epoch
- Number of detections
- Learning rate schedule

The trained model is saved to the checkpoints directory.
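
The training cell uses `CosineAnnealingLR` for the learning rate schedule. As an illustration of what that schedule looks like, the sketch below evaluates the closed-form cosine annealing formula for the default settings in this notebook (initial rate 1e-3, 15 epochs). `cosine_lr` is a hypothetical helper, and PyTorch updates the rate recursively, so its values may differ by floating-point rounding.

```python
import math


def cosine_lr(epoch, base_lr, total_epochs, min_lr=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, min_lr at total_epochs."""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + cosine)


# Learning rate at the start of each epoch, plus the final value after epoch 15.
schedule = [cosine_lr(e, base_lr=1e-3, total_epochs=15) for e in range(16)]
print(f"{schedule[0]:.6f} {schedule[15]:.6f}")  # 0.001000 0.000000
```

The rate decays slowly at first, fastest around the middle of training, and flattens out again near zero at the end.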

In [None]:
class YOLOLoss(nn.Module):
    """YOLO-style loss: each target is assigned to its best-matching anchor at every scale."""
    
    def __init__(self, num_classes=3, num_anchors=3, img_size=256):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.img_size = img_size
        
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='none')
        self.ce = nn.CrossEntropyLoss(reduction='none')
        
        # Anchors scaled for 256x256 input (same as model)
        self.anchors = torch.tensor([
            [[4, 5], [6, 12], [13, 9]],      # P3 (stride 8)
            [[12, 24], [25, 18], [24, 48]],  # P4 (stride 16)
            [[46, 36], [62, 79], [149, 130]] # P5 (stride 32)
        ], dtype=torch.float32)
        
        self.strides = [8, 16, 32]  # Feature map strides
        
        # Loss weights
        self.lambda_obj = 1.0
        self.lambda_noobj = 0.5
        self.lambda_box = 5.0
        self.lambda_cls = 1.0
        
    def forward(self, predictions, targets):
        """
        predictions: list of tensors [P3, P4, P5]
        targets: tensor of shape [N, 6] where each row is [batch_idx, class_id, x, y, w, h]
        """
        device = predictions[0].device
        batch_size = predictions[0].shape[0]
        
        total_obj_loss = torch.tensor(0.0, device=device)
        total_box_loss = torch.tensor(0.0, device=device)
        total_cls_loss = torch.tensor(0.0, device=device)
        
        num_pos = 0
        
        for scale_idx, pred in enumerate(predictions):
            _, _, h, w = pred.shape
            stride = self.strides[scale_idx]
            anchors = self.anchors[scale_idx].to(device)
            
            # Reshape prediction
            pred = pred.view(batch_size, self.num_anchors, 5 + self.num_classes, h, w)
            pred = pred.permute(0, 1, 3, 4, 2).contiguous()
            
            # Create target tensors
            obj_mask = torch.zeros(batch_size, self.num_anchors, h, w, device=device)
            noobj_mask = torch.ones(batch_size, self.num_anchors, h, w, device=device)
            tx = torch.zeros(batch_size, self.num_anchors, h, w, device=device)
            ty = torch.zeros(batch_size, self.num_anchors, h, w, device=device)
            tw = torch.zeros(batch_size, self.num_anchors, h, w, device=device)
            th = torch.zeros(batch_size, self.num_anchors, h, w, device=device)
            tcls = torch.zeros(batch_size, self.num_anchors, h, w, device=device, dtype=torch.long)
            
            if len(targets) > 0:
                for target in targets:
                    b = int(target[0].item())
                    cls_id = int(target[1].item())
                    gx, gy, gw, gh = target[2:6]
                    
                    # Convert to feature map coordinates
                    gx_fm = gx.item() * w
                    gy_fm = gy.item() * h
                    gi = min(max(int(gx_fm), 0), w - 1)
                    gj = min(max(int(gy_fm), 0), h - 1)
                    
                    # Convert to pixel size
                    gw_px = gw.item() * self.img_size
                    gh_px = gh.item() * self.img_size
                    
                    # Find best anchor
                    anchor_ious = []
                    for a_idx in range(self.num_anchors):
                        aw, ah = anchors[a_idx]
                        inter_w = min(gw_px, aw.item())
                        inter_h = min(gh_px, ah.item())
                        inter = inter_w * inter_h
                        union = gw_px * gh_px + aw.item() * ah.item() - inter
                        anchor_ious.append(inter / (union + 1e-6))
                    
                    best_anchor = int(np.argmax(anchor_ious))
                    
                    # Set masks and targets
                    obj_mask[b, best_anchor, gj, gi] = 1
                    noobj_mask[b, best_anchor, gj, gi] = 0
                    tx[b, best_anchor, gj, gi] = gx_fm - gi
                    ty[b, best_anchor, gj, gi] = gy_fm - gj
                    aw, ah = anchors[best_anchor]
                    tw[b, best_anchor, gj, gi] = torch.log(torch.tensor(gw_px / aw.item() + 1e-6))
                    th[b, best_anchor, gj, gi] = torch.log(torch.tensor(gh_px / ah.item() + 1e-6))
                    tcls[b, best_anchor, gj, gi] = cls_id
                    num_pos += 1
            
            # Compute losses
            pred_obj = pred[..., 4]
            pred_xy = torch.sigmoid(pred[..., :2])
            pred_wh = pred[..., 2:4]
            pred_cls = pred[..., 5:]
            
            # Objectness loss
            obj_loss = self.bce(pred_obj, obj_mask)
            total_obj_loss += self.lambda_obj * (obj_loss * obj_mask).sum()
            total_obj_loss += self.lambda_noobj * (obj_loss * noobj_mask).sum()
            
            # Box and class loss for positive samples
            if obj_mask.sum() > 0:
                box_loss_x = (self.mse(pred_xy[..., 0], tx) * obj_mask).sum()
                box_loss_y = (self.mse(pred_xy[..., 1], ty) * obj_mask).sum()
                box_loss_w = (self.mse(pred_wh[..., 0], tw.to(device)) * obj_mask).sum()
                box_loss_h = (self.mse(pred_wh[..., 1], th.to(device)) * obj_mask).sum()
                total_box_loss += self.lambda_box * (box_loss_x + box_loss_y + box_loss_w + box_loss_h)
                
                pred_cls_flat = pred_cls[obj_mask == 1]
                tcls_flat = tcls[obj_mask == 1]
                if len(pred_cls_flat) > 0:
                    total_cls_loss += self.lambda_cls * self.ce(pred_cls_flat, tcls_flat).sum()
        
        num_pos = max(num_pos, 1)
        total_loss = (total_obj_loss + total_box_loss + total_cls_loss) / num_pos
        
        return total_loss, {
            'obj_loss': total_obj_loss.item() / num_pos,
            'box_loss': total_box_loss.item() / num_pos,
            'cls_loss': total_cls_loss.item() / num_pos
        }


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(dataloader, desc='Training')
    for imgs, targets in pbar:
        imgs = imgs.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(imgs)
        
        loss, loss_items = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / max(num_batches, 1)


def validate_epoch(model, dataloader, criterion, device):
    """Validate for one epoch."""
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for imgs, targets in dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            
            outputs = model(imgs)
            loss, _ = criterion(outputs, targets)
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / max(num_batches, 1)


# Training setup
criterion = YOLOLoss(num_classes=CONFIG['num_classes'], img_size=CONFIG['img_size'])

# Select optimizer based on CONFIG
if CONFIG['optimizer'].lower() == 'sgd':
    # SGD with momentum - faster training, good for quick experiments
    optimizer = optim.SGD(
        model.parameters(), 
        lr=CONFIG['learning_rate'],
        momentum=CONFIG['sgd_momentum'],
        weight_decay=CONFIG['weight_decay'],
        nesterov=True
    )
    print(f"Using SGD optimizer with momentum={CONFIG['sgd_momentum']}")
else:
    # AdamW - better convergence, recommended for final training
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=CONFIG['learning_rate'], 
        weight_decay=CONFIG['weight_decay']
    )
    print("Using AdamW optimizer")

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs'])

# Training metrics storage
training_metrics = {
    'train_losses': [],
    'val_losses': [],
    'learning_rates': [],
    'epochs': []
}

print(f"\nStarting training for {CONFIG['num_epochs']} epochs...")
print("="*60)

best_val_loss = float('inf')

for epoch in range(CONFIG['num_epochs']):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    
    # Validate
    val_loss = validate_epoch(model, val_loader, criterion, device)
    
    # Record the learning rate used during this epoch, then update the scheduler
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    
    # Store metrics
    training_metrics['train_losses'].append(train_loss)
    training_metrics['val_losses'].append(val_loss)
    training_metrics['learning_rates'].append(current_lr)
    training_metrics['epochs'].append(epoch + 1)
    
    # Print progress
    print(f"Epoch [{epoch+1}/{CONFIG['num_epochs']}] "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint_path = os.path.join(OUTPUT_DIR, 'checkpoints', 'best_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, checkpoint_path)
        print(f"  -> Saved best model (val_loss: {val_loss:.4f})")

# Save final model
final_checkpoint_path = os.path.join(OUTPUT_DIR, 'checkpoints', 'final_model.pth')
torch.save({
    'epoch': CONFIG['num_epochs'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'training_metrics': training_metrics,
}, final_checkpoint_path)

print("="*60)
print(f"Training complete! Best validation loss: {best_val_loss:.4f}")
print(f"Models saved to: {os.path.join(OUTPUT_DIR, 'checkpoints')}")


## Training Metrics Visualization

Visualize the training progress including:
- Training and validation loss curves
- Learning rate schedule
- Loss convergence on a log scale
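
The learning-rate panel should trace the cosine curve produced by `CosineAnnealingLR` with `T_max` set to the number of epochs. As a rough cross-check, the closed form documented for that scheduler can be sketched in plain Python. The function name and the numeric values below are illustrative assumptions, not values taken from `CONFIG`.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Hypothetical settings; the notebook reads the real ones from CONFIG.
base_lr, num_epochs = 1e-3, 10
schedule = [cosine_annealing_lr(e, base_lr, num_epochs) for e in range(num_epochs + 1)]

for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

If the plotted curve departs noticeably from this shape, the scheduler is probably being stepped more or less often than once per epoch.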

In [None]:
viz_dir = os.path.join(OUTPUT_DIR, 'visualizations')

# Plot Training Metrics
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('YOLO-SwinV2 Training Metrics', fontsize=14, fontweight='bold')

# 1. Training and Validation Loss
ax1 = axes[0]
ax1.plot(training_metrics['epochs'], training_metrics['train_losses'], 
         'b-', linewidth=2, label='Train Loss')
ax1.plot(training_metrics['epochs'], training_metrics['val_losses'], 
         'r-', linewidth=2, label='Val Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Learning Rate Schedule
ax2 = axes[1]
ax2.plot(training_metrics['epochs'], training_metrics['learning_rates'], 
         'g-', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Learning Rate')
ax2.set_title('Learning Rate Schedule')
ax2.grid(True, alpha=0.3)

# 3. Loss Convergence (Log Scale)
ax3 = axes[2]
ax3.semilogy(training_metrics['epochs'], training_metrics['train_losses'], 
             'b-', linewidth=2, label='Train Loss')
ax3.semilogy(training_metrics['epochs'], training_metrics['val_losses'], 
             'r-', linewidth=2, label='Val Loss')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Loss (log scale)')
ax3.set_title('Loss Convergence (Log Scale)')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(viz_dir, 'training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

# Print Training Summary
print("\n" + "="*50)
print("         YOLO-SwinV2 Training Summary")
print("="*50)
print(f"\nTotal Epochs: {len(training_metrics['epochs'])}")
print(f"Final Train Loss: {training_metrics['train_losses'][-1]:.4f}")
print(f"Final Val Loss: {training_metrics['val_losses'][-1]:.4f}")
print(f"Best Val Loss: {min(training_metrics['val_losses']):.4f}")
print(f"Best Epoch: {training_metrics['val_losses'].index(min(training_metrics['val_losses'])) + 1}")
print("="*50)

# Save training metrics to JSON
training_export = {
    'model_name': 'YOLO-SwinV2',
    'config': {k: str(v) if isinstance(v, set) else v for k, v in CONFIG.items()},
    'training_metrics': training_metrics,
    'best_val_loss': min(training_metrics['val_losses']),
    'final_train_loss': training_metrics['train_losses'][-1],
    'final_val_loss': training_metrics['val_losses'][-1],
}

with open(os.path.join(OUTPUT_DIR, 'YOLO_SwinV2_training_metrics.json'), 'w') as f:
    json.dump(training_export, f, indent=2)

print(f"\nTraining metrics saved to: {os.path.join(OUTPUT_DIR, 'YOLO_SwinV2_training_metrics.json')}")


## Video Detection and Comparison

Apply the trained YOLO-SwinV2 model to the test video and compare detection metrics with the pre-trained YOLO-V5m baseline.
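
Each frame is resized directly to a square of side `CONFIG['img_size']` (no letterboxing), so predicted boxes have to be stretched back to the original frame with separate x and y scale factors and then clipped. A minimal sketch of that mapping follows. The helper name `scale_box_to_frame` and the example sizes are assumptions for illustration.

```python
def scale_box_to_frame(box, img_size, frame_w, frame_h):
    """Map (x1, y1, x2, y2) from the square model input back to the original frame."""
    x1, y1, x2, y2 = box
    scale_x = frame_w / img_size  # horizontal stretch factor
    scale_y = frame_h / img_size  # vertical stretch factor
    x1, x2 = int(x1 * scale_x), int(x2 * scale_x)
    y1, y2 = int(y1 * scale_y), int(y2 * scale_y)
    # Clip so that drawing never reaches outside the frame.
    return max(0, x1), max(0, y1), min(frame_w, x2), min(frame_h, y2)

# Hypothetical example: 640x640 model input, 1280x720 video frame.
scaled = scale_box_to_frame((320, 320, 700, 640), 640, 1280, 720)
print(scaled)  # (640, 360, 1280, 720)
```

Because the resize does not preserve aspect ratio, the two scale factors differ whenever the video is not square.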


In [None]:
# Load best model for inference
best_model_path = os.path.join(OUTPUT_DIR, 'checkpoints', 'best_model.pth')
if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")

model.eval()

# Detection metrics storage
inference_metrics = {
    'frame_indices': [],
    'detections_per_frame': [],
    'confidence_scores': [],
    'class_counts': {'car': [], 'truck': [], 'bus': []},
}

CLASS_NAMES = {0: 'car', 1: 'bus', 2: 'truck'}

# Process video
print(f"\nProcessing video: {VIDEO_PATH}")

if os.path.exists(VIDEO_PATH):
    cap = cv2.VideoCapture(VIDEO_PATH)
    
    ret, frame = cap.read()
    if not ret:
        raise RuntimeError("Couldn't read the video")
    
    h, w = frame.shape[:2]
    fps = cap.get(cv2.CAP_PROP_FPS) or 15  # keep fractional rates such as 29.97; fall back to 15 if unknown
    
    out = cv2.VideoWriter(
        os.path.join(OUTPUT_DIR, "YOLO_SwinV2_highway_with_detection.avi"),
        cv2.VideoWriter_fourcc(*"XVID"),
        fps,
        (w, h)
    )
    
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    frame_index = 0
    
    pbar = tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), desc='Processing video')
    
    with torch.no_grad():
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            frame_index += 1
            
            # Preprocess frame
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img_resized = cv2.resize(img, (CONFIG['img_size'], CONFIG['img_size']))
            img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1) / 255.0
            img_tensor = img_tensor.unsqueeze(0).to(device)
            
            # Get predictions
            outputs = model(img_tensor)
            boxes = model.decode_predictions(outputs, conf_thresh=CONFIG['conf_thresh'], 
                                             img_size=CONFIG['img_size'])[0]
            
            # Collect metrics
            inference_metrics['frame_indices'].append(frame_index)
            inference_metrics['detections_per_frame'].append(len(boxes))
            
            class_count = {'car': 0, 'truck': 0, 'bus': 0}
            
            for box in boxes:
                x1, y1, x2, y2, conf, cls_id = box
                conf = float(conf)  # plain float so np.mean and the JSON export work even for tensor scalars
                
                # Scale to original frame size
                scale_x = w / CONFIG['img_size']
                scale_y = h / CONFIG['img_size']
                x1, x2 = int(x1 * scale_x), int(x2 * scale_x)
                y1, y2 = int(y1 * scale_y), int(y2 * scale_y)
                
                # Clip to frame boundaries
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(w, x2), min(h, y2)
                
                cls_name = CLASS_NAMES.get(int(cls_id), 'car')
                class_count[cls_name] += 1
                inference_metrics['confidence_scores'].append(conf)
                
                # Draw bounding box
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f"{cls_name} ({conf:.2f})", (x1, y1 - 7),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
            
            for cls in class_count:
                inference_metrics['class_counts'][cls].append(class_count[cls])
            
            out.write(frame)
            pbar.update(1)
    
    pbar.close()
    cap.release()
    out.release()
    
    # Convert to MP4
    output_video_path = os.path.join(OUTPUT_DIR, 'highway_with_detection.mp4')
    avi_path = os.path.join(OUTPUT_DIR, "YOLO_SwinV2_highway_with_detection.avi")
    
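    # -loglevel error keeps ffmpeg quiet without relying on a POSIX-only /dev/null redirect
    ffmpeg_status = os.system(f'ffmpeg -y -loglevel error -i "{avi_path}" -vcodec libx264 -crf 23 -pix_fmt yuv420p "{output_video_path}"')
    if ffmpeg_status != 0:
        raise RuntimeError("ffmpeg conversion failed; check that ffmpeg is installed and on PATH")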
    
    print(f"\nVideo saved to: {output_video_path}")
    print(f"Total frames processed: {frame_index}")
    print(f"Total detections: {sum(inference_metrics['detections_per_frame'])}")
    
    # Display the video in the notebook
    print("\nVehicle Detection Video for YOLO-SwinV2:")
    with open(output_video_path, 'rb') as video_file:
        video_mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(video_mp4).decode()
    display(HTML(f'<video width=640 controls><source src="{data_url}" type="video/mp4"></video>'))
else:
    print(f"Video not found at: {VIDEO_PATH}")


## Model Comparison: YOLO-SwinV2 vs YOLO-V5m

Compare the inference metrics of the trained YOLO-SwinV2 model with the pre-trained YOLO-V5m baseline.
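
The comparison cell reads a fixed set of keys from the baseline JSON, under `summary` and `per_frame_data`. A small check like the one below can report missing keys before any plotting starts. This is only a sketch: the helper name `missing_baseline_keys` is hypothetical, and the key lists are collected from the lookups made in the comparison code.

```python
# Keys the comparison cell looks up in the YOLO-V5m metrics file.
REQUIRED_SUMMARY_KEYS = [
    'total_frames', 'total_detections', 'avg_detections_per_frame',
    'mean_confidence', 'std_confidence',
    'total_cars', 'total_trucks', 'total_buses',
]
REQUIRED_PER_FRAME_KEYS = ['frame_indices', 'detections_per_frame']

def missing_baseline_keys(metrics):
    """Return dotted names of required keys that are absent from the metrics dict."""
    missing = []
    sections = (('summary', REQUIRED_SUMMARY_KEYS),
                ('per_frame_data', REQUIRED_PER_FRAME_KEYS))
    for section, keys in sections:
        present = metrics.get(section, {})
        missing += [f'{section}.{key}' for key in keys if key not in present]
    return missing

# Hypothetical, incomplete baseline used only to show the output format.
example = {'summary': {'total_frames': 100}, 'per_frame_data': {'frame_indices': []}}
print(missing_baseline_keys(example))
```

An empty list means the file has everything the plots and the summary table need.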


In [None]:
# Load YOLO-V5m metrics for comparison
yolov5m_metrics = None
if os.path.exists(YOLOV5M_METRICS_PATH):
    with open(YOLOV5M_METRICS_PATH, 'r') as f:
        yolov5m_metrics = json.load(f)
    print("Loaded YOLO-V5m metrics for comparison")
else:
    print(f"YOLO-V5m metrics not found at: {YOLOV5M_METRICS_PATH}")
    print("Run YOLO_V5m.ipynb first to generate baseline metrics")

# Calculate YOLO-SwinV2 summary metrics
swinv2_summary = {
    'total_frames': len(inference_metrics['frame_indices']),
    'total_detections': sum(inference_metrics['detections_per_frame']),
    'avg_detections_per_frame': np.mean(inference_metrics['detections_per_frame']) if inference_metrics['detections_per_frame'] else 0,
    'mean_confidence': np.mean(inference_metrics['confidence_scores']) if inference_metrics['confidence_scores'] else 0,
    'std_confidence': np.std(inference_metrics['confidence_scores']) if inference_metrics['confidence_scores'] else 0,
    'total_cars': sum(inference_metrics['class_counts']['car']),
    'total_trucks': sum(inference_metrics['class_counts']['truck']),
    'total_buses': sum(inference_metrics['class_counts']['bus']),
}

# Create comparison visualization
if yolov5m_metrics:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Model Comparison: YOLO-SwinV2 vs YOLO-V5m', fontsize=14, fontweight='bold')
    
    # 1. Total Detections Comparison
    ax1 = axes[0, 0]
    models = ['YOLO-V5m', 'YOLO-SwinV2']
    detections = [yolov5m_metrics['summary']['total_detections'], swinv2_summary['total_detections']]
    bars = ax1.bar(models, detections, color=['#2E86AB', '#A23B72'])
    ax1.set_ylabel('Total Detections')
    ax1.set_title('Total Detections Comparison')
    for bar, val in zip(bars, detections):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 10, 
                f'{val}', ha='center', va='bottom', fontweight='bold')
    
    # 2. Confidence Score Comparison
    ax2 = axes[0, 1]
    v5m_conf = [yolov5m_metrics['summary']['mean_confidence'], 
                yolov5m_metrics['summary']['std_confidence']]
    swinv2_conf = [swinv2_summary['mean_confidence'], swinv2_summary['std_confidence']]
    x = np.arange(2)
    width = 0.35
    ax2.bar(x - width/2, v5m_conf, width, label='YOLO-V5m', color='#2E86AB')
    ax2.bar(x + width/2, swinv2_conf, width, label='YOLO-SwinV2', color='#A23B72')
    ax2.set_xticks(x)
    ax2.set_xticklabels(['Mean Conf', 'Std Conf'])
    ax2.set_ylabel('Confidence Score')
    ax2.set_title('Confidence Score Comparison')
    ax2.legend()
    
    # 3. Class Distribution Comparison
    ax3 = axes[1, 0]
    x = np.arange(3)
    v5m_classes = [yolov5m_metrics['summary']['total_cars'],
                   yolov5m_metrics['summary']['total_trucks'],
                   yolov5m_metrics['summary']['total_buses']]
    swinv2_classes = [swinv2_summary['total_cars'],
                      swinv2_summary['total_trucks'],
                      swinv2_summary['total_buses']]
    ax3.bar(x - width/2, v5m_classes, width, label='YOLO-V5m', color='#2E86AB')
    ax3.bar(x + width/2, swinv2_classes, width, label='YOLO-SwinV2', color='#A23B72')
    ax3.set_xticks(x)
    ax3.set_xticklabels(['Cars', 'Trucks', 'Buses'])
    ax3.set_ylabel('Total Count')
    ax3.set_title('Detection by Class')
    ax3.legend()
    
    # 4. Detections per Frame Comparison
    ax4 = axes[1, 1]
    if len(inference_metrics['frame_indices']) > 0:
        # Use min length for comparison
        min_len = min(len(yolov5m_metrics['per_frame_data']['frame_indices']), 
                      len(inference_metrics['frame_indices']))
        ax4.plot(yolov5m_metrics['per_frame_data']['frame_indices'][:min_len],
                yolov5m_metrics['per_frame_data']['detections_per_frame'][:min_len],
                'b-', alpha=0.7, label='YOLO-V5m')
        ax4.plot(inference_metrics['frame_indices'][:min_len],
                inference_metrics['detections_per_frame'][:min_len],
                'r-', alpha=0.7, label='YOLO-SwinV2')
    ax4.set_xlabel('Frame')
    ax4.set_ylabel('Detections')
    ax4.set_title('Detections per Frame')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, 'model_comparison.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print Comparison Summary
    print("\n" + "="*70)
    print("                    MODEL COMPARISON SUMMARY")
    print("="*70)
    print(f"\n{'Metric':<30} {'YOLO-V5m':<20} {'YOLO-SwinV2':<20}")
    print("-"*70)
    print(f"{'Total Frames':<30} {yolov5m_metrics['summary']['total_frames']:<20} {swinv2_summary['total_frames']:<20}")
    print(f"{'Total Detections':<30} {yolov5m_metrics['summary']['total_detections']:<20} {swinv2_summary['total_detections']:<20}")
    print(f"{'Avg Detections/Frame':<30} {yolov5m_metrics['summary']['avg_detections_per_frame']:<20.2f} {swinv2_summary['avg_detections_per_frame']:<20.2f}")
    print(f"{'Mean Confidence':<30} {yolov5m_metrics['summary']['mean_confidence']:<20.4f} {swinv2_summary['mean_confidence']:<20.4f}")
    print(f"{'Total Cars':<30} {yolov5m_metrics['summary']['total_cars']:<20} {swinv2_summary['total_cars']:<20}")
    print(f"{'Total Trucks':<30} {yolov5m_metrics['summary']['total_trucks']:<20} {swinv2_summary['total_trucks']:<20}")
    print(f"{'Total Buses':<30} {yolov5m_metrics['summary']['total_buses']:<20} {swinv2_summary['total_buses']:<20}")
    print("="*70)

# Export YOLO-SwinV2 inference metrics
swinv2_export = {
    'model_name': 'YOLO-SwinV2 (Trained)',
    'config': {k: str(v) if isinstance(v, set) else v for k, v in CONFIG.items()},
    'summary': swinv2_summary,
    'per_frame_data': {
        'frame_indices': inference_metrics['frame_indices'],
        'detections_per_frame': inference_metrics['detections_per_frame'],
        'class_counts': inference_metrics['class_counts']
    }
}

with open(os.path.join(OUTPUT_DIR, 'YOLO_SwinV2_metrics.json'), 'w') as f:
    json.dump(swinv2_export, f, indent=2)

print(f"\nYOLO-SwinV2 metrics exported to: {os.path.join(OUTPUT_DIR, 'YOLO_SwinV2_metrics.json')}")
