# Project: Transfer Learning on Flowers

**Module 3.3, Lesson 2 (Capstone)** | CourseAI

This is the Series 3 capstone project. You will combine everything you have learned about CNNs into a single practitioner workflow:

1. **Explore** the dataset (Oxford Flowers, 8 species)
2. **Feature extraction** — freeze a pretrained ResNet-18 backbone, train a new head
3. **Grad-CAM validation** — check if the model focuses on flowers or shortcuts
4. **Fine-tuning** — unfreeze layer4 with a differential learning rate, compare
5. **Final comparison** — accuracy table + Grad-CAM heatmaps side by side

No new concepts. Every technique here was taught in a prior lesson. The challenge is putting them together.

**Estimated time:** 30–45 minutes on a Colab GPU (T4).

---

## Setup

Run this cell to import everything and select a device (GPU if available).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.models as models
from torchvision.models import ResNet18_Weights
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import copy

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

---

## Phase 1: Explore the Data

We use **Oxford Flowers102**, a dataset of flower photographs from 102 species. To keep training fast and the task manageable, we filter to **8 visually distinct species** with roughly 50–80 images each.

This is a realistic small-dataset scenario: enough data for transfer learning, far too little to train from scratch.

In [None]:
# Download Oxford Flowers102
# This downloads ~350MB the first time
from torchvision.datasets import Flowers102

# Download the train, val, and test splits
raw_train = Flowers102(root='./data', split='train', download=True)
raw_val = Flowers102(root='./data', split='val', download=True)
raw_test = Flowers102(root='./data', split='test', download=True)

print(f'Train: {len(raw_train)} images')
print(f'Val:   {len(raw_val)} images')
print(f'Test:  {len(raw_test)} images')

In [None]:
# The Flowers102 labels are 0-indexed (0 to 101).
# We'll pick 8 visually distinct species and remap labels to 0-7.
#
# Selected classes (these have enough samples and are visually distinct).
# The display names below are for readability only and may not match the
# official label list exactly: the original dataset numbers classes 1-102,
# and torchvision shifts them to 0-101. What matters here is visual
# distinctness between classes, not the precise species name.

SELECTED_CLASSES = [1, 10, 17, 28, 51, 63, 70, 82]  # Original Flowers102 label indices
CLASS_NAMES = [
    'Pink Primrose',     # class 1
    'Globe Thistle',     # class 10
    'Purple Coneflower', # class 17
    'Stemless Gentian',  # class 28
    'Wild Pansy',        # class 51
    'Black-eyed Susan',  # class 63
    'Bird of Paradise',  # class 70
    'Clematis',          # class 82
]

# Create mapping from original labels to new 0-7 labels
label_map = {orig: new for new, orig in enumerate(SELECTED_CLASSES)}
NUM_CLASSES = len(SELECTED_CLASSES)

print(f'Selected {NUM_CLASSES} classes:')
for i, name in enumerate(CLASS_NAMES):
    print(f'  {i}: {name} (original label {SELECTED_CLASSES[i]})')
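`nn.CrossEntropyLoss` expects integer targets in the range `[0, NUM_CLASSES)`, which is why the original Flowers102 indices are remapped to 0-7. A small pure-Python sketch of the mapping and its inverse (the inverse is a convenience not used elsewhere in the notebook, handy when reporting results by original label):

```python
SELECTED_CLASSES = [1, 10, 17, 28, 51, 63, 70, 82]
NUM_CLASSES = len(SELECTED_CLASSES)

# Original Flowers102 index -> contiguous index in [0, NUM_CLASSES)
label_map = {orig: new for new, orig in enumerate(SELECTED_CLASSES)}
# Contiguous index -> original index
inverse_map = {new: orig for orig, new in label_map.items()}

print(label_map[70], inverse_map[6])  # 6 70

# Every remapped label is a valid CrossEntropyLoss target
assert set(label_map.values()) == set(range(NUM_CLASSES))
```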

In [None]:
# Filter datasets to only include selected classes

def filter_dataset(dataset, selected_classes, label_map):
    """Return indices of samples belonging to selected classes."""
    indices = []
    mapped_labels = []
    for i in range(len(dataset)):
        _, label = dataset[i]
        if label in selected_classes:
            indices.append(i)
            mapped_labels.append(label_map[label])
    return indices, mapped_labels

print('Filtering datasets (this takes a moment)...')

# Combine train + val for our training set (Flowers102 train split is small)
train_indices_raw, train_labels_raw = filter_dataset(raw_train, SELECTED_CLASSES, label_map)
val_indices_raw, val_labels_raw = filter_dataset(raw_val, SELECTED_CLASSES, label_map)
test_indices, test_labels = filter_dataset(raw_test, SELECTED_CLASSES, label_map)

# Combine train+val for training, use test for evaluation
# (Flowers102 train split has only 10 images per class — too few alone)
train_indices = train_indices_raw + val_indices_raw
train_labels = train_labels_raw + val_labels_raw

print(f'\nFiltered dataset sizes:')
print(f'  Train: {len(train_indices)} images')
print(f'  Test:  {len(test_indices)} images')

# Class distribution
print(f'\nTrain class distribution:')
train_counts = Counter(train_labels)
for cls_idx in range(NUM_CLASSES):
    count = train_counts.get(cls_idx, 0)
    print(f'  {CLASS_NAMES[cls_idx]}: {count} images')

print(f'\nTest class distribution:')
test_counts = Counter(test_labels)
for cls_idx in range(NUM_CLASSES):
    count = test_counts.get(cls_idx, 0)
    print(f'  {CLASS_NAMES[cls_idx]}: {count} images')
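`filter_dataset` decodes every image just to read its label, which is why the cell above is slow. If the labels are already available as a plain Python list, the same filtering and the class counts need no image I/O at all. A sketch under that assumption; `filter_by_labels` and the `all_labels` list are hypothetical stand-ins, not part of torchvision:

```python
from collections import Counter

def filter_by_labels(all_labels, label_map):
    """Return (indices, remapped_labels) for samples whose label is in label_map."""
    indices, mapped = [], []
    for i, label in enumerate(all_labels):
        if label in label_map:
            indices.append(i)
            mapped.append(label_map[label])
    return indices, mapped

# Hypothetical label list for a tiny dataset of 8 samples
all_labels = [1, 5, 10, 1, 82, 10, 40, 1]
label_map = {1: 0, 10: 1, 82: 2}

indices, mapped = filter_by_labels(all_labels, label_map)
print(indices)          # [0, 2, 3, 4, 5, 7]
print(mapped)           # [0, 1, 0, 2, 1, 0]
print(Counter(mapped))  # Counter({0: 3, 1: 2, 2: 1})
```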

In [None]:
# Define transforms
# Training: augmentation to help with small dataset
# Test: just resize and normalize (no augmentation)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

print('Transforms defined.')
print('Training augmentation: RandomResizedCrop, RandomHorizontalFlip, ColorJitter')
print('Test: CenterCrop only (no augmentation)')
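`transforms.Normalize` computes `(x - mean) / std` per channel, and the `unnormalize` helper defined later for display inverts it with `x * std + mean`. A NumPy sketch of that round trip on a random image in channel-first layout, using the same ImageNet statistics:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

rng = np.random.default_rng(0)
img = rng.random((3, 224, 224))  # fake image, values in [0, 1), layout [C, H, W]

normalized = (img - IMAGENET_MEAN) / IMAGENET_STD     # what Normalize does
restored = normalized * IMAGENET_STD + IMAGENET_MEAN  # what unnormalize does

print(np.allclose(restored, img))  # True
# Normalized values leave [0, 1]; a black pixel's red channel becomes about -2.118
print((0.0 - 0.485) / 0.229)
```

This is also why `unnormalize` ends with `.clamp(0, 1)`: small floating-point errors could otherwise push display values slightly outside the valid range.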

In [None]:
# Custom dataset wrapper that applies our label mapping and transforms

class FilteredFlowersDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, indices, labels, transform):
        self.base_dataset = base_dataset
        self.indices = indices
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img, _ = self.base_dataset[self.indices[idx]]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# FilteredFlowersDataset wraps a single base dataset, but train_indices spans
# two (raw_train and raw_val), so we also need a wrapper that combines them.

class CombinedFilteredDataset(torch.utils.data.Dataset):
    """Combines samples from multiple base datasets with pre-computed indices and labels."""
    def __init__(self, datasets_with_indices, labels, transform):
        # datasets_with_indices: list of (dataset, [indices])
        self.items = []  # (dataset, original_index)
        for dataset, indices in datasets_with_indices:
            for idx in indices:
                self.items.append((dataset, idx))
        self.labels = labels
        self.transform = transform
        assert len(self.items) == len(self.labels)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        dataset, orig_idx = self.items[idx]
        img, _ = dataset[orig_idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label


train_dataset = CombinedFilteredDataset(
    [(raw_train, train_indices_raw), (raw_val, val_indices_raw)],
    train_labels,
    train_transform,
)

test_dataset = FilteredFlowersDataset(
    raw_test, test_indices, test_labels, test_transform,
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

print(f'Train dataset: {len(train_dataset)} images')
print(f'Test dataset:  {len(test_dataset)} images')
print(f'Train batches: {len(train_loader)}')
print(f'Test batches:  {len(test_loader)}')
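`CombinedFilteredDataset` stitches two index lists together by hand. PyTorch's built-in `Subset` (already imported in the setup cell) and `ConcatDataset` can express the same index bookkeeping. A sketch on toy `TensorDataset`s; the notebook's label remapping and per-sample transform would still need a small wrapper on top:

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Two toy "base datasets": sample i holds the value i (+100 for the second one)
ds_a = TensorDataset(torch.arange(5), torch.zeros(5, dtype=torch.long))
ds_b = TensorDataset(torch.arange(100, 105), torch.ones(5, dtype=torch.long))

# Keep only some indices from each, then chain them: a's picks first, then b's
combined = ConcatDataset([Subset(ds_a, [0, 3]), Subset(ds_b, [1, 4])])

values = [int(combined[i][0]) for i in range(len(combined))]
print(len(combined))  # 4
print(values)         # [0, 3, 101, 104]
```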

In [None]:
# Visualize sample images from each class

def unnormalize(tensor):
    """Reverse ImageNet normalization for display."""
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return (tensor.cpu() * std + mean).clamp(0, 1)

# Show 2 images per class
fig, axes = plt.subplots(2, NUM_CLASSES, figsize=(20, 6))

# Collect samples per class
class_samples = {i: [] for i in range(NUM_CLASSES)}
for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]
    if len(class_samples[label]) < 2:
        class_samples[label].append(img)
    if all(len(v) >= 2 for v in class_samples.values()):
        break

for cls_idx in range(NUM_CLASSES):
    for row in range(2):
        if row < len(class_samples[cls_idx]):
            img = unnormalize(class_samples[cls_idx][row])
            axes[row, cls_idx].imshow(img.permute(1, 2, 0).numpy())
        axes[row, cls_idx].axis('off')
        if row == 0:
            axes[row, cls_idx].set_title(CLASS_NAMES[cls_idx], fontsize=9)

fig.suptitle('Sample Images from Each Class', fontsize=14)
plt.tight_layout()
plt.show()

### What to Notice

- Are the classes visually distinct? Can you tell them apart yourself?
- Look at the backgrounds — are they varied or consistent within a class? Consistent backgrounds could become shortcuts.
- This is a **small** dataset. Transfer learning is the only viable strategy here.

---

## Phase 2: Feature Extraction

The practitioner workflow says: **start with the simplest strategy**. Feature extraction is simpler than fine-tuning:

1. Load pretrained ResNet-18
2. Freeze all backbone parameters
3. Replace the classification head for 8 classes
4. Train only the head

You did this on CIFAR-10 in the Transfer Learning lesson. Now do it on flowers.
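Freezing works by setting `requires_grad = False`: autograd never computes gradients for those tensors, so their `.grad` stays `None` and the optimizer has nothing to apply. A toy sketch (a two-layer stand-in for the backbone, not ResNet-18) that checks this directly:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(10, 16), nn.ReLU())  # stand-in for the pretrained backbone
head = nn.Linear(16, 3)                                  # new classification head
model = nn.Sequential(backbone, head)

# Freeze the backbone only
for p in backbone.parameters():
    p.requires_grad = False

x = torch.randn(4, 10)
y = torch.tensor([0, 1, 2, 0])
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

print([p.grad is None for p in backbone.parameters()])  # [True, True]
print([p.grad is None for p in head.parameters()])      # [False, False]

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 51 (16 * 3 weights + 3 biases)
```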

In [None]:
def create_feature_extraction_model(num_classes):
    """Load pretrained ResNet-18 and set up for feature extraction.

    Steps:
        1. Load pretrained ResNet-18
        2. Freeze ALL backbone parameters (requires_grad = False)
        3. Replace model.fc with a new Linear layer for num_classes

    Returns:
        model with frozen backbone and trainable classification head
    """
    model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    # TODO: Freeze all backbone parameters
    # Hint: iterate over model.parameters() and set requires_grad = False
    # Expected: every parameter in the model should have requires_grad == False after this
    # (1-2 lines)


    # TODO: Replace the classification head for num_classes
    # Hint: the original head is model.fc — check model.fc.in_features for the input size
    # Expected: model.fc should be a new nn.Linear(in_features, num_classes)
    # (1-2 lines)


    return model

fe_model = create_feature_extraction_model(NUM_CLASSES).to(device)

# Verify: only the fc layer should be trainable
trainable = sum(p.numel() for p in fe_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in fe_model.parameters())
fc_in = fe_model.fc.in_features
print(f'Trainable parameters: {trainable:,} / {total:,} ({trainable/total:.1%})')
print(f'Only training the classification head ({fc_in} -> {NUM_CLASSES})')

# Sanity check — if this fails, revisit your TODO implementation
assert trainable < total, "All parameters are trainable — did you forget to freeze the backbone?"
assert trainable > 0, "No trainable parameters — did you freeze the new fc layer too?"

In [None]:
# Training utilities (provided)

def train_one_epoch(model, loader, optimizer, criterion):
    """Train for one epoch. Returns (loss, accuracy)."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total


def evaluate(model, loader, criterion):
    """Evaluate on a dataset. Returns (loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    return total_loss / total, correct / total


def train_model(model, train_loader, test_loader, optimizer, num_epochs=15):
    """Full training loop with logging. Returns history dict."""
    criterion = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        test_loss, test_acc = evaluate(model, test_loader, criterion)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f'Epoch {epoch+1:2d}/{num_epochs}  '
              f'Train Loss: {train_loss:.4f}  Train Acc: {train_acc:.1%}  '
              f'Test Loss: {test_loss:.4f}  Test Acc: {test_acc:.1%}')

    return history

print('Training utilities loaded.')
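`train_one_epoch` and `evaluate` multiply each batch's mean loss by `images.size(0)` before dividing by the total sample count. That matters because the last batch is usually smaller than `batch_size`. A small sketch with made-up per-sample losses showing that the weighted version matches the true per-sample mean, while simply averaging the batch means does not:

```python
import numpy as np

# Hypothetical per-sample losses: 32 easy samples, then a last batch of 4 hard ones
per_sample_loss = np.array([1.0] * 32 + [3.0] * 4)
batches = [per_sample_loss[:32], per_sample_loss[32:]]  # batch_size=32 -> sizes 32 and 4

# Pattern used in the notebook: batch mean * batch size, divided by total samples
weighted_mean = sum(b.mean() * len(b) for b in batches) / len(per_sample_loss)
# Naive alternative: average the batch means, ignoring batch sizes
naive_mean = np.mean([b.mean() for b in batches])

print(round(weighted_mean, 4))  # 1.2222 (matches per_sample_loss.mean())
print(naive_mean)               # 2.0 (over-weights the small last batch)
```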

In [None]:
# Train the feature extraction model (provided)
# Only the fc layer is trainable, so we pass only fc parameters to the optimizer.

fe_optimizer = optim.Adam(fe_model.fc.parameters(), lr=1e-3)

print('Training feature extraction model...')
print('=' * 70)
fe_history = train_model(fe_model, train_loader, test_loader, fe_optimizer, num_epochs=15)
print('=' * 70)
print(f'\nFinal test accuracy: {fe_history["test_acc"][-1]:.1%}')

In [None]:
# Plot training curves

def plot_training_curves(history, title='Training Curves'):
    """Plot loss and accuracy curves."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    epochs = range(1, len(history['train_loss']) + 1)

    ax1.plot(epochs, history['train_loss'], 'b-', label='Train')
    ax1.plot(epochs, history['test_loss'], 'r-', label='Test')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    ax2.plot(epochs, history['train_acc'], 'b-', label='Train')
    ax2.plot(epochs, history['test_acc'], 'r-', label='Test')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1.05)

    fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

plot_training_curves(fe_history, 'Feature Extraction Training')

### Checkpoint: Feature Extraction Results

Before moving on, note your results:
- What accuracy did you achieve?
- Is there a large gap between train and test accuracy (overfitting)?
- Training should be fast because only the fc layer is being updated.
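To make the overfitting question concrete, a small helper can read the `history` dict that `train_model` returns and report the final train/test gap and the best test epoch. A sketch; `summarize_history` is a hypothetical helper and the `example_history` values are made up to show the output format, not results from this notebook:

```python
def summarize_history(history):
    """Summarize a history dict with 'train_acc' and 'test_acc' lists."""
    final_train = history['train_acc'][-1]
    final_test = history['test_acc'][-1]
    best_idx = max(range(len(history['test_acc'])), key=lambda i: history['test_acc'][i])
    return {
        'final_train_acc': final_train,
        'final_test_acc': final_test,
        'gap': final_train - final_test,   # a large positive gap suggests overfitting
        'best_test_epoch': best_idx + 1,   # 1-indexed, matching the training log
        'best_test_acc': history['test_acc'][best_idx],
    }

# Made-up numbers, only to show the output format
example_history = {'train_acc': [0.60, 0.85, 0.95], 'test_acc': [0.55, 0.80, 0.78]}
summary = summarize_history(example_history)
print(summary['best_test_epoch'], round(summary['gap'], 2))  # 2 0.17
```

With the real run you would call `summarize_history(fe_history)`.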

**But accuracy alone is not enough.** Time for the most important step.

---

## Phase 3: Grad-CAM Validation

Getting high accuracy was step two of the practitioner workflow. Checking *why* the model is accurate is where the real work starts.

**Question:** Is the model right for the right reasons? Does it focus on the flowers, or on something else (background, pot, image borders)?

You implemented Grad-CAM from scratch in Visualizing Features. The utility below is provided so you can focus on **interpretation**, not reimplementation.

In [None]:
# Grad-CAM utility (provided)

def grad_cam(model, img_tensor, target_class=None):
    """Compute Grad-CAM for a given image and target class.

    Args:
        model: the model (in eval mode)
        img_tensor: preprocessed image [1, 3, 224, 224] on device
        target_class: int class index. If None, uses predicted class.

    Returns:
        cam: numpy array [224, 224], values in [0, 1]
        predicted_class: the class index used
    """
    model.eval()
    stored = {}

    def forward_hook(module, input, output):
        stored['activations'] = output

    def backward_hook(module, grad_input, grad_output):
        stored['gradients'] = grad_output[0]

    # With a frozen backbone (fe_model), nothing before fc requires grad, so no
    # gradient would ever reach layer4 and the backward hook would never fire.
    # Making the input require grad keeps layer4's output in the autograd graph.
    img_tensor = img_tensor.detach().clone().requires_grad_(True)

    fhook = model.layer4.register_forward_hook(forward_hook)
    bhook = model.layer4.register_full_backward_hook(backward_hook)
    try:
        output = model(img_tensor)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        model.zero_grad()
        output[0, target_class].backward()
    finally:
        # Always remove the hooks, even if the forward or backward pass raises
        fhook.remove()
        bhook.remove()

    gradients = stored['gradients']      # [1, 512, 7, 7]
    activations = stored['activations']  # [1, 512, 7, 7]

    weights = gradients.mean(dim=[2, 3])  # [1, 512]: one weight per channel
    cam = (weights.unsqueeze(-1).unsqueeze(-1) * activations).sum(dim=1, keepdim=True)
    cam = F.relu(cam)
    cam = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)
    cam = cam.squeeze().detach().cpu().numpy()

    if cam.max() > 0:
        cam = cam / cam.max()

    return cam, target_class


def show_grad_cam_overlay(img_tensor, cam, class_name, ax=None):
    """Overlay Grad-CAM heatmap on the image."""
    img_np = unnormalize(img_tensor.squeeze(0)).permute(1, 2, 0).numpy()
    heatmap = plt.cm.jet(cam)[:, :, :3]
    overlay = 0.5 * img_np + 0.5 * heatmap
    overlay = np.clip(overlay, 0, 1)

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    ax.imshow(overlay)
    ax.set_title(f'Grad-CAM: {class_name}', fontsize=10)
    ax.axis('off')

print('Grad-CAM utilities loaded.')
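The heart of `grad_cam` is a few lines: average each channel's gradient over the spatial grid to get one weight per channel, take the weighted sum of the activation maps, apply ReLU, and rescale to [0, 1]. In symbols, $w_k = \frac{1}{HW}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}$ and $\text{CAM} = \text{ReLU}\left(\sum_k w_k A^k\right)$. A NumPy sketch of just that step on toy arrays (2 channels on a 3x3 grid instead of 512 on 7x7, and no upsampling); `cam_from_grads` is a hypothetical name:

```python
import numpy as np

def cam_from_grads(activations, gradients):
    """Grad-CAM map from activations and gradients, both shaped [K, H, W]."""
    weights = gradients.mean(axis=(1, 2))                     # [K]: one weight per channel
    cam = (weights[:, None, None] * activations).sum(axis=0)  # [H, W]: weighted sum of maps
    cam = np.maximum(cam, 0)                                  # ReLU keeps positive evidence only
    if cam.max() > 0:
        cam = cam / cam.max()                                 # rescale to [0, 1]
    return cam

# Channel 0 fires in the top-left corner, channel 1 in the bottom-right corner
acts = np.zeros((2, 3, 3))
acts[0, 0, 0] = 1.0
acts[1, 2, 2] = 1.0
# Gradients say channel 0 supports the class and channel 1 argues against it
grads = np.stack([np.full((3, 3), 0.5), np.full((3, 3), -0.5)])

cam = cam_from_grads(acts, grads)
print(cam)  # 1.0 in the top-left cell, 0.0 everywhere else
```

The bottom-right cell ends up at 0 rather than negative because ReLU discards evidence against the target class.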

In [None]:
# Run Grad-CAM on correctly classified test images (provided)
#
# For each class, find 2 correctly classified images and show:
# - Original image
# - Grad-CAM overlay
#
# Ask yourself: does the model focus on the flower, or on something else?

# Collect correctly classified samples (2 per class)
fe_model.eval()
correct_samples = {i: [] for i in range(NUM_CLASSES)}

for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]
    img_batch = img.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = fe_model(img_batch).argmax(dim=1).item()

    if pred == label and len(correct_samples[label]) < 2:
        correct_samples[label].append((img_batch, label))

    if all(len(v) >= 2 for v in correct_samples.values()):
        break

# Display Grad-CAM for each class
fig, axes = plt.subplots(NUM_CLASSES, 4, figsize=(16, 4 * NUM_CLASSES))

for cls_idx in range(NUM_CLASSES):
    for sample_idx in range(min(2, len(correct_samples[cls_idx]))):
        img_batch, label = correct_samples[cls_idx][sample_idx]

        # Original
        col_offset = sample_idx * 2
        img_display = unnormalize(img_batch.squeeze(0)).permute(1, 2, 0).numpy()
        axes[cls_idx, col_offset].imshow(img_display)
        axes[cls_idx, col_offset].set_title(f'{CLASS_NAMES[cls_idx]}', fontsize=9)
        axes[cls_idx, col_offset].axis('off')

        # Grad-CAM
        cam, _ = grad_cam(fe_model, img_batch, target_class=label)
        show_grad_cam_overlay(img_batch, cam, CLASS_NAMES[cls_idx], ax=axes[cls_idx, col_offset + 1])

fig.suptitle('Feature Extraction Model: Grad-CAM Validation', fontsize=16)
plt.tight_layout()
plt.show()

### Interpret Your Grad-CAM Results

Look at each heatmap carefully:

- **Good signs:** Heatmap highlights the flower petals, center, or distinctive shape
- **Warning signs:** Heatmap highlights background, image borders, or non-flower regions
- **Ambiguous:** Heatmap includes flower + some surrounding context (leaves, stem) — this can be legitimate

Remember the husky/wolf example from Visualizing Features. Your flower model might have its own version of this.

**Key question for each heatmap:** *"If I showed this to someone who does not know ML, would they agree the model is focusing on the right thing?"*
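If you want a number to go with the eyeball check, one option is the fraction of heatmap mass that falls inside a region you mark as "flower". This is a hypothetical helper, not part of the course toolkit, and it assumes you have a boolean mask of the same size as the CAM (for example a rectangle drawn by hand around the flower). Values near 1 mean the heat is on the flower; values near 0 mean it is elsewhere:

```python
import numpy as np

def cam_mass_inside(cam, mask):
    """Fraction of total Grad-CAM mass inside a boolean mask (same shape as cam)."""
    total = cam.sum()
    if total == 0:
        return 0.0
    return float(cam[mask].sum() / total)

# Toy 4x4 example: the "flower" occupies the central 2x2 block
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

focused = np.zeros((4, 4))
focused[1:3, 1:3] = 1.0  # all heat on the flower
border = np.zeros((4, 4))
border[0, :] = 1.0       # all heat on the top edge

print(cam_mass_inside(focused, mask))  # 1.0
print(cam_mass_inside(border, mask))   # 0.0
```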

---

### Optional: The Other Two Tools

Grad-CAM answers "what mattered for this prediction?" But the **three questions, three tools** framework from Visualizing Features gives you two more lenses:

1. **Filter visualization** (conv1 weights) — What patterns does the first layer detect?
2. **Activation maps** (layer4 output) — What does the network see at the deepest level?

Since you froze the backbone for feature extraction, conv1 filters should be **identical** to the pretrained ImageNet filters. This is a concrete confirmation of what "frozen" means — the early features were not touched.
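The same check can also be done numerically: snapshot a frozen layer's weights, train, and compare with `torch.equal`. One caveat worth knowing: `requires_grad = False` does not freeze BatchNorm running statistics. In `train()` mode (which `train_one_epoch` uses) they still update on every forward pass, so a "frozen" ResNet backbone is not bit-for-bit unchanged even though its weights are. A toy sketch of both effects, using a small conv + BatchNorm stand-in rather than ResNet-18:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU(),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Linear(4, 2)
model = nn.Sequential(backbone, head)

for p in backbone.parameters():
    p.requires_grad = False
before = copy.deepcopy(backbone.state_dict())  # copies weights AND BatchNorm buffers

optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
model.train()  # same mode train_one_epoch uses
x, y = torch.randn(8, 3, 16, 16), torch.randint(0, 2, (8,))
optimizer.zero_grad()
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()

after = backbone.state_dict()
weights_same = torch.equal(before['0.weight'], after['0.weight'])
bn_stats_same = torch.equal(before['1.running_mean'], after['1.running_mean'])
print(weights_same)   # True: conv weights are frozen
print(bn_stats_same)  # False: BatchNorm running mean still moved
```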

Run the cell below to check. This is optional but reinforces the full visualization toolkit.

In [None]:
# Optional: Visualize conv1 filters and layer4 activation maps
# from the feature extraction model

# --- 1. Conv1 filter visualization ---
# Since the backbone was frozen, these should be identical to ImageNet pretrained filters.
filters = fe_model.conv1.weight.data.cpu().clone()

# Normalize each filter to [0, 1] for display
filters_min = filters.flatten(1).min(dim=1).values[:, None, None, None]
filters_max = filters.flatten(1).max(dim=1).values[:, None, None, None]
filters_norm = (filters - filters_min) / (filters_max - filters_min + 1e-8)

fig, axes = plt.subplots(4, 16, figsize=(16, 4))
for i in range(min(64, filters_norm.shape[0])):
    row, col = i // 16, i % 16
    axes[row, col].imshow(filters_norm[i].permute(1, 2, 0).numpy())
    axes[row, col].axis('off')
fig.suptitle('Conv1 Filters (frozen — should match ImageNet pretrained)', fontsize=13)
plt.tight_layout()
plt.show()

# --- 2. Layer4 activation maps ---
# Capture layer4 activations on a sample flower image using a forward hook.
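# Use the first correctly classified sample collected in Phase 3
# (skipping any class that ended up with no correct samples)
sample_img = next(samples[0][0] for samples in correct_samples.values() if samples)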

stored_acts = {}
hook = fe_model.layer4.register_forward_hook(lambda m, inp, out: stored_acts.update({'act': out}))
fe_model.eval()
with torch.no_grad():
    fe_model(sample_img)
hook.remove()

acts = stored_acts['act'].squeeze(0).cpu()  # [512, 7, 7]

# Show the first 16 activation maps
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(16):
    row, col = i // 8, i % 8
    axes[row, col].imshow(acts[i].numpy(), cmap='viridis')
    axes[row, col].axis('off')
    axes[row, col].set_title(f'ch {i}', fontsize=8)
fig.suptitle('Layer4 Activation Maps (feature extraction model)', fontsize=13)
plt.tight_layout()
plt.show()

print('Conv1 filters are frozen ImageNet features — edge and color detectors.')
print('Layer4 activations show what high-level patterns the model detected in this flower.')
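Showing channels 0-15 is an arbitrary choice, and many of the 512 layer4 channels may be close to zero for any given image. A variant picks the channels that respond most strongly. A NumPy sketch on a random stand-in for `acts` (shape [512, 7, 7]); `top_channels` is a hypothetical helper, and with the real tensor you would pass `acts.numpy()` and then plot `acts[i]` for each returned index:

```python
import numpy as np

def top_channels(acts, k=16):
    """Indices of the k channels with the highest mean activation, strongest first."""
    channel_means = acts.reshape(acts.shape[0], -1).mean(axis=1)  # [C]
    return np.argsort(channel_means)[::-1][:k]

rng = np.random.default_rng(0)
fake_acts = rng.random((512, 7, 7)) * 0.1  # weak background activity
fake_acts[42] += 5.0                       # one strongly responding channel
fake_acts[7] += 3.0                        # and a second one

print(top_channels(fake_acts, k=3)[:2].tolist())  # [42, 7]
```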

---

## Phase 4: Fine-Tuning

Now that you have feature extraction results and Grad-CAM heatmaps, try fine-tuning to see if adapting the backbone helps.

Feature extraction freezes the entire backbone. Fine-tuning goes one step further: **unfreeze the last residual stage (layer4)** so the backbone can adapt its high-level features to the flower domain.

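The key technique: **differential learning rates**. The unfrozen backbone layers get a much lower learning rate than the classification head. This keeps updates to the pretrained weights small, which reduces the risk of overwriting useful features while the randomly initialized head is still producing large, noisy gradients.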

You saw this pattern in Transfer Learning. Now apply it to your own data and compare to the feature extraction baseline — both accuracy **and** Grad-CAM focus.
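To see what two parameter groups actually do, note that Adam's very first update moves each parameter by roughly its group's learning rate (the bias-corrected first step is `lr * g / (|g| + eps)`, which is about `lr` in magnitude). A toy sketch with hypothetical `backbone`/`head` modules standing in for layer4 and fc:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Linear(8, 8)  # stand-in for the unfrozen layer4
head = nn.Linear(8, 3)      # stand-in for the new fc layer
model = nn.Sequential(backbone, nn.ReLU(), head)

optimizer = torch.optim.Adam([
    {'params': backbone.parameters(), 'lr': 1e-4},  # gentle updates for pretrained weights
    {'params': head.parameters(), 'lr': 1e-3},      # faster learning for the new head
])

before_bb = backbone.weight.detach().clone()
before_head = head.weight.detach().clone()

x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
optimizer.zero_grad()
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()

# Largest single-element change in each group after one step
step_bb = (backbone.weight.detach() - before_bb).abs().max().item()
step_head = (head.weight.detach() - before_head).abs().max().item()
print(f'largest backbone step: {step_bb:.1e}')   # about 1e-04
print(f'largest head step:     {step_head:.1e}')  # about 1e-03
```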

In [None]:
# TODO: Create a fine-tuning model
#
# Start from a fresh pretrained ResNet-18 (not the already-trained fe_model).
# Three steps: freeze everything, then selectively unfreeze layer4, then replace the head.

def create_finetuning_model(num_classes):
    """Load pretrained ResNet-18 set up for fine-tuning with layer4 unfrozen.

    Steps:
        1. Load pretrained ResNet-18
        2. Freeze ALL parameters
        3. Unfreeze layer4 parameters (set requires_grad = True)
        4. Replace model.fc for num_classes (new layers are trainable by default)

    Returns:
        model with layer4 + fc trainable, everything else frozen
    """
    model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    # TODO: Freeze all parameters (same as feature extraction)
    # (1-2 lines)


    # TODO: Unfreeze layer4 parameters
    # Hint: iterate over model.layer4.parameters() and set requires_grad = True
    # (1-2 lines)


    # TODO: Replace the classification head for num_classes
    # (same as feature extraction — 1-2 lines)


    return model

ft_model = create_finetuning_model(NUM_CLASSES).to(device)

# Verify: layer4 + fc should be trainable
trainable = sum(p.numel() for p in ft_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in ft_model.parameters())
layer4_params = sum(p.numel() for p in ft_model.layer4.parameters())
fc_params = sum(p.numel() for p in ft_model.fc.parameters())
print(f'Trainable parameters: {trainable:,} / {total:,} ({trainable/total:.1%})')
print(f'Trainable layers: layer4 ({layer4_params:,}) + fc ({fc_params:,})')

# Sanity checks
assert trainable > fc_params, "Only fc is trainable; did you forget to unfreeze layer4?"
assert trainable < total, "Everything is trainable; did you forget to freeze the earlier layers?"
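A quick way to confirm which layers ended up trainable is to group `named_parameters()` by their top-level module name (ResNet parameter names look like `layer4.0.conv1.weight`, so the first dotted component is the stage). On `ft_model` you would expect only `layer4` and `fc`. A sketch on a hypothetical toy model whose attribute names mimic ResNet's:

```python
import torch.nn as nn

class ToyResNetLike(nn.Module):
    """Hypothetical stand-in with ResNet-style attribute names."""
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.fc = nn.Linear(8, 3)

def trainable_modules(model):
    """Sorted top-level module names that own at least one trainable parameter."""
    return sorted({name.split('.')[0] for name, p in model.named_parameters() if p.requires_grad})

model = ToyResNetLike()
for p in model.parameters():
    p.requires_grad = False
for p in model.layer4.parameters():
    p.requires_grad = True
model.fc = nn.Linear(8, 3)  # a freshly created layer is trainable by default

print(trainable_modules(model))  # ['fc', 'layer4']
```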

In [None]:
# TODO: Set up optimizer with differential learning rates
#
# Create an Adam optimizer with TWO parameter groups:
#   1. layer4 parameters → low learning rate (1e-4) to adapt pretrained features gently
#   2. fc parameters → higher learning rate (1e-3) to learn the new classification head
#
# Hint: pass a list of dicts to optim.Adam:
#   optim.Adam([
#       {'params': ..., 'lr': ...},
#       {'params': ..., 'lr': ...},
#   ])
#
# You did this in Transfer Learning — same pattern, your own code.

# TODO: Create the optimizer with differential learning rates
# (1-4 lines)
ft_optimizer = None  # Replace this


# Verify (this will fail if ft_optimizer is still None)
assert ft_optimizer is not None, "Create the optimizer above"
print('Optimizer configured with differential learning rates:')
for i, group in enumerate(ft_optimizer.param_groups):
    print(f'  Group {i}: lr={group["lr"]}, {sum(p.numel() for p in group["params"]):,} parameters')

In [None]:
# Train the fine-tuning model
print('Training fine-tuning model...')
print('=' * 70)
ft_history = train_model(ft_model, train_loader, test_loader, ft_optimizer, num_epochs=15)
print('=' * 70)
print(f'\nFinal test accuracy: {ft_history["test_acc"][-1]:.1%}')

In [None]:
plot_training_curves(ft_history, 'Fine-Tuning Training')

### Checkpoint: Fine-Tuning vs Feature Extraction

Compare the two approaches:
- Did fine-tuning improve test accuracy over feature extraction?
- Was there more overfitting (larger train/test accuracy gap)?
- Was the improvement (if any) worth the added complexity?

On a small dataset like this, fine-tuning may or may not help. That is a realistic outcome, not a failure.
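Before reading too much into a small difference, it helps to know how noisy a single test-accuracy number is. Treating each test image as an independent trial, the standard error of an accuracy $p$ measured on $n$ images is $\sqrt{p(1-p)/n}$. This is a rough rule of thumb (it ignores run-to-run training variance and the fact that both models are scored on the same images), not a formal significance test. A sketch with hypothetical numbers:

```python
import math

def accuracy_standard_error(acc, n):
    """Binomial standard error of an accuracy measured on n independent samples."""
    return math.sqrt(acc * (1 - acc) / n)

# Hypothetical: 90% accuracy on 400 test images
se = accuracy_standard_error(0.90, 400)
print(f'{se:.3f}')  # 0.015, i.e. about 1.5 percentage points
```

With numbers like these, a gap of a point or two between the two approaches may well be noise from a single run.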

---

## Phase 5: Final Comparison

Build the complete picture: accuracy numbers **and** Grad-CAM heatmaps for both approaches, side by side.

In [None]:
# Accuracy comparison table

fe_final_acc = fe_history['test_acc'][-1]
ft_final_acc = ft_history['test_acc'][-1]

print('=' * 50)
print('ACCURACY COMPARISON')
print('=' * 50)
print(f'{"Approach":<25} {"Test Accuracy":<15}')
print('-' * 40)
print(f'{"Feature Extraction":<25} {fe_final_acc:<15.1%}')
print(f'{"Fine-Tuning (layer4)":<25} {ft_final_acc:<15.1%}')
print('-' * 40)
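diff = ft_final_acc - fe_final_acc
if diff > 0:
    direction = 'improvement'
elif diff < 0:
    direction = 'decrease'
else:
    direction = 'no change'
# Accuracies are fractions, so the difference is in percentage points
print(f'Difference: {abs(diff) * 100:.1f} percentage points ({direction})')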
print('=' * 50)
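Overall accuracy can hide a single class that one approach handles much worse. A per-class breakdown only needs the true and predicted labels for the test set, which you could collect by looping over `test_loader` the way `evaluate` does. A NumPy sketch; `per_class_accuracy` is a hypothetical helper and the labels are made up:

```python
import numpy as np

def per_class_accuracy(y_true, y_pred, num_classes):
    """Accuracy for each class; NaN for classes with no samples."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accs = np.full(num_classes, np.nan)
    for c in range(num_classes):
        in_class = y_true == c
        if in_class.any():
            accs[c] = (y_pred[in_class] == c).mean()
    return accs

# Made-up labels for 3 classes
y_true = [0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 0, 2]

accs = per_class_accuracy(y_true, y_pred, 3)
print(accs.tolist())  # [0.75, 1.0, 0.5]
```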

In [None]:
# Side-by-side training curves

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

epochs_fe = range(1, len(fe_history['test_acc']) + 1)
epochs_ft = range(1, len(ft_history['test_acc']) + 1)

ax1.plot(epochs_fe, fe_history['test_acc'], 'b-', label='Feature Extraction', linewidth=2)
ax1.plot(epochs_ft, ft_history['test_acc'], 'r-', label='Fine-Tuning', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Test Accuracy')
ax1.set_title('Test Accuracy Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 1.05)

ax2.plot(epochs_fe, fe_history['test_loss'], 'b-', label='Feature Extraction', linewidth=2)
ax2.plot(epochs_ft, ft_history['test_loss'], 'r-', label='Fine-Tuning', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Loss Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Side-by-side Grad-CAM comparison (provided)
#
# For each class, show the same test image with:
# - Original image
# - Grad-CAM from feature extraction model
# - Grad-CAM from fine-tuning model
#
# Look for differences in spatial focus between the two approaches.

# Collect one test image per class
comparison_images = {}
for idx in range(len(test_dataset)):
    img, label = test_dataset[idx]
    if label not in comparison_images:
        comparison_images[label] = img.unsqueeze(0).to(device)
    if len(comparison_images) == NUM_CLASSES:
        break

fig, axes = plt.subplots(NUM_CLASSES, 3, figsize=(14, 4 * NUM_CLASSES))

for cls_idx in range(NUM_CLASSES):
    img_batch = comparison_images[cls_idx]

    # Original
    img_display = unnormalize(img_batch.squeeze(0)).permute(1, 2, 0).numpy()
    axes[cls_idx, 0].imshow(img_display)
    axes[cls_idx, 0].set_title(f'{CLASS_NAMES[cls_idx]}', fontsize=10)
    axes[cls_idx, 0].axis('off')

    # Grad-CAM: feature extraction model
    cam_fe, _ = grad_cam(fe_model, img_batch, target_class=cls_idx)
    show_grad_cam_overlay(img_batch, cam_fe, 'Feature Extraction', ax=axes[cls_idx, 1])

    # Grad-CAM: fine-tuning model
    cam_ft, _ = grad_cam(ft_model, img_batch, target_class=cls_idx)
    show_grad_cam_overlay(img_batch, cam_ft, 'Fine-Tuning', ax=axes[cls_idx, 2])

fig.suptitle('Grad-CAM Comparison: Feature Extraction vs Fine-Tuning', fontsize=16)
plt.tight_layout()
plt.show()

### Interpret the Comparison

Compare the Grad-CAM heatmaps for the two approaches:

- **Did fine-tuning change what the model focuses on?** Fine-tuning often produces tighter, more object-focused heatmaps, because layer4 (the layer Grad-CAM reads from) has adapted to the specific task. That is a tendency, not a guarantee, so check each class.
- **Are there classes where one approach focuses on the flower and the other does not?** This is the clearest evidence of the value of fine-tuning (or the danger of it).
- **Are there any classes where both approaches focus on something other than the flower?** That would suggest a dataset bias — something about the images that correlates with the label beyond the flower itself.
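To compare two heatmaps with a number rather than by eye, one simple option (an illustrative choice, not a standard Grad-CAM metric) is the intersection-over-union of their hottest regions, for example the top 20% of pixels in each. A value near 1 means both models look at the same place; near 0 means they disagree. A NumPy sketch; `top_region_iou` is hypothetical, and `cam_fe` / `cam_ft` from the cell above would be the real inputs:

```python
import numpy as np

def top_region_iou(cam_a, cam_b, top_fraction=0.2):
    """IoU of the top `top_fraction` pixels of two same-shape heatmaps."""
    k = max(1, int(round(top_fraction * cam_a.size)))
    # Boolean masks marking each map's k highest-valued pixels
    top_a = np.zeros(cam_a.size, dtype=bool)
    top_b = np.zeros(cam_b.size, dtype=bool)
    top_a[np.argsort(cam_a.ravel())[-k:]] = True
    top_b[np.argsort(cam_b.ravel())[-k:]] = True
    union = np.logical_or(top_a, top_b).sum()
    return float(np.logical_and(top_a, top_b).sum() / union)

# Toy 10x10 maps: a bright blob near the top-left and another near the bottom-right
yy, xx = np.mgrid[0:10, 0:10]
blob_tl = np.exp(-((yy - 2) ** 2 + (xx - 2) ** 2) / 4.0)
blob_br = np.exp(-((yy - 7) ** 2 + (xx - 7) ** 2) / 4.0)

print(top_region_iou(blob_tl, blob_tl))  # 1.0
print(top_region_iou(blob_tl, blob_br))  # 0.0
```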

---

## Summary

You just completed the full practitioner workflow:

| Step | What You Did |
|------|--------------|
| **1. Explore** | Understood the dataset: 8 flower species, ~50–80 images each |
| **2. Feature Extraction** | Froze ResNet-18 backbone, trained classification head |
| **3. Grad-CAM Validation** | Checked if the model focuses on flowers, not shortcuts |
| **4. Fine-Tuning** | Unfroze layer4 with differential LR, compared to baseline |
| **5. Comparison** | Accuracy table + Grad-CAM heatmaps side by side |

The most important step was **not** the one that maximized accuracy. It was the Grad-CAM validation — the step where you checked whether the model learned the right features.

**Correct prediction does not mean correct reasoning.** You now have the tools to check.

---

### Series 3 Complete

You started Series 3 asking "what is a convolution?" You ended it by fine-tuning a pretrained CNN on a custom dataset and using Grad-CAM to verify the model's reasoning.

The practical superpower you built: not just "can I get high accuracy?" but **"can I understand what my model learned and trust its reasoning?"**

Next up: **Series 4 — LLMs and Transformers**. A different architecture, a different data modality, but the same practitioner mindset: understand the model, do not just use it.