In [2]:
import sys
import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F

import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

# Use a reduced training set (the first 1,000 samples) to make overfitting more likely.
# Both splits are stored under ./data (download=True) and use the ToTensor() transform.
mnist_train_full = datasets.MNIST(root='./data', download=True, train=True, transform=ToTensor())
mnist_train = torch.utils.data.Subset(mnist_train_full, list(range(1000)))
mnist_test  = datasets.MNIST(root='./data', download=True, train=False, transform=ToTensor())

# Shuffle only the training batches. Accuracy does not depend on the order in
# which test batches are visited, so the test loader iterates deterministically.
train_dataloader = DataLoader(mnist_train, batch_size=32, shuffle=True)
test_dataloader  = DataLoader(mnist_test,  batch_size=32, shuffle=False)

# Phase 2 — the simpler model: a single 100-unit hidden layer with ReLU,
# mapping 784 input features to 10 outputs.
loss_fn = nn.CrossEntropyLoss()

model = nn.Sequential(
    *(
        nn.Linear(784, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
)

# Adam over every parameter of the model, with a learning rate of 0.001.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Train for 10 epochs and print, after each epoch, the sum of the per-batch losses.
# NOTE(review): the printed label says "complex", but the model defined above is
# the simpler Phase 2 network — the label is kept unchanged here; confirm whether
# it should be renamed.
model.train()  # explicit, so the loop is still correct if re-run after model.eval()
for epoch in range(10):
    loss_sum = 0.0
    for X, y in train_dataloader:
        # Flatten every sample in the batch to a 784-element row for the first Linear layer.
        X = X.reshape((-1, 784))

        optimizer.zero_grad()
        outputs = model(X)
        # CrossEntropyLoss accepts integer class indices directly, so the labels
        # are passed as-is instead of being expanded to float one-hot vectors.
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    print("complex epoch loss:", loss_sum)

# Evaluate on the test set: count correct predictions over all batches and
# print the overall accuracy (correct / total).
model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    correct = total = 0
    for X, y in test_dataloader:
        X = X.reshape((-1, 784))
        # Take argmax over the raw model outputs: softmax is monotonic within
        # each row, so it never changes which class has the highest score.
        preds = model(X).argmax(dim=1)
        correct += (preds == y).sum().item()
        total   += y.size(0)
    print("complex model test accuracy:", correct / total)

# Phase 2 instruction (already applied to the model defined above): change the model
# to a simpler one by REMOVING the extra hidden layers and activations.
#       Keep only: Linear(784, 100) -> ReLU -> Linear(100, 10)

complex epoch loss: 59.096773862838745
complex epoch loss: 30.13172221183777
complex epoch loss: 18.637634754180908
complex epoch loss: 13.793805181980133
complex epoch loss: 11.327433317899704
complex epoch loss: 9.262314550578594
complex epoch loss: 8.126669861376286
complex epoch loss: 7.241165563464165
complex epoch loss: 6.1419263780117035
complex epoch loss: 5.448169752955437
complex model test accuracy: 0.8724
