# MNIST Handwritten Digit Classification (PyTorch CNN)

## Assignment Questions (Follow in Order)

- **Q1 (15 pts):** Build the MNIST data pipeline (transforms, train/val split, DataLoaders).
- **Q2 (20 pts):** Implement a CNN classifier with the specified architecture.
- **Q3 (40 pts):** Implement training and validation functions, run the training loop with checkpointing, and plot the curves.
  - 3.1 (15 pts): Training function
  - 3.2 (8 pts): Validation function
  - 3.3 (11 pts): Training loop with best checkpoint saving
  - 3.4 (6 pts): Plot training and validation curves
- **Q4 (5 pts):** Load the best checkpoint and evaluate on the test set.

**Total: 80 points**

---

## Starter Notebook (Student Version)

Fill in the sections marked with:

```python
# ========== YOUR CODE STARTS HERE ==========
# ========== YOUR CODE ENDS HERE ============
```

Do **not** change the rubric section above.

In [None]:
# If you are running in a fresh environment, uncomment to install dependencies:
# !pip -q install torch torchvision matplotlib tqdm scikit-learn

import os
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

print("torch:", torch.__version__)
device = "cpu"  # For fixed, reproducible results. (You may switch to "cuda" after you finish debugging.)
print("device:", device)

def set_seed(seed: int = 42):
    """Make results as reproducible as possible across runs."""
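    # PYTHONHASHSEED is read at interpreter startup, so setting it here only
    # affects Python subprocesses started after this point, not the current kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)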
    # If you later switch to CUDA and want maximal determinism:
    # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

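    # Determinism flags. These are harmless on CPU. On GPU, an op that has no
    # deterministic implementation raises a RuntimeError once
    # torch.use_deterministic_algorithms(True) is active.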
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        torch.use_deterministic_algorithms(True)
    except Exception as e:
        print("Warning: could not enable full deterministic algorithms:", e)

set_seed(42)
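
As a quick sanity check of what seeding buys you, the sketch below (not part of the graded work) reseeds PyTorch's global generator and draws the same tensor twice. The helper name `draw` is hypothetical and exists only for this illustration.

```python
import torch


def draw(seed: int) -> torch.Tensor:
    """Reseed the global generator, then draw four uniform values."""
    torch.manual_seed(seed)
    return torch.rand(4)


a = draw(42)
b = draw(42)
c = draw(43)

print(torch.equal(a, b))  # same seed, so the draws match
print(torch.equal(a, c))  # a different seed gives different values
```

Note that running this cell resets the global seed, so call `set_seed(42)` again afterwards if you want the rest of the notebook to stay reproducible.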


## 1. Dataset and DataLoader (15 pts)

In [None]:
# Q1: Build the data pipeline
data_dir = "./data"
batch_size = 128  # Use 128 for training
num_workers = 0  # For fully reproducible ordering across platforms

# ========== YOUR CODE STARTS HERE ==========
# TODO:
# 1) Create transforms for train and test: convert to tensor, then normalize with mean=0.1307, std=0.3081
#    Avoid random augmentation for reproducibility
# 2) Load MNIST datasets (train and test) using datasets.MNIST()
# 3) Split training set into train (55k) and validation (5k) using random_split
# 4) Create three DataLoaders (train, val, test)
#    For train_loader, use shuffle=True

train_tf = None
test_tf = None
full_train = None
test_set = None
train_set, val_set = None, None
train_loader = None
val_loader = None
test_loader = None
# ========== YOUR CODE ENDS HERE ============

print("train/val/test:", len(train_set), len(val_set), len(test_set))
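
The splitting and batching steps can be tried on synthetic data before touching MNIST. This is only a sketch under stated assumptions: the tensors below are random stand-ins (600 samples instead of 60,000), and every variable name is hypothetical so nothing clashes with the graded variables. It shows `random_split` with a dedicated seeded generator, which is one way to keep the split stable regardless of other random calls.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for MNIST: 600 fake "images" with integer labels.
images = torch.rand(600, 1, 28, 28)
labels = torch.randint(0, 10, (600,))

# Same arithmetic that transforms.Normalize((0.1307,), (0.3081,)) applies.
images = (images - 0.1307) / 0.3081
full = TensorDataset(images, labels)

# A dedicated, seeded generator makes the split independent of other RNG use.
gen = torch.Generator().manual_seed(42)
train_part, val_part = random_split(full, [550, 50], generator=gen)

train_dl = DataLoader(train_part, batch_size=128, shuffle=True, num_workers=0)
val_dl = DataLoader(val_part, batch_size=128, shuffle=False, num_workers=0)

xb, yb = next(iter(train_dl))
print(len(train_part), len(val_part))  # 550 50
print(xb.shape, yb.shape)
```

Only the training loader shuffles; validation and test order does not affect the metrics, so leaving it fixed keeps runs easier to compare.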

## 2. Define the CNN Model (20 pts)

In [None]:
# Q2: Define your CNN model
class MNISTCNN(nn.Module):
    """Input: (B,1,28,28) -> Output: (B,10)"""
    def __init__(self, num_classes=10):
        super().__init__()
        # ========== YOUR CODE STARTS HERE ==========
        # Build a CNN with:
        # - Feature extraction: 3 convolutional blocks
        #   * First block: 1 input channel to 32 output channels, with BatchNorm and ReLU, then MaxPool
        #   * Second block: 32 to 64 channels, with BatchNorm and ReLU, then MaxPool
        #   * Third block: 64 to 128 channels, with ReLU (no pooling)
        #   Use kernel_size=3 and padding=1 for all convolutions, and kernel_size=2 for pooling (nn.MaxPool2d(2))
        # - Classifier: Flatten, then fully connected layers
        #   * Each MaxPool halves the spatial size (28 -> 14 -> 7), so the feature map entering the classifier is 7x7
        #   * First FC layer: 128*7*7 inputs to 256 outputs, with ReLU and Dropout(0.3)
        #   * Final FC layer: 256 to num_classes outputs
        
        self.features = None  # Define your convolutional layers here
        self.classifier = None  # Define your fully connected layers here
        # ========== YOUR CODE ENDS HERE ============

    def forward(self, x):
        # ========== YOUR CODE STARTS HERE ==========
        # Pass input through features, then classifier
        return None
        # ========== YOUR CODE ENDS HERE ============

model = MNISTCNN().to(device)
print(model)

# Parameter count
num_params = sum(p.numel() for p in model.parameters())
print("Total params:", num_params)
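
The 7x7 feature map quoted in the architecture comments follows from the standard output-size formula for convolution and pooling layers, `floor((n + 2*padding - kernel) / stride) + 1`. The sketch below walks the 28x28 input through the specified layers in plain Python; the helper `out_size` is a hypothetical name and assumes dilation 1.

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 28
size = out_size(size, kernel=3, padding=1)  # conv block 1 keeps 28
size = out_size(size, kernel=2, stride=2)   # max pool halves to 14
size = out_size(size, kernel=3, padding=1)  # conv block 2 keeps 14
size = out_size(size, kernel=2, stride=2)   # max pool halves to 7
size = out_size(size, kernel=3, padding=1)  # conv block 3 keeps 7 (no pooling)

flat_features = 128 * size * size
print(size, flat_features)  # 7 6272
```

The value 6272 is the `128*7*7` input width of the first fully connected layer.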

## 3. Training, Validation, and Curves (40 pts)

### 3.1 Training Function (15 pts)

In [None]:
# Q3.1: Training function
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_one_epoch(model, loader):
    """Train for one epoch and return (avg_loss, accuracy).

    Uses the module-level `criterion`, `optimizer`, and `device` defined above.
    """
    # ========== YOUR CODE STARTS HERE ==========
    # TODO:
    # - Set model to training mode
    # - Loop through batches in the loader
    # - For each batch: 
    #   * Zero gradients
    #   * Forward pass
    #   * Compute loss
    #   * Backward pass
    #   * Optimizer step
    # - Track total loss and accuracy (IMPORTANT: take the argmax over the class dimension to get predicted labels, then compare with targets)
    # - Return average loss and accuracy
    
    return None, None
    # ========== YOUR CODE ENDS HERE ============
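
The per-batch steps listed in the TODO can be seen in isolation on a toy problem. This is a minimal sketch, not the assignment solution: `toy_model` is a hypothetical linear layer standing in for the CNN, the data is random, and SGD is used only to keep the example small. It also shows one way to accumulate metrics, by summing over samples and dividing once, so a smaller final batch does not skew the average.

```python
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 3)            # hypothetical stand-in for the CNN
toy_criterion = nn.CrossEntropyLoss()  # returns the mean loss over the batch
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

inputs = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))

toy_model.train()
toy_optimizer.zero_grad()              # clear gradients from the previous step
logits = toy_model(inputs)             # forward pass, shape (8, 3)
loss = toy_criterion(logits, targets)
loss.backward()                        # backward pass
toy_optimizer.step()                   # parameter update

# Accumulate sums over samples, then divide once at the end of the epoch.
n = inputs.size(0)
total_loss = loss.item() * n           # undo the per-batch mean
total_correct = (logits.argmax(dim=1) == targets).sum().item()
total_seen = n

avg_loss = total_loss / total_seen
accuracy = total_correct / total_seen
print(f"loss={avg_loss:.4f} acc={accuracy:.4f}")
```

In a real epoch the three running totals would be updated inside the loop over the loader, with the two divisions done after the loop.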

### 3.2 Validation Function (8 pts)

In [None]:
# Q3.2: Validation function
@torch.no_grad()
def evaluate(model, loader):
    """Evaluate model and return (avg_loss, accuracy)."""
    # ========== YOUR CODE STARTS HERE ==========
    # TODO:
    # - Set model to evaluation mode
    # - Loop through batches (the @torch.no_grad() decorator above already disables gradient tracking)
    # - Compute loss and accuracy
    # - Return average loss and accuracy
    
    return None, None
    # ========== YOUR CODE ENDS HERE ============
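
Two separate mechanisms are at work during evaluation: `model.eval()` changes the behavior of layers such as Dropout and BatchNorm, while `torch.no_grad()` stops autograd from recording operations. The sketch below shows each effect in isolation; the helper `scaled_sum` is a hypothetical name used only here.

```python
import torch
import torch.nn as nn

layer = nn.Dropout(p=0.3)
x = torch.ones(1000)

layer.train()
train_out = layer(x)  # roughly 30% zeros, survivors scaled by 1 / (1 - 0.3)

layer.eval()
eval_out = layer(x)   # dropout is the identity in evaluation mode


@torch.no_grad()
def scaled_sum(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: results computed here carry no autograd graph."""
    return (t * 2).sum()


w = torch.ones(3, requires_grad=True)
print(torch.equal(eval_out, x))     # True
print(scaled_sum(w).requires_grad)  # False
```

Forgetting `model.eval()` leaves Dropout active and BatchNorm updating its running statistics, which makes validation metrics noisy even when gradients are disabled.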

### 3.3 Training Loop with Best-Checkpoint Saving (11 pts)

In [None]:
# Q3.3: Training loop with checkpointing
epochs = 10
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
best_val_acc = 0.0
best_epoch = -1
ckpt_path = "./checkpoints/best_mnist_cnn.pt"
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

# ========== YOUR CODE STARTS HERE ==========
# TODO:
# - Loop through epochs
# - Each epoch: train, validate, and append the four metrics to `history`
# - When validation accuracy improves, update `best_val_acc` and `best_epoch` and save a checkpoint to `ckpt_path`
# - Print training and validation metrics each epoch
# - Checkpoint format: a dictionary containing the keys `model_state_dict` and `epoch`

# ========== YOUR CODE ENDS HERE ============

print("Best val acc:", best_val_acc, "at epoch", best_epoch)
print("Saved to:", ckpt_path)
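
The "save only when validation accuracy improves" pattern can be sketched without any real training. Everything below is illustrative: `toy_model` is a hypothetical stand-in for the CNN, the accuracy list holds placeholder values rather than real results, and the checkpoint goes to a temporary directory instead of `ckpt_path`.

```python
import os
import tempfile

import torch
import torch.nn as nn

toy_model = nn.Linear(4, 3)                     # hypothetical stand-in for the CNN
fake_val_accs = [0.90, 0.95, 0.93, 0.97, 0.96]  # placeholder values, not real results

toy_best_acc, toy_best_epoch = 0.0, -1
toy_ckpt = os.path.join(tempfile.mkdtemp(), "best_toy.pt")

for epoch, val_acc in enumerate(fake_val_accs, start=1):
    # Training and validation for this epoch would run here.
    if val_acc > toy_best_acc:
        toy_best_acc, toy_best_epoch = val_acc, epoch
        torch.save(
            {"model_state_dict": toy_model.state_dict(), "epoch": epoch},
            toy_ckpt,
        )

print(toy_best_acc, toy_best_epoch)  # 0.97 4
```

Using a strict `>` comparison keeps the earliest epoch when two epochs tie, so the file is rewritten only on a genuine improvement.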

### 3.4 Plot Training Curves (6 pts)

In [None]:
# Q3.4: Plot curves
# ========== YOUR CODE STARTS HERE ==========
# TODO:
# - Create two plots: one for loss (train vs val), one for accuracy (train vs val)
# - Use the history dictionary to get the values
# - Add labels, legends, and display the plots

# ========== YOUR CODE ENDS HERE ============
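
A side-by-side layout is one reasonable way to present the two curves. The sketch below assumes a dictionary with the same four keys as `history`, filled with placeholder numbers purely to exercise the plotting calls; `demo_history`, `epoch_axis`, and the axis variable names are hypothetical.

```python
import matplotlib.pyplot as plt

# Placeholder numbers to exercise the plotting code; these are not real results.
demo_history = {
    "train_loss": [0.60, 0.30, 0.20],
    "val_loss": [0.50, 0.32, 0.25],
    "train_acc": [0.80, 0.91, 0.94],
    "val_acc": [0.84, 0.90, 0.92],
}
epoch_axis = range(1, len(demo_history["train_loss"]) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: training vs validation loss.
ax_loss.plot(epoch_axis, demo_history["train_loss"], label="train")
ax_loss.plot(epoch_axis, demo_history["val_loss"], label="val")
ax_loss.set_xlabel("epoch")
ax_loss.set_ylabel("loss")
ax_loss.set_title("Loss")
ax_loss.legend()

# Right panel: training vs validation accuracy.
ax_acc.plot(epoch_axis, demo_history["train_acc"], label="train")
ax_acc.plot(epoch_axis, demo_history["val_acc"], label="val")
ax_acc.set_xlabel("epoch")
ax_acc.set_ylabel("accuracy")
ax_acc.set_title("Accuracy")
ax_acc.legend()

fig.tight_layout()
```

In a notebook the figure renders at the end of the cell; in a standalone script, add `plt.show()` after the last line.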

## 4. Testing (5 pts)

In [None]:
# Q4: Test evaluation
# ========== YOUR CODE STARTS HERE ==========
# TODO:
# - Load the best checkpoint from `ckpt_path` (a dictionary with keys `model_state_dict` and `epoch`)
# - Load the model state from the checkpoint
# - Evaluate on the test set
# - Print test loss and accuracy

test_loss, test_acc = None, None
# ========== YOUR CODE ENDS HERE ============

print(f"Test loss: {test_loss:.4f} | Test acc: {test_acc:.4f}")
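
Restoring a checkpoint in the dictionary format described above is a round trip through `torch.save`, `torch.load`, and `load_state_dict`. The sketch below demonstrates it on a toy model under stated assumptions: `src_model` and `dst_model` are hypothetical linear layers standing in for the CNN, the epoch value 7 is arbitrary, and the file lives in a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Write a toy checkpoint in the assignment's format, then read it back.
src_model = nn.Linear(4, 3)  # hypothetical stand-in for the trained CNN
toy_path = os.path.join(tempfile.mkdtemp(), "toy_ckpt.pt")
torch.save({"model_state_dict": src_model.state_dict(), "epoch": 7}, toy_path)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
ckpt = torch.load(toy_path, map_location="cpu")

dst_model = nn.Linear(4, 3)  # must match the architecture of the saved model
dst_model.load_state_dict(ckpt["model_state_dict"])
dst_model.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(src_model.state_dict().values(), dst_model.state_dict().values())
)
print(ckpt["epoch"], same)  # 7 True
```

Because only the state dict is stored, the model class has to be constructed first and the weights loaded into it afterwards.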