# MNIST from Scratch

Can we train a model to recognize handwritten digits using only NumPy?

1. Load the MNIST dataset and store as NumPy arrays
2. Train a simple model to solve MNIST using PyTorch
3. Do the same with NumPy by implementing the ML algorithms

In [1]:
import numpy as np
import torch
from torch import nn, optim
from tqdm import trange
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [2]:
# Seed the global RNGs of both PyTorch and NumPy so repeated runs draw the
# same random numbers.
torch.manual_seed(0)
np.random.seed(0)
# get_device is imported from the local helpers module.
# NOTE(review): presumably it picks the best available backend (the recorded
# cell output shows 'mps:0') — confirm in helpers.
device = get_device()

Load the MNIST dataset

In [3]:
# Load MNIST and materialize both splits as NumPy arrays: images become
# float32 arrays divided by 255, labels are kept as returned by the dataset.
dataset = load_dataset("mnist")
train_split, test_split = dataset["train"], dataset["test"]

X_train = np.array(
    [np.array(img) for img in train_split["image"]],
    dtype=np.float32,
) / 255.0
Y_train = np.array(train_split["label"])

X_test = np.array(
    [np.array(img) for img in test_split["image"]],
    dtype=np.float32,
) / 255.0
Y_test = np.array(test_split["label"])

Solve with PyTorch

In [4]:
class MLP(nn.Module):
    """Two-layer perceptron: 784 -> 128 -> ReLU -> 10, with no bias terms.

    The layers are bias-free so the trained weights can be reproduced with
    two plain matrix multiplications.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(784, 128, bias=False)
        self.l2 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        """Return the unnormalized class scores for a batch of inputs.

        Args:
            x: Tensor whose elements can be regrouped into rows of 784
                values (e.g. ``(N, 784)`` or ``(N, 28, 28)``).

        Returns:
            Tensor of shape ``(N, 10)``.
        """
        # Implement ``forward`` rather than overriding ``__call__``:
        # ``nn.Module.__call__`` is what dispatches registered hooks before
        # delegating to ``forward``.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
        x = x.reshape(-1, 784)
        x = self.l1(x).relu()
        x = self.l2(x)
        return x

In [5]:
# Build the network and move it to the selected device.
model = MLP().to(device)
# Smoke test: push one random 784-value input through the untrained model;
# as the last expression of the cell, its (1, 10) output is displayed.
model(torch.rand(1, 784, device=device))

tensor([[ 0.2202,  0.2427, -0.1632,  0.0537,  0.1001, -0.0506,  0.1258,  0.0454,
         -0.0495, -0.3454]], device='mps:0', grad_fn=<LinearBackward0>)

In [6]:
# Hyperparameters for the PyTorch run.
lr = 0.005  # learning rate handed to optim.SGD
epochs = 3  # number of times train() is called
batch_size = 32  # used to derive train_steps and test_steps

In [7]:
# Plain SGD over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=lr)
# Number of whole batches per pass; the floor division drops any remainder
# smaller than batch_size.
train_steps = len(X_train) // batch_size
# NOTE(review): test_steps is not referenced anywhere in the visible
# notebook — confirm nothing else needs it before removing.
test_steps = len(X_test) // batch_size

# NOTE(review): batch_size itself is never passed to train(); presumably the
# helper derives or defaults it — verify against helpers.train.
for epoch in range(epochs):
    train(model, X_train, Y_train, optimizer, train_steps, device=device)

# evaluate comes from the local helpers module; the recorded output shows it
# reporting the test-set accuracy.
evaluate(model, X_test, Y_test, device=device)

loss 0.65 accuracy 0.82: 100%|██████████| 1875/1875 [00:04<00:00, 397.86it/s]
loss 0.40 accuracy 0.91: 100%|██████████| 1875/1875 [00:04<00:00, 419.45it/s]
loss 0.29 accuracy 0.92: 100%|██████████| 1875/1875 [00:04<00:00, 423.58it/s]
100%|██████████| 79/79 [00:00<00:00, 1050.49it/s]


test set accuracy is 0.9035


Sanity check with NumPy

In [8]:
# Pull the trained weights out of the PyTorch model as NumPy arrays.
# nn.Linear stores weights as (out_features, in_features), so transpose to
# (in_features, out_features) for use as `x @ w`.
w1 = model.l1.weight.detach().cpu().numpy().T
w2 = model.l2.weight.detach().cpu().numpy().T

In [9]:
def forward(x: np.ndarray) -> np.ndarray:
    """Run the two-layer network on `x` and return the raw class scores.

    The input is regrouped into rows of 784 values, multiplied by the
    module-level `w1`, passed through ReLU, and multiplied by the
    module-level `w2`. No softmax is applied.
    """
    flat = x.reshape(-1, 784)
    hidden = np.maximum(flat @ w1, 0)  # ReLU
    return hidden @ w2

In [10]:
# Classify the test set with the NumPy forward pass and report the
# percentage of predictions that match the labels.
pred = np.argmax(forward(X_test), axis=1)
accuracy = 100 * np.mean(pred == Y_test)
print(f"test set accuracy is {accuracy}")

test set accuracy is 90.35


Now solve with NumPy

In [11]:
def layer_init(m: int, h: int) -> np.ndarray:
    """Return an (m, h) weight matrix of small random values.

    Entries are drawn uniformly from [-1, 1) using NumPy's global RNG and
    then divided by sqrt(m * h).
    """
    shape = (m, h)
    scale = np.sqrt(m * h)
    return np.random.uniform(low=-1.0, high=1.0, size=shape) / scale

In [12]:
# Replace the weights copied from the PyTorch model with fresh random ones.
# Both calls draw from NumPy's global RNG, so their order determines which
# random values each matrix receives.
w1 = layer_init(784, 128)
w2 = layer_init(128, 10)

In [13]:
def train_forward(
    x0: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the forward pass and keep every intermediate needed for backprop.

    Uses the module-level weight matrices `w1` and `w2`.

    Args:
        x0: Batch of inputs; regrouped into rows of 784 values.

    Returns:
        `(x3, x2, x1, x0)` where `x3` holds the softmax probabilities
        (batch_size, 10), `x2` the ReLU activations (batch_size, 128),
        `x1` the pre-activations (batch_size, 128) and `x0` the flattened
        input (batch_size, 784).
    """
    x0 = x0.reshape(-1, 784)
    x1 = x0 @ w1  # x1.shape = (batch_size, 128)
    x2 = np.maximum(x1, 0)  # x2.shape = (batch_size, 128), relu
    logits = x2 @ w2  # logits.shape = (batch_size, 10)
    # Softmax is invariant to subtracting a per-row constant; subtracting the
    # row maximum keeps every exponent <= 0 so np.exp cannot overflow to inf
    # (which would otherwise produce inf / inf = nan).
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    x3 = exps / exps.sum(axis=1, keepdims=True)  # softmax
    return x3, x2, x1, x0

