# SOTA-ish MNIST MLP
Simple MLP architecture that achieves roughly SOTA performance (>98%) on MNIST dataset within **30** epochs (50 training epochs shown).

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets
from torchvision.transforms import v2, ToTensor

In [3]:
# Download training data from open datasets.
transforms = v2.Compose([
    # v2.RandomResizedCrop(size=(28, 28), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=15),
    # v2.ToDtype(torch.float32, scale=True),
    # v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
    ToTensor()
])

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms,
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [4]:
BATCH_SIZE = 64

train_dl = DataLoader(training_data, batch_size = BATCH_SIZE)
test_dl = DataLoader(test_data, batch_size = BATCH_SIZE)

In [5]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [6]:
class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_block = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 384),
            nn.ReLU(),
            # nn.BatchNorm1d(192),
            nn.Linear(384, 10),
            # nn.Softmax(dim=1),
        )
    
    def forward(self, x):
        x = self.flatten(x)
        probs = self.linear_block(x)
        return probs

In [7]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # print("PREDICTION: ", pred)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [8]:
model = NN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=1e-3,
)

EPOCHS = 50

for t in range(EPOCHS):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dl, model, loss_fn, optimizer)
    test(test_dl, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.307436  [   64/60000]
loss: 0.594470  [ 6464/60000]
loss: 0.535015  [12864/60000]
loss: 0.401990  [19264/60000]
loss: 0.375865  [25664/60000]
loss: 0.408990  [32064/60000]
loss: 0.354631  [38464/60000]
loss: 0.326443  [44864/60000]
loss: 0.263131  [51264/60000]
loss: 0.611126  [57664/60000]
Test Error: 
 Accuracy: 92.0%, Avg loss: 0.243021 

Epoch 2
-------------------------------
loss: 0.224656  [   64/60000]
loss: 0.382130  [ 6464/60000]
loss: 0.320709  [12864/60000]
loss: 0.321481  [19264/60000]
loss: 0.268204  [25664/60000]
loss: 0.335202  [32064/60000]
loss: 0.254348  [38464/60000]
loss: 0.309981  [44864/60000]
loss: 0.206965  [51264/60000]
loss: 0.276780  [57664/60000]
Test Error: 
 Accuracy: 95.1%, Avg loss: 0.158749 

Epoch 3
-------------------------------
loss: 0.146737  [   64/60000]
loss: 0.159576  [ 6464/60000]
loss: 0.295618  [12864/60000]
loss: 0.243006  [19264/60000]
loss: 0.294980  [25664/60000]
loss: 0.163240  [32064/600