# MNIST from Scratch

Can I train a model to recognize handwritten digits using only numpy?

1. Load the MNIST dataset from the web and store as numpy arrays
2. Train a simple model to solve MNIST using PyTorch
3. Do the same with numpy by implementing various ML algorithms

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
from torch import nn, Tensor, optim

In [2]:
# The MNIST CSV files are read from ../data/mnist relative to the working directory.
# strict=True makes resolve() raise if the working directory does not exist.
data_dir = Path(".").resolve(strict=True).parent.joinpath("data", "mnist")

In [3]:
def load_dataset(split: str) -> tuple[np.ndarray, np.ndarray]:
    """Load one MNIST split from ``data_dir / f"mnist_{split}.csv"``.

    Args:
        split: Suffix of the CSV file name (the notebook uses "train" and "test").

    Returns:
        A ``(X, y)`` pair. ``X`` contains every column except ``label``,
        divided by 255.0; ``y`` contains the ``label`` column.
    """
    dataset = pd.read_csv(data_dir / f"mnist_{split}.csv")
    # Scale by 1/255 (pixel values are assumed to be in 0..255 -- TODO confirm).
    X = dataset.drop("label", axis=1).to_numpy() / 255.0
    y = dataset["label"].to_numpy()
    return X, y

In [4]:
# Load both splits as (features, labels) pairs; load_dataset divides the features by 255.0.
X_train, y_train = load_dataset("train")
X_test, y_test = load_dataset("test")

In [5]:
class NeuralNet(nn.Module):
    """Bias-free fully connected network: 784 -> 256 -> 64 -> 10 with ReLU in between."""

    def __init__(self):
        super().__init__()
        # Layers are created in input-to-output order and exposed as layer1..layer3.
        self.layer1 = nn.Linear(in_features=28 * 28, out_features=256, bias=False)
        self.layer2 = nn.Linear(in_features=256, out_features=64, bias=False)
        self.layer3 = nn.Linear(in_features=64, out_features=10, bias=False)
        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Flatten ``x`` into rows of 784 values and return 10 raw scores per row."""
        hidden = x.view(-1, 28 * 28)
        for layer in (self.layer1, self.layer2):
            hidden = self.act(layer(hidden))
        # No activation after the last layer: the raw scores are returned.
        return self.layer3(hidden)

In [6]:
# Hyperparameters consumed by the training loop below.
iterations = 3000  # number of training-loop iterations (one mini-batch each)
batch_size = 64  # number of samples drawn per iteration
learning_rate = 0.005  # passed to Adam as `lr`

model = NeuralNet()
# CrossEntropyLoss is applied directly to the model's raw outputs and integer class labels.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# training loop
# Each iteration trains on one randomly drawn mini-batch; `t` is the tqdm bar so its
# description can be updated with the current loss.
for i in (t := tqdm(range(iterations))):
    # Draw batch_size row indices uniformly at random (with replacement).
    sample = np.random.randint(0, len(X_train), size=batch_size)
    out = model(Tensor(X_train[sample]))
    # Tensor(...) builds a float tensor; .long() converts the labels to integer class indices.
    loss = loss_func(out, Tensor(y_train[sample]).long())
    loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next backward() does not accumulate onto them.
    model.zero_grad()
    t.set_description(f"Iteration {i} loss {loss.item():.2f}")

Iteration 2999 loss 0.03: 100%|██████████| 3000/3000 [00:17<00:00, 169.71it/s]


In [8]:
# test accuracy
# argmax over dim=1 turns the per-class scores into one predicted digit per row.
pred = model(Tensor(X_test)).argmax(dim=1)
# `pred` already holds 1-D class indices, so it is compared to the labels directly
# (calling argmax(dim=1) on it again fails: a 1-D tensor has no dim 1).
accuracy = (pred == Tensor(y_test).long()).float().mean().item()
accuracy

0.9729999899864197