# Transfer Learning Image Classification (PyTorch)
Comprehensive, reproducible workflow: data prep, augmentation, baseline CNN, ResNet50 transfer learning (two-phase), metrics, confusion matrix, Grad-CAM, and saving artifacts.

# Objectives
- Build a high-performance classifier for a custom dataset via transfer learning (ResNet50).
- Two-phase training: feature extraction (frozen backbone) then fine-tuning (top layers unfrozen, low LR).
- Baseline CNN for comparison.
- Evaluation: accuracy, precision, recall, F1, confusion matrix.
- Interpretability: Grad-CAM heatmaps.
- Reproducibility: deterministic seeds, clear configs, checkpoints.

> Before running, set `Config.data_root` (default `"./data/your_dataset"`) to a dataset root containing `train/`, `val/`, and `test/` folders.

In [None]:
# Imports & Environment
import os
import json
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

from utils.data import build_transforms, create_dataloaders
from utils.models import build_baseline, build_resnet50, unfreeze_top_layers, param_groups
from utils.gradcam import GradCAM, overlay_heatmap

# Reproducibility
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

    Args:
        seed: value fed to ``random``, ``numpy.random``, ``torch`` (CPU) and all
            CUDA devices.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

set_seed(seed=42)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Configuration
@dataclass
class Config:
    """Hyper-parameters and paths shared by every cell of the notebook.

    Values are range-checked in ``__post_init__`` so a bad setting (e.g.
    ``grad_accum_steps=0``) fails immediately at construction instead of deep
    inside a training loop.
    """

    data_root: str = "./data/your_dataset"  # update to your dataset path
    image_size: int = 224  # passed to create_dataloaders
    batch_size: int = 32
    num_workers: int = 2  # DataLoader worker processes
    num_epochs_head: int = 8  # phase 1: head-only training
    num_epochs_ft: int = 12  # phase 2: fine-tuning
    lr_head: float = 3e-4  # head LR (phase 1 and head group in phase 2)
    lr_backbone: float = 1e-5  # base LR for the backbone param group in phase 2
    weight_decay: float = 1e-4
    early_stop_patience: int = 5  # phase-2 epochs without val-acc improvement before stopping
    grad_accum_steps: int = 1  # batches per optimizer step
    trainable_layers_ft: int = 10  # passed to unfreeze_top_layers in phase 2
    baseline_epochs: int = 10
    output_dir: str = "./models"  # checkpoints and training histories
    viz_dir: str = "./visualizations"  # saved plots

    def __post_init__(self):
        """Validate numeric fields.

        Raises:
            ValueError: if a field is outside its allowed range.
        """
        # Must be >= 1: e.g. grad_accum_steps=0 would make the training loop
        # evaluate ``step % 0``.
        for name in ("image_size", "batch_size", "grad_accum_steps"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"Config.{name} must be >= 1, got {value!r}")
        # Must be >= 0 (zero epochs / zero workers / zero LR are legal settings).
        for name in (
            "num_workers",
            "num_epochs_head",
            "num_epochs_ft",
            "baseline_epochs",
            "early_stop_patience",
            "trainable_layers_ft",
            "lr_head",
            "lr_backbone",
            "weight_decay",
        ):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"Config.{name} must be >= 0, got {value!r}")


cfg = Config()
# Create artifact directories up front so later torch.save / savefig calls
# never hit a missing path.
for artifact_dir in (cfg.output_dir, cfg.viz_dir):
    Path(artifact_dir).mkdir(parents=True, exist_ok=True)

cfg

In [None]:
# Data Loading & Visualization Helpers
# create_dataloaders (utils.data) returns the loaders dict, the split sizes and
# the class-name list used by every later cell.
loaders, sizes, class_names = create_dataloaders(
    data_root=cfg.data_root,
    image_size=cfg.image_size,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)
print("Class names:", class_names)
print("Sizes:", sizes)
num_classes = len(class_names)

# Quick sanity check: visualize a training batch
def show_batch(loader: DataLoader, title: str = "Train Batch"):
    """Display up to 16 images from the first batch of ``loader`` in a 4-column grid.

    The batch labels are read but not shown.
    """
    batch_images, _ = next(iter(loader))
    tiled = make_grid(batch_images[:16], nrow=4, normalize=True, padding=2)
    # make_grid produces (C, H, W); imshow expects (H, W, C).
    hwc = tiled.permute(1, 2, 0).cpu()
    plt.figure(figsize=(8, 8))
    plt.title(title)
    plt.imshow(hwc)
    plt.axis("off")
    plt.show()

# Uncomment to visualize
# show_batch(loaders["train"], "Train examples")

In [None]:
# Training Utilities

