# Week 11 — Representation Learning (CNNs)

This notebook explores convolutional networks and learned representations. You'll:
- Build CNNs and understand receptive fields
- Visualize filters and activations
- Apply transfer learning with pretrained models
- Conduct ablation studies

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Seed PyTorch's global RNG once, up front, so reruns start from the same state
SEED = 42
torch.manual_seed(SEED)
print(f"PyTorch version: {torch.__version__}")

## 1. Build a Small CNN

In [None]:
class SimpleCNN(nn.Module):
    """Small CNN: two (3x3 conv -> ReLU -> 2x2 max-pool) stages and a two-layer head.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (default 1).
        input_size: Spatial size of the input images, either an int for square
            images or a ``(height, width)`` pair (default 28).

    With the default arguments the layers are identical to the original fixed
    architecture (1x28x28 inputs, ``fc1`` fed by 64 * 7 * 7 features).

    Raises:
        ValueError: If ``input_size`` is too small to survive the two pooling
            stages (either dimension below 4).
    """

    def __init__(self, num_classes=10, in_channels=1, input_size=28):
        super().__init__()
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # The padded 3x3 convs keep the spatial size; each 2x2/stride-2 pool
        # halves it (rounding down), so two pools give size // 4.
        feat_h, feat_w = height // 4, width // 4
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 poolings"
            )

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * feat_h * feat_w, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to logits ``(N, num_classes)``."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten everything but the batch dimension; unlike ``view`` this
        # also works when the pooled tensor is not contiguous.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the network, show its layer layout, and count its parameters
model = SimpleCNN(num_classes=10)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params}")

## 2. Load MNIST and Train

In [None]:
# Load MNIST
# Preprocessing: convert to a tensor, then apply (x - 0.5) / 0.5 per pixel
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits share the same storage location, download flag and preprocessing
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training split is shuffled
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Training loop
# Fresh model, cross-entropy loss on the logits, Adam optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 5
for epoch in range(n_epochs):
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        # Forward pass
        logits = model(images)
        loss = criterion(logits, labels)

        # Backward pass: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress only on every 200th batch
        if batch_idx % 200 != 0:
            continue
        print(f"Epoch {epoch+1}/{n_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

# Evaluate
# Count correct top-1 predictions over the whole test split, without gradients
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        pred = torch.argmax(output, dim=1)
        correct += int((pred == target).sum())

n_test = len(test_dataset)
accuracy = 100. * correct / n_test
print(f"\nTest Accuracy: {accuracy:.2f}%")

## 3. Visualize Filters and Activations

In [None]:
# Visualize first conv layer filters
# Copy the conv1 weights to a NumPy array for plotting
filters = model.conv1.weight.detach().cpu().numpy()

fig, axes = plt.subplots(4, 8, figsize=(12, 6))
# Pair each axis with one filter's first input-channel kernel; zip stops at the
# shorter of the two, so any surplus axes are left untouched
for ax, kernel in zip(axes.flat, filters[:, 0]):
    ax.imshow(kernel, cmap='gray')
    ax.axis('off')
plt.suptitle('Conv1 Filters (32 filters, 3x3 each)')
plt.tight_layout()
plt.show()

# Capture activations using hooks
# Filled in by the hooks below: label -> most recent output of the hooked module
activations = {}


def get_activation(name):
    """Return a forward hook that saves the module's output under ``activations[name]``.

    The saved tensor is detached from the autograd graph. Each call of the
    hook overwrites the previous entry for ``name``.
    """
    def hook(module, inputs, result):
        activations[name] = result.detach()
    return hook

# Attach the hooks only for the duration of one forward pass.
# register_forward_hook adds a new hook on every call, so the returned handles
# are kept and removed afterwards; otherwise each re-run of this cell would
# stack another pair of hooks on the model.
hook_handles = [
    model.conv1.register_forward_hook(get_activation('conv1')),
    model.conv2.register_forward_hook(get_activation('conv2')),
]

# Get a sample image
sample_img, _ = test_dataset[0]
sample_img = sample_img.unsqueeze(0)  # add a leading batch dimension of size 1

try:
    # Only the captured activations are needed, so skip building an autograd graph
    with torch.no_grad():
        _ = model(sample_img)
finally:
    # Detach the hooks even if the forward pass raises
    for handle in hook_handles:
        handle.remove()

# Visualize activations
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
# Take the conv1 output for the first (only) image in the batch as a NumPy array
conv1_act = activations['conv1'][0].cpu().numpy()
# One feature map per axis; zip stops at the shorter of the two sequences
for ax, feature_map in zip(axes.flat, conv1_act):
    ax.imshow(feature_map, cmap='viridis')
    ax.axis('off')
plt.suptitle('Conv1 Activations (32 feature maps)')
plt.tight_layout()
plt.show()

## 4. Transfer Learning with Pretrained Model

In [None]:
# Load pretrained ResNet18 (for demonstration)
# torchvision >= 0.13 selects weights through the `weights=` enum and deprecates
# `pretrained=True`; IMAGENET1K_V1 is the checkpoint the old flag loaded. The
# old flag is kept only as a fallback for older torchvision installs.
if hasattr(models, "ResNet18_Weights"):
    pretrained_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    pretrained_model = models.resnet18(pretrained=True)

# Freeze all layers
for param in pretrained_model.parameters():
    param.requires_grad = False

# Replace final layer
# The new Linear is created after the freeze, so its weight and bias keep the
# default requires_grad=True and are the only trainable parameters.
num_ftrs = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(num_ftrs, 10)

# NOTE: ResNet18's first convolution expects 3-channel input. The 1-channel
# MNIST tensors produced earlier in this notebook must be expanded to 3 channels
# (e.g. with transforms.Grayscale(num_output_channels=3)) before being fed to
# this model.
print("Pretrained ResNet18 adapted for MNIST (10 classes)")
trainable_params = sum(p.numel() for p in pretrained_model.parameters() if p.requires_grad)
total_params_resnet = sum(p.numel() for p in pretrained_model.parameters())
print(f"Trainable parameters: {trainable_params}")
print(f"Total parameters: {total_params_resnet}")

## Exercises

1. **Deeper CNNs**: Build a 5-layer CNN and compare to the 2-layer version
2. **Receptive Fields**: Calculate and visualize receptive field sizes
3. **Feature Extraction**: Extract features from a pretrained model and train a classifier
4. **Ablation Study**: Vary depth/width and document effects on accuracy

## Deliverables

- [ ] CNN implementation and training on MNIST/CIFAR-10
- [ ] Filter and activation visualizations
- [ ] Transfer learning experiment
- [ ] Ablation study report