In [14]:
def cross_entropy_loss(pred: np.ndarray, labels: np.ndarray) -> tuple[float, np.ndarray]:
    """Compute the mean cross-entropy loss and the output-layer error.

    Args:
        pred: Predicted class probabilities, shape (batch_size, num_classes).
        labels: Integer class indices, shape (batch_size,).

    Returns:
        `(loss, error)` where `loss` is the mean negative log-probability
        assigned to the true class and `error` is `pred - one_hot(labels)`,
        with the same shape as `pred`. The loss is infinite if a true class
        was assigned probability exactly 0.
    """
    batch_size, num_classes = labels.shape[0], pred.shape[1]
    rows = np.arange(batch_size)
    # NOTE: actual is the one-hot encoding of labels
    actual = np.zeros((batch_size, num_classes))
    actual[rows, labels] = 1
    # Only the true-class probability contributes to the loss. Selecting it
    # directly (instead of summing actual * log(pred)) avoids 0 * log(0),
    # which evaluates to nan whenever some other class has probability 0.
    loss = -np.mean(np.log(pred[rows, labels]))
    error = pred - actual
    return loss, error

In [15]:
def backward(
    error: np.ndarray, xs: tuple[np.ndarray, ...]
) -> tuple[np.ndarray, np.ndarray]:
    """Backpropagate the output error to gradients for both weight matrices.

    Uses the module-level weight matrices `w1` and `w2`.

    Args:
        error: Gradient at the network output, shape (batch_size, 10).
        xs: The 4-tuple `(x3, x2, x1, x0)`: output (unused here), ReLU
            activations, pre-activations and flattened input.

    Returns:
        `(dw2, dw1)`, shaped like `w2` and `w1` respectively.
    """
    _, x2, x1, x0 = xs
    dx3 = error  # dx3.shape = (batch_size, 10), derivative of loss function
    dw2 = x2.T @ dx3  # dw2.shape = (128, 10), derivative of dot
    dx2 = dx3 @ w2.T  # dx2.shape = (batch_size, 128), derivative of dot
    # ReLU passes the gradient through only where its input was positive.
    dx1 = (x1 > 0).astype(np.float64) * dx2  # dx1.shape = (batch_size, 128)
    dw1 = x0.T @ dx1  # dw1.shape = (784, 128)
    # Sanity checks: each gradient must match the weight it will update.
    assert dw2.shape == w2.shape
    assert dw1.shape == w1.shape
    return dw2, dw1

In [16]:
def update_weights(dws: tuple[np.ndarray], *, lr: float) -> tuple[np.ndarray]:
    """Take one gradient-descent step on the module-level weights, in place.

    `dws` is paired positionally with `(w2, w1)`, i.e. the gradient for `w2`
    comes first. Each weight array is modified in place by subtracting
    `lr` times its gradient, and the tuple `(w2, w1)` is returned.
    """
    global w1, w2
    params = (w2, w1)
    for param, grad in zip(params, dws):
        # Write the result back into the existing array so every reference
        # to the weights sees the update.
        np.subtract(param, lr * grad, out=param)
    return params

In [17]:
# Train for 1000 steps on random mini-batches of 256 samples (indices are
# drawn with replacement), showing the latest loss in the progress bar.
progress = trange(1000)
for _ in progress:
    batch_idx = np.random.randint(0, len(X_train), size=256)
    activations = train_forward(X_train[batch_idx])
    probs = activations[0]
    loss, error = cross_entropy_loss(probs, Y_train[batch_idx].astype(int))
    update_weights(backward(error, activations), lr=0.001)
    progress.set_description(f"loss {loss.item():.2f}")

# Evaluate the trained weights on the test set.
pred = np.argmax(forward(X_test), axis=1)
accuracy = 100 * np.mean(pred == Y_test)
print(f"test set accuracy is {accuracy}")

loss 0.15: 100%|██████████| 1000/1000 [00:01<00:00, 570.85it/s]


test set accuracy is 95.42