def accuracy_from_logits(logits, targets):
    """Return the fraction of rows whose arg-max over dim 1 equals the target, as a float."""
    hits = logits.argmax(dim=1).eq(targets)
    return hits.float().mean().item()


def train_one_epoch(model, loader, criterion, optimizer, scheduler=None, grad_accum=1):
    """Run one training epoch with optional gradient accumulation.

    Gradients from ``grad_accum`` consecutive batches are accumulated before
    each ``optimizer.step()``. If the epoch ends partway through a group, the
    leftover gradients are still applied. Otherwise they would be silently
    discarded by the next epoch's ``zero_grad``.

    Args:
        model: network to train; switched to ``train()`` mode. Batches are
            moved to the device of the model's parameters.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss()`` default) — confirm for custom losses.
        optimizer: optimizer over the parameters to update.
        scheduler: optional plateau-style scheduler. If given, it is stepped
            once at the end of the epoch with the mean training loss.
        grad_accum: number of batches per optimizer step (must be >= 1).

    Returns:
        ``(mean_loss, mean_acc)``, averaged over samples rather than batches,
        so a short final batch is weighted correctly.

    Raises:
        ValueError: if ``grad_accum < 1`` or ``loader`` yields no samples.
    """
    if grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")
    # Inputs follow the model's device (CPU if the model has no parameters).
    dev = next((p.device for p in model.parameters()), torch.device("cpu"))
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    pending = False  # True while gradients are accumulated but not yet applied
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader, start=1):
        x, y = x.to(dev), y.to(dev)
        logits = model(x)
        loss = criterion(logits, y)
        # Scale so a full group's accumulated gradient is the mean over its batches.
        (loss / grad_accum).backward()
        pending = True
        if step % grad_accum == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending = False
        batch = y.size(0)
        loss_sum += loss.item() * batch
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += batch
    if pending:
        # Flush the trailing partial group instead of dropping its gradients.
        optimizer.step()
        optimizer.zero_grad()
    if seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    mean_loss = loss_sum / seen
    if scheduler is not None:
        scheduler.step(mean_loss)
    return mean_loss, correct / seen


def eval_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Loss and accuracy are averaged per sample, not per batch, so a smaller
    final batch does not skew the reported metrics.

    Args:
        model: network to evaluate; switched to ``eval()`` mode. Batches are
            moved to the device of the model's parameters.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss()`` default) — confirm for custom losses.

    Returns:
        ``(mean_loss, mean_acc, preds, targets)``. ``preds`` and ``targets``
        are 1-D CPU tensors covering every evaluated sample in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Inputs follow the model's device (CPU if the model has no parameters).
    dev = next((p.device for p in model.parameters()), torch.device("cpu"))
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            logits = model(x)
            batch = y.size(0)
            loss_sum += criterion(logits, y).item() * batch
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            seen += batch
            all_preds.append(preds.cpu())
            all_targets.append(y.cpu())
    if seen == 0:
        raise ValueError("eval_model: loader yielded no samples")
    return loss_sum / seen, correct / seen, torch.cat(all_preds), torch.cat(all_targets)


def plot_curves(history, title: str, path: str):
    """Plot train/val loss and accuracy side by side, save to ``path`` at 200 dpi, then show.

    ``history`` must map ``train_loss``, ``val_loss``, ``train_acc`` and
    ``val_acc`` to per-epoch sequences.
    """
    panels = (
        ("Loss", "train_loss", "val_loss"),
        ("Accuracy", "train_acc", "val_acc"),
    )
    plt.figure(figsize=(8, 4))
    for position, (panel_title, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label="train")
        plt.plot(history[val_key], label="val")
        plt.title(panel_title)
        plt.legend()
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.show()


