# Exercise 1: Image Classification and Modeling Choices

In this exercise, we'll explore how different **modeling choices** affect performance on image classification tasks. Machine learning is a *modeling science*: we have to make informed decisions about what matters in our data and what doesn't. The choices we're engaging with in this exercise are:

- **Data representation**: How do we structure our input?
- **Inductive biases**: What assumptions do we build into our model?
- **Data augmentation**: Can we simulate data that *exemplify* important features of the problem?

We'll work with two datasets:
- **FashionMNIST**: 28×28 grayscale images of clothing items (10 classes)
- **CIFAR-10**: 32×32 color images of objects (10 classes)

## Learning Objectives

By the end of this exercise, you should understand:
1. Why a linear model treats images as unstructured vectors
2. How CNNs encode spatial structure as an inductive bias
3. When and why data augmentation helps
4. How modeling choices interact with dataset characteristics
5. How to detect overfitting using train/validation curves

In [None]:
# Imports and configuration
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Device and reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(67)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(67)

# Human-readable class names for FashionMNIST
class_names = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
print('Using device:', device)

## Part 1: Loading and Exploring FashionMNIST

Let's start by loading our first dataset to understand what we're working with.

**Important**: We'll split our data into three sets:
- **Training set**: Used to update model weights
- **Validation set**: Used to monitor overfitting and tune hyperparameters
- **Test set**: Used only for final evaluation (never seen during training!)
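
As a rough illustration of what a seeded split does, the sketch below shuffles indices with a fixed seed and cuts them into two disjoint groups. It uses only the standard library; `split_indices` is a hypothetical helper written for this note (it is not how `random_split` is implemented), and the sizes simply mirror the 48k/12k split used in this exercise.

```python
import random

def split_indices(n, train_size, seed=67):
    """Shuffle range(n) with a fixed seed and cut it into train/val index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # same seed -> same shuffle every run
    return idx[:train_size], idx[train_size:]

train_idx, val_idx = split_indices(60000, 48000)

# No sample may appear in both sets, otherwise validation metrics are optimistic
overlap = set(train_idx) & set(val_idx)
print(len(train_idx), len(val_idx), len(overlap))
```

Re-running with the same seed reproduces the same split, which is what makes results comparable across runs.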

In [None]:
# Data loading and preprocessing (FashionMNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5 maps [0, 1] to the [-1, 1] range
])

data_dir = './data'

# Load full training set
full_train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform)

# Split into train (48k) and validation (12k) - 80/20 split
train_size = 48000
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size], 
                                          generator=torch.Generator().manual_seed(67))

# Load test set (10k samples)
test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform)

# Batch size: Number of samples processed before each weight update
# Larger = fewer, smoother updates per epoch but more memory; smaller = noisier updates and slower epochs
# 64 is a common choice balancing speed and stability
batch_size = 64

# num_workers: Parallel data loading processes (adjust based on your CPU cores)
num_workers = 4

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f'Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}, Test samples: {len(test_dataset)}')

In [None]:
# Let's visualize a few examples to understand our data
images, labels = next(iter(train_loader))
print('Batch shapes:', images.shape, labels.shape)  # (batch_size, 1, 28, 28)

plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1, 8, i+1)
    plt.imshow(images[i].squeeze().numpy(), cmap='gray')
    plt.title(class_names[labels[i]], fontsize=9)
    plt.axis('off')
plt.suptitle('Sample FashionMNIST Images', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

## Part 2: Baseline Linear Model

**The simplest approach**: Treat each image as a flat vector of 784 pixels (28×28) and learn a linear mapping to 10 classes.

**Key question**: What does this model assume about images? Does it know that nearby pixels are related? Does it understand that a shirt in the top-left corner is the same as a shirt in the center?

**Spoiler**: No! A linear model treats each pixel position as an independent feature.
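
One way to see what "unstructured vector" means: if every image had its pixels scrambled by the same fixed permutation, a linear model could reach exactly the same scores by permuting its weights the same way. The toy sketch below (standard library only, a random 4×4 "image", and a hypothetical `linear_score` helper) checks this for a single class score.

```python
import random

def linear_score(pixels, weights, bias=0.0):
    """Score for one class: a weighted sum over flattened pixels."""
    return sum(p * w for p, w in zip(pixels, weights)) + bias

rng = random.Random(0)
n = 16  # a 4x4 toy image, flattened
pixels = [rng.random() for _ in range(n)]
weights = [rng.uniform(-1, 1) for _ in range(n)]

# Scramble pixel positions and apply the same permutation to the weights
perm = list(range(n))
rng.shuffle(perm)
scrambled_pixels = [pixels[i] for i in perm]
scrambled_weights = [weights[i] for i in perm]

original = linear_score(pixels, weights)
scrambled = linear_score(scrambled_pixels, scrambled_weights)
print(abs(original - scrambled) < 1e-9)
```

The score is unchanged, so nothing in the model depends on which pixels are neighbors. A convolutional model does not have this property, because it relies on the spatial layout.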

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, input_dim=784, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 784)
        return self.fc(x)

linear_model = LinearClassifier().to(device)
print(f"Linear model parameters: {sum(p.numel() for p in linear_model.parameters()):,}")

In [None]:
# Training and evaluation helpers
def evaluate(model, loader):
    """Compute accuracy on a dataset."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labs in loader:
            imgs, labs = imgs.to(device), labs.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            correct += (preds == labs).sum().item()
            total += labs.size(0)
    return 100.0 * correct / total

def compute_loss(model, loader, criterion):
    """Compute average loss on a dataset."""
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for imgs, labs in loader:
            imgs, labs = imgs.to(device), labs.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labs)
            total_loss += loss.item() * imgs.size(0)
            total_samples += imgs.size(0)
    return total_loss / total_samples

def get_predictions(model, loader):
    """Get all predictions and true labels for confusion matrix."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for imgs, labs in loader:
            imgs, labs = imgs.to(device), labs.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labs.cpu().numpy())
    return np.array(all_labels), np.array(all_preds)

def plot_confusion_matrix(model, loader, class_names, title="Confusion Matrix"):
    """Plot a confusion matrix for the model's predictions."""
    y_true, y_pred = get_predictions(model, loader)
    cm = confusion_matrix(y_true, y_pred)
    
    # Normalize to percentages
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Percentage (%)'})
    plt.title(title, fontsize=14, pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
    
    return cm

def plot_training_curves(history, title="Training History"):
    """Plot training and validation loss curves."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
    
    # Loss curves
    ax1.plot(history['train_loss'], label='Train Loss', linewidth=2)
    ax1.plot(history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=11)
    ax1.set_ylabel('Loss', fontsize=11)
    ax1.set_title('Loss Curves', fontsize=12)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Accuracy curves
    ax2.plot(history['val_acc'], label='Val Accuracy', linewidth=2, color='green')
    ax2.set_xlabel('Epoch', fontsize=11)
    ax2.set_ylabel('Accuracy (%)', fontsize=11)
    ax2.set_title('Validation Accuracy', fontsize=12)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

def train_model(model, train_loader, val_loader, epochs=5, lr=1e-3):
    """
    Train a model and report progress.
    
    Args:
        epochs: Number of complete passes through training data
                More epochs = more learning, but risk of overfitting
        lr: Learning rate - controls step size during optimization
            Too high = unstable training; too low = slow convergence
            1e-3 (0.001) is a common starting point for Adam optimizer
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Track metrics for plotting
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_acc': []
    }
    
    for epoch in range(1, epochs+1):
        # Training phase
        model.train()
        running_loss = 0.0
        for imgs, labs in train_loader:
            imgs, labs = imgs.to(device), labs.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * imgs.size(0)
        
        # Compute metrics
        train_loss = running_loss / len(train_loader.dataset)
        val_loss = compute_loss(model, val_loader, criterion)
        val_acc = evaluate(model, val_loader)
        
        # Store history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f'Epoch {epoch:02d} — Train Loss: {train_loss:.4f} — Val Loss: {val_loss:.4f} — Val Acc: {val_acc:.2f}%')
    
    return model, history

# Train the linear baseline
print("Training linear model...")
linear_model, linear_history = train_model(linear_model, train_loader, val_loader, epochs=10)
linear_acc = evaluate(linear_model, test_loader)

### Detecting Overfitting: Training Curves

Before looking at the confusion matrix, let's examine the **training curves**. These are crucial for understanding how well your model generalizes.

**What to look for:**
- **Train loss decreasing**: Model is learning from training data ✓
- **Val loss decreasing**: Model generalizes to new data ✓
- **Val loss increasing while train loss decreases**: **OVERFITTING!** ⚠️
  - The model is memorizing training data instead of learning patterns
  - Solution: Stop training earlier, add regularization, or get more data
- **Gap between train and val loss**: Some gap is normal, but large gaps indicate overfitting
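
The checks above can also be read off a history dictionary numerically. The sketch below assumes a dictionary with `train_loss` and `val_loss` lists (the same keys `train_model` records); `summarize_history` is a hypothetical helper and the loss values are invented purely for illustration.

```python
def summarize_history(history):
    """Return the 1-indexed epoch with the lowest val loss and the final train/val gap."""
    val = history['val_loss']
    train = history['train_loss']
    best_epoch = min(range(len(val)), key=val.__getitem__) + 1
    final_gap = val[-1] - train[-1]
    return best_epoch, final_gap

# Made-up numbers: train loss keeps falling, val loss turns upward after epoch 4
toy_history = {
    'train_loss': [1.20, 0.90, 0.70, 0.55, 0.42, 0.30],
    'val_loss':   [1.25, 1.00, 0.88, 0.86, 0.91, 0.99],
}

best_epoch, final_gap = summarize_history(toy_history)
print(f'best epoch: {best_epoch}, final gap: {final_gap:.2f}')
```

A best epoch well before the last one, together with a growing gap, is the numerical signature of the overfitting pattern described above.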

Let's see how our linear model performs:

In [None]:
# Plot training curves for linear model
plot_training_curves(linear_history, title="Linear Model Training History (FashionMNIST)")

### Understanding Model Predictions: The Confusion Matrix

Now let's introduce another diagnostic tool: the **confusion matrix**.

**What is it?**  
A confusion matrix shows where your model gets confused. Each row represents the true class, and each column represents what the model predicted.

**How to read it:**
- **Diagonal elements** (top-left to bottom-right): Correct predictions. Higher is better!
- **Off-diagonal elements**: Mistakes. The value at row *i*, column *j* tells you how often class *i* was misclassified as class *j*.
- **Patterns to look for:**
  - Are certain classes systematically confused? (e.g., "Shirt" vs "T-shirt/top")
  - Does the model have a bias toward predicting certain classes?
  - Are mistakes symmetric? (Does the model confuse A→B as often as B→A?)
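
To make the row/column convention concrete, here is a minimal pure-Python sketch that builds a confusion matrix from label lists and converts each row to percentages. The helper names and the tiny label lists are made up for illustration; the notebook itself uses `sklearn.metrics.confusion_matrix`.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def row_percentages(cm):
    """Normalize each row so it sums to 100 (per-true-class percentages)."""
    out = []
    for row in cm:
        total = sum(row)
        out.append([100.0 * c / total if total else 0.0 for c in row])
    return out

y_true = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 2, 2]

cm = confusion_counts(y_true, y_pred, 3)
pct = row_percentages(cm)
for row in pct:
    print(row)
```

In this toy case class 0 is mistaken for class 1 in 25% of its samples, while class 1 is mistaken for class 0 in 50% of its samples, so the confusion is not symmetric.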

Let's see where our linear model struggles:

In [None]:
# Confusion matrix for linear model
plot_confusion_matrix(linear_model, test_loader, class_names, 
                     title="Linear Model Confusion Matrix (FashionMNIST)")

### Error Analysis: Hardest-to-Classify Images

Let's look at the images our model struggles with most. This can reveal:
- Ambiguous cases even humans would find difficult
- Systematic biases in the model
- Data quality issues

We'll find images where the model was most confident but wrong.
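
The ranking idea is simple enough to sketch without PyTorch: turn logits into probabilities with softmax, keep only the wrong predictions, and sort by the probability assigned to the (wrong) predicted class. The helpers and the toy logits below are illustrative assumptions, not part of the notebook's code.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def confident_errors(all_logits, labels):
    """Return (confidence, sample_index, predicted, true) for wrong predictions, most confident first."""
    errors = []
    for idx, (logits, label) in enumerate(zip(all_logits, labels)):
        probs = softmax(logits)
        pred = max(range(len(probs)), key=probs.__getitem__)
        if pred != label:
            errors.append((probs[pred], idx, pred, label))
    errors.sort(reverse=True)
    return errors

# Toy logits for four samples and three classes
toy_logits = [[4.0, 0.0, 0.0], [0.2, 0.1, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 2.0]]
toy_labels = [1, 2, 1, 2]

ranked = confident_errors(toy_logits, toy_labels)
for conf, idx, pred, true in ranked:
    print(f'sample {idx}: predicted {pred}, true {true}, confidence {conf:.2f}')
```

Sample 0 is wrong with high confidence and sample 1 is wrong but nearly undecided; the first kind is usually the more informative one to inspect.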

In [None]:
def show_hardest_examples(model, loader, class_names, num_examples=16):
    """Show the hardest-to-classify examples (high confidence, wrong prediction)."""
    model.eval()
    
    # Collect all predictions with confidence scores
    all_errors = []
    
    with torch.no_grad():
        for imgs, labs in loader:
            imgs, labs = imgs.to(device), labs.to(device)
            outputs = model(imgs)
            probs = torch.softmax(outputs, dim=1)
            confidences, preds = probs.max(1)
            
            # Find incorrect predictions
            incorrect = preds != labs
            
            # Store errors with their confidence
            for i in range(len(imgs)):
                if incorrect[i]:
                    all_errors.append({
                        'image': imgs[i].cpu(),
                        'true_label': labs[i].item(),
                        'pred_label': preds[i].item(),
                        'confidence': confidences[i].item()
                    })
    
    # Sort by confidence (descending) - most confident mistakes first
    all_errors.sort(key=lambda x: x['confidence'], reverse=True)
    
    # Plot top errors
    num_to_show = min(num_examples, len(all_errors))
    fig = plt.figure(figsize=(12, 8))
    
    for i in range(num_to_show):
        error = all_errors[i]
        plt.subplot(4, 4, i+1)
        
        # Handle grayscale vs RGB
        img = error['image'].squeeze()
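        if img.ndim == 3:  # RGB
            img = img.permute(1, 2, 0).numpy()
            # Inputs are normalized, so rescale to [0, 1] for display
            # (colors are approximate; exact un-normalization needs the dataset mean/std)
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)
            plt.imshow(img)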
        else:  # Grayscale
            plt.imshow(img.numpy(), cmap='gray')
        
        true_name = class_names[error['true_label']]
        pred_name = class_names[error['pred_label']]
        conf = error['confidence'] * 100
        
        plt.title(f'True: {true_name}\nPred: {pred_name} ({conf:.0f}%)', 
                 fontsize=8, color='red')
        plt.axis('off')
    
    plt.suptitle('Hardest-to-Classify Examples (High Confidence Errors)', fontsize=12)
    plt.tight_layout()
    plt.show()

# Show hardest examples for linear model
show_hardest_examples(linear_model, test_loader, class_names)

### Reflection: Linear Model Performance

The linear model achieves around **83-84% accuracy**. Not bad for such a simple approach!

But think about what it's learning: it's finding patterns like "if pixel 234 is bright and pixel 567 is dark, it's probably a sneaker." It has no concept of shapes, edges, or spatial relationships.

**Looking at the training curves**: Notice that train and validation losses track closely together - this model is **not overfitting** significantly. Why? Because it's too simple to memorize the training data!

**Looking at the confusion matrix**, you'll likely notice:
- **Shirt vs. T-shirt/top**: These are often confused because they're visually similar
- **Trousers and ankle boots**: These are rarely confused with other classes - why?

**Looking at the hardest examples**: You'll see that even the model's most confident mistakes are often genuinely ambiguous cases.

**Question to consider**: Why does this work at all? What makes FashionMNIST amenable to this approach?

## Part 3: Convolutional Neural Network (CNN)

Now let's refine our **inductive bias**. We want to exploit that images have spatial structure.

**Key ideas**:
- **Convolutions**: Look at local patches of pixels, not individual pixels
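- **Translation equivariance**: The same filter is applied at every position, so a sneaker is detected whether it's in the top-left or bottom-right; pooling then makes the output approximately translation *invariant*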
- **Hierarchical features**: Early layers detect edges, later layers detect shapes

We're choosing a *model architecture* based on our prior knowledge of the data and how they relate to the task. This is one of the core principles of machine learning.
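
A small 1D analogue may help make the translation idea concrete. In the sketch below (standard library only, toy signal and kernel chosen for illustration), sliding the same kernel over a shifted input gives a shifted output, and a global max over that output is unchanged by the shift. Like PyTorch's `Conv2d`, the hypothetical helper `correlate1d` computes a cross-correlation and does not flip the kernel.

```python
def correlate1d(signal, kernel):
    """'Valid' cross-correlation: slide the kernel over the signal without padding."""
    k = len(kernel)
    return [sum(kernel[j] * signal[i + j] for j in range(k))
            for i in range(len(signal) - k + 1)]

signal = [0, 0, 1, 2, 1, 0, 0, 0]
shifted = [0] + signal[:-1]          # same pattern, moved one step to the right
kernel = [1, 0, -1]                  # a simple edge-like filter

out = correlate1d(signal, kernel)
out_shifted = correlate1d(shifted, kernel)

print(out)          # [-1, -2, 0, 2, 1, 0]
print(out_shifted)  # [0, -1, -2, 0, 2, 1]
print(max(out) == max(out_shifted))  # global max pooling ignores the shift
```

The filter response moves with the pattern (equivariance), and pooling over positions removes the dependence on where it occurred (invariance). A linear model on flattened pixels has neither property built in.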

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Convolutional feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),  # 32 filters, 3×3 kernels
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # downsample by 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((7, 7))  # ensure fixed size for classifier
        )
        # Classification head
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

cnn_model = SimpleCNN(in_channels=1).to(device)
print(f"CNN parameters: {sum(p.numel() for p in cnn_model.parameters()):,}")

In [None]:
# Train the CNN
print("Training CNN...")
cnn_model, cnn_history = train_model(cnn_model, train_loader, val_loader, epochs=10, lr=1e-3)
cnn_acc = evaluate(cnn_model, test_loader)

In [None]:
# Plot training curves for CNN
plot_training_curves(cnn_history, title="CNN Training History (FashionMNIST)")

In [None]:
# Confusion matrix for CNN on FashionMNIST
plot_confusion_matrix(cnn_model, test_loader, class_names,
                     title="CNN Confusion Matrix (FashionMNIST)")

In [None]:
# Show hardest examples for CNN
show_hardest_examples(cnn_model, test_loader, class_names)

In [None]:
# Visualize some predictions
cnn_model.eval()
imgs, labs = next(iter(test_loader))
imgs, labs = imgs.to(device), labs.to(device)
with torch.no_grad():
    outs = cnn_model(imgs[:32])
    preds = outs.argmax(dim=1)

plt.figure(figsize=(12, 8))
for i in range(32):
    plt.subplot(4, 8, i+1)
    plt.imshow(imgs[i].cpu().squeeze(), cmap='gray')
    correct = '✓' if preds[i] == labs[i] else '✗'
    plt.title(f'{correct} {class_names[preds[i]]}', fontsize=9, 
              color='green' if preds[i] == labs[i] else 'red')
    plt.axis('off')
plt.suptitle('CNN Predictions on FashionMNIST', fontsize=12)
plt.tight_layout()
plt.show()

### Reflection: The importance of choosing the right inductive bias

The CNN achieves around **91-92% accuracy** - a significant improvement over the linear model!

**Why?** Because we made and encoded some good assumptions:
- Nearby pixels are related (convolutions)
- Features can appear anywhere (translation invariance)
- Complex patterns are built from simpler ones (hierarchical layers)

**Looking at the training curves**: The CNN shows a slightly larger gap between train and validation loss compared to the linear model. This is expected - more powerful models can fit training data better, but we're not seeing severe overfitting yet.

**Comparing the confusion matrices**: Notice how the CNN's confusion matrix has stronger diagonal values and weaker off-diagonal values. The CNN still confuses similar items (Shirt/T-shirt, Pullover/Coat), but less frequently. This shows that spatial structure helps the model learn more discriminative features.

**Looking at the hardest examples**: Even the CNN's mistakes are often on genuinely difficult cases. Compare these to the linear model's errors - do you notice any patterns?

**Key insight**: The CNN doesn't just have more parameters; its parameters are organized into small, shared, local filters that match the structure of the problem.
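
For reference, the parameter counts of the two models can be worked out by hand from the layer sizes used in this exercise. The sketch below does that arithmetic with two small hypothetical helpers; the results should agree with what the `sum(p.numel() ...)` printouts report.

```python
def conv2d_params(in_ch, out_ch, k):
    """Weights (in_ch * out_ch * k * k) plus one bias per output channel."""
    return in_ch * out_ch * k * k + out_ch

def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features

# LinearClassifier: 784 -> 10
linear_total = linear_params(784, 10)

# SimpleCNN with in_channels=1: two 3x3 conv layers, then 64*7*7 -> 128 -> 10
conv_total = conv2d_params(1, 32, 3) + conv2d_params(32, 64, 3)
head_total = linear_params(64 * 7 * 7, 128) + linear_params(128, 10)
cnn_total = conv_total + head_total

print(f'linear: {linear_total:,}')      # linear: 7,850
print(f'cnn conv layers: {conv_total:,}')  # cnn conv layers: 18,816
print(f'cnn total: {cnn_total:,}')      # cnn total: 421,642
```

Note that the convolutional layers account for under 19k of the CNN's parameters; most of the rest sit in the first fully connected layer of the classification head.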

---

**Comparison so far**:
- Linear model: ~83-84% (treats image as flat vector)
- CNN: ~91-92% (exploits spatial structure)

## Part 4: Transfer to CIFAR-10

Now let's test our CNN architecture on a more challenging dataset: **CIFAR-10**. These are 32×32 color images of real-world objects (planes, cars, animals, etc.).

**Key differences from FashionMNIST**:
- Color images (3 channels instead of 1)
- More complex, natural images
- More variation within each class

Let's see how our CNN architecture performs without any modifications.

In [None]:
# Load CIFAR-10 dataset
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Load and split CIFAR-10
cifar_full_train = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=cifar_transform)

# Split into train (40k) and validation (10k) - 80/20 split
cifar_train_size = 40000
cifar_val_size = len(cifar_full_train) - cifar_train_size
cifar_train_ds, cifar_val_ds = random_split(cifar_full_train, [cifar_train_size, cifar_val_size],
                                             generator=torch.Generator().manual_seed(67))

cifar_test_ds = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=cifar_transform)

# Batch size for CIFAR-10: Using 128 instead of 64
# Larger batch = better GPU utilization for color images (3 channels vs 1)
# Also provides more stable gradient estimates for this harder task
cifar_batch = 128

cifar_train_loader = DataLoader(cifar_train_ds, batch_size=cifar_batch, shuffle=True, num_workers=num_workers)
cifar_val_loader = DataLoader(cifar_val_ds, batch_size=cifar_batch, shuffle=False, num_workers=num_workers)
cifar_test_loader = DataLoader(cifar_test_ds, batch_size=cifar_batch, shuffle=False, num_workers=num_workers)

cifar_class_names = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
print(f'CIFAR-10 — Train: {len(cifar_train_ds)}, Val: {len(cifar_val_ds)}, Test: {len(cifar_test_ds)}')

In [None]:
# Visualize CIFAR-10 samples
imgs_c, labs_c = next(iter(cifar_train_loader))
plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1, 8, i+1)
    img = imgs_c[i].permute(1, 2, 0).numpy()
    img = (img * [0.247, 0.243, 0.261]) + [0.4914, 0.4822, 0.4465]
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.title(cifar_class_names[labs_c[i]], fontsize=9)
    plt.axis('off')
plt.suptitle('Sample CIFAR-10 Images', fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
# Train CNN on CIFAR-10 (note: in_channels=3 for RGB)
# Using 10 epochs here, the same as on FashionMNIST
# Watch the training curves for signs of overfitting!
print("Training CNN on CIFAR-10...")
cnn_cifar = SimpleCNN(in_channels=3, num_classes=10).to(device)
cnn_cifar, cifar_history = train_model(cnn_cifar, cifar_train_loader, cifar_val_loader, epochs=10, lr=1e-3)
cifar_acc = evaluate(cnn_cifar, cifar_test_loader)

In [None]:
# Plot training curves for CIFAR-10
plot_training_curves(cifar_history, title="CNN Training History (CIFAR-10, No Augmentation)")

In [None]:
# Confusion matrix for CNN on CIFAR-10 (no augmentation)
plot_confusion_matrix(cnn_cifar, cifar_test_loader, cifar_class_names,
                     title="CNN Confusion Matrix (CIFAR-10, No Augmentation)")

In [None]:
# Show hardest examples for CIFAR-10
show_hardest_examples(cnn_cifar, cifar_test_loader, cifar_class_names)

### Reflection: Performance on CIFAR-10

The CNN achieves around **74% accuracy** on CIFAR-10. That's much lower than on FashionMNIST!

**Why the drop?**
- CIFAR-10 images are more complex and varied
- Objects can appear at different scales, angles, and positions
- Background clutter and occlusion
- Only 40,000 training images for a harder task

**Looking at the training curves**: Notice the gap between train and validation loss growing over time? This is **overfitting**! The model is starting to memorize training examples rather than learning generalizable patterns. This suggests we could benefit from:
1. Early stopping (stop training when val loss stops improving)
2. Regularization techniques
3. **Data augmentation** (our next topic!)
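
Early stopping, the first option above, is commonly implemented with a "patience" counter: stop once validation loss has failed to improve for a set number of epochs. The sketch below is one minimal way to write it; the `EarlyStopping` class, its parameters, and the loss values are assumptions for illustration and are not part of this notebook's `train_model`.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without val-loss improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's val loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses: improvement stalls after epoch 3
toy_val_losses = [1.10, 0.95, 0.90, 0.92, 0.93, 0.96, 0.99]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(toy_val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f'stopped at epoch {stopped_at}, best val loss {stopper.best:.2f}')
```

In practice one would also save the model weights whenever `best` improves, so the checkpoint from the best epoch can be restored after stopping.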

**Looking at the confusion matrix**, you'll notice interesting patterns:
- **Cat vs. Dog**: A classic challenge even for humans with low-resolution images!
- **Bird vs. Plane**: Both can appear in the sky with similar backgrounds
- **Car vs. Truck**: Vehicles are often confused with each other

**Looking at the hardest examples**: These reveal the challenges of CIFAR-10 - low resolution, occlusion, unusual angles, etc.

**Question**: What modeling choice could help us make better use of our limited training data?

## Part 5: Data Augmentation as a Strategic Choice

**The problem**: Our model is overfitting to specific details in the training images (exact positions, orientations, etc.).

**The solution**: Use **data augmentation** to artificially expand our training set by applying random transformations that preserve the label. This builds a *learnable invariance* into the dataset: the training data now *exemplify* that labels don't change under transformations such as horizontal flips or small shifts that crop away part of the edge.

For CIFAR-10, we'll use:
- Random horizontal flips (a car is still a car when flipped)
- Random crops (objects can appear anywhere in the frame)

This is another *modeling choice*: we're encoding our knowledge that these transformations don't change the object's identity.
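
To show what these two transformations do to the pixels, here is a toy version on a 4×4 grid of integers using only the standard library. The helpers `hflip` and `random_crop` are simplified stand-ins written for this note, not torchvision's implementations; the crop pads with zeros and then cuts out a window of the original size at a random offset.

```python
import random

def hflip(img):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in img]

def random_crop(img, padding, rng):
    """Zero-pad all sides by `padding`, then crop a random window of the original size."""
    h, w = len(img), len(img[0])
    blank = [0] * (w + 2 * padding)
    padded = (
        [blank[:] for _ in range(padding)]
        + [[0] * padding + list(row) + [0] * padding for row in img]
        + [blank[:] for _ in range(padding)]
    )
    top = rng.randint(0, 2 * padding)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * padding)
    return [row[left:left + w] for row in padded[top:top + h]]

rng = random.Random(67)
img = [[r * 4 + c for c in range(4)] for r in range(4)]

flipped = hflip(img)
cropped = random_crop(img, padding=1, rng=rng)

print(flipped[0])  # [3, 2, 1, 0]
for row in cropped:
    print(row)
```

The cropped image has the same size as the input but its content is shifted by up to `padding` pixels in each direction, which is how a padded random crop simulates objects appearing at slightly different positions.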

In [None]:
# Create augmented CIFAR-10 training set
# Note: We only augment TRAINING data, never validation or test data!
cifar_aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 50% chance of horizontal flip
    transforms.RandomCrop(32, padding=4),  # Random crop with padding
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Create new training set with augmentation
cifar_train_aug_full = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=False, transform=cifar_aug_transform)
cifar_train_aug_ds, _ = random_split(cifar_train_aug_full, [cifar_train_size, cifar_val_size],
                                     generator=torch.Generator().manual_seed(67))

cifar_train_aug_loader = DataLoader(cifar_train_aug_ds, batch_size=cifar_batch, shuffle=True, num_workers=num_workers)

# Train CNN with augmentation
print("Training CNN on CIFAR-10 (with augmentation)...")
cnn_cifar_aug = SimpleCNN(in_channels=3, num_classes=10).to(device)
# We can train longer here without overfitting
cnn_cifar_aug, cifar_aug_history = train_model(cnn_cifar_aug, cifar_train_aug_loader, cifar_val_loader, epochs=20, lr=1e-3)
cifar_acc_aug = evaluate(cnn_cifar_aug, cifar_test_loader)

In [None]:
# Plot training curves with augmentation
plot_training_curves(cifar_aug_history, title="CNN Training History (CIFAR-10, With Augmentation)")

In [None]:
# Confusion matrix for CNN on CIFAR-10 (with augmentation)
plot_confusion_matrix(cnn_cifar_aug, cifar_test_loader, cifar_class_names,
                     title="CNN Confusion Matrix (CIFAR-10, With Augmentation)")

In [None]:
# Show hardest examples with augmentation
show_hardest_examples(cnn_cifar_aug, cifar_test_loader, cifar_class_names)

In [None]:
# Visualize predictions with augmented model
cnn_cifar_aug.eval()
imgs_c, labs_c = next(iter(cifar_test_loader))
imgs_c, labs_c = imgs_c.to(device), labs_c.to(device)
with torch.no_grad():
    outs_c = cnn_cifar_aug(imgs_c[:8])
    preds_c = outs_c.argmax(dim=1)

plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1, 8, i+1)
    img = imgs_c[i].cpu().permute(1, 2, 0).numpy()
    img = (img * [0.247, 0.243, 0.261]) + [0.4914, 0.4822, 0.4465]
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    correct = '✓' if preds_c[i] == labs_c[i] else '✗'
    plt.title(f'{correct} {cifar_class_names[preds_c[i]]}', fontsize=9,
              color='green' if preds_c[i] == labs_c[i] else 'red')
    plt.axis('off')
plt.suptitle('CNN Predictions on CIFAR-10 (with augmentation)', fontsize=12)
plt.tight_layout()
plt.show()

### Reflection: The Impact of Augmentation

With augmentation, accuracy improves to around **77%**. The gain may look modest, but augmentation also reduces overfitting (we can train for longer without the model memorizing the data) and tends to make the model more robust to test images that are flipped or shifted relative to the training examples.

**Looking at the training curves**: Compare the two CIFAR-10 training curves (scroll up). With augmentation:
- The gap between train and val loss is **smaller** - less overfitting!
- Training loss is higher (model can't perfectly memorize augmented data)
- Validation loss is lower (better generalization)
- We could potentially train even longer without overfitting

**Why does augmentation help more on CIFAR-10 than it would on FashionMNIST?**
- CIFAR-10 has more variation in object position, scale, and orientation
- The training set is relatively small for the complexity of the task
- Augmentation effectively teaches the model to be invariant to these transformations

**Key insight**: Data augmentation is most valuable when:
1. The task has natural invariances (flips, rotations, crops)
2. The test-time data is expected to vary along these transformations
3. The task is complex and the dataset is not expected to exemplify these invariances without augmentation

---

**Note**: Try adding augmentation to FashionMNIST (you'd see minimal improvement) vs. CIFAR-10 (clear improvement). This illustrates how modeling choices interact with dataset characteristics.

## Summary: Modeling Choices Matter

Let's review what we've learned about how different choices affect performance:

| Model | Dataset | Accuracy | Key Insight |
|-------|---------|----------|-------------|
| Linear | FashionMNIST | ~84% | Surprisingly effective on simple, aligned images |
| CNN | FashionMNIST | ~92% | Spatial inductive bias helps significantly |
| CNN | CIFAR-10 (no aug) | ~74% | Same architecture, harder dataset, shows overfitting |
| CNN | CIFAR-10 (aug) | ~77% | Augmentation reduces overfitting and improves generalization |

### Key Takeaways

1. **Inductive bias matters**: The CNN's spatial structure assumption led to an improvement of roughly 8 percentage points on FashionMNIST
2. **Dataset characteristics matter**: The same CNN performs differently on FashionMNIST vs. CIFAR-10
3. **Strategic augmentation matters**: Augmentation provides a bigger boost on complex, varied datasets
4. **Overfitting detection is crucial**: Training curves reveal when models memorize vs. generalize
5. **Machine learning is modeling**: Every choice (architecture, data representation, augmentation) encodes assumptions about the problem

### Questions for Further Exploration

- Why didn't we use augmentation on FashionMNIST? (Try it and see!)
- What other augmentations might help on CIFAR-10? (color jittering, cutout, etc.)
- How would a linear model perform on CIFAR-10? (Spoiler: poorly!)
- What if we trained for more epochs? Used a deeper network?

The point isn't to achieve state-of-the-art results—it's to understand how and why different modeling choices lead to different outcomes.

In [None]:
# Optional: Print final comparison
print("\n" + "="*60)
print("FINAL RESULTS SUMMARY")
print("="*60)
print(f"Linear Model (FashionMNIST):          {linear_acc:.2f}%")
print(f"CNN (FashionMNIST):                   {cnn_acc:.2f}%")
print(f"CNN (CIFAR-10, no augmentation):      {cifar_acc:.2f}%")
print(f"CNN (CIFAR-10, with augmentation):    {cifar_acc_aug:.2f}%")
print("="*60)
print("\nKey Observations:")
print(f"- CNN improves over Linear by {cnn_acc - linear_acc:.1f} points on FashionMNIST")
print(f"- CIFAR-10 is much harder ({cnn_acc - cifar_acc:.1f} points below FashionMNIST without augmentation)")
print("- Augmentation helps reduce overfitting on CIFAR-10")
print("- Check training curves to understand generalization!")