# Transfer Learning
## Module 3.2, Lesson 3 — Feature Extraction and Fine-Tuning a Pretrained ResNet

In this notebook you will:
1. Set up a **small dataset** (3 classes from CIFAR-10, limited to 500 training images per class)
2. Try **training from scratch** — see it overfit and fail
3. Load a **pretrained ResNet-18** via `torchvision.models`
4. Do **feature extraction** — freeze the backbone, replace the head, train
5. Do **fine-tuning** — unfreeze the last stage with differential learning rates
6. **Compare all three approaches** side by side

The key insight: same architecture, same data — the difference is the starting point.

---

**Prerequisites:** ResNets and Skip Connections lesson, PyTorch training loop

## 0. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import matplotlib.pyplot as plt
import numpy as np
import time
import copy

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Create a Small Dataset

To make transfer learning's benefit obvious, we need a **small dataset** where training from scratch will overfit.

We take CIFAR-10 but keep only 3 classes (cat, dog, horse) and limit to 500 training images per class. This simulates a realistic scenario: you have a niche classification task and limited data.

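**Important:** We upsample the 32x32 CIFAR images to 224x224, the input resolution the ImageNet-pretrained ResNet-18 was trained on, and normalize with ImageNet statistics so the inputs match what the pretrained weights expect.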

In [None]:
# Classes to keep: cat (3), dog (5), horse (7)
SELECTED_CLASSES = [3, 5, 7]
CLASS_NAMES = ['cat', 'dog', 'horse']
NUM_CLASSES = len(SELECTED_CLASSES)
SAMPLES_PER_CLASS = 500  # small dataset!

# ImageNet normalization (required for pretrained models)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
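
As a quick check of what `Normalize` does with these constants, the sketch below applies the per-channel formula `(x - mean) / std` to a single RGB pixel in plain Python and then inverts it. The pixel value and the helper names `normalize_pixel` and `denormalize_pixel` are made up for illustration; `transforms.Normalize` applies the same arithmetic to every pixel of a tensor.

```python
# Per-channel ImageNet statistics (same values as in the cell above).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]


def denormalize_pixel(rgb_norm):
    """Invert the normalization: x = x_norm * std + mean."""
    return [x * s + m for x, m, s in zip(rgb_norm, IMAGENET_MEAN, IMAGENET_STD)]


pixel = [0.5, 0.5, 0.5]            # hypothetical mid-gray pixel after ToTensor()
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)

print([round(v, 3) for v in normalized])
print([round(v, 3) for v in restored])
```

The inverse step is the same arithmetic the visualization helper later uses to undo normalization before calling `imshow`.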

### Data Augmentation

You already know `Compose`, `ToTensor`, and `Normalize` from the Datasets and DataLoaders lesson. For small datasets, we add random augmentations to prevent overfitting:

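- **RandomResizedCrop(224)**: crops a random region (random scale and aspect ratio) and resizes it to 224x224, adding position and scale variation
- **RandomHorizontalFlip()**: mirrors the image with probability 0.5, so the model sees both orientations of each image across epochs
- **ColorJitter**: randomly varies brightness and contrast, making the model more robust to lighting changes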

Validation transforms are deterministic (always the same crop, no random flips).

In [None]:
# Training transforms: augmentation + ImageNet normalization
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation transforms: deterministic (no random augmentation)
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
def create_small_dataset(dataset, selected_classes, samples_per_class):
    """Filter a dataset to keep only selected classes with limited samples.
    
    Returns indices for the subset and a mapping from original labels to new labels (0, 1, 2, ...).
    """
    # Build label mapping: original_class -> new_class (0, 1, 2, ...)
    label_map = {orig: new for new, orig in enumerate(selected_classes)}
    
    # Collect indices for each selected class
    class_indices = {c: [] for c in selected_classes}
    targets = np.array(dataset.targets)
    
    for cls in selected_classes:
        all_indices = np.where(targets == cls)[0]
        # Take at most samples_per_class
        chosen = all_indices[:samples_per_class]
        class_indices[cls] = chosen.tolist()
    
    # Flatten all indices
    all_chosen = []
    for cls in selected_classes:
        all_chosen.extend(class_indices[cls])
    
    return all_chosen, label_map


class RemappedSubset(torch.utils.data.Dataset):
    """A subset that remaps labels to 0, 1, 2, ..."""
    def __init__(self, dataset, indices, label_map):
        self.dataset = dataset
        self.indices = indices
        self.label_map = label_map
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        img, label = self.dataset[self.indices[idx]]
        return img, self.label_map[label]
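
The remapping matters because `nn.CrossEntropyLoss` expects integer targets in the range 0 to C-1, and the original CIFAR-10 labels 3, 5, 7 fall outside that range when C = 3. Here is a minimal plain-Python sketch of the same selection and remapping logic, using a made-up list of labels in place of `dataset.targets`:

```python
selected_classes = [3, 5, 7]      # cat, dog, horse in CIFAR-10 numbering
samples_per_class = 2             # tiny limit so the result is easy to read

# Hypothetical label list standing in for dataset.targets.
targets = [3, 0, 5, 7, 3, 3, 9, 5, 7, 7]

# Original label -> contiguous new label.
label_map = {orig: new for new, orig in enumerate(selected_classes)}

# Keep the first `samples_per_class` positions for each selected class.
chosen = []
for cls in selected_classes:
    positions = [i for i, t in enumerate(targets) if t == cls]
    chosen.extend(positions[:samples_per_class])

# Labels the model would actually see for the chosen samples.
remapped = [label_map[targets[i]] for i in chosen]

print(label_map)
print(chosen)
print(remapped)
```

Note that taking the first N indices per class is deterministic, so every run of the notebook trains on the same 1,500 images.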

In [None]:
# Download CIFAR-10
full_train = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
full_test = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform
)

# Create small training set (500 per class = 1500 total)
train_indices, label_map = create_small_dataset(full_train, SELECTED_CLASSES, SAMPLES_PER_CLASS)
train_dataset = RemappedSubset(full_train, train_indices, label_map)

# Create test set (all available test images for selected classes)
test_indices, _ = create_small_dataset(full_test, SELECTED_CLASSES, samples_per_class=10000)
test_dataset = RemappedSubset(full_test, test_indices, label_map)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Classes: {CLASS_NAMES}')
print(f'Label mapping: {label_map}')

### Visualize Some Training Images

In [None]:
def show_images(dataset, n=8):
    """Display a grid of images from the dataset."""
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2.5))
    for i in range(n):
        img, label = dataset[i]
        # Undo normalization for display
        img = img.clone()
        for c in range(3):
            img[c] = img[c] * IMAGENET_STD[c] + IMAGENET_MEAN[c]
        img = img.clamp(0, 1)
        axes[i].imshow(img.permute(1, 2, 0).numpy())
        axes[i].set_title(CLASS_NAMES[label], fontsize=10)
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

show_images(train_dataset)

## 2. Training Loop (Provided)

This training function works for all three approaches. It tracks training loss, training accuracy, and test accuracy per epoch.

**Note on `nn.CrossEntropyLoss`:** This combines log-softmax and negative log-likelihood into one operation. It takes raw logits (no softmax needed) and is the standard loss for multi-class classification. You have used `nn.MSELoss` before — `CrossEntropyLoss` is the classification equivalent.
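
To make the "log-softmax plus negative log-likelihood" description concrete, here is the same arithmetic written out in plain Python for a made-up batch. The helper names `log_softmax` and `cross_entropy` are introduced for this sketch only; it mirrors what `nn.CrossEntropyLoss` computes with its default mean reduction, without calling PyTorch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax for one list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy(batch_logits, targets):
    """Mean over the batch of -log_softmax(logits)[target]."""
    losses = [-log_softmax(z)[t] for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits for 2 samples and 3 classes (cat, dog, horse).
batch_logits = [[2.0, 0.5, -1.0],
                [0.0, 0.0, 0.0]]
targets = [0, 2]                  # class indices, not one-hot vectors

loss = cross_entropy(batch_logits, targets)
print(round(loss, 4))
```

The second sample has uniform logits, so its loss is ln(3) ≈ 1.0986. That is a useful reference value: a 3-class model that has learned nothing yet should start near it.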

In [None]:
def train_model(model, train_loader, test_loader, optimizer, epochs=15, label='Model'):
    """Train a model and return training history."""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        train_loss = running_loss / total
        train_acc = 100.0 * correct / total
        
        # Evaluation phase
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        test_acc = 100.0 * correct / total
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        
        print(f'[{label}] Epoch {epoch+1:2d}/{epochs} | '
              f'Train Loss: {train_loss:.4f} | '
              f'Train Acc: {train_acc:.1f}% | '
              f'Test Acc: {test_acc:.1f}%')
    
    return history
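
The loop multiplies each batch loss by `inputs.size(0)` before summing because `CrossEntropyLoss` returns a per-batch mean and the last batch is smaller: 1,500 images at batch size 32 gives 46 full batches plus one of 28. The sketch below uses made-up loss values to show how the sample-weighted average differs from a naive average of batch means.

```python
# Batch sizes for 1,500 samples at batch size 32: 46 full batches plus one of 28.
batch_sizes = [32] * 46 + [28]
batch_losses = [1.0] * 46 + [2.0]   # made-up values; the last batch is an outlier

# Sample-weighted average, as in train_model: sum(loss * n) / sum(n).
running_loss = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
total = sum(batch_sizes)
weighted_mean = running_loss / total

# Naive average of batch means over-weights the smaller last batch.
naive_mean = sum(batch_losses) / len(batch_losses)

print(total, round(weighted_mean, 4), round(naive_mean, 4))
```

The difference is small here, but the weighted form is the per-sample mean regardless of how the data happens to split into batches.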

## 3. Approach 1: Training from Scratch (Baseline)

First, train a ResNet-18 from random initialization on our small dataset. This is the baseline — we expect it to overfit badly because 1500 images are nowhere near enough to learn 11M parameters.

In [None]:
# ResNet-18 from scratch (random initialization)
scratch_model = models.resnet18(weights=None)  # No pretrained weights
scratch_model.fc = nn.Linear(512, NUM_CLASSES)  # Replace 1000-class head with 3-class head

scratch_optimizer = optim.Adam(scratch_model.parameters(), lr=1e-3)

print(f'Parameters: {sum(p.numel() for p in scratch_model.parameters()):,}')
print(f'Training images: {len(train_dataset)}')
print(f'Ratio: {sum(p.numel() for p in scratch_model.parameters()) / len(train_dataset):.0f} params per image')
print(f'\nThis is a recipe for overfitting!\n')

scratch_history = train_model(
    scratch_model, train_loader, test_loader,
    optimizer=scratch_optimizer, epochs=15, label='Scratch'
)

## 4. YOUR TURN: Approach 2 — Feature Extraction

Now use a **pretrained** ResNet-18. The three steps from the lesson:

1. **Load pretrained:** `models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)`
2. **Freeze all parameters:** `for param in model.parameters(): param.requires_grad = False`
3. **Replace the head:** `model.fc = nn.Linear(512, num_classes)` — new layer is trainable by default

Then create an optimizer that only updates `model.fc.parameters()`.
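
The same freeze-then-replace pattern can be tried on a toy network before touching ResNet-18. The sketch below is not the exercise solution: `ToyNet` is a hypothetical stand-in with a `backbone` and an `fc` head, chosen so the code runs without downloading any weights.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a pretrained model: backbone + fc head."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, 10)   # original 10-class head

    def forward(self, x):
        return self.fc(self.backbone(x))


model = ToyNet()

# Freeze every existing parameter.
for param in model.parameters():
    param.requires_grad = False

# Replace the head; a newly constructed layer is trainable by default.
model.fc = nn.Linear(16, 3)

# Optimize only the head.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

# One training step on random data.
backbone_before = model.backbone[0].weight.clone()
loss = nn.CrossEntropyLoss()(model(torch.randn(4, 8)), torch.tensor([0, 1, 2, 0]))
loss.backward()
optimizer.step()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {trainable}')
print('Backbone unchanged:', torch.equal(backbone_before, model.backbone[0].weight))
print('Backbone grad is None:', model.backbone[0].weight.grad is None)
```

The order matters: freezing first and replacing the head second leaves the new head trainable. Doing it the other way round would freeze the new head as well.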

Fill in the `TODO` sections below.

In [None]:
# TODO 1: Load a pretrained ResNet-18
# Hint: models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

fe_model = ...  # YOUR CODE HERE

# TODO 2: Freeze ALL parameters in the model
# Hint: loop over model.parameters() and set requires_grad = False

...  # YOUR CODE HERE

# TODO 3: Replace the classification head
# The original head is model.fc = Linear(512, 1000)
# Replace it with Linear(512, NUM_CLASSES)
# The new layer's parameters will have requires_grad=True by default

...  # YOUR CODE HERE

# Verify: count trainable vs frozen parameters
trainable = sum(p.numel() for p in fe_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in fe_model.parameters() if not p.requires_grad)
print(f'Trainable parameters: {trainable:,}')
print(f'Frozen parameters: {frozen:,}')
print(f'Total: {trainable + frozen:,}')
print(f'\nOnly {trainable / (trainable + frozen) * 100:.2f}% of parameters are trainable!')

In [None]:
# TODO 4: Create an optimizer that only updates the head's parameters
# Hint: optim.Adam(fe_model.fc.parameters(), lr=1e-3)
# We only pass model.fc.parameters() because everything else is frozen

fe_optimizer = ...  # YOUR CODE HERE

# Train
fe_history = train_model(
    fe_model, train_loader, test_loader,
    optimizer=fe_optimizer, epochs=15, label='Feature Extraction'
)

### What Just Happened?

Compare the feature extraction results to training from scratch:
- **Training accuracy:** Feature extraction should have lower training accuracy (it cannot overfit as easily because only the tiny head is being trained)
- **Test accuracy:** Feature extraction should have much HIGHER test accuracy — the pretrained features generalize
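- **Training speed:** Each feature extraction step is cheaper because no gradients are computed or stored for the frozen backbone. The forward pass through the backbone still runs, so the speedup comes from the backward pass and the optimizer update.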

This is the same architecture, the same data, the same training loop. The only difference is that the backbone started from ImageNet weights instead of random initialization.
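
One caveat worth checking for yourself: `requires_grad = False` stops gradient updates, but BatchNorm layers also keep running statistics that are updated during the forward pass whenever the module is in training mode. Because `train_model` calls `model.train()`, the frozen backbone's BatchNorm buffers can still drift. The sketch below shows the effect on a single standalone `nn.BatchNorm2d` layer with made-up input; it does not measure whether the drift matters for this notebook's results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(4)
for param in bn.parameters():          # freeze weight and bias
    param.requires_grad = False

x = torch.randn(8, 4, 5, 5) * 3 + 2    # made-up batch with non-zero mean

# Training mode: running statistics are updated even though params are frozen.
bn.train()
mean_before = bn.running_mean.clone()
bn(x)
changed_in_train = not torch.equal(mean_before, bn.running_mean)

# Eval mode: running statistics are left untouched.
bn.eval()
mean_before = bn.running_mean.clone()
bn(x)
changed_in_eval = not torch.equal(mean_before, bn.running_mean)

print('Changed in train mode:', changed_in_train)
print('Changed in eval mode:', changed_in_eval)
```

If you want the backbone to stay exactly as pretrained, one common option is to call `.eval()` on the frozen submodules after `model.train()`. That is a design choice outside what this notebook does.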

## 5. YOUR TURN: Approach 3 — Fine-Tuning

Feature extraction keeps the backbone completely frozen. **Fine-tuning** goes one step further: unfreeze some pretrained layers and train them with a lower learning rate.

The key skill: **differential learning rates** via parameter groups.

Steps:
1. Load a fresh pretrained model (do not reuse the feature extraction model — its head is already trained)
2. Freeze all parameters
3. Replace the head
4. Unfreeze `model.layer4` (the last residual stage)
5. Create an optimizer with two parameter groups:
   - `model.fc.parameters()` with lr=1e-3 (new head, learn fast)
   - `model.layer4.parameters()` with lr=1e-5 (pretrained, learn slowly)
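
Parameter groups can be inspected after construction through `optimizer.param_groups`, which is a handy way to confirm each group received the intended learning rate. A minimal sketch, using two small `nn.Linear` layers as hypothetical stand-ins for `layer4` and the new head:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for a pretrained stage and a new head.
layer4 = nn.Linear(16, 16)
fc = nn.Linear(16, 3)

optimizer = optim.Adam([
    {'params': fc.parameters(), 'lr': 1e-3},       # new head: larger steps
    {'params': layer4.parameters(), 'lr': 1e-5},   # pretrained: smaller steps
])

# Each dict becomes one entry in optimizer.param_groups, in the same order.
for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group['params'])
    print(f"group {i}: lr={group['lr']:g}, params={n_params}")
```

Any parameter that is not passed in one of the groups is never updated by the optimizer, whatever its `requires_grad` flag says.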

In [None]:
# TODO 5: Load a fresh pretrained ResNet-18

ft_model = ...  # YOUR CODE HERE

# TODO 6: Freeze all parameters

...  # YOUR CODE HERE

# TODO 7: Replace the head

...  # YOUR CODE HERE

# TODO 8: Unfreeze layer4 (the last residual stage)
# Hint: for param in ft_model.layer4.parameters(): param.requires_grad = True

...  # YOUR CODE HERE

# Verify
trainable = sum(p.numel() for p in ft_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in ft_model.parameters() if not p.requires_grad)
print(f'Trainable parameters: {trainable:,} (head + layer4)')
print(f'Frozen parameters: {frozen:,}')

In [None]:
# TODO 9: Create optimizer with differential learning rates
# Two parameter groups:
#   1. model.fc.parameters() with lr=1e-3  (new head: learn fast)
#   2. model.layer4.parameters() with lr=1e-5  (pretrained: learn slow)
#
# Hint:
# optimizer = optim.Adam([
#     {'params': ft_model.fc.parameters(), 'lr': 1e-3},
#     {'params': ft_model.layer4.parameters(), 'lr': 1e-5},
# ])

ft_optimizer = ...  # YOUR CODE HERE

# Train
ft_history = train_model(
    ft_model, train_loader, test_loader,
    optimizer=ft_optimizer, epochs=15, label='Fine-Tuning'
)

## 6. Compare All Three Approaches

Now let's see all three approaches side by side. The comparison should make it viscerally clear why transfer learning is the default approach for practical deep learning.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
epochs_range = range(1, len(scratch_history['train_loss']) + 1)  # matches the epochs actually run

colors = {
    'scratch': '#ef4444',      # red
    'feature_ext': '#3b82f6',  # blue
    'fine_tune': '#22c55e',    # green
}

# Training Loss
axes[0].plot(epochs_range, scratch_history['train_loss'], 'o-', label='From Scratch', color=colors['scratch'], markersize=4)
axes[0].plot(epochs_range, fe_history['train_loss'], 's-', label='Feature Extraction', color=colors['feature_ext'], markersize=4)
axes[0].plot(epochs_range, ft_history['train_loss'], '^-', label='Fine-Tuning', color=colors['fine_tune'], markersize=4)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss')
axes[0].legend(fontsize=8)
axes[0].grid(True, alpha=0.3)

# Training Accuracy
axes[1].plot(epochs_range, scratch_history['train_acc'], 'o-', label='From Scratch', color=colors['scratch'], markersize=4)
axes[1].plot(epochs_range, fe_history['train_acc'], 's-', label='Feature Extraction', color=colors['feature_ext'], markersize=4)
axes[1].plot(epochs_range, ft_history['train_acc'], '^-', label='Fine-Tuning', color=colors['fine_tune'], markersize=4)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training Accuracy')
axes[1].legend(fontsize=8)
axes[1].grid(True, alpha=0.3)

# Test Accuracy (the one that matters!)
axes[2].plot(epochs_range, scratch_history['test_acc'], 'o-', label='From Scratch', color=colors['scratch'], markersize=4)
axes[2].plot(epochs_range, fe_history['test_acc'], 's-', label='Feature Extraction', color=colors['feature_ext'], markersize=4)
axes[2].plot(epochs_range, ft_history['test_acc'], '^-', label='Fine-Tuning', color=colors['fine_tune'], markersize=4)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy (%)')
axes[2].set_title('Test Accuracy (what matters!)')
axes[2].legend(fontsize=8)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Summary table
print('\n' + '=' * 70)
print('COMPARISON SUMMARY')
print('=' * 70)
print(f'{"":25s} {"From Scratch":>15s} {"Feature Ext":>15s} {"Fine-Tune":>15s}')
print('-' * 70)
print(f'{"Final train acc":25s} {scratch_history["train_acc"][-1]:>14.1f}% {fe_history["train_acc"][-1]:>14.1f}% {ft_history["train_acc"][-1]:>14.1f}%')
print(f'{"Final test acc":25s} {scratch_history["test_acc"][-1]:>14.1f}% {fe_history["test_acc"][-1]:>14.1f}% {ft_history["test_acc"][-1]:>14.1f}%')
print(f'{"Best test acc":25s} {max(scratch_history["test_acc"]):>14.1f}% {max(fe_history["test_acc"]):>14.1f}% {max(ft_history["test_acc"]):>14.1f}%')

gap = scratch_history['train_acc'][-1] - scratch_history['test_acc'][-1]
print(f'{"Train-test gap (scratch)":25s} {gap:>14.1f}%')
print('-' * 70)
print()
print('Key observations:')
print('  1. From scratch: high train acc, low test acc = overfitting')
print('  2. Feature extraction: strong test acc with minimal training')
print('  3. Fine-tuning: may improve slightly over feature extraction')
print()
print('Same architecture. Same data. The difference is the STARTING POINT.')

## 7. Inspect What Changed

Let's verify our understanding by counting how many parameters each approach actually trained.

In [None]:
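# Rebuild the model as it was trained: pretrained backbone + 3-class head
pretrained = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
pretrained.fc = nn.Linear(pretrained.fc.in_features, NUM_CLASSES)  # without this, fc counts the 1000-class head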

print('ResNet-18 top-level modules:')
print('-' * 40)
for name, module in pretrained.named_children():
    num_params = sum(p.numel() for p in module.parameters())
    print(f'  {name:12s} — {num_params:>10,} params')

total = sum(p.numel() for p in pretrained.parameters())
fc_params = sum(p.numel() for p in pretrained.fc.parameters())
layer4_params = sum(p.numel() for p in pretrained.layer4.parameters())

print(f'\nTotal: {total:,}')
print(f'\nFeature extraction trained: fc = {fc_params:,} params ({fc_params/total*100:.2f}%)')
print(f'Fine-tuning trained: fc + layer4 = {fc_params + layer4_params:,} params ({(fc_params + layer4_params)/total*100:.1f}%)')
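
The head sizes can be checked by hand: a `Linear(in, out)` layer has `in * out` weights plus `out` biases. The short calculation below uses only that formula; `linear_params` is a helper name introduced here for illustration.

```python
def linear_params(in_features, out_features):
    """Parameter count of a fully connected layer: weights plus biases."""
    return in_features * out_features + out_features


imagenet_head = linear_params(512, 1000)   # original 1000-class head
new_head = linear_params(512, 3)           # 3-class head used in this notebook

print(f'Original head: {imagenet_head:,}')
print(f'New head:      {new_head:,}')
print(f'Ratio:         {imagenet_head / new_head:.0f}x')
```

The 3-class head is the count that feature extraction actually trains, so any percentage reported for it should be computed after the head has been replaced.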

## 8. Reflection

Before moving on, consider:

1. **The overfitting gap:** Training from scratch shows a huge gap between training accuracy (~99%) and test accuracy (~40-60%). This is exactly what you predicted from the Overfitting and Regularization lesson. 11M parameters + 1500 images = memorization, not learning.

2. **Feature extraction is surprisingly effective:** With only ~1,500 trainable parameters (the fc layer), feature extraction achieves strong test accuracy. The pretrained backbone converts raw pixels into meaningful 512-dimensional feature vectors, and a simple linear classifier on top is enough.

3. **Fine-tuning is a refinement, not a revolution:** Fine-tuning may improve over feature extraction by a few percentage points. Its real value appears when your domain differs significantly from ImageNet and the later layers need to adapt.

4. **The practical takeaway:** Always start with feature extraction. It is fast, hard to mess up, and gives a strong baseline. Only add fine-tuning if you need more accuracy and have enough data.

5. **Transfer learning changes the economics of deep learning:** You no longer need millions of images and weeks of GPU time. A pretrained model + a few hundred labeled images + an afternoon = a working classifier.

---

**Module 3.2 complete!** You now understand CNN architectures from LeNet through ResNet and how to put them to practical use with transfer learning.