# Model Comparison: SimpleCNN vs RADIO

Compare two encoders on real Pascal VOC 2012 data:
- **SimpleCNN**: Lightweight baseline (~2M params)
- **RADIO**: NVIDIA foundation model (~300M params)

**Dataset**: Pascal VOC 2012 (download from Kaggle - see notebook 03)

## 1. Setup

In [None]:
# Notebook setup: make the local `vlm_eval` package importable and load the
# libraries used throughout the comparison.
import sys
from pathlib import Path
# Prepend the parent of the current working directory to the import path.
# NOTE(review): assumes the notebook is launched from a subdirectory of the
# repository root (where `vlm_eval` lives) — confirm if run from elsewhere.
sys.path.insert(0, str(Path.cwd().parent))

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time  # timing for the inference-speed benchmark
import pandas as pd

# Project registries: encoders, task heads and datasets are looked up by name.
from vlm_eval import EncoderRegistry, HeadRegistry, DatasetRegistry

print("✓ Setup complete")

## 2. Load Dataset

In [None]:
# Pascal VOC validation subset; both encoders see the same input resolution
# so the comparison is fair.
voc_config = {
    "root": "./data/pascal_voc",
    "split": "val",
    "subset_size": 30,
    "image_size": 224,
}
dataset = DatasetRegistry.get("pascal_voc", **voc_config)

# Deterministic order: every later `next(iter(dataloader))` yields the same
# first batch, so benchmark and prediction cells use identical images.
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)
print(f"Dataset: {len(dataset)} samples")

## 3. Create Both Encoders

In [None]:
NUM_CLASSES = 21  # output classes for both linear-probe heads

# SimpleCNN baseline: no pretrained weights; encoder not frozen.
encoder_cnn = EncoderRegistry.get("simple_cnn", variant="base", pretrained=False)
head_cnn = HeadRegistry.get(
    "linear_probe",
    encoder=encoder_cnn,
    num_classes=NUM_CLASSES,
    freeze_encoder=False,
)

# RADIO: pretrained foundation model; encoder frozen so only the probe is
# (presumably) trainable.
print("Loading RADIO...")
encoder_radio = EncoderRegistry.get("radio", variant="base", pretrained=True)
head_radio = HeadRegistry.get(
    "linear_probe",
    encoder=encoder_radio,
    num_classes=NUM_CLASSES,
    freeze_encoder=True,
)

print("✓ Both models created")

## 4. Architecture Comparison

In [None]:
def _architecture_column(encoder, pretrained_label):
    """Build one encoder's column of the comparison table.

    Args:
        encoder: object exposing ``get_num_parameters()``, ``output_channels``
            and ``patch_size``.
        pretrained_label: text for the "Pretrained" row ("Yes"/"No").

    Returns:
        ``(column_values, num_params)`` so the parameter count, queried only
        once, can be reused for the bar chart.
    """
    num_params = encoder.get_num_parameters()
    column = [
        f"{num_params:,}",
        f"{num_params / 1e6:.2f}M",
        encoder.output_channels,
        encoder.patch_size,
        pretrained_label,
    ]
    return column, num_params


# Labels mirror the `pretrained=` flags used when the encoders were created.
cnn_column, cnn_params = _architecture_column(encoder_cnn, "No")
radio_column, radio_params = _architecture_column(encoder_radio, "Yes")

comparison = pd.DataFrame({
    "Metric": ["Parameters", "Params (M)", "Output Channels", "Patch Size", "Pretrained"],
    "SimpleCNN": cnn_column,
    "RADIO": radio_column,
})

print("\n=== Architecture Comparison ===")
print(comparison.to_string(index=False))

# Visualize
fig, ax = plt.subplots(figsize=(10, 5))
params = [cnn_params / 1e6, radio_params / 1e6]
bars = ax.bar(['SimpleCNN', 'RADIO'], params, color=['#3498db', '#e74c3c'], alpha=0.7)
ax.set_ylabel('Parameters (Millions)')
ax.set_title('Model Size Comparison')
for bar, p in zip(bars, params):
    # Annotate each bar with its size, centred just above the bar top.
    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
            f'{p:.1f}M', ha='center', va='bottom', fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Inference Speed Benchmark

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
head_cnn = head_cnn.to(device).eval()
head_radio = head_radio.to(device).eval()
# The encoders are also called directly below, so move them to the device and
# put them in eval mode explicitly. If a head already registers its encoder as
# a submodule this is a no-op; otherwise it avoids a CPU/GPU device mismatch
# and train-mode BatchNorm/Dropout during timing and prediction.
encoder_cnn = encoder_cnn.to(device).eval()
encoder_radio = encoder_radio.to(device).eval()

test_batch = next(iter(dataloader))["image"].to(device)


def _sync():
    """Wait for all queued CUDA work so host-side timestamps are meaningful.

    CUDA kernels launch asynchronously; without a sync the timer would mostly
    measure launch overhead. No-op on CPU.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup: the first forward passes pay one-off costs (e.g. lazy initialisation,
# kernel selection), so run a few untimed iterations first.
with torch.no_grad():
    for _ in range(3):
        _ = head_cnn(encoder_cnn(test_batch))
        _ = head_radio(encoder_radio(test_batch))
_sync()


def benchmark(encoder, head, n=20, batch=None):
    """Return the mean latency (ms) of ``head(encoder(batch))`` over ``n`` runs.

    Args:
        encoder: feature extractor applied to the input batch.
        head: module applied to the encoder output.
        n: number of timed iterations (must be >= 1).
        batch: input images; defaults to the module-level ``test_batch``.

    Raises:
        ValueError: if ``n`` < 1 (the mean of zero samples is undefined).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if batch is None:
        batch = test_batch
    times = []
    with torch.no_grad():
        for _ in range(n):
            _sync()  # don't charge previously queued GPU work to this iteration
            start = time.perf_counter()  # monotonic, high-resolution clock
            _ = head(encoder(batch))
            _sync()  # wait until this forward pass has actually finished
            times.append(time.perf_counter() - start)
    return np.mean(times) * 1000  # ms


time_cnn = benchmark(encoder_cnn, head_cnn)
time_radio = benchmark(encoder_radio, head_radio)

print(f"SimpleCNN: {time_cnn:.1f} ms")
print(f"RADIO: {time_radio:.1f} ms")
# Report the ratio in whichever direction actually holds on this hardware.
ratio = time_radio / time_cnn
if ratio >= 1:
    print(f"Slowdown: RADIO is {ratio:.1f}x slower than SimpleCNN")
else:
    print(f"Speedup: RADIO is {1 / ratio:.1f}x faster than SimpleCNN")

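## 6. Side-by-Side Predictions

Note: this notebook has no training step, so the linear-probe heads are used exactly as `HeadRegistry.get` created them. These maps illustrate the pipeline and output format rather than the relative segmentation quality of the two encoders.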

In [None]:
batch = next(iter(dataloader))
images = batch["image"].to(device)
masks = batch["mask"].to(device)

with torch.no_grad():
    preds_cnn = head_cnn(encoder_cnn(images)).argmax(dim=1)
    preds_radio = head_radio(encoder_radio(images)).argmax(dim=1)

NUM_CLASSES = 21  # matches num_classes of both linear-probe heads


def _image_for_display(img):
    """Convert a CHW image tensor to an HWC float array in [0, 1] for imshow.

    Min-max rescaling keeps the display correct whether or not the dataset
    normalises images (raw normalised values would be clipped by imshow);
    ``.float()`` also makes half/bfloat16 tensors convertible to NumPy.
    """
    arr = img.detach().cpu().permute(1, 2, 0).float().numpy()
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        arr = (arr - lo) / (hi - lo)
    return np.clip(arr, 0.0, 1.0)


def _label_map(labels):
    """Return a label tensor as a NumPy array with invalid labels masked out.

    Values outside ``[0, NUM_CLASSES - 1]`` (e.g. the 255 "void"/boundary
    value in raw VOC annotations, if the dataset keeps it) are hidden instead
    of being clipped onto the colour of the last class.
    """
    return np.ma.masked_outside(labels.detach().cpu().numpy(), 0, NUM_CLASSES - 1)


# Visualize up to two samples; squeeze=False keeps `axes` 2-D even for one row.
n_show = min(2, images.shape[0])
fig, axes = plt.subplots(n_show, 4, figsize=(16, 4 * n_show), squeeze=False)
label_panels = (("Ground Truth", masks), ("SimpleCNN", preds_cnn), ("RADIO", preds_radio))
for i in range(n_show):
    axes[i, 0].imshow(_image_for_display(images[i]))
    axes[i, 0].set_title("Input")
    axes[i, 0].axis('off')

    for col, (title, label_batch) in enumerate(label_panels, start=1):
        axes[i, col].imshow(_label_map(label_batch[i]), cmap='tab20',
                            vmin=0, vmax=NUM_CLASSES - 1)
        axes[i, col].set_title(title)
        axes[i, col].axis('off')

plt.suptitle("Model Comparison on Real Images", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Feature Comparison

In [None]:
sample = images[0:1]
with torch.no_grad():
    feat_cnn = encoder_cnn(sample)
    feat_radio = encoder_radio(sample)

# The channel plots index features as [batch, channel, H, W]. Fail with a clear
# message if an encoder returns another layout (e.g. a token sequence or tuple)
# instead of a cryptic imshow shape error.
for name, feat in (("SimpleCNN", feat_cnn), ("RADIO", feat_radio)):
    if not isinstance(feat, torch.Tensor) or feat.ndim != 4:
        shape = getattr(feat, "shape", type(feat).__name__)
        raise ValueError(
            f"{name} features must be a 4-D (B, C, H, W) tensor to plot "
            f"channel maps; got {shape}"
        )

n_cols = 4
fig, axes = plt.subplots(3, n_cols, figsize=(16, 12))

# Row 0: the input image (min-max rescaled in case the dataset normalises it).
img = sample[0].detach().cpu().permute(1, 2, 0).float().numpy()
lo, hi = img.min(), img.max()
if hi > lo:
    img = (img - lo) / (hi - lo)
axes[0, 0].imshow(img)
axes[0, 0].set_title("Input")
for ax in axes[0]:
    ax.axis('off')

# Rows 1-2: the first channels of each encoder's feature map. `.float()` lets
# half/bfloat16 features convert to NumPy; panels beyond the channel count stay blank.
for row, (name, feat) in enumerate((("SimpleCNN", feat_cnn), ("RADIO", feat_radio)), start=1):
    channel_maps = feat[0].detach().cpu().float().numpy()
    for j in range(n_cols):
        ax = axes[row, j]
        ax.axis('off')
        if j < channel_maps.shape[0]:
            ax.imshow(channel_maps[j], cmap='viridis')
            ax.set_title(f"{name} Ch{j}")

plt.suptitle("Feature Comparison", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"SimpleCNN: {feat_cnn.shape}")
print(f"RADIO: {feat_radio.shape}")
print(f"RADIO has {feat_radio.shape[1]/feat_cnn.shape[1]:.1f}x more channels")

## Summary

### SimpleCNN
✅ Fast inference  
✅ Lightweight  
❌ Limited capacity  
❌ No pretrained weights  

### RADIO
✅ State-of-the-art features  
✅ Pretrained on massive data  
✅ Rich 1024-dim features  
❌ Slower inference  
❌ More memory  

### When to Use Each
- **SimpleCNN**: Fast prototyping, limited resources, simple tasks
- **RADIO**: Best accuracy, transfer learning, challenging tasks

The framework makes it easy to swap encoders and compare!