In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# ------------------------------
# 1. Setup
# ------------------------------
# Device: use CUDA if available, otherwise mps
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

# Data transform: convert images (0-255) → tensors (0-1)
transform = transforms.ToTensor()

# Download + prepare datasets
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# DataLoaders: provide mini-batches
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# ------------------------------
# 2. Define the Model
# ------------------------------
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()                 # Flatten 28x28 → 784
        self.fc1 = nn.Linear(784, 256)              # First hidden layer
        self.relu = nn.ReLU()                       # Non-linearity
        self.fc2 = nn.Linear(256, 10)               # Output layer (10 classes)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)                             # No softmax here; loss expects raw logits
        return x

model = MLP().to(device)

# ------------------------------
# 3. Loss and Optimizer
# ------------------------------
criterion = nn.CrossEntropyLoss()                   # Combines LogSoftmax + NLLLoss
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ------------------------------
# 4. Training Loop
# ------------------------------
for epoch in range(1, 4):                           # Train for 3 epochs
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()                       # Reset gradients
        output = model(data)                        # Forward pass
        loss = criterion(output, target)            # Compute loss
        loss.backward()                             # Backprop
        optimizer.step()                            # Update weights

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# ------------------------------
# 5. Evaluation
# ------------------------------
model.eval()
correct = 0
with torch.no_grad():                               # No gradients needed for eval
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1)                 # Pick class with highest logit
        correct += pred.eq(target).sum().item()

accuracy = correct / len(test_dataset)
print(f"Test Accuracy: {accuracy*100:.2f}%")

Epoch 1, Loss: 0.3067
Epoch 2, Loss: 0.0200
Epoch 3, Loss: 0.1897
Test Accuracy: 97.38%
