# MNIST Handwritten Digit Recognition (MLP in PyTorch)

This notebook trains a Multi-Layer Perceptron (MLP) on the MNIST dataset and walks through data loading, preprocessing, model construction, training, evaluation, visualization, saving/loading checkpoints, and an optional lightweight hyperparameter sweep.

Sections mirror a production workflow and include optional mixed precision and hyperparameter exploration.


## 1. Install / Verify Dependencies

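If running in a fresh environment, you can optionally install the required packages with the commented-out command below. The check cell that follows does not install anything: it only reports which required packages are missing and prints the matching `pip install` command.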

```bash
# Optional (uncomment if needed)
# pip install torch torchvision torchaudio matplotlib seaborn scikit-learn tqdm
```


In [None]:
# (Optional) quick dependency check: report which required packages cannot be found.
# Nothing is installed here; a suggested pip command is printed for any missing ones.
import importlib
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
import sys

# Note: these are *import* names (scikit-learn is imported as `sklearn`).
required = ["torch", "torchvision", "matplotlib", "seaborn", "sklearn", "tqdm"]
missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]
print("Missing packages:", missing if missing else "None. All good.")
if missing:
    print("You can install them with: pip install " + " ".join(missing))

## 2. Imports & Device Selection


In [None]:
import os, math, json, time, random, pathlib
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Pick the fastest available backend: CUDA first, then Apple Silicon (MPS), then CPU.
# The hasattr guard covers PyTorch builds that predate the MPS backend.
if torch.cuda.is_available():
    backend = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    backend = "mps"
else:
    backend = "cpu"
device = torch.device(backend)
print("Using device:", device)

## 3. Reproducibility (Set All Random Seeds)


In [None]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs are repeatable.
seed = 42
for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Optional stricter determinism (may slow training)
# torch.use_deterministic_algorithms(True)
print("Seeds set to", seed)

## 4. Hyperparameter Configuration


In [None]:
# Central configuration: later cells read every tunable value from this dict.
hparams = dict(
    batch_size=128,
    epochs=10,
    learning_rate=1e-3,
    weight_decay=0.0,
    hidden_sizes=[256, 128],
    input_dim=28 * 28,  # pixels per image; the MLP flattens each input to this length
    num_classes=10,
    scheduler_patience=3,
    scheduler_factor=0.5,
    num_workers=2,
    amp=True,  # set False to disable mixed precision
)
print(hparams)

## 5. Define Data Transforms & Download MNIST


In [None]:
# Convert images to tensors and standardize with the MNIST mean/std constants used here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same root, download flag and transform; only `train` differs.
data_root = Path("./data")
mnist_kwargs = dict(root=data_root, download=True, transform=transform)
train_dataset_full = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)
print("Train size (raw):", len(train_dataset_full), " Test size:", len(test_dataset))

## 6. Create DataLoaders (Train / Val / Test Split)


In [None]:
# Hold out a fixed validation subset from the official training split.
val_size = 5000
train_size = len(train_dataset_full) - val_size
# A dedicated, seeded generator makes the split identical every time this cell runs.
# Drawing from the global RNG would reshuffle the split on re-run, so validation
# samples could overlap data the model has already been trained on.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset = random_split(
    train_dataset_full, [train_size, val_size], generator=split_generator
)

kwargs = {
    "batch_size": hparams["batch_size"],
    "num_workers": hparams["num_workers"],
    # Pinned host memory only helps host->CUDA copies.
    "pin_memory": device.type == "cuda",
}
train_loader = DataLoader(train_dataset, shuffle=True, **kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **kwargs)
print(
    f"Train: {len(train_dataset)}  Val: {len(val_dataset)}  Test: {len(test_dataset)}"
)

## 7. Define MLP Model Class


In [None]:
class MLP(nn.Module):
    """Fully connected classifier: one Linear -> ReLU pair per hidden size, then a final Linear.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_sizes: widths of the hidden layers, in order (may be empty).
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_sizes, num_classes):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_sizes:
            # Order must stay Linear, ReLU so state_dict keys (net.0, net.2, ...) are stable.
            layers += [nn.Linear(prev, h), nn.ReLU()]
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and return logits of shape (batch, num_classes)."""
        # torch.flatten copies when needed, whereas .view raises on non-contiguous inputs.
        x = torch.flatten(x, start_dim=1)
        return self.net(x)


# Build the network from the configured sizes and move it to the selected device.
model = MLP(
    input_dim=hparams["input_dim"],
    hidden_sizes=hparams["hidden_sizes"],
    num_classes=hparams["num_classes"],
).to(device)
print(model)

## 8. Utility: Count & Display Trainable Parameters


In [None]:
def count_params(model):
    """Print how many trainable (requires_grad) parameters ``model`` has.

    Counts of at least one million are shown as ``X.XXM``, at least one thousand
    as ``X.XK``, otherwise as the plain integer; the exact count follows in parentheses.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    # (threshold, suffix, format spec), checked largest first
    tiers = ((1e6, "M", ".2f"), (1e3, "K", ".1f"))
    human = next(
        (
            f"{n_trainable / scale:{spec}}{suffix}"
            for scale, suffix, spec in tiers
            if n_trainable >= scale
        ),
        str(n_trainable),
    )
    print(f"Trainable parameters: {human} ({n_trainable})")


count_params(model=model)  # report size of the freshly built network

## 9. Initialize Model, Loss Function, Optimizer, Scheduler


In [None]:
# Loss, optimizer and LR schedule for the main training run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=hparams["learning_rate"],
    weight_decay=hparams["weight_decay"],
)
# mode="min": the LR is multiplied by `scheduler_factor` when the monitored metric
# (the training loop passes validation loss) stops decreasing for `scheduler_patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=hparams["scheduler_factor"],
    patience=hparams["scheduler_patience"],
)
print("Optimizer and scheduler initialized")

# Gradient scaling only applies to CUDA mixed precision; elsewhere the scaler is a no-op.
use_amp = device.type == "cuda" and hparams["amp"]
# torch.amp.GradScaler("cuda", ...) supersedes the deprecated torch.cuda.amp.GradScaler;
# fall back to the old class on PyTorch releases that lack the new one.
_amp_grad_scaler = getattr(getattr(torch, "amp", None), "GradScaler", None)
if _amp_grad_scaler is not None:
    scaler = _amp_grad_scaler("cuda", enabled=use_amp)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

## 10. Training Step Function


In [None]:
def train_one_epoch(model, dataloader, optimizer, device, scaler=None, loss_fn=None):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: network to train; switched to train mode here.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device each batch is moved to (the model is expected to already be there).
        scaler: optional GradScaler; autocast + loss scaling are used only when it is enabled.
        loss_fn: callable ``loss_fn(outputs, labels)``; defaults to the notebook-level
            ``criterion`` so existing calls behave as before.

    Returns:
        ``(avg_loss, accuracy)`` averaged over all samples seen this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level loss
    use_amp = scaler is not None and scaler.is_enabled()
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(dataloader, leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        # torch.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast;
        # the scaler is only ever enabled on CUDA, so "cuda" is the right device type.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, labels)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        batch_loss = loss.item()  # one host sync per batch, reused below
        batch_size = labels.size(0)
        # Re-weight the batch-mean loss so the epoch average is per-sample.
        running_loss += batch_loss * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
        pbar.set_description(f"loss {batch_loss:.4f}")
    if total == 0:
        raise ValueError("train_one_epoch() received a dataloader that yielded no samples")
    avg_loss = running_loss / total
    acc = correct / total
    return avg_loss, acc

## 11. Evaluation Function


In [None]:
def evaluate(model, dataloader, device, return_preds=False, loss_fn=None):
    """Compute average loss and accuracy of ``model`` over ``dataloader`` without gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to (the model is expected to already be there).
        return_preds: if True, also return all predictions and labels, concatenated on CPU.
        loss_fn: callable ``loss_fn(outputs, labels)``; defaults to the notebook-level
            ``criterion`` so existing calls behave as before.

    Returns:
        ``(avg_loss, acc)``, or ``(avg_loss, acc, preds, labels)`` when ``return_preds`` is True.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level loss
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Assumes loss_fn returns the batch mean; re-weight so the average is per-sample.
            running_loss += loss_fn(outputs, labels).item() * batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size
            if return_preds:
                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())
    if total == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")
    avg_loss = running_loss / total
    acc = correct / total
    if return_preds:
        return avg_loss, acc, torch.cat(all_preds), torch.cat(all_labels)
    return avg_loss, acc

## 12. Main Training Loop


In [None]:
# Main training loop: train, validate, step the LR scheduler, log metrics, and keep
# a snapshot of the weights with the best validation accuracy seen so far.
history = {
    "epoch": [],
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": [],
    "lr": [],
}
best_val_acc = 0.0
best_state = None
start_time = time.time()
for epoch in range(1, hparams["epochs"] + 1):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, device, scaler
    )
    val_loss, val_acc = evaluate(model, val_loader, device)
    scheduler.step(val_loss)  # ReduceLROnPlateau monitors validation loss
    current_lr = optimizer.param_groups[0]["lr"]
    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)
    history["lr"].append(current_lr)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameter tensors, which the
        # optimizer keeps updating in place; clone them to freeze this epoch's weights.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    print(
        f"Epoch {epoch:02d}/{hparams['epochs']} | TLoss {train_loss:.4f} VLoss {val_loss:.4f} | TAcc {train_acc*100:.2f}% VAcc {val_acc*100:.2f}% | LR {current_lr:.2e}"
    )

total_time = time.time() - start_time
print(
    f"Training completed in {total_time/60:.2f} minutes. Best Val Acc: {best_val_acc*100:.2f}%"
)

## 13. Track & Store Metrics


In [None]:
# Save metrics to JSON
metrics_path = Path("training_metrics.json")
with open(metrics_path, "w") as f:
    json.dump(history, f, indent=2)
print("Metrics saved to", metrics_path.resolve())

## 14. Plot Training vs Validation Curves


In [None]:
# Side-by-side loss and accuracy curves, train vs. validation, per epoch.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
panels = [
    (loss_ax, "train_loss", "val_loss", "Train Loss", "Val Loss", "Loss"),
    (acc_ax, "train_acc", "val_acc", "Train Acc", "Val Acc", "Accuracy"),
]
for ax, train_key, val_key, train_label, val_label, metric_name in panels:
    ax.plot(history["epoch"], history[train_key], label=train_label)
    ax.plot(history["epoch"], history[val_key], label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_name)
    ax.set_title(metric_name)
    ax.legend()
    ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

## 15. Final Test Evaluation


In [None]:
# Load best weights
if best_state is not None:
    model.load_state_dict(best_state)

test_loss, test_acc, test_preds, test_labels = evaluate(
    model, test_loader, device, return_preds=True
)
print(f"Test Loss: {test_loss:.4f}  Test Accuracy: {test_acc*100:.2f}%")

## 16. Confusion Matrix & Classification Report


In [None]:
# Per-class precision/recall/F1 plus a count-annotated confusion matrix on the test set.
cm = confusion_matrix(test_labels, test_preds)
report = classification_report(test_labels, test_preds, digits=4)
print(report)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

## 17. Visualize Sample Predictions


In [None]:
# Collect the first 36 test images with predicted/true labels and show them in a grid
# (green title = correct prediction, red = wrong).
max_examples = 36
model.eval()
examples = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        preds = model(images).argmax(dim=1)
        remaining = max_examples - len(examples)
        for img, pred, label in list(zip(images, preds, labels))[:remaining]:
            examples.append((img.cpu(), pred.item(), label.item()))
        if len(examples) >= max_examples:
            break

cols = 6
rows = math.ceil(len(examples) / cols)
plt.figure(figsize=(cols * 1.5, rows * 1.5))
for position, (img, pred, label) in enumerate(examples, start=1):
    plt.subplot(rows, cols, position)
    plt.imshow(img.squeeze(0), cmap="gray")
    plt.title(
        f"p:{pred} t:{label}",
        color="red" if pred != label else "green",
        fontsize=8,
    )
    plt.axis("off")
plt.tight_layout()
plt.show()

## 18. Save Model Checkpoint


In [None]:
# Bundle weights (best snapshot if one exists, else current), config and metrics into one file.
weights_to_save = model.state_dict() if best_state is None else best_state
ckpt = dict(
    model_state=weights_to_save,
    hyperparameters=hparams,
    best_val_acc=best_val_acc,
    test_acc=float(test_acc),
    metrics=history,
)
ckpt_path = Path("mnist_mlp.pt")
torch.save(ckpt, ckpt_path)
print("Checkpoint saved to", ckpt_path.resolve())

## 19. Load Checkpoint & Inference Sanity Check


In [None]:
# Reload the checkpoint into a fresh model and score one test batch as a sanity check.
# torch.load unpickles the file: only load checkpoints from trusted sources.
loaded = torch.load("mnist_mlp.pt", map_location=device)
# Rebuild the architecture from the checkpoint's own hyperparameters rather than the live
# `hparams`, so this check fails if the checkpoint is not self-describing.
saved_hparams = loaded["hyperparameters"]
model2 = MLP(
    saved_hparams["input_dim"],
    saved_hparams["hidden_sizes"],
    saved_hparams["num_classes"],
).to(device)
model2.load_state_dict(loaded["model_state"])
model2.eval()
with torch.no_grad():
    images, labels = next(iter(test_loader))
    images = images.to(device)
    outputs = model2(images)
    preds = outputs.argmax(dim=1)
print(
    "Sanity batch accuracy:", (preds.cpu() == labels).float().mean().item() * 100, "%"
)

## 20. Optional: Simple Hyperparameter Sweep


In [None]:
enable_sweep = False  # set True to run a quick sweep (will take extra time)
sweep_results = []

if enable_sweep:
    # Compare hidden-layer layouts with a short training run each; the optimizer uses
    # the same learning rate / weight decay as the main run so only the layout varies.
    sweep_configs = [[512, 256, 128], [256, 256, 128], [512, 256], [128, 64]]
    short_epochs = 3
    for cfg in sweep_configs:
        temp_model = MLP(hparams["input_dim"], cfg, hparams["num_classes"]).to(device)
        temp_opt = optim.Adam(
            temp_model.parameters(),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"],
        )
        for _ in range(short_epochs):
            train_one_epoch(temp_model, train_loader, temp_opt, device)
        # Evaluate once after training; a local name avoids clobbering the main loop's `val_acc`.
        _, sweep_val_acc = evaluate(temp_model, val_loader, device)
        sweep_results.append({"hidden": cfg, "val_acc": sweep_val_acc})
    print("Sweep results:")
    for r in sweep_results:
        print(r)

## 21. Optional: Mixed Precision Note

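Mixed precision (autocast + GradScaler) is enabled automatically when running on CUDA with `hparams['amp']` set to True; on CPU and MPS, training always runs in full precision. If you see numerical instability (e.g., NaN or inf losses), set `hparams['amp'] = False`.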


## 22. (Optional) Export Notebook to Script

```bash
# Run this in a notebook cell (the leading `!` executes it as a shell command); drop the `!` when running it from a terminal
# !jupyter nbconvert --to script mnist_mlp.ipynb
```


## 23. Environment Info & Cleanup

Below we print the PyTorch version and device details; when running on CUDA, we also release cached GPU memory with `torch.cuda.empty_cache()`.


In [None]:
# Report the runtime environment and, on CUDA, return cached allocator memory to the driver.
print("Torch version:", torch.__version__)
print("Device:", device)
if device.type == "cuda":
    print("CUDA device count:", torch.cuda.device_count())
    # Query the active device instead of hard-coding index 0, which may not be the one in use.
    print("Current device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
    # Frees unused cached blocks only; tensors still referenced stay allocated.
    torch.cuda.empty_cache()
elif device.type == "mps":
    print("Using Apple Silicon MPS backend")
else:
    print("Running on CPU")