# Model Comparison: SimpleCNN vs RADIO

This notebook compares two encoders for semantic segmentation:
1. **SimpleCNN**: Lightweight CNN encoder (the framework's baseline, trained from scratch)
2. **RADIO**: NVIDIA's pretrained vision foundation model

We'll compare:
- Architecture and parameter counts
- Inference speed
- Prediction quality
- Feature representations

## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path

# Add project to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import time

# Import framework components
from vlm_eval import EncoderRegistry, HeadRegistry, DatasetRegistry

print("✓ Imports successful!")

## 2. Load Dataset

We use a 30-image subset of the Pascal VOC validation split, resized to 224x224, for the comparison.

In [None]:
# Create dataset
dataset = DatasetRegistry.get(
    "pascal_voc",
    root="./data/pascal_voc",
    split="val",
    download=True,
    subset_size=30,
    image_size=224  # Use 224 for fair comparison (both models support this)
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

print(f"Dataset: {len(dataset)} samples")
print(f"Batches: {len(dataloader)}")

## 3. Create Both Encoders

Load SimpleCNN and RADIO encoders.

In [None]:
# Create SimpleCNN encoder
print("Creating SimpleCNN encoder...")
encoder_cnn = EncoderRegistry.get("simple_cnn", variant="base", pretrained=False)

# Create RADIO encoder
print("\nCreating RADIO encoder (may take a moment)...")
encoder_radio = EncoderRegistry.get("radio", variant="base", pretrained=True)

print("\n✓ Both encoders created successfully!")

## 4. Architecture Comparison

Compare model architectures and parameter counts.

In [None]:
import pandas as pd

# Gather architecture info
comparison_data = {
    "Metric": [
        "Model Name",
        "Total Parameters",
        "Parameters (M)",
        "Output Channels",
        "Patch Size",
        "Pretrained"
    ],
    "SimpleCNN": [
        "SimpleCNN (Base)",
        f"{encoder_cnn.get_num_parameters():,}",
        f"{encoder_cnn.get_num_parameters() / 1e6:.2f}M",
        encoder_cnn.output_channels,
        encoder_cnn.patch_size,
        "No"
    ],
    "RADIO": [
        "NVIDIA RADIO",
        f"{encoder_radio.get_num_parameters():,}",
        f"{encoder_radio.get_num_parameters() / 1e6:.2f}M",
        encoder_radio.output_channels,
        encoder_radio.patch_size,
        "Yes (Foundation)"
    ]
}

df = pd.DataFrame(comparison_data)
print("\n=== Architecture Comparison ===")
print(df.to_string(index=False))

# Visualize parameter counts
fig, ax = plt.subplots(figsize=(10, 5))
models = ['SimpleCNN', 'RADIO']
params = [
    encoder_cnn.get_num_parameters() / 1e6,
    encoder_radio.get_num_parameters() / 1e6
]
colors = ['#3498db', '#e74c3c']

bars = ax.bar(models, params, color=colors, alpha=0.7)
ax.set_ylabel('Parameters (Millions)', fontsize=12)
ax.set_title('Model Size Comparison', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar, param in zip(bars, params):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{param:.1f}M',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

## 5. Create Segmentation Heads

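Attach a linear probe head to each encoder. The SimpleCNN encoder stays trainable because it has no pretrained weights, while the RADIO encoder is frozen so that only its linear head would be trained.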

In [None]:
# Create heads
head_cnn = HeadRegistry.get(
    "linear_probe",
    encoder=encoder_cnn,
    num_classes=21,
    freeze_encoder=False
)

head_radio = HeadRegistry.get(
    "linear_probe",
    encoder=encoder_radio,
    num_classes=21,
    freeze_encoder=True  # Freeze RADIO for linear probing
)

print("SimpleCNN Head:")
print(f"  Total params: {head_cnn.get_num_parameters():,}")
print(f"  Trainable params: {head_cnn.get_num_parameters(trainable_only=True):,}")

print("\nRADIO Head:")
print(f"  Total params: {head_radio.get_num_parameters():,}")
print(f"  Trainable params: {head_radio.get_num_parameters(trainable_only=True):,}")
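
The total and trainable counts printed above can be reproduced with plain PyTorch. The following is a minimal sketch, assuming `get_num_parameters` behaves like a sum over `parameters()`. The helper `count_parameters` and the toy convolution layers are hypothetical stand-ins, not part of the framework.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Count parameters, optionally only those with requires_grad=True."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Toy stand-ins: a small "encoder" followed by a 1x1 conv acting as a linear probe
toy_encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # 8*3*3*3 + 8 = 224 params
toy_head = nn.Conv2d(8, 21, kernel_size=1)               # 21*8 + 21 = 189 params
toy_model = nn.Sequential(toy_encoder, toy_head)

# Freezing the encoder leaves only the head trainable
for p in toy_encoder.parameters():
    p.requires_grad = False

total = count_parameters(toy_model)
trainable = count_parameters(toy_model, trainable_only=True)
print(f"total={total}, trainable={trainable}")
```

With a frozen encoder, the trainable count should equal the size of the head alone, which is a quick sanity check that freezing took effect.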

## 6. Inference Speed Comparison

Measure inference time for both models on a single batch, averaged over 20 runs after a warmup pass.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

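# Move models to device and switch to eval mode.
# The encoders are called directly below, so move them explicitly as well
# (harmless if the heads already hold them as submodules).
encoder_cnn = encoder_cnn.to(device).eval()
encoder_radio = encoder_radio.to(device).eval()
head_cnn = head_cnn.to(device).eval()
head_radio = head_radio.to(device).eval()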

# Get a test batch
test_batch = next(iter(dataloader))
test_images = test_batch["image"].to(device)

def sync_device():
    # CUDA kernels run asynchronously; wait for them so timings are accurate
    if device.type == 'cuda':
        torch.cuda.synchronize()

# Warmup
with torch.no_grad():
    _ = head_cnn(encoder_cnn(test_images))
    _ = head_radio(encoder_radio(test_images))
sync_device()

# Benchmark SimpleCNN
num_runs = 20
times_cnn = []
with torch.no_grad():
    for _ in range(num_runs):
        sync_device()  # make sure no earlier work is still queued
        start = time.perf_counter()
        features = encoder_cnn(test_images)
        _ = head_cnn(features)
        sync_device()
        times_cnn.append(time.perf_counter() - start)

# Benchmark RADIO
times_radio = []
with torch.no_grad():
    for _ in range(num_runs):
        sync_device()
        start = time.perf_counter()
        features = encoder_radio(test_images)
        _ = head_radio(features)
        sync_device()
        times_radio.append(time.perf_counter() - start)

# Results
avg_time_cnn = np.mean(times_cnn) * 1000  # Convert to ms
avg_time_radio = np.mean(times_radio) * 1000

print(f"\n=== Inference Speed (batch_size={test_images.shape[0]}) ===")
print(f"SimpleCNN: {avg_time_cnn:.2f} ms/batch")
print(f"RADIO: {avg_time_radio:.2f} ms/batch")
print(f"Time ratio (RADIO / SimpleCNN): {avg_time_radio / avg_time_cnn:.2f}x")

# Visualize
fig, ax = plt.subplots(figsize=(10, 5))
models = ['SimpleCNN', 'RADIO']
times = [avg_time_cnn, avg_time_radio]
colors = ['#2ecc71', '#e67e22']

bars = ax.bar(models, times, color=colors, alpha=0.7)
ax.set_ylabel('Inference Time (ms)', fontsize=12)
ax.set_title('Inference Speed Comparison', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

for bar, t in zip(bars, times):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{t:.1f}ms',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()
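
The two timing loops above follow the same pattern, so they could be folded into one helper. This is a sketch only: `benchmark` and its `sync` parameter are hypothetical names, and the workload below is a CPU-only stand-in so the example runs without a model. It also reports the median and standard deviation, which are less sensitive to a single slow run than the mean alone.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark(
    fn: Callable[[], object],
    num_runs: int = 20,
    warmup: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Time `fn` and return mean/median/stdev in milliseconds.

    `sync` is called right before starting and right before stopping the
    clock; on a GPU it would be something like `torch.cuda.synchronize`.
    """
    def _sync() -> None:
        if sync is not None:
            sync()

    # Warmup runs are executed but not timed
    for _ in range(warmup):
        fn()

    times_ms = []
    for _ in range(num_runs):
        _sync()
        start = time.perf_counter()
        fn()
        _sync()
        times_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "mean_ms": statistics.mean(times_ms),
        "median_ms": statistics.median(times_ms),
        "stdev_ms": statistics.stdev(times_ms) if len(times_ms) > 1 else 0.0,
    }


# CPU-only stand-in workload
stats = benchmark(lambda: sum(i * i for i in range(10_000)), num_runs=5, warmup=1)
print({name: round(value, 3) for name, value in stats.items()})
```

In this notebook, `fn` could be `lambda: head_cnn(encoder_cnn(test_images))`, called inside `torch.no_grad()`, with `sync=torch.cuda.synchronize` when running on a GPU.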

## 7. Side-by-Side Predictions

Compare predictions from both models on the same images.

In [None]:
# Get predictions from both models
batch = next(iter(dataloader))
images = batch["image"].to(device)
masks = batch["mask"].to(device)

with torch.no_grad():
    # SimpleCNN predictions
    features_cnn = encoder_cnn(images)
    logits_cnn = head_cnn(features_cnn)
    preds_cnn = logits_cnn.argmax(dim=1)
    
    # RADIO predictions
    features_radio = encoder_radio(images)
    logits_radio = head_radio(features_radio)
    preds_radio = logits_radio.argmax(dim=1)

# Visualize first 2 samples
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(2):
    # Input image
    axes[i, 0].imshow(images[i].cpu().permute(1, 2, 0).numpy())
    axes[i, 0].set_title("Input Image", fontsize=10)
    axes[i, 0].axis('off')
    
    # Ground truth
    axes[i, 1].imshow(masks[i].cpu().numpy(), cmap='tab20', vmin=0, vmax=20)
    axes[i, 1].set_title("Ground Truth", fontsize=10)
    axes[i, 1].axis('off')
    
    # SimpleCNN prediction
    axes[i, 2].imshow(preds_cnn[i].cpu().numpy(), cmap='tab20', vmin=0, vmax=20)
    axes[i, 2].set_title("SimpleCNN Pred", fontsize=10)
    axes[i, 2].axis('off')
    
    # RADIO prediction
    axes[i, 3].imshow(preds_radio[i].cpu().numpy(), cmap='tab20', vmin=0, vmax=20)
    axes[i, 3].set_title("RADIO Pred", fontsize=10)
    axes[i, 3].axis('off')

plt.suptitle("Model Comparison: Predictions (Untrained Heads)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Note: Both heads are untrained, so these predictions are essentially random.")
print("After training, RADIO is expected to perform better thanks to its pretrained features.")
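
Visual inspection alone does not quantify prediction quality. Once the heads are trained, a metric such as mean intersection-over-union (mIoU) gives a number to compare. The sketch below is a NumPy version under stated assumptions: `mean_iou` is a hypothetical helper, masks use 255 as the ignore label (the usual Pascal VOC convention, not verified for this dataset wrapper), and predictions have the same spatial size as the masks.

```python
import numpy as np


def mean_iou(pred: np.ndarray, target: np.ndarray,
             num_classes: int = 21, ignore_index: int = 255) -> float:
    """Mean IoU over classes that appear in the prediction or the target."""
    valid = target != ignore_index
    pred = pred[valid].astype(np.int64)
    target = target[valid].astype(np.int64)

    # Confusion matrix: rows = ground truth, columns = prediction
    conf = np.bincount(
        target * num_classes + pred, minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)

    intersection = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection
    present = union > 0
    if not present.any():
        return float("nan")  # nothing to score (all pixels ignored)
    return float((intersection[present] / union[present]).mean())


# Tiny worked example: one pixel is ignored (255), one pixel is misclassified
target = np.array([[0, 0, 1], [1, 255, 2]])
pred = np.array([[0, 1, 1], [1, 0, 2]])
score = mean_iou(pred, target)
print(f"mIoU = {score:.4f}")
```

In the worked example the per-class IoUs are 1/2, 2/3, and 1, so the mean is 13/18. In the notebook, the call would look like `mean_iou(preds_cnn.cpu().numpy(), masks.cpu().numpy())`.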

## 8. Feature Comparison

Compare the spatial features extracted by both encoders.

In [None]:
# Get features from both encoders
sample_idx = 0
sample_image = images[sample_idx:sample_idx+1]

with torch.no_grad():
    features_cnn = encoder_cnn(sample_image)
    features_radio = encoder_radio(sample_image)

# Visualize feature maps
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Original image
axes[0, 0].imshow(sample_image[0].cpu().permute(1, 2, 0).numpy())
axes[0, 0].set_title("Input Image", fontsize=11, fontweight='bold')
axes[0, 0].axis('off')
for j in range(1, 4):
    axes[0, j].axis('off')

# SimpleCNN features
feat_cnn = features_cnn[0].cpu().numpy()
for j in range(4):
    axes[1, j].imshow(feat_cnn[j], cmap='viridis')
    axes[1, j].set_title(f"SimpleCNN Ch{j}", fontsize=10)
    axes[1, j].axis('off')

# RADIO features
feat_radio = features_radio[0].cpu().numpy()
for j in range(4):
    axes[2, j].imshow(feat_radio[j], cmap='viridis')
    axes[2, j].set_title(f"RADIO Ch{j}", fontsize=10)
    axes[2, j].axis('off')

plt.suptitle("Feature Map Comparison (First 4 Channels)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"SimpleCNN feature shape: {features_cnn.shape}")
print(f"RADIO feature shape: {features_radio.shape}")
print(f"\nRADIO has {features_radio.shape[1] / features_cnn.shape[1]:.1f}x more channels")
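
The first four channels are an arbitrary slice of each feature map, which matters more for an encoder with many channels. A common alternative, not used in this notebook, is to project all channels onto their top three principal components and display the result as an RGB image. The sketch below assumes features shaped `(C, H, W)` with at least three channels. `pca_rgb` is a hypothetical helper and the random input is a stand-in for real encoder output.

```python
import numpy as np


def pca_rgb(features: np.ndarray) -> np.ndarray:
    """Project a (C, H, W) feature map onto its top 3 principal components.

    Returns an (H, W, 3) array scaled to [0, 1] per component for imshow.
    """
    c, h, w = features.shape
    flat = features.reshape(c, h * w).T             # (H*W, C): one row per location
    flat = flat - flat.mean(axis=0, keepdims=True)  # center each channel

    # Right singular vectors of the centered data are the principal axes
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    projected = flat @ vt[:3].T                     # (H*W, 3)

    # Rescale each component independently to [0, 1]
    lo = projected.min(axis=0, keepdims=True)
    hi = projected.max(axis=0, keepdims=True)
    projected = (projected - lo) / np.maximum(hi - lo, 1e-8)
    return projected.reshape(h, w, 3)


# Random stand-in for an encoder output with 16 channels on a 14x14 grid
rng = np.random.default_rng(0)
demo = rng.normal(size=(16, 14, 14)).astype(np.float32)
rgb = pca_rgb(demo)
print(rgb.shape)
```

In the notebook, `pca_rgb(feat_cnn)` and `pca_rgb(feat_radio)` could be passed directly to `imshow`, giving one image per encoder regardless of channel count.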

## 9. Summary Comparison Table

Comprehensive comparison of all metrics.

In [None]:
summary_data = {
    "Metric": [
        "Parameters (M)",
        "Output Channels",
        "Spatial Resolution",
        "Inference Time (ms)",
        "Pretrained",
        "Best Use Case"
    ],
    "SimpleCNN": [
        f"{encoder_cnn.get_num_parameters() / 1e6:.2f}M",
        encoder_cnn.output_channels,
        f"{features_cnn.shape[2]}x{features_cnn.shape[3]}",
        f"{avg_time_cnn:.1f}",
        "No",
        "Fast inference, limited data"
    ],
    "RADIO": [
        f"{encoder_radio.get_num_parameters() / 1e6:.2f}M",
        encoder_radio.output_channels,
        f"{features_radio.shape[2]}x{features_radio.shape[3]}",
        f"{avg_time_radio:.1f}",
        "Yes (Foundation)",
        "Best accuracy, rich features"
    ]
}

df_summary = pd.DataFrame(summary_data)
print("\n" + "="*60)
print("COMPREHENSIVE MODEL COMPARISON")
print("="*60)
print(df_summary.to_string(index=False))
print("="*60)

## Summary

### Key Findings

**SimpleCNN:**
- ✅ **Fast**: Lower inference time per batch (see Section 6)
- ✅ **Lightweight**: Fewer parameters, less memory
- ✅ **Simple**: Easy to train from scratch
- ❌ **Limited**: Lower capacity, no pretrained weights

**RADIO:**
- ✅ **Powerful**: Large pretrained foundation model
- ✅ **Rich Features**: High-dimensional spatial features
- ✅ **Pretrained**: Features can be reused with a frozen encoder and a linear head
- ❌ **Slower**: Higher computational cost
- ❌ **Large**: More parameters and memory

### When to Use Each Model

**Use SimpleCNN when:**
- You need fast inference
- You have limited computational resources
- You're prototyping or testing the framework
- You have a simple segmentation task

**Use RADIO when:**
- You need the best possible accuracy
- You have limited training data (leverage pretrained features)
- You're working on challenging segmentation tasks
- Computational cost is not a primary concern

### Next Steps

1. **Train both models**: Compare performance after training
2. **Try other encoders**: DINOv2, CLIP, SAM, etc.
3. **Experiment with datasets**: Test on different segmentation benchmarks
4. **Optimize inference**: Quantization, pruning, distillation

The VLM evaluation framework makes it easy to swap encoders and compare them fairly!