# Transfer Learning Comparison on Caltech-256: CNNs vs Vision Transformers

This notebook implements a comprehensive comparison of pretrained CNN and Vision Transformer
architectures for transfer learning on the Caltech-256 dataset.

## Models Compared
| Family | Models |
|--------|--------|
| **CNNs** | ResNet-50, ResNet-101, VGG-16, EfficientNet-B0, MobileNet-V2, InceptionV3 |
| **Transformers** | ViT-B/16, ViT-L/16, Swin-T, DeiT-Base |

## Sections
1. [Configuration](#1.-Configuration)
2. [Setup & Imports](#2.-Setup-&-Imports)
3. [Dataset Loading](#3.-Dataset-Loading)
4. [Model Preparation](#4.-Model-Preparation)
5. [Training Pipeline](#5.-Training-Pipeline)
6. [Evaluation](#6.-Evaluation)
7. [Visualisation & Comparison](#7.-Visualisation-&-Comparison)
8. [Results Summary](#8.-Results-Summary)
9. [Conclusions](#9.-Conclusions)

## 1. Configuration

Edit the values in this cell to customise training behaviour.
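
Before a long run, a few sanity checks on the configuration values can catch mistakes early. The helper below is a hypothetical sketch (`validate_config` is not part of `utils.py`); it only checks the relationships implied by the comments in the configuration cell, such as leaving a remainder for the test split.

```python
def validate_config(train_ratio, val_ratio, num_epochs, unfreeze_after):
    """Return a list of problems; an empty list means the values look sane.

    Hypothetical helper for illustration, not part of the notebook's modules.
    """
    problems = []
    if not (0 < train_ratio < 1 and 0 < val_ratio < 1):
        problems.append("TRAIN_RATIO and VAL_RATIO must each lie strictly between 0 and 1")
    if train_ratio + val_ratio >= 1:
        problems.append("TRAIN_RATIO + VAL_RATIO must be < 1 so that a test split remains")
    if num_epochs < 1:
        problems.append("NUM_EPOCHS must be at least 1")
    if unfreeze_after is not None and not (0 < unfreeze_after < num_epochs):
        problems.append("UNFREEZE_AFTER should be between 1 and NUM_EPOCHS - 1, or None")
    return problems


# The default values from this notebook pass all checks
default_problems = validate_config(0.70, 0.15, 30, 5)
print(default_problems)
```

Calling it with `TRAIN_RATIO`, `VAL_RATIO`, `NUM_EPOCHS` and `UNFREEZE_AFTER` right after the configuration cell would surface a bad combination before any model is loaded.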

In [None]:
# ---------------------------------------------------------------
# ⚙️  CONFIGURATION – edit this cell before running the notebook
# ---------------------------------------------------------------

# --- Dataset ---
DATA_DIR = "/path/to/caltech-256"   # Root directory that contains class folders
                                     # e.g. 001.ak47/, 002.american-flag/, ...

# --- Splits ---
TRAIN_RATIO = 0.70   # 70 % training
VAL_RATIO   = 0.15   # 15 % validation; the remaining 15 % becomes the test split

# --- Training ---
BATCH_SIZE  = 32
NUM_EPOCHS  = 30
LR          = 1e-3
WEIGHT_DECAY = 1e-4
PATIENCE    = 7       # Early-stopping patience (epochs without val_loss improvement)

# --- Transfer learning mode ---
FREEZE_BACKBONE = True   # True  → feature extraction (only classifier trains)
                          # False → full fine-tuning (all layers train)
UNFREEZE_AFTER  = 5      # Epochs before unfreezing last UNFREEZE_N_LAYERS layers
                          # (set to None to keep backbone frozen throughout)
UNFREEZE_N_LAYERS = 3    # Number of backbone layer groups to unfreeze after UNFREEZE_AFTER epochs

# --- Augmentation ---
AUGMENT_TRAIN = True

# --- Data loader ---
NUM_WORKERS = 4

# --- Reproducibility ---
SEED = 42

# --- Checkpoints & results ---
CHECKPOINT_DIR = "checkpoints"
RESULTS_DIR    = "results"

# --- Models to run (subset if you want a quick test) ---
# All supported keys:
#   CNNs        : "resnet50", "resnet101", "vgg16", "efficientnet_b0",
#                 "mobilenet_v2", "inception_v3"
#   Transformers: "vit_b_16", "vit_l_16", "swin_t", "deit_base_patch16_224"
MODELS_TO_RUN = [
    "resnet50",
    "resnet101",
    "vgg16",
    "efficientnet_b0",
    "mobilenet_v2",
    "inception_v3",
    "vit_b_16",
    "vit_l_16",
    "swin_t",
    "deit_base_patch16_224",
]

## 2. Setup & Imports

In [None]:
import sys
import os
from pathlib import Path

# Add the repo root to sys.path so that models.py / utils.py are importable
repo_root = Path(".").resolve()
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams["figure.dpi"] = 100

# Local modules
from models import (
    get_model,
    list_models,
    count_all_parameters,
    freeze_backbone,
    unfreeze_last_n_layers,
    INPUT_SIZES,
    MODEL_NAMES,
)
from utils import (
    set_seed,
    load_caltech256,
    train_model,
    evaluate_model,
    measure_inference_time,
    build_summary_df,
    save_summary_csv,
    plot_training_curves,
    plot_accuracy_comparison,
    plot_inference_time_comparison,
    plot_size_vs_accuracy,
    plot_confusion_matrix,
)

# ── Device ──────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ── Reproducibility ─────────────────────────────────────────
set_seed(SEED)

print(f"\nAvailable models: {list_models()}")
print(f"Models selected for this run: {MODELS_TO_RUN}")

## 3. Dataset Loading

Caltech-256 is expected to be organised as:
```
DATA_DIR/
    001.ak47/
        image_0001.jpg
        ...
    002.american-flag/
        ...
    ...
    257.clutter/
```

The dataset is split **once** (reproducibly) using `SEED` into train / val / test subsets.
Data augmentation (random crop, flip, colour jitter, rotation) is applied only to the
training split.
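
The internals of `load_caltech256` are not shown in this notebook, so the following is only a minimal sketch of what a seeded, one-time split can look like. `split_indices` is a hypothetical name, and the real loader may differ (for example, it may stratify by class).

```python
import random


def split_indices(n_items, train_ratio, val_ratio, seed):
    """Return disjoint (train, val, test) index lists; test takes the remainder."""
    indices = list(range(n_items))
    # A dedicated Random instance keeps the shuffle independent of global state
    random.Random(seed).shuffle(indices)
    n_train = int(n_items * train_ratio)
    n_val = int(n_items * val_ratio)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000, 0.70, 0.15, seed=42)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the shuffle depends only on `seed`, calling the function again with the same arguments yields the same three subsets, which is what makes results comparable across models.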

In [None]:
# ── Load Caltech-256 ────────────────────────────────────────
# Note: inception_v3 requires 299×299 input; all other models use 224×224.
# We therefore build one set of loaders per input size: the 224×224 set is
# shared by every model except InceptionV3, which gets its own 299×299 set.

print("Loading Caltech-256 …")

train_loader_224, val_loader_224, test_loader_224, class_names = load_caltech256(
    data_dir=DATA_DIR,
    input_size=224,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED,
    augment=AUGMENT_TRAIN,
)

# Dedicated 299×299 loaders for InceptionV3
train_loader_299, val_loader_299, test_loader_299, _ = load_caltech256(
    data_dir=DATA_DIR,
    input_size=299,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED,
    augment=AUGMENT_TRAIN,
)

NUM_CLASSES = len(class_names)
print(f"Number of classes: {NUM_CLASSES}")
print(f"Class names (first 10): {class_names[:10]}")

# Helper: pick the right loaders for a model key
def get_loaders(model_key: str):
    if model_key == "inception_v3":
        return train_loader_299, val_loader_299, test_loader_299
    return train_loader_224, val_loader_224, test_loader_224

In [None]:
# ── Visualise a batch ───────────────────────────────────────
import torchvision.utils as vutils

def imshow_batch(loader, title="Sample batch", n=8):
    imgs, labels = next(iter(loader))
    # Denormalise
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    imgs = (imgs[:n] * std + mean).clamp(0, 1)
    grid = vutils.make_grid(imgs, nrow=n, padding=2)
    fig, ax = plt.subplots(figsize=(16, 2.5))
    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()
    plt.show()
    label_names = [class_names[l] for l in labels[:n].tolist()]
    print("Labels:", label_names)

imshow_batch(train_loader_224, title="Training batch sample (224×224)")

## 4. Model Preparation

For each selected model we:
1. Load pretrained ImageNet weights via `torchvision.models` (or `timm` for DeiT).
2. Replace the final classification head with a new linear layer for `NUM_CLASSES` outputs.
3. Optionally freeze the backbone so that only the head is trained in the first phase.

The `get_model()` factory handles all architecture-specific head replacements.
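
To make steps 2 and 3 concrete, here is a small sketch using a toy network instead of a real pretrained model, so it runs without downloading weights. `TinyNet` and `replace_head_and_freeze` are hypothetical names; `get_model()` may implement this differently, and real architectures name their heads differently (torchvision's ResNet, for instance, uses `fc`).

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Toy stand-in for a pretrained network: a backbone plus a linear head."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.backbone(x))


def replace_head_and_freeze(model: nn.Module, num_classes: int, freeze: bool = True):
    """Freeze existing weights (optional), then attach a fresh trainable head."""
    if freeze:
        for p in model.parameters():
            p.requires_grad = False
    # A newly created layer has requires_grad=True by default
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# 257 matches the folder layout above (256 object classes plus clutter)
model = replace_head_and_freeze(TinyNet(), num_classes=257)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
logits = model(torch.randn(2, 3, 32, 32))
print(f"trainable={trainable:,} of total={total:,}, logits shape={tuple(logits.shape)}")
```

Freezing before swapping the head is the simplest ordering: the old parameters are all switched off, and only the new layer ends up trainable.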

In [None]:
# ── Instantiate all selected models ─────────────────────────
models_dict: dict = {}
params_dict:  dict = {}

for key in MODELS_TO_RUN:
    print(f"Loading {MODEL_NAMES.get(key, key)} …", end="  ")
    m = get_model(key, pretrained=True, freeze=FREEZE_BACKBONE, device=device)
    models_dict[key] = m
    params_dict[key] = count_all_parameters(m)
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"total params={params_dict[key]:,}  trainable={trainable:,}")

print("\nAll models loaded successfully.")

In [None]:
# ── Parameter summary table ─────────────────────────────────
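# Keep the values numeric so that sort_values orders them by magnitude
# (formatted strings would sort lexicographically, e.g. "138.36" before "25.56").
param_rows = [
    {"Model": MODEL_NAMES.get(k, k), "Total Params (M)": round(v / 1e6, 2)}
    for k, v in params_dict.items()
]
pd.DataFrame(param_rows).sort_values("Total Params (M)", ascending=False)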

## 5. Training Pipeline

Each model is trained with:
- **Loss**: Cross-Entropy
- **Optimiser**: Adam (`lr=LR`, `weight_decay=WEIGHT_DECAY`)
- **LR Scheduler**: `ReduceLROnPlateau` (halves LR when val_loss stagnates for 3 epochs)
- **Early stopping**: stops if val_loss doesn't improve for `PATIENCE` epochs
- **Checkpointing**: saves the model state with the best validation accuracy to `CHECKPOINT_DIR/`

If `UNFREEZE_AFTER` is not `None`, the last `UNFREEZE_N_LAYERS` backbone layer groups are
unfrozen after that many epochs (gradual fine-tuning).
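
The early-stopping rule described above can be sketched in a few lines. This is an illustrative stand-alone class, not the code inside `train_model`; the name `EarlyStopping` and the `min_delta` parameter are assumptions.

```python
class EarlyStopping:
    """Signal a stop once val_loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's val_loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [1.00, 0.80, 0.85, 0.82, 0.81, 0.79]  # made-up values for illustration
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best val_loss {stopper.best_loss}")
```

With `patience=3`, the loop stops at epoch 5: epochs 3 to 5 all fail to beat the 0.80 reached at epoch 2, so the later 0.79 is never seen. This is the usual trade-off when choosing `PATIENCE`.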

In [None]:
# ── Training loop (iterates over all selected models) ────────
import time

all_histories: dict = {}

for key in MODELS_TO_RUN:
    model = models_dict[key]
    train_ldr, val_ldr, test_ldr = get_loaders(key)

    # ─── Phase 1: feature extraction ─────────────────────────
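    # Cap phase 1 at NUM_EPOCHS in case UNFREEZE_AFTER is set larger than the total budget
    phase1_epochs = min(UNFREEZE_AFTER, NUM_EPOCHS) if UNFREEZE_AFTER is not None else NUM_EPOCHS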
    print(f"\n{'='*60}")
    print(f" Training {MODEL_NAMES.get(key, key)}  (phase 1 – {phase1_epochs} epochs)")
    print(f"{'='*60}")

    history = train_model(
        model, train_ldr, val_ldr,
        model_key=key,
        num_epochs=phase1_epochs,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        patience=PATIENCE,
        checkpoint_dir=CHECKPOINT_DIR,
        device=device,
    )

    # ─── Phase 2: fine-tuning (optional) ─────────────────────
    if UNFREEZE_AFTER is not None and NUM_EPOCHS > UNFREEZE_AFTER:
        print(f"\n --- Phase 2: unfreezing last {UNFREEZE_N_LAYERS} layer groups ---")
        unfreeze_last_n_layers(model, key, UNFREEZE_N_LAYERS)
        trainable_now = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"     Trainable params after unfreeze: {trainable_now:,}")

        history2 = train_model(
            model, train_ldr, val_ldr,
            model_key=key,
            num_epochs=NUM_EPOCHS - UNFREEZE_AFTER,
            lr=LR / 10,          # smaller LR for fine-tuning
            weight_decay=WEIGHT_DECAY,
            patience=PATIENCE,
            checkpoint_dir=CHECKPOINT_DIR,
            device=device,
            resume_checkpoint=history["checkpoint_path"],
        )
        # Merge histories
        for metric in ("train_loss", "val_loss", "train_acc", "val_acc", "epoch_times"):
            history[metric].extend(history2[metric])
        history["total_time"] = history.get("total_time", 0) + history2["total_time"]
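        if history2["best_val_acc"] > history["best_val_acc"]:
            history["best_val_acc"] = history2["best_val_acc"]
            # Best epoch on the merged timeline (1-indexed). Assumes best_val_acc is
            # the maximum of the per-epoch val_acc values recorded by train_model.
            history["best_epoch"] = int(np.argmax(history["val_acc"])) + 1
            history["checkpoint_path"] = history2["checkpoint_path"]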

    all_histories[key] = history
    print(f"\n✓ {MODEL_NAMES.get(key, key)} done."
          f"  best_val_acc={history['best_val_acc']:.4f}"
          f"  total_time={history['total_time']:.0f}s")

## 6. Evaluation

We evaluate each model on the held-out **test set** and record:
- **Top-1 accuracy** – fraction of correctly classified images
- **Top-5 accuracy** – fraction of images whose true class is in the top-5 predictions
- **Inference time** – average ms per image (single-image benchmark on the same device)
- **Per-class metrics** – precision, recall, F1-score from scikit-learn
- **Confusion matrix** – saved as PNG for the best-performing model
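
As a reference for the first two metrics, top-k accuracy can be computed as below. This is a dependency-free sketch with made-up scores; `evaluate_model` presumably works on tensors, and `topk_accuracy` is a hypothetical name.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices ordered from highest to lowest score, truncated to k
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top
    return hits / len(labels)


# Made-up scores for 4 samples over 5 classes (illustration only)
scores = [
    [0.10, 0.60, 0.10, 0.10, 0.10],  # true class 1 is ranked 1st
    [0.40, 0.30, 0.10, 0.10, 0.10],  # true class 1 is ranked 2nd
    [0.35, 0.30, 0.20, 0.10, 0.05],  # true class 3 is ranked 4th
    [0.05, 0.10, 0.20, 0.30, 0.35],  # true class 2 is ranked 3rd
]
labels = [1, 1, 3, 2]

top1 = topk_accuracy(scores, labels, k=1)
top3 = topk_accuracy(scores, labels, k=3)
print(f"top-1={top1:.2f}  top-3={top3:.2f}")
```

The example uses `k=3` only because the toy problem has five classes; the notebook's top-5 metric follows the same rule with `k=5`.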

In [None]:
# ── Evaluate all models ──────────────────────────────────────
all_results: dict = {}

for key in MODELS_TO_RUN:
    model = models_dict[key]
    _, _, test_ldr = get_loaders(key)
    input_sz = INPUT_SIZES[key]

    print(f"\nEvaluating {MODEL_NAMES.get(key, key)} …")

    # Load best checkpoint
    ckpt_path = all_histories[key]["checkpoint_path"]
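    if Path(ckpt_path).exists():
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        print(f"  Loaded checkpoint: {ckpt_path}")
    else:
        # Make the fallback explicit: results then reflect the last-epoch weights
        print(f"  WARNING: checkpoint not found at {ckpt_path}; evaluating in-memory weights")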

    eval_res = evaluate_model(model, test_ldr, class_names, device=device)

    # Standalone inference benchmark
    eval_res["avg_inference_ms"] = measure_inference_time(
        model, input_size=input_sz, n_runs=50, batch_size=1, device=device,
    )
    eval_res["total_time"]    = all_histories[key]["total_time"]
    eval_res["best_val_acc"]  = all_histories[key]["best_val_acc"]

    all_results[key] = eval_res

    print(f"  Top-1: {eval_res['top1_acc']*100:.2f}%  "
          f"Top-5: {eval_res['top5_acc']*100:.2f}%  "
          f"Inf: {eval_res['avg_inference_ms']:.2f} ms/img")

## 7. Visualisation & Comparison

All plots are saved to `RESULTS_DIR/plots/`.

In [None]:
# ── Training curves ─────────────────────────────────────────
plots_dir = Path(RESULTS_DIR) / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists

for key, history in all_histories.items():
    plot_training_curves(history, key, save_dir=plots_dir)

print("Training curves saved.")

# ── Display training curves inline for each model ───────────
fig, axes = plt.subplots(len(MODELS_TO_RUN), 2,
                          figsize=(14, 4 * len(MODELS_TO_RUN)))
if len(MODELS_TO_RUN) == 1:
    axes = [axes]

for ax_row, key in zip(axes, MODELS_TO_RUN):
    h = all_histories[key]
    epochs = range(1, len(h["train_loss"]) + 1)
    ax_row[0].plot(epochs, h["train_loss"], label="Train")
    ax_row[0].plot(epochs, h["val_loss"],   label="Val")
    ax_row[0].set_title(f"{MODEL_NAMES.get(key,key)} – Loss")
    ax_row[0].set_xlabel("Epoch"); ax_row[0].set_ylabel("Loss")
    ax_row[0].legend()

    ax_row[1].plot(epochs, [a*100 for a in h["train_acc"]], label="Train")
    ax_row[1].plot(epochs, [a*100 for a in h["val_acc"]],   label="Val")
    ax_row[1].set_title(f"{MODEL_NAMES.get(key,key)} – Accuracy")
    ax_row[1].set_xlabel("Epoch"); ax_row[1].set_ylabel("Accuracy (%)")
    ax_row[1].legend()

fig.tight_layout()
plt.show()

In [None]:
# ── Summary DataFrame ────────────────────────────────────────
summary_df = build_summary_df(all_results, params_dict)
print(summary_df.to_string(index=False))

# Save CSV
save_summary_csv(summary_df, path=Path(RESULTS_DIR) / "summary.csv")

In [None]:
# ── Accuracy comparison bar chart ───────────────────────────
plot_accuracy_comparison(summary_df, save_dir=plots_dir)

fig, ax = plt.subplots(figsize=(10, 5))
colors = ["steelblue" if any(t in m for t in ["ViT", "Swin", "DeiT"])
          else "salmon" for m in summary_df["Model"]]
ax.bar(summary_df["Model"], summary_df["Top-1 Acc (%)"], color=colors)
ax.set_xlabel("Model")
ax.set_ylabel("Top-1 Accuracy (%)")
ax.set_title("Top-1 Accuracy – CNNs (red) vs Transformers (blue)")
# Rotate the existing tick labels; set_xticklabels without fixed ticks can emit a warning
plt.setp(ax.get_xticklabels(), rotation=35, ha="right")
fig.tight_layout()
plt.show()

In [None]:
# ── Inference time comparison ───────────────────────────────
plot_inference_time_comparison(summary_df, save_dir=plots_dir)

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(summary_df["Model"], summary_df["Inf. Time (ms)"], color="mediumseagreen")
ax.set_xlabel("Model")
ax.set_ylabel("Avg Inference Time (ms / image)")
ax.set_title("Inference Time Comparison")
# Rotate the existing tick labels; set_xticklabels without fixed ticks can emit a warning
plt.setp(ax.get_xticklabels(), rotation=35, ha="right")
fig.tight_layout()
plt.show()

In [None]:
# ── Model size vs accuracy scatter ──────────────────────────
plot_size_vs_accuracy(summary_df, save_dir=plots_dir)

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(summary_df["Params (M)"], summary_df["Top-1 Acc (%)"], s=80)
for _, row in summary_df.iterrows():
    ax.annotate(row["Model"], (row["Params (M)"], row["Top-1 Acc (%)"]),
                textcoords="offset points", xytext=(5, 5), fontsize=9)
ax.set_xlabel("Parameters (M)")
ax.set_ylabel("Top-1 Accuracy (%)")
ax.set_title("Model Size vs Accuracy Trade-off")
fig.tight_layout()
plt.show()

In [None]:
# ── Confusion matrix for the best model ─────────────────────
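# Sort explicitly rather than relying on the row order returned by build_summary_df
best_name = summary_df.sort_values("Top-1 Acc (%)", ascending=False).iloc[0]["Model"]
# Reverse MODEL_NAMES to map the display name back to its model key
rev_names = {v: k for k, v in MODEL_NAMES.items()}
best_model_key = rev_names.get(best_name, best_name)
if best_model_key not in all_results:
    best_model_key = MODELS_TO_RUN[0]  # last-resort fallback

print(f"Best model: {best_name} (key={best_model_key})")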

plot_confusion_matrix(
    all_results[best_model_key]["confusion_matrix"],
    class_names,
    model_key=best_model_key,
    save_dir=plots_dir,
    max_classes=30,
)
print("Confusion matrix saved.")

## 8. Results Summary

In [None]:
# ── Final ranked table ──────────────────────────────────────
print("=" * 70)
print(" FINAL RESULTS SUMMARY")
print("=" * 70)
print(summary_df.to_string(index=False))
print()

best = summary_df.sort_values("Top-1 Acc (%)", ascending=False).iloc[0]
print(f"🏆 Best model by Top-1 accuracy: {best['Model']}")
print(f"   Top-1 Acc : {best['Top-1 Acc (%)']:.2f}%")
print(f"   Top-5 Acc : {best['Top-5 Acc (%)']:.2f}%")
print(f"   Inf. Time  : {best['Inf. Time (ms)']:.2f} ms/img")
print(f"   Params     : {best['Params (M)']:.1f} M")

fastest = summary_df.sort_values("Inf. Time (ms)").iloc[0]
print(f"\n⚡ Fastest model by inference time: {fastest['Model']}")
print(f"   Inf. Time: {fastest['Inf. Time (ms)']:.2f} ms/img")
print(f"   Top-1 Acc: {fastest['Top-1 Acc (%)']:.2f}%")

most_efficient = (summary_df["Top-1 Acc (%)"] / summary_df["Params (M)"]).idxmax()
eff = summary_df.loc[most_efficient]   # idxmax returns an index label, so use .loc, not .iloc
print(f"\n💡 Most parameter-efficient: {eff['Model']}")
print(f"   Acc/Params ratio: {eff['Top-1 Acc (%)']/eff['Params (M)']:.2f}% per M params")

## 9. Conclusions

### Key Observations
- **Accuracy**: Pretrained Vision Transformers (ViT-B/16, ViT-L/16, DeiT, Swin) are often
  expected to reach higher top-1 accuracy on Caltech-256 when sufficient training data is
  available; a common explanation is that self-attention captures long-range dependencies
  that local convolutions miss. Check this against the summary table in Section 8.
- **Speed**: Lightweight CNNs (MobileNet-V2, EfficientNet-B0) offer the fastest inference,
  making them ideal for deployment-constrained environments.
- **Parameter efficiency**: EfficientNet-B0 and MobileNet-V2 deliver competitive accuracy
  with far fewer parameters than ViT-L/16 or VGG-16.
- **Training stability**: CNNs generally converge faster in the early epochs due to their
  inductive biases (locality, translation equivariance), while Transformers may require
  more epochs or data augmentation to reach peak performance.

### Recommendations
| Use-case | Recommended model |
|----------|-------------------|
| Highest accuracy | ViT-B/16 or Swin-T |
| Edge / mobile deployment | MobileNet-V2 or EfficientNet-B0 |
| Balanced accuracy & speed | ResNet-50 or EfficientNet-B0 |
| Fine-tuning with limited data | ResNet-50 (strong ImageNet prior) |

### Next Steps
- Experiment with stronger augmentation (RandAugment, MixUp, CutMix) for the Transformer models.
- Try longer training with a cosine LR schedule.
- Explore few-shot or semi-supervised approaches for the long-tail classes.
- Perform per-class error analysis to identify systematic weaknesses.
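
For the cosine schedule mentioned above, the learning rate at each epoch follows a closed form, sketched below in plain Python. `cosine_lr` is a hypothetical helper; in practice PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` provides this kind of schedule, and the values `30` and `1e-3` simply mirror `NUM_EPOCHS` and `LR` from the configuration cell.

```python
import math


def cosine_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Cosine-annealed learning rate for a 0-indexed epoch in [0, total_epochs]."""
    cos_term = 1 + math.cos(math.pi * epoch / total_epochs)
    return lr_min + 0.5 * (lr_max - lr_min) * cos_term


# One value per epoch boundary: starts at lr_max and decays smoothly towards lr_min
schedule = [cosine_lr(e, total_epochs=30, lr_max=1e-3) for e in range(31)]
print(f"start={schedule[0]:.2e}  middle={schedule[15]:.2e}  end={schedule[-1]:.2e}")
```

Unlike `ReduceLROnPlateau`, this schedule does not react to the validation loss; the decay is fixed in advance by the epoch budget.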