# MNIST Classifier Exercise





In this exercise, you will apply what you've learned to build, train, and evaluate a neural network to classify handwritten digits from the famous MNIST dataset.


## Step 1: Setup and imports

First, we import the necessary libraries and, most importantly, select the device. The code uses a GPU for faster training when one is available and falls back to the CPU otherwise. We also seed Python's `random`, NumPy and PyTorch so that runs are reproducible.
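Seeding matters because weight initialization, `shuffle=True` in the training `DataLoader` and `random.randint` in the plotting cell all draw from these generators. The following is a minimal sketch, assuming only `torch` is installed. It shows that re-seeding with the same value reproduces the same random tensor, and that `.to(device)` moves a tensor to whichever device was selected.

```python
import torch

SEED = 42
device = "cuda" if torch.cuda.is_available() else "cpu"

# Re-seeding with the same value replays the same pseudo-random sequence
torch.manual_seed(SEED)
a = torch.rand(3)
torch.manual_seed(SEED)
b = torch.rand(3)
print(torch.equal(a, b))  # True

# Tensors (and models) must be on the same device before they interact
x = a.to(device)
print(x.device.type)      # "cuda" or "cpu", matching `device`
```

Note that a fixed seed does not make every GPU kernel deterministic, so results on a GPU can still differ slightly between runs.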

In [None]:
# --- Core PyTorch and data handling libraries ---
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms

# --- Visualization and analysis ---
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# --- Device and reproducibility ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


## Step 2 — Prepare the dataset transforms and load MNIST

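We scale pixel values to `[0, 1]` with `ToTensor()` and then **standardize** them with the MNIST training-set mean and standard deviation, so each pixel `x` becomes `(x - mean) / std`. Keep these two transforms and the statistics unchanged for the lab.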

In [None]:
# MNIST statistics (train set)
mnist_mean = 0.1307
mnist_std = 0.3081

transform = transforms.Compose([
    # MNIST images are already 28x28, so this resize is effectively a no-op
    transforms.Resize((28, 28)),
    # Convert the PIL image to a float tensor of shape 1x28x28 scaled to [0, 1]
    transforms.ToTensor(),
    # Standardize each pixel: (x - mean) / std
    transforms.Normalize((mnist_mean,), (mnist_std,)),
])

train_set = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_set  = datasets.MNIST(root='data', train=False, download=True, transform=transform)

print('Train samples:', len(train_set))
print('Test samples :', len(test_set))


### Visualize some training samples
The images are standardized, so the helper function below **unnormalizes** them (`x * std + mean`) back to the `[0, 1]` range before plotting.
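The helper below plots one image at a time. When you plot a whole batch, for example to show predictions in Step 6, you can apply the same inversion to a `B x 1 x 28 x 28` tensor through broadcasting. The following is a minimal sketch. The `unnormalize` name is hypothetical, and the clamp guards against tiny floating-point overshoot.

```python
import torch

mnist_mean, mnist_std = 0.1307, 0.3081

def unnormalize(batch: torch.Tensor, mean: float = mnist_mean, std: float = mnist_std) -> torch.Tensor:
    """Invert Normalize (x = z * std + mean), then clamp to the valid [0, 1] range."""
    return (batch * std + mean).clamp(0.0, 1.0)

# Hypothetical batch of 4 images in [0, 1], shape B x C x H x W
torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)
z = (x - mnist_mean) / mnist_std       # what transforms.Normalize produces
x_back = unnormalize(z)

print(torch.allclose(x, x_back, atol=1e-5))  # True
```

Passing the result to `plt.imshow(..., vmin=0, vmax=1)` fixes the gray scale instead of letting matplotlib rescale each image to its own min and max.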

In [None]:
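def imshow_normalized(tensor_img, mean=mnist_mean, std=mnist_std):
    """Unnormalize a CxHxW tensor and plot it as HxW (grayscale)."""
    img = tensor_img.clone().cpu().numpy()
    # Invert Normalize: x = z * std + mean, giving values back in [0, 1]
    img = img * std + mean
    # Fix the color scale to [0, 1] so matplotlib does not rescale each image
    plt.imshow(img.squeeze(), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')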

plt.figure(figsize=(6,6))
cols, rows = 4, 4
for i in range(1, cols*rows + 1):
    idx = random.randint(0, len(train_set)-1)
    img, label = train_set[idx]
    plt.subplot(rows, cols, i)
    imshow_normalized(img)
    plt.title(f"Label: {label}")
plt.tight_layout()
plt.show()


## Step 3 — Create DataLoaders

Wrap each dataset in a `DataLoader` with batch size `BATCH_SIZE`. Shuffle the training set every epoch (`shuffle=True`). The test set does not need shuffling.
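To see what a `DataLoader` yields, here is a sketch on a hypothetical in-memory dataset of 100 blank 1x28x28 images. It shows that the loader splits the data into batches of `batch_size` and keeps a smaller final batch (the default `drop_last=False`).

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 100 images of shape 1x28x28 with integer labels
images = torch.zeros(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

batch_sizes = [xb.shape[0] for xb, yb in loader]
print(len(loader), batch_sizes)   # 2 [64, 36]

x, y = next(iter(loader))
print(x.shape, y.shape)           # torch.Size([64, 1, 28, 28]) torch.Size([64])

# The same rounding-up rule gives the MNIST batch counts
print(math.ceil(60000 / 64), math.ceil(10000 / 64))  # 938 157
```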

In [None]:
BATCH_SIZE = 64
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

print('Train batches:', len(train_loader))
print('Test  batches:', len(test_loader))


## Step 4 — Define your network 

Implement a PyTorch `nn.Module` for classification. 

Below is a skeleton: fill in the `TODO` parts.

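Hints:
- Input size is `28*28 = 784` after flattening each `1x28x28` image (for example with `nn.Flatten()`).
- Output size must be `10` (classes 0..9).
- Use `nn.ReLU()` activations and `nn.Linear` layers.
- `forward` should return raw logits (no softmax), because `nn.CrossEntropyLoss` applies log-softmax internally.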
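Before writing the class, it can help to check the shapes that flow through the layers. The following is a sketch with a random batch shaped like an MNIST batch. The hidden size of 128 is an arbitrary assumption, and wrapping these layers into `NeuralNetwork` is left to you.

```python
import torch
from torch import nn

batch = torch.randn(64, 1, 28, 28)   # hypothetical batch, same shape as an MNIST batch

flatten = nn.Flatten()               # keeps dim 0 (batch) and flattens the rest
flat = flatten(batch)
print(flat.shape)                    # torch.Size([64, 784])

hidden = nn.Linear(28 * 28, 128)     # 128 hidden units is an arbitrary choice
act = nn.ReLU()
out = nn.Linear(128, 10)             # one logit per digit class

logits = out(act(hidden(flat)))
print(logits.shape)                  # torch.Size([64, 10])
```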


In [None]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: build your network layers here
        
        
        raise NotImplementedError("Define your network layers in __init__")

    def forward(self, x):
        # TODO: implement forward pass
        
        
        raise NotImplementedError("Implement forward() to return logits")

# Initialize model (students should implement the class above first)
try:
    model = NeuralNetwork().to(device)
    print(model)
except NotImplementedError as e:
    print('Model not defined yet — fill the TODO in the class above.')


## Step 5 — Training setup 

Create the loss function, the optimizer and the training and evaluation loops by filling in the TODOs below. Use `nn.CrossEntropyLoss()` for the loss and `torch.optim.SGD` or `torch.optim.Adam` for the optimizer.
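The core of a PyTorch training step is the three-call update `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()`. The following is a minimal sketch on a hypothetical random batch. A single `nn.Linear` layer stands in for your network, and `lr=1e-2` matches the `learning_rate` below.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy batch: 8 random 784-dim inputs with labels in 0..9
x = torch.randn(8, 784)
y = torch.randint(0, 10, (8,))

model = nn.Linear(784, 10)          # stand-in for NeuralNetwork
loss_fn = nn.CrossEntropyLoss()     # expects raw logits and integer class labels
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

losses = []
for step in range(5):
    logits = model(x)               # shape 8x10, no softmax
    loss = loss_fn(logits, y)       # mean loss over the batch
    optimizer.zero_grad()           # clear gradients left from the previous step
    loss.backward()                 # compute d(loss)/d(parameters)
    optimizer.step()                # update the parameters in place
    losses.append(loss.item())

print(losses[-1] < losses[0])       # True: the loss drops on this fixed batch
```

If you switch to `Adam`, a smaller learning rate such as `1e-3` (its default) is a common starting point.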

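Implement `train_loop` (one pass over the training data with parameter updates) and `test_loop` (one pass over the test data without updates). Both return the average per-sample loss and the accuracy as a fraction.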
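Per-sample averaging matters because the last batch is usually smaller (60,000 is not a multiple of 64), so averaging the per-batch means would over-weight it. The following sketch shows the bookkeeping on two hypothetical batches of precomputed logits. In a real `test_loop`, you would also call `model.eval()` and compute the logits with the model inside `torch.no_grad()`.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()  # reduction="mean" by default

# Two hypothetical batches of different sizes (logits, labels)
batches = [
    (torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 0.0]]), torch.tensor([0, 1, 1])),
    (torch.tensor([[0.0, 3.0]]), torch.tensor([1])),
]

total_loss, correct, n_samples = 0.0, 0, 0
with torch.no_grad():                                   # no gradients needed for evaluation
    for logits, y in batches:
        batch_loss = loss_fn(logits, y)                 # mean over this batch
        total_loss += batch_loss.item() * y.size(0)     # undo the mean to get a sum over samples
        correct += (logits.argmax(dim=1) == y).sum().item()
        n_samples += y.size(0)

avg_loss = total_loss / n_samples
accuracy = correct / n_samples
print(n_samples, correct, accuracy)   # 4 3 0.75
```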

In [None]:
# Hyperparameters (feel free to experiment)
learning_rate = 1e-2
epochs = 10

# TODO: choose loss and optimizer
loss_fn = None  # TODO
optimizer = None  # TODO


# Optional, feel free to implement something else you are comfortable with
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """Train for one epoch. Return (avg_loss, accuracy_fraction).
    Implement per-sample loss averaging and exact accuracy counting.
    """
    # TODO: implement training loop
    raise NotImplementedError("Implement train_loop")

def test_loop(dataloader, model, loss_fn, device):
    """Evaluate on validation/test set. Return (avg_loss, accuracy_fraction)."""
    # TODO: implement test loop
    raise NotImplementedError("Implement test_loop")


### Run training 

Run this cell only after you have implemented the model, chosen `loss_fn` and `optimizer`, and filled in `train_loop` and `test_loop`. It prints the train and test loss and accuracy after every epoch.
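The training cell prints the per-epoch metrics but does not keep them, and the learning curves in Step 6 need them. The following is a sketch with a hypothetical `fit` helper that records every metric in a dictionary. It takes two zero-argument callables, so it makes no assumptions about how your loops are written. The demo uses dummy callables in place of the real loops.

```python
from typing import Callable, Dict, List, Tuple

Metrics = Tuple[float, float]  # (avg_loss, accuracy_fraction), as returned by the loops

def fit(train_epoch: Callable[[], Metrics],
        eval_epoch: Callable[[], Metrics],
        epochs: int) -> Dict[str, List[float]]:
    """Run `epochs` epochs and record every metric so it can be plotted later."""
    history: Dict[str, List[float]] = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_epoch()
        test_loss, test_acc = eval_epoch()
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        print(f"Epoch {epoch}/{epochs} | train loss {train_loss:.4f}, acc {100 * train_acc:.2f}% "
              f"| test loss {test_loss:.4f}, acc {100 * test_acc:.2f}%")
    return history

# Demo with dummy callables standing in for train_loop/test_loop
fake_losses = iter([1.0, 0.5, 0.25])
history = fit(lambda: (next(fake_losses), 0.9), lambda: (0.6, 0.85), epochs=3)
print(history["train_loss"])  # [1.0, 0.5, 0.25]
```

With the real loops, the call would be `fit(lambda: train_loop(train_loader, model, loss_fn, optimizer, device), lambda: test_loop(test_loader, model, loss_fn, device), epochs)`. You can then plot the curves with, for example, `plt.plot(history["train_loss"], label="train")`.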

In [None]:
# Run training: implement the train_loop/test_loop first!
try:
    for epoch in range(1, epochs+1):
        print(f"Epoch {epoch}/{epochs}")
        train_loss, train_acc = train_loop(train_loader, model, loss_fn, optimizer, device)
        test_loss, test_acc = test_loop(test_loader, model, loss_fn, device)
        print(f"Train loss: {train_loss:.4f}, Train acc: {100*train_acc:.2f}% | Test loss: {test_loss:.4f}, Test acc: {100*test_acc:.2f}%")
except (NotImplementedError, NameError):  # NameError: `model` was never created
    print('Training loop or model not implemented yet. Fill the TODOs above before running training.')


## Step 6 — Evaluate & visualize (***Student Task***)

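Once you have trained the model, compute and plot the learning curves, the confusion matrix on the test set, and some sample predictions. For the learning curves, store the per-epoch train and test losses and accuracies in lists during training, because the training cell above only prints them.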
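One way to build the confusion matrix without an extra library is `torch.bincount` over a combined `(true, predicted)` index. The following is a sketch with a hypothetical helper named `confusion_matrix` and toy 3-class labels. Rows are true classes and columns are predicted classes.

```python
import torch

def confusion_matrix(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    idx = y_true * num_classes + y_pred          # unique index for each (true, pred) pair
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

# Hypothetical labels and predictions for a 3-class toy case
y_true = torch.tensor([0, 0, 1, 2, 2, 2])
y_pred = torch.tensor([0, 1, 1, 2, 0, 2])
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm.tolist())  # [[1, 1, 0], [0, 1, 0], [1, 0, 2]]
```

For MNIST, you would collect `y_true` and `y_pred = logits.argmax(dim=1)` over the whole test set under `torch.no_grad()`, then call the helper with `num_classes=10`. You can plot the result with `sns.heatmap(cm.cpu().numpy(), annot=True, fmt="d")`.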

## Step 7 — Saving and loading models

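Once you're satisfied with the trained model, save its `state_dict` (the learned parameters). The code below is provided: uncomment it and run it after training. To load the parameters, you must first create a `NeuralNetwork` instance, because the `state_dict` stores only tensors, not the architecture.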
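A quick way to confirm that saving and loading works is a round trip through a temporary file. The following sketch uses `nn.Linear(784, 10)` as a hypothetical stand-in for `NeuralNetwork`. Passing `map_location=device` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(784, 10).to(device)          # stand-in for NeuralNetwork

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mnist_model_student.pth")
    torch.save(model.state_dict(), path)       # parameters only, not the class definition

    restored = nn.Linear(784, 10).to(device)   # the same architecture must be built first
    restored.load_state_dict(torch.load(path, map_location=device))
    restored.eval()                            # inference mode (matters for dropout/batchnorm)

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                               restored.state_dict().values()))
print(same)  # True
```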


In [None]:
# Example: saving and loading (run AFTER training)
PATH = 'mnist_model_student.pth'
# torch.save(model.state_dict(), PATH)
# To load:
# model_loaded = NeuralNetwork().to(device)
# model_loaded.load_state_dict(torch.load(PATH, map_location=device))
# model_loaded.eval()




### Final notes for students
- Try different architectures, learning rates, optimizers and schedulers.
- Experiment with batch size and number of epochs.
- Compare normalizing with the dataset statistics against the simpler `mean=0.5, std=0.5`.
- Critically analyze your results: how well is your model performing, and how could it be improved?
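The constants `0.1307` and `0.3081` are the pixel mean and standard deviation of the MNIST training images after `ToTensor()`. The following sketch shows how such statistics can be computed and what the two normalization choices do. It uses hypothetical random binary images instead of real MNIST, and the `mean_std` helper name is also hypothetical. If you stack the real training images loaded with `transform=transforms.ToTensor()`, you should get values close to those constants.

```python
import torch

def mean_std(images: torch.Tensor):
    """Pixel mean and std over a stack of images shaped N x 1 x H x W with values in [0, 1]."""
    return images.mean().item(), images.std().item()

torch.manual_seed(0)
# Hypothetical data: mostly-black images with some bright pixels, loosely MNIST-like
images = (torch.rand(1000, 1, 28, 28) > 0.85).float()
mean, std = mean_std(images)

standardized = (images - mean) / std      # dataset statistics: mean ~0, std ~1
half = (images - 0.5) / 0.5               # mean=0.5, std=0.5: maps [0, 1] to [-1, 1]

print(standardized.mean().item(), standardized.std().item())  # close to 0 and 1
print(half.min().item(), half.max().item())                   # -1.0 1.0
```

Standardizing with the dataset statistics centers the inputs at zero with unit spread. The `0.5/0.5` choice only rescales to `[-1, 1]`, and because MNIST is mostly background, its inputs stay far from zero mean.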