In [None]:
# Baseline CNN (from scratch)
baseline = build_baseline(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(baseline.parameters(), lr=1e-3, weight_decay=cfg.weight_decay)
# Halve the LR after 2 epochs without val-loss improvement. It is stepped once
# per epoch below with val_loss; it is not passed to train_one_epoch, which
# would step it a second time with the training loss.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

history_base = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0.0, a model stuck at 0% val accuracy would never save
# baseline_best.pth, and the evaluation cell's torch.load would fail.
best_val = float("-inf")
best_path = os.path.join(cfg.output_dir, "baseline_best.pth")

for epoch in range(cfg.baseline_epochs):
    train_loss, train_acc = train_one_epoch(baseline, loaders["train"], criterion, optimizer)
    val_loss, val_acc, _, _ = eval_model(baseline, loaders["val"], criterion)
    scheduler.step(val_loss)
    history_base["train_loss"].append(train_loss)
    history_base["val_loss"].append(val_loss)
    history_base["train_acc"].append(train_acc)
    history_base["val_acc"].append(val_acc)
    print(f"[Baseline] Epoch {epoch+1}/{cfg.baseline_epochs} - train_acc: {train_acc:.3f}, val_acc: {val_acc:.3f}")
    if val_acc > best_val:
        best_val = val_acc
        torch.save({"model_state": baseline.state_dict(), "class_names": class_names}, best_path)

plot_curves(history_base, "Baseline CNN", os.path.join(cfg.viz_dir, "baseline_curves.png"))

In [None]:
# Transfer Learning: ResNet50 Setup
resnet = build_resnet50(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
# Only the head (resnet.fc) is given to the optimizer, so backbone weights are
# not updated in this phase.
# NOTE(review): whether backbone gradients are skipped (requires_grad=False)
# depends on build_resnet50. BatchNorm running statistics in the backbone may
# still update, because train_one_epoch puts the whole model in train() mode.
optimizer = AdamW(resnet.fc.parameters(), lr=cfg.lr_head, weight_decay=cfg.weight_decay)

history_head = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
# -inf guarantees head_ckpt is written on the first epoch.
best_val = float("-inf")
head_ckpt = os.path.join(cfg.output_dir, "resnet50_head.pth")

print("Phase 1: Feature extraction (backbone frozen)")
for epoch in range(cfg.num_epochs_head):
    train_loss, train_acc = train_one_epoch(resnet, loaders["train"], criterion, optimizer, grad_accum=cfg.grad_accum_steps)
    val_loss, val_acc, _, _ = eval_model(resnet, loaders["val"], criterion)
    history_head["train_loss"].append(train_loss)
    history_head["val_loss"].append(val_loss)
    history_head["train_acc"].append(train_acc)
    history_head["val_acc"].append(val_acc)
    print(f"[Head] Epoch {epoch+1}/{cfg.num_epochs_head} - train_acc: {train_acc:.3f}, val_acc: {val_acc:.3f}")
    if val_acc > best_val:
        best_val = val_acc
        torch.save({"model_state": resnet.state_dict(), "class_names": class_names}, head_ckpt)

plot_curves(history_head, "ResNet50 - Head Training", os.path.join(cfg.viz_dir, "resnet_head_curves.png"))

# Start fine-tuning from the best head seen on validation, not the last epoch.
# Only reload when this run actually trained (and therefore saved) a head, so
# a stale checkpoint from an earlier run is never picked up.
if history_head["val_acc"]:
    resnet.load_state_dict(torch.load(head_ckpt, map_location=device)["model_state"])

In [None]:
# Transfer Learning: Fine-Tuning Top Layers
# Unfreeze top layers
unfreeze_top_layers(resnet, trainable_layers=cfg.trainable_layers_ft)

# Differential learning rates: lower for backbone, higher for head
optimizer_ft = AdamW(param_groups(resnet, base_lr=cfg.lr_backbone, head_lr=cfg.lr_head), weight_decay=cfg.weight_decay)
# Stepped once per epoch below with val_loss (not passed to train_one_epoch).
scheduler_ft = ReduceLROnPlateau(optimizer_ft, mode="min", factor=0.5, patience=2)

history_ft = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
# -inf guarantees ft_ckpt is written on the first epoch. The evaluation cell
# loads resnet50_finetuned.pth unconditionally, so with 0.0 a run stuck at 0%
# val accuracy would crash there.
best_val_ft = float("-inf")
ft_ckpt = os.path.join(cfg.output_dir, "resnet50_finetuned.pth")
patience = cfg.early_stop_patience
wait = 0  # consecutive epochs without a val-accuracy improvement

print("Phase 2: Fine-tuning (top layers unfrozen)")
for epoch in range(cfg.num_epochs_ft):
    train_loss, train_acc = train_one_epoch(resnet, loaders["train"], criterion, optimizer_ft, grad_accum=cfg.grad_accum_steps)
    val_loss, val_acc, _, _ = eval_model(resnet, loaders["val"], criterion)
    scheduler_ft.step(val_loss)

    history_ft["train_loss"].append(train_loss)
    history_ft["val_loss"].append(val_loss)
    history_ft["train_acc"].append(train_acc)
    history_ft["val_acc"].append(val_acc)

    print(f"[FT] Epoch {epoch+1}/{cfg.num_epochs_ft} - train_acc: {train_acc:.3f}, val_acc: {val_acc:.3f}")

    if val_acc > best_val_ft:
        best_val_ft = val_acc
        torch.save({"model_state": resnet.state_dict(), "class_names": class_names}, ft_ckpt)
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping triggered")
            break

plot_curves(history_ft, "ResNet50 - Fine-tuning", os.path.join(cfg.viz_dir, "resnet_ft_curves.png"))

In [None]:
# Evaluation on Test Set (Baseline vs Transfer)

def evaluate_and_report(model, loader, name: str):
    """Evaluate ``model`` on ``loader``, print metrics and save/show a confusion matrix.

    Prints loss, accuracy and weighted precision/recall/F1, then a per-class
    ``classification_report``. It then draws the confusion matrix, saves it as
    ``cm_{name}.png`` in ``cfg.viz_dir`` and shows it. Class labels come from
    the notebook-level ``class_names``.
    """
    criterion = nn.CrossEntropyLoss()
    # eval_model switches the model to eval() mode itself.
    loss, acc, preds, targets = eval_model(model, loader, criterion)
    y_true, y_pred = targets.numpy(), preds.numpy()
    # Pin the label set to every known class. Without it, a class missing from
    # this split makes classification_report raise on the target_names length
    # mismatch, and it shrinks the confusion matrix so the heatmap tick labels
    # no longer line up.
    labels = list(range(len(class_names)))
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average="weighted", zero_division=0
    )
    print(f"{name} - loss: {loss:.4f}, acc: {acc:.4f}, precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}")
    print(classification_report(y_true, y_pred, labels=labels, target_names=class_names, zero_division=0))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.title(f"Confusion Matrix - {name}")
    plt.ylabel("True")
    plt.xlabel("Pred")
    plt.tight_layout()
    plt.savefig(os.path.join(cfg.viz_dir, f"cm_{name}.png"), dpi=200)
    plt.show()

# Load best checkpoints and evaluate
baseline_ckpt = torch.load(
    os.path.join(cfg.output_dir, "baseline_best.pth"),
    map_location=device,
)
baseline.load_state_dict(baseline_ckpt["model_state"])
# nn.Module.to returns the same module, so rebinding only (re)places it on `device`.
baseline = baseline.to(device)
evaluate_and_report(baseline, loaders["test"], name="baseline")

resnet_ckpt = torch.load(
    os.path.join(cfg.output_dir, "resnet50_finetuned.pth"),
    map_location=device,
)
resnet.load_state_dict(resnet_ckpt["model_state"])
resnet = resnet.to(device)
evaluate_and_report(resnet, loaders["test"], name="resnet50_ft")

In [None]:
# Grad-CAM Visualization
# Grad-CAM target: the final block of ResNet50's last stage (layer4[-1]).
target_block = resnet.layer4[-1]
cam_extractor = GradCAM(resnet, target_layer=target_block)

# Select a few test images
def gradcam_on_samples(loader: DataLoader, class_names, k: int = 3):
    """Plot, save and show Grad-CAM overlays for the first ``k`` images of one batch.

    Uses the notebook-level ``resnet`` and ``cam_extractor``. Each figure puts
    the input image (titled with true and predicted class) beside its Grad-CAM
    overlay and is written to ``cfg.viz_dir/gradcam_{i}.png``.
    """
    resnet.eval()
    batch_x, batch_y = next(iter(loader))
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)
    n_show = min(k, batch_x.size(0))
    for i in range(n_show):
        sample = batch_x[i].unsqueeze(0)  # keep a batch dimension of 1
        true_idx = batch_y[i].item()
        with torch.no_grad():
            pred_idx = resnet(sample).argmax(dim=1).item()
        # Explain the predicted class, not the true one.
        heatmap = cam_extractor(sample, class_idx=pred_idx)
        base_img, overlay = overlay_heatmap(sample.squeeze(0), heatmap)

        panels = (
            (f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}", base_img),
            ("Grad-CAM", overlay),
        )
        plt.figure(figsize=(6, 3))
        for position, (panel_title, picture) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.title(panel_title)
            plt.imshow(picture)
            plt.axis("off")
        plt.tight_layout()
        plt.savefig(os.path.join(cfg.viz_dir, f"gradcam_{i}.png"), dpi=200)
        plt.show()

# Uncomment to run Grad-CAM
# gradcam_on_samples(loaders["test"], class_names, k=3)

In [None]:
# Save Final Artifacts
# Save fine-tuned model state and config metadata
final_artifact = dict(
    model_state=resnet.state_dict(),
    class_names=class_names,
    config=asdict(cfg),
)
final_path = os.path.join(cfg.output_dir, "resnet50_final.pth")
torch.save(final_artifact, final_path)
print("Saved final model to", final_path)

# Persist every phase's per-epoch curves next to the weights.
history_files = (
    ("history_baseline.json", history_base),
    ("history_head.json", history_head),
    ("history_ft.json", history_ft),
)
for file_name, history in history_files:
    with open(os.path.join(cfg.output_dir, file_name), "w") as f:
        json.dump(history, f, indent=2)
print("Saved training histories.")