# MNIST Project: Your First Image Classifier

**Module 2.2, Lesson 2 (Project)** | CourseAI

This is a project notebook — the main deliverable of the lesson. You will build a complete image classifier from scratch:

1. **Load MNIST** with torchvision — transforms, DataLoaders
2. **Build the model** — MNISTClassifier with three linear layers
3. **Write the training loop** — cross-entropy loss, Adam, accuracy tracking
4. **Evaluate on the test set** — model.eval(), torch.no_grad()
5. **Visualize predictions** — correct and incorrect with confidence scores
6. **Build an improved model** — BatchNorm, Dropout, weight decay; compare

Steps 1–4 are **guided** (mostly complete code with small blanks). Steps 5–6 are **supported** (template provided, more work required).

**Estimated time:** 20–30 minutes.

---

## Setup

Run this cell to import everything and configure the environment.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# For nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

---

## Step 1: Load MNIST

MNIST contains 70,000 grayscale images of handwritten digits (0–9), each 28×28 pixels, split into 60,000 training and 10,000 test images. We apply two transforms:

- **ToTensor()** — converts PIL images to tensors and scales pixel values from [0, 255] to [0.0, 1.0]
- **Normalize((0.1307,), (0.3081,))** — normalizes using the MNIST dataset mean and standard deviation
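
To make the two transforms concrete, here is a minimal sketch that reproduces only their arithmetic on a made-up 2×2 "image". The tensor `raw` is a hypothetical stand-in, not MNIST data, and the real `ToTensor()` also rearranges dimensions, which this sketch skips.

```python
import torch

# A fake 2x2 "image" with raw 8-bit pixel values (0 = black, 255 = white)
raw = torch.tensor([[0, 255], [128, 33]], dtype=torch.uint8)

# Value scaling performed by ToTensor(): [0, 255] -> [0.0, 1.0]
scaled = raw.float() / 255.0

# Normalize((0.1307,), (0.3081,)) computes (x - mean) / std per channel
mean, std = 0.1307, 0.3081
normalized = (scaled - mean) / std

print(scaled)
print(normalized)
# A black pixel maps to about -0.42 and a white pixel to about 2.82
```

After normalization the inputs are roughly zero-centered with unit variance across the dataset, which tends to make optimization better behaved.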

Fill in the DataLoader creation below.

In [None]:
# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Download and load MNIST
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# TODO: Create DataLoaders for train and test sets
# train_loader: batch_size=64, shuffle=True
# test_loader:  batch_size=64, shuffle=False
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples:     {len(test_dataset)}')
print(f'Image shape:      {train_dataset[0][0].shape}')
print(f'Train batches:    {len(train_loader)}')
print(f'Test batches:     {len(test_loader)}')

In [None]:
# Visualize some samples
fig, axes = plt.subplots(2, 10, figsize=(14, 3))

for i in range(20):
    img, label = train_dataset[i]
    row, col = i // 10, i % 10
    axes[row, col].imshow(img.squeeze(), cmap='gray')
    axes[row, col].set_title(str(label), fontsize=10)
    axes[row, col].axis('off')

fig.suptitle('MNIST Sample Images', fontsize=14)
plt.tight_layout()
plt.show()

---

## Step 2: Build the Model

Your `MNISTClassifier` is a simple dense (fully-connected) network:

```
Input (1x28x28)
  -> Flatten            # 784
  -> Linear(784, 256)   # first hidden layer
  -> ReLU
  -> Linear(256, 128)   # second hidden layer
  -> ReLU
  -> Linear(128, 10)    # output (one per digit class)
```
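
The parameter count printed by the verification cell can be predicted by hand from the diagram above. A small sketch in plain Python, where `linear_param_count` is a helper name introduced here for illustration:

```python
# (in_features, out_features) for each Linear layer in the diagram
layer_shapes = [(784, 256), (256, 128), (128, 10)]

def linear_param_count(n_in, n_out):
    # Weight matrix of shape (n_out, n_in) plus one bias per output unit
    return n_in * n_out + n_out

per_layer = [linear_param_count(n_in, n_out) for n_in, n_out in layer_shapes]
total_params = sum(per_layer)

print(per_layer)      # [200960, 32896, 1290]
print(total_params)   # 235146
```

Most of the parameters sit in the first layer, because it connects all 784 input pixels to 256 hidden units.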

The `__init__` method is provided. Fill in the `forward` method.

In [None]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # TODO: Implement the forward pass
        # 1. Flatten the input
        # 2. fc1 -> ReLU
        # 3. fc2 -> ReLU
        # 4. fc3 (no activation — CrossEntropyLoss handles softmax)
        #
        # Hint: x = self.relu(self.fc1(x))
        pass  # Replace this

# Create model and verify
model = MNISTClassifier().to(device)

# Quick dimension check
test_input = torch.randn(1, 1, 28, 28).to(device)
test_output = model(test_input)
print(f'Input shape:  {test_input.shape}')   # [1, 1, 28, 28]
# getattr avoids a crash here if forward() still returns None
print(f'Output shape: {getattr(test_output, "shape", None)}')  # [1, 10]
print(f'Parameters:   {sum(p.numel() for p in model.parameters()):,}')

if test_output is not None and test_output.shape == torch.Size([1, 10]):
    print('\nDimensions correct! Ready to train.')
else:
    print('\nDimension mismatch — check your forward method.')

---

## Step 3: Training Loop

The training loop follows the standard PyTorch pattern:

1. Forward pass → compute loss
2. Backward pass → compute gradients
3. Optimizer step → update weights
4. Track metrics
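
On step 4: the training loop below accumulates `loss.item() * images.size(0)` rather than `loss.item()` alone, because the last batch is smaller than the others. A minimal sketch in plain Python of why that matters; the three batch losses are made-up numbers chosen only for illustration.

```python
import math

# 60,000 training images in batches of 64
n_samples, batch_size = 60_000, 64
n_batches = math.ceil(n_samples / batch_size)
last_batch_size = n_samples - (n_batches - 1) * batch_size
print(n_batches, last_batch_size)  # 938 32

# Hypothetical per-batch mean losses and the matching batch sizes
batch_losses = [0.5, 0.5, 2.0]
batch_sizes = [64, 64, 32]

# Weighting by batch size gives the true per-sample mean loss
weighted_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

# Averaging the batch means over-weights the small final batch
naive_mean = sum(batch_losses) / len(batch_losses)

print(weighted_mean)  # 0.8
print(naive_mean)     # 1.0
```

With 938 batches the difference is tiny in practice, but the weighted form is exact for any batch size.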

We use **CrossEntropyLoss** (combines LogSoftmax + NLLLoss) and **Adam** optimizer with lr=1e-3.
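
The statement that CrossEntropyLoss combines LogSoftmax and NLLLoss can be checked numerically. The sketch below uses random logits and arbitrary target labels as stand-ins for real model outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10)             # 4 samples, 10 classes (raw scores)
targets = torch.tensor([3, 0, 9, 1])    # arbitrary class labels

# One-step version: takes raw logits directly
ce = nn.CrossEntropyLoss()(logits, targets)

# Two-step version: log-softmax first, then negative log-likelihood
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

# By hand: mean of -log p(correct class)
manual = -log_probs[torch.arange(4), targets].mean()

print(ce.item(), nll.item(), manual.item())  # all three values agree
```

This is why the model's last layer outputs raw logits: applying softmax inside the model as well would normalize twice.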

Fill in the backward pass and optimizer step in the training loop (three lines).

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training configuration
num_epochs = 10

# History tracking
history = {
    'train_loss': [],
    'train_acc': [],
    'test_acc': [],
}

print(f'Training for {num_epochs} epochs...')
print('=' * 65)

for epoch in range(num_epochs):
    # --- Training phase ---
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # TODO: Backward pass and optimizer step
        # Hint: three lines — zero_grad, backward, step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track metrics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    history['train_loss'].append(epoch_loss)
    history['train_acc'].append(epoch_acc)

    # --- Evaluation phase (computed here for logging, full eval in Step 4) ---
    model.eval()
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            test_correct += (predicted == labels).sum().item()
            test_total += labels.size(0)

    test_acc = test_correct / test_total
    history['test_acc'].append(test_acc)

    print(f'Epoch {epoch+1:2d}/{num_epochs}  '
          f'Train Loss: {epoch_loss:.4f}  '
          f'Train Acc: {epoch_acc:.1%}  '
          f'Test Acc: {test_acc:.1%}')

print('=' * 65)
print(f'\nFinal test accuracy: {history["test_acc"][-1]:.1%}')

---

## Step 4: Evaluate on the Test Set

We already computed test accuracy inside the training loop for logging, but here we do a clean standalone evaluation to confirm the final numbers.

Two important patterns:
- **model.eval()** — switches BatchNorm/Dropout to inference mode (matters later in Step 6)
- **torch.no_grad()** — disables gradient computation, saves memory
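
Both patterns can be observed on a toy network. This is a sketch under stated assumptions: `net` is a throwaway two-layer module invented for the demonstration, not the classifier from this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 100), nn.Dropout(0.5))
x = torch.randn(8, 20)

# Training mode: Dropout draws a new random mask on every forward pass
net.train()
train_outputs_differ = not torch.equal(net(x), net(x))

# Eval mode: Dropout is a no-op, so repeated passes give identical outputs
net.eval()
eval_outputs_match = torch.equal(net(x), net(x))

# eval() alone does not turn off gradient tracking; no_grad() does
grad_tracked_outside = net(x).requires_grad
with torch.no_grad():
    grad_tracked_inside = net(x).requires_grad

print(train_outputs_differ, eval_outputs_match)    # True True
print(grad_tracked_outside, grad_tracked_inside)   # True False
```

The two calls are independent: `model.eval()` changes layer behavior, while `torch.no_grad()` changes what autograd records. Evaluation code normally wants both.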

In [None]:
def evaluate_model(model, test_loader):
    """Evaluate model on the test set. Returns accuracy."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy

test_accuracy = evaluate_model(model, test_loader)
print(f'Test accuracy: {test_accuracy:.2%}')
# round() rather than int(): float error can leave the product just below a whole number
print(f'Correctly classified: {round(test_accuracy * len(test_dataset)):,} / {len(test_dataset):,}')

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

epochs_range = range(1, num_epochs + 1)

# Loss
ax1.plot(epochs_range, history['train_loss'], 'o-', linewidth=2, label='Train Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(alpha=0.3)

# Accuracy
ax2.plot(epochs_range, history['train_acc'], 'o-', linewidth=2, label='Train Acc')
ax2.plot(epochs_range, history['test_acc'], 's-', linewidth=2, label='Test Acc')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Accuracy')
ax2.legend()
ax2.grid(alpha=0.3)
ax2.set_ylim(0.9, 1.0)

plt.tight_layout()
plt.show()

---

## Step 5: Visualize Predictions

Numbers only tell part of the story. Looking at actual predictions reveals **what** the model gets wrong and **how confident** it is.

Your task: fill in the plotting code to show a grid of predictions. Give correct predictions a green title and incorrect ones a red title, and show the softmax confidence for each.

This is a **supported** step — the template is provided but you need to fill in more than before.
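
As a small illustration of where the confidence score comes from, the sketch below turns hand-picked logits into probabilities. The three-class logits are hypothetical values, chosen only to keep the output readable.

```python
import torch

# Two samples, three classes (made-up raw scores)
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 0.5, 3.0]])

# Softmax turns each row of logits into probabilities that sum to 1
probs = torch.softmax(logits, dim=1)

# Confidence is the largest probability; the prediction is its index
conf, pred = torch.max(probs, dim=1)

print(probs.sum(dim=1))   # each row sums to 1
print(pred.tolist())      # [0, 2]
print(conf)               # one confidence per sample
```

Softmax preserves the ordering of the logits, so the predicted class is the same with or without it. A high confidence is not a guarantee of correctness, which is what the red-titled images will show.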

In [None]:
# Collect predictions on a batch of test images
model.eval()
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

with torch.no_grad():
    outputs = model(images)
    probabilities = torch.softmax(outputs, dim=1)
    confidences, predicted = torch.max(probabilities, dim=1)

# Move to CPU for plotting
images_cpu = images.cpu()
labels_cpu = labels.cpu()
predicted_cpu = predicted.cpu()
confidences_cpu = confidences.cpu()

# TODO: Plot a 4x8 grid of predictions
#
# For each image:
# - Show the grayscale image
# - Title format: "pred (conf%)" e.g. "7 (98.3%)"
# - If correct: green title
# - If incorrect: red title, and include true label: "pred (conf%) [true]"
#
# Hint:
#   ax.set_title(title, color='green')  or  color='red'
#   Use images_cpu[i].squeeze() with cmap='gray'

fig, axes = plt.subplots(4, 8, figsize=(16, 8))

for i in range(32):
    row, col = i // 8, i % 8
    ax = axes[row, col]

    ax.imshow(images_cpu[i].squeeze(), cmap='gray')
    ax.axis('off')

    pred = predicted_cpu[i].item()
    true = labels_cpu[i].item()
    conf = confidences_cpu[i].item() * 100

    # Your code here: set the title with color based on correctness
    pass

fig.suptitle('Model Predictions (green=correct, red=incorrect)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Show specifically the incorrect predictions (if any)
# This helps you understand what the model finds hard

model.eval()
wrong_images = []
wrong_labels = []
wrong_preds = []
wrong_confs = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        confs, preds = torch.max(probs, dim=1)

        mask = preds != labels
        if mask.any():
            wrong_images.append(images[mask].cpu())
            wrong_labels.append(labels[mask].cpu())
            wrong_preds.append(preds[mask].cpu())
            wrong_confs.append(confs[mask].cpu())

        if sum(len(w) for w in wrong_images) >= 16:
            break

wrong_images = torch.cat(wrong_images)[:16]
wrong_labels = torch.cat(wrong_labels)[:16]
wrong_preds = torch.cat(wrong_preds)[:16]
wrong_confs = torch.cat(wrong_confs)[:16]

n_wrong = len(wrong_images)
cols = min(8, n_wrong)
rows = (n_wrong + cols - 1) // cols

# squeeze=False always returns a 2D array of axes, even for a single row or column
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2.5 * rows), squeeze=False)

for i in range(n_wrong):
    row, col = i // cols, i % cols
    ax = axes[row, col]
    ax.imshow(wrong_images[i].squeeze(), cmap='gray')
    ax.set_title(
        f'Pred: {wrong_preds[i].item()} ({wrong_confs[i].item()*100:.1f}%)\nTrue: {wrong_labels[i].item()}',
        color='red', fontsize=9
    )
    ax.axis('off')

# Hide unused subplots
for i in range(n_wrong, rows * cols):
    row, col = i // cols, i % cols
    axes[row, col].axis('off')

fig.suptitle('Incorrect Predictions', fontsize=14, color='red')
plt.tight_layout()
plt.show()

print(f'Showing {n_wrong} incorrect predictions from the test set.')
print('Notice: many of these are genuinely ambiguous — even humans might disagree.')

---

## Step 6: Build an Improved Model

The simple model works, but we can do better with three standard techniques:

- **BatchNorm1d** after each hidden linear layer: normalizes each feature over the batch, which stabilizes training
- **Dropout(0.3)**: zeros each activation with probability 0.3 during training and rescales the rest by 1/(1 - 0.3), which reduces overfitting
- **weight_decay=1e-4** in the optimizer: adds an L2 penalty on the parameters
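
The first two techniques can be seen in isolation on synthetic tensors. A minimal sketch, assuming default layer settings; the inputs are random or constant stand-ins rather than real activations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# BatchNorm1d in training mode: each feature is normalized over the batch
x = torch.randn(64, 3) * 3.0 + 5.0      # features with mean ~5, std ~3
bn = nn.BatchNorm1d(3)
bn.train()
y = bn(x)
feature_means = y.mean(dim=0)                 # close to 0 for every feature
feature_vars = y.var(dim=0, unbiased=False)   # close to 1 for every feature

# Dropout in training mode: zero with probability p, scale survivors by 1/(1-p)
drop = nn.Dropout(0.3)
drop.train()
dropped = drop(torch.ones(10_000))
zero_fraction = (dropped == 0).float().mean().item()   # roughly 0.3
kept_value = dropped.max().item()                      # 1 / 0.7, about 1.4286

print(feature_means, feature_vars)
print(zero_fraction, kept_value)
```

The rescaling of surviving activations keeps the expected activation the same in training and inference, so nothing needs to be adjusted when Dropout is switched off by `model.eval()`.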

This is a **supported** step. The template gives you the structure, but you fill in the layers.

```
Input (1x28x28)
  -> Flatten                      # 784
  -> Linear(784, 256)
  -> BatchNorm1d(256)
  -> ReLU
  -> Dropout(0.3)
  -> Linear(256, 128)
  -> BatchNorm1d(128)
  -> ReLU
  -> Dropout(0.3)
  -> Linear(128, 10)              # output
```
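
The added layers change the parameter count only slightly. A short sketch of the bookkeeping, where `count_params` is a helper name introduced here for illustration:

```python
import torch.nn as nn

def count_params(module):
    """Number of learnable values reported by module.parameters()."""
    return sum(p.numel() for p in module.parameters())

bn1 = nn.BatchNorm1d(256)   # learnable scale and shift: 2 * 256
bn2 = nn.BatchNorm1d(128)   # learnable scale and shift: 2 * 128
dropout = nn.Dropout(0.3)   # nothing to learn

extra_from_batchnorm = count_params(bn1) + count_params(bn2)
extra_from_dropout = count_params(dropout)

print(extra_from_batchnorm)  # 768
print(extra_from_dropout)    # 0

# Running statistics are stored as buffers, so parameters() does not count them
print([name for name, _ in bn1.named_buffers()])
```

These 768 values are the difference the summary cell reports between the two models.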

In [None]:
class ImprovedMNIST(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: Define the layers
        # Same linear layers as MNISTClassifier, but add:
        #   - nn.BatchNorm1d(size) after each hidden linear layer
        #   - nn.Dropout(0.3) after each ReLU
        #
        # Layers to define:
        #   self.flatten
        #   self.fc1, self.bn1
        #   self.fc2, self.bn2
        #   self.fc3
        #   self.relu, self.dropout
        pass  # Replace with your layer definitions

    def forward(self, x):
        # TODO: Implement the forward pass
        # Order for each hidden layer: linear -> batchnorm -> relu -> dropout
        # Final layer: just linear (no batchnorm, no relu, no dropout)
        pass  # Replace this

improved_model = ImprovedMNIST().to(device)

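# Verify dimensions
# eval() is required for this check: BatchNorm1d raises an error on a
# single-sample batch in training mode. The training cell calls train() again.
improved_model.eval()
test_input = torch.randn(1, 1, 28, 28).to(device)
with torch.no_grad():
    test_output = improved_model(test_input)
print(f'Input shape:  {test_input.shape}')
# getattr avoids a crash here if forward() still returns None
print(f'Output shape: {getattr(test_output, "shape", None)}')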
print(f'Parameters:   {sum(p.numel() for p in improved_model.parameters()):,}')

if test_output is not None and test_output.shape == torch.Size([1, 10]):
    print('\nDimensions correct! Ready to train.')
else:
    print('\nDimension mismatch — check your forward method.')

In [None]:
# Train the improved model
# Note: weight_decay=1e-4 adds L2 regularization
improved_criterion = nn.CrossEntropyLoss()
improved_optimizer = optim.Adam(improved_model.parameters(), lr=1e-3, weight_decay=1e-4)

improved_history = {
    'train_loss': [],
    'train_acc': [],
    'test_acc': [],
}

print(f'Training improved model for {num_epochs} epochs...')
print('=' * 65)

for epoch in range(num_epochs):
    # Training
    improved_model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        improved_optimizer.zero_grad()
        outputs = improved_model(images)
        loss = improved_criterion(outputs, labels)
        loss.backward()
        improved_optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    improved_history['train_loss'].append(epoch_loss)
    improved_history['train_acc'].append(epoch_acc)

    # Evaluation
    test_acc = evaluate_model(improved_model, test_loader)
    improved_history['test_acc'].append(test_acc)

    print(f'Epoch {epoch+1:2d}/{num_epochs}  '
          f'Train Loss: {epoch_loss:.4f}  '
          f'Train Acc: {epoch_acc:.1%}  '
          f'Test Acc: {test_acc:.1%}')

print('=' * 65)
print(f'\nFinal test accuracy: {improved_history["test_acc"][-1]:.1%}')

In [None]:
# Side-by-side training curves: simple vs improved
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

epochs_range = range(1, num_epochs + 1)

# Loss comparison
ax1.plot(epochs_range, history['train_loss'], 'o-', linewidth=2, label='Simple')
ax1.plot(epochs_range, improved_history['train_loss'], 's-', linewidth=2, label='Improved')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Loss')
ax1.set_title('Training Loss: Simple vs Improved')
ax1.legend()
ax1.grid(alpha=0.3)

# Accuracy comparison
ax2.plot(epochs_range, history['test_acc'], 'o-', linewidth=2, label='Simple')
ax2.plot(epochs_range, improved_history['test_acc'], 's-', linewidth=2, label='Improved')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Test Accuracy')
ax2.set_title('Test Accuracy: Simple vs Improved')
ax2.legend()
ax2.grid(alpha=0.3)
ax2.set_ylim(0.95, 1.0)

plt.tight_layout()
plt.show()

---

## Summary

Final comparison of the two models:

In [None]:
# Final comparison table
simple_acc = history['test_acc'][-1]
improved_acc = improved_history['test_acc'][-1]
simple_params = sum(p.numel() for p in model.parameters())
improved_params = sum(p.numel() for p in improved_model.parameters())

print('=' * 55)
print(f'{"Model":<20} {"Test Accuracy":>15} {"Parameters":>15}')
print('-' * 55)
print(f'{"Simple":<20} {simple_acc:>14.2%} {simple_params:>15,}')
print(f'{"Improved":<20} {improved_acc:>14.2%} {improved_params:>15,}')
print('-' * 55)

diff = improved_acc - simple_acc
# Handle a tie explicitly instead of reporting it as a decrease
direction = 'improvement' if diff > 0 else 'decrease' if diff < 0 else 'no change'
print(f'Difference: {abs(diff):.2%} {direction}')
print(f'Extra parameters from BatchNorm: {improved_params - simple_params:,}')
print('=' * 55)

### What You Built

| Step | What Happened |
|------|---------------|
| **1. Data Loading** | MNIST with transforms, DataLoaders for batching |
| **2. Simple Model** | Three linear layers, ReLU activations |
| **3. Training Loop** | CrossEntropyLoss, Adam optimizer, epoch tracking |
| **4. Evaluation** | model.eval(), torch.no_grad(), test accuracy |
| **5. Visualization** | Correct/incorrect predictions with confidence scores |
| **6. Improved Model** | BatchNorm + Dropout + weight decay, compared training curves |

**Key takeaways:**
- The full training pipeline is: data → model → loss → optimizer → loop → evaluate
- `model.eval()` and `torch.no_grad()` are essential for correct evaluation
- BatchNorm and Dropout are cheap additions (768 extra parameters from BatchNorm here, none from Dropout) that stabilize and regularize training
- Looking at actual predictions (not just accuracy numbers) reveals what the model struggles with