# MNIST CNN Project

In this notebook, you'll build a CNN for MNIST, train it, and compare it to the dense network you built in your PyTorch project.

**What you'll do:**
- Load MNIST data (provided)
- See the dense network baseline (provided)
- Fill in the CNN class (your task)
- Verify dimensions match the architecture diagram
- Train and compare both models

The architecture: Conv(1->32, 3x3, pad=1) -> ReLU -> MaxPool(2x2) -> Conv(32->64, 3x3, pad=1) -> ReLU -> MaxPool(2x2) -> Flatten -> Linear(3136->128) -> ReLU -> Linear(128->10)

You traced this step by step in Building a CNN. Now you implement it.
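
As a quick refresher on where 14x14, 7x7, and 3136 come from, here is a minimal sketch of the size arithmetic in plain Python. The helper name `conv_out` is made up for this sketch. It applies the usual output-size formula, `floor((size + 2*padding - kernel) / stride) + 1`, which covers both the conv layers and the pooling windows.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial size after a conv or pooling window (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
trace = [size]

size = conv_out(size, kernel=3, padding=1)  # Conv 3x3, pad=1 keeps the size
trace.append(size)
size = conv_out(size, kernel=2, stride=2)   # MaxPool 2x2 halves it
trace.append(size)
size = conv_out(size, kernel=3, padding=1)  # second conv keeps the size
trace.append(size)
size = conv_out(size, kernel=2, stride=2)   # second pool halves it again
trace.append(size)

flat_features = 64 * size * size            # 64 channels after the second conv
print(' -> '.join(str(s) for s in trace), '| flattened:', flat_features)
```

The trace comes out as 28 -> 28 -> 14 -> 14 -> 7, and 64 * 7 * 7 = 3136 is the input size of the first linear layer.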

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# For nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

## 1. Load MNIST Data

Same data loading you used in your PyTorch project. Nothing new here.

In [None]:
# Download and load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Image shape: {train_dataset[0][0].shape}')

## 2. Dense Network Baseline

This is the dense network from your PyTorch project: flatten the 28x28 image into 784 values, then pass through fully-connected layers. We'll train it and use it as the baseline.
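
The parameter count of a network like this can be worked out by hand: each fully-connected layer has `n_in * n_out` weights plus `n_out` biases. Here is a small sketch, assuming the 784 -> 128 -> 64 -> 10 layer sizes used in the cell below. `linear_params` and `expected_dense_params` are names made up for this sketch.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully-connected layer (hypothetical helper)."""
    return n_in * n_out + n_out

layer_sizes = [(784, 128), (128, 64), (64, 10)]
per_layer = [linear_params(n_in, n_out) for n_in, n_out in layer_sizes]
expected_dense_params = sum(per_layer)

for (n_in, n_out), count in zip(layer_sizes, per_layer):
    print(f'Linear({n_in}->{n_out}): {count:,}')
print(f'Total: {expected_dense_params:,}')
```

The first layer alone accounts for 100,480 of the 109,386 parameters, because every one of its 128 units is connected to all 784 pixels.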

In [None]:
class DenseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

dense_model = DenseNetwork().to(device)
print(f'Dense network parameters: {sum(p.numel() for p in dense_model.parameters()):,}')

## 3. Your CNN

Fill in the `__init__` and `forward` methods. The architecture is the one you traced in Building a CNN:

```
Input (1x28x28)
  -> Conv2d(1, 32, 3, padding=1) -> ReLU -> MaxPool2d(2)    # 32x14x14
  -> Conv2d(32, 64, 3, padding=1) -> ReLU -> MaxPool2d(2)   # 64x7x7
  -> Flatten                                                  # 3136
  -> Linear(3136, 128) -> ReLU                               # 128
  -> Linear(128, 10)                                         # 10
```

Each TODO tells you exactly what layer to add.

In [None]:
class MnistCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: First conv block
        # self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # self.pool1 = nn.MaxPool2d(kernel_size=2)

        # TODO: Second conv block
        # self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # self.pool2 = nn.MaxPool2d(kernel_size=2)

        # TODO: Classifier
        # self.flatten = nn.Flatten()
        # self.fc1 = nn.Linear(3136, 128)   # 7 * 7 * 64 = 3136
        # self.fc2 = nn.Linear(128, 10)

        self.relu = nn.ReLU()

    def forward(self, x):
        # TODO: Pass x through the layers
        # Conv1 -> ReLU -> Pool1
        # Conv2 -> ReLU -> Pool2
        # Flatten -> FC1 -> ReLU -> FC2
        return x

cnn_model = MnistCNN().to(device)
print(f'CNN parameters: {sum(p.numel() for p in cnn_model.parameters()):,}')

## 4. Dimension Verification

Before training, verify your architecture produces the right output shape. Pass a random input through and check.
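
If the check fails, it helps to see where the shape first goes wrong. Here is a minimal sketch that pushes a random input through a standalone `nn.Sequential` stack built from the diagram's layers and records the shape after each one. `reference_stack` is a name made up for this sketch. It mirrors the diagram rather than your `MnistCNN` class, and it runs on the CPU.

```python
import torch
import torch.nn as nn

# Standalone stack mirroring the architecture diagram (assumption: not your class)
reference_stack = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(3136, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

x = torch.randn(1, 1, 28, 28)
shapes = []
with torch.no_grad():
    for layer in reference_stack:       # nn.Sequential iterates over its layers
        x = layer(x)
        shapes.append((layer.__class__.__name__, tuple(x.shape)))

for name, shape in shapes:
    print(f'{name:<10} {shape}')
```

Compare each printed shape with the comments in the diagram. The first line that disagrees points at the layer to fix.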

In [None]:
# Test with a random input
test_input = torch.randn(1, 1, 28, 28).to(device)
test_output = cnn_model(test_input)

print(f'Input shape:  {test_input.shape}')   # Should be [1, 1, 28, 28]
print(f'Output shape: {test_output.shape}')  # Should be [1, 10]

if test_output.shape == torch.Size([1, 10]):
    print('\nDimensions correct! Ready to train.')
else:
    print('\nDimension mismatch! Check your architecture.')

## 5. Training Loop

Same training loop you used for the dense network. Forward, loss, backward, update. The only difference is what happens inside `model(x)` -- and that's the architecture you just built.
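
To see the four steps in isolation, here is a minimal sketch of a single update on a toy problem. The tiny linear model and random batch (`toy_model`, `toy_inputs`, `toy_labels`) are stand-ins made up for this sketch, not MNIST data. The optimizer and loss match the ones used in the cell below.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
toy_model = nn.Linear(4, 3)                 # stand-in for any nn.Module
toy_inputs = torch.randn(8, 4)              # batch of 8 samples, 4 features
toy_labels = torch.randint(0, 3, (8,))      # class indices in {0, 1, 2}

toy_criterion = nn.CrossEntropyLoss()       # takes raw logits and class indices
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)

weights_before = toy_model.weight.detach().clone()

toy_optimizer.zero_grad()                   # clear gradients from earlier steps
logits = toy_model(toy_inputs)              # forward
loss = toy_criterion(logits, toy_labels)    # loss
loss.backward()                             # backward: fill .grad on each parameter
toy_optimizer.step()                        # update: adjust the parameters

weights_changed = not torch.equal(weights_before, toy_model.weight.detach())
print(f'loss: {loss.item():.4f} | weights changed: {weights_changed}')
```

Nothing in these four lines depends on what the model looks like inside, which is why the same loop can train both networks.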

In [None]:
def train_model(model, train_loader, test_loader, epochs=5, lr=1e-3):
    """Train a model and return training history."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    history = {'train_loss': [], 'test_acc': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        n_batches = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            n_batches += 1
        
        avg_loss = running_loss / n_batches
        history['train_loss'].append(avg_loss)
        
        # Evaluation
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted class
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        acc = 100.0 * correct / total
        history['test_acc'].append(acc)
        
        print(f'Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Test Acc: {acc:.2f}%')
    
    return history

# Train the dense network (fresh instance, so training starts from newly initialized weights)
print('=== Training Dense Network ===')
dense_model = DenseNetwork().to(device)
dense_history = train_model(dense_model, train_loader, test_loader, epochs=5, lr=1e-3)

print()

# Train the CNN (fresh instance, so training starts from newly initialized weights)
print('=== Training CNN ===')
cnn_model = MnistCNN().to(device)
cnn_history = train_model(cnn_model, train_loader, test_loader, epochs=5, lr=1e-3)

## 6. Side-by-Side Comparison

Same data, same optimizer, same learning rate, same number of epochs, same loss function. The only thing we changed on purpose is the architecture.

In [None]:
# Final accuracy comparison
dense_acc = dense_history['test_acc'][-1]
cnn_acc = cnn_history['test_acc'][-1]

print('=' * 45)
print(f'{"Model":<20} {"Test Accuracy":>12} {"Parameters":>12}')
print('-' * 45)
print(f'{"Dense Network":<20} {dense_acc:>11.2f}% {sum(p.numel() for p in dense_model.parameters()):>12,}')
print(f'{"CNN":<20} {cnn_acc:>11.2f}% {sum(p.numel() for p in cnn_model.parameters()):>12,}')
print('=' * 45)

## 7. Parameter Count Breakdown

Where do each model's parameters live?
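
Before reading the printout, you can predict the CNN's numbers by hand. A conv layer has `c_in * c_out * k * k` weights plus `c_out` biases, and a linear layer has `n_in * n_out` weights plus `n_out` biases. Here is a small sketch, assuming the CNN is implemented exactly as in the architecture diagram. `conv_params`, `linear_params`, and `expected_cnn` are names made up for this sketch.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k conv layer (hypothetical helper)."""
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer (hypothetical helper)."""
    return n_in * n_out + n_out

expected_cnn = {
    'conv1': conv_params(1, 32, 3),     # 320
    'conv2': conv_params(32, 64, 3),    # 18,496
    'fc1': linear_params(3136, 128),    # 401,536
    'fc2': linear_params(128, 10),      # 1,290
}
expected_cnn_total = sum(expected_cnn.values())
conv_share = (expected_cnn['conv1'] + expected_cnn['conv2']) / expected_cnn_total

for layer_name, count in expected_cnn.items():
    print(f'{layer_name:<6} {count:>10,}')
print(f'TOTAL  {expected_cnn_total:>10,}  (conv layers: {conv_share:.1%})')
```

Most of the CNN's parameters sit in `fc1`, not in the conv layers. The two conv layers together hold under 5% of the total.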

In [None]:
def print_param_breakdown(model, name):
    """Print parameter count per layer."""
    print(f'\n{name}')
    print('-' * 50)
    total = 0
    for param_name, param in model.named_parameters():
        count = param.numel()
        total += count
        print(f'  {param_name:<25} {count:>10,}')
    print(f'  {"TOTAL":<25} {total:>10,}')
    return total

dense_total = print_param_breakdown(dense_model, 'Dense Network')
cnn_total = print_param_breakdown(cnn_model, 'CNN')

# Highlight the conv stack vs dense first layer
print('\n' + '=' * 50)
print('KEY COMPARISON: Feature Extraction')
print('=' * 50)

dense_first_layer = 784 * 128 + 128  # fc1 weights + bias
cnn_conv_stack = (1*32*3*3 + 32) + (32*64*3*3 + 64)  # conv1 + conv2

print(f'Dense first layer (784->128):  {dense_first_layer:>10,} params')
print(f'CNN entire conv stack:         {cnn_conv_stack:>10,} params')
print(f'\nThe conv stack does MORE with {dense_first_layer - cnn_conv_stack:,} FEWER parameters.')
print('Weight sharing is the reason: one 3x3 filter detects a feature everywhere.')

## 8. Training Curves

Visualize how both models learned over time.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
epochs_range = range(1, len(dense_history['train_loss']) + 1)  # follows however many epochs were trained

# Loss curves
axes[0].plot(epochs_range, dense_history['train_loss'], 'o-', label='Dense Network', linewidth=2)
axes[0].plot(epochs_range, cnn_history['train_loss'], 's-', label='CNN', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Accuracy curves
axes[1].plot(epochs_range, dense_history['test_acc'], 'o-', label='Dense Network', linewidth=2)
axes[1].plot(epochs_range, cnn_history['test_acc'], 's-', label='CNN', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Test Accuracy (%)')
axes[1].set_title('Test Accuracy')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Reflection

Answer in your own words:

**"The CNN achieves higher accuracy than the dense network. The CNN's conv layers use far fewer parameters than the dense network's first layer, yet extract spatial features more effectively. Why?"**

Your answer should touch on:
- Spatial locality (filters see local neighborhoods, not all 784 pixels)
- Weight sharing (same filter applied at every position)
- Feature hierarchy (edges -> shapes -> digits through conv-pool stages)
- The dense network's weakness (a shifted digit is an entirely different input vector)
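
To make the last two bullets concrete, here is a toy 1D sketch in plain Python. It is an illustration under simplifying assumptions (one dimension, one hand-picked filter, no padding), not a model of MNIST. `correlate_valid`, `edge_filter`, and the two signals are made up for this sketch.

```python
def correlate_valid(signal, kernel):
    """Slide the kernel over the signal with no padding (hypothetical helper)."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]

edge_filter = [-1, 1]                   # fires on a rising edge, anywhere
original = [0, 0, 1, 1, 0, 0, 0, 0]     # a short "stroke"
shifted = [0, 0, 0, 0, 1, 1, 0, 0]      # same stroke, two positions to the right

resp_original = correlate_valid(original, edge_filter)
resp_shifted = correlate_valid(shifted, edge_filter)

# Dense view: compare the two inputs position by position
overlap = sum(a * b for a, b in zip(original, shifted))

print('filter response (original):', resp_original)
print('filter response (shifted): ', resp_shifted)
print('position-wise overlap of the two inputs:', overlap)
```

The filter's response to the shifted stroke is the original response moved two positions to the right: the same weights found the same edges. Compared position by position, as a dense layer sees them, the two inputs have zero overlap.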

---

**Key Takeaways:**
- A simple CNN beats a dense network on MNIST with no tricks
- The CNN's conv stack is far more parameter-efficient than the dense network's first layer (18,816 vs. 100,480 parameters), although the CNN's total count is larger because of its 3136->128 linear layer
- The training loop is identical -- only the architecture changed
- Architecture encodes assumptions about data; matching architecture to data structure is the key insight
- The dense network flattens pixels; the CNN flattens features