# MNIST from Scratch

Can I train a model to recognize handwritten digits using only NumPy?

1. Load the MNIST dataset from the web and store it as NumPy arrays
2. Train a simple model to solve MNIST using tinygrad (a small PyTorch-like library)
3. Do the same in pure NumPy by implementing the forward pass, backpropagation, and SGD by hand

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import trange
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam, get_parameters
from datasets import fetch_mnist
from extra.training import train, evaluate

ops_triton not available No module named 'pycuda'


Load the MNIST dataset and scale the pixel values by 1/255

In [2]:
X_train, Y_train, X_test, Y_test = fetch_mnist()
# Scale pixel intensities by 1/255. Rebinding (rather than ``/=``) avoids
# mutating the arrays returned by fetch_mnist in place, and also works if they
# come back with an integer dtype, where in-place true division would raise.
X_train = X_train / 255.0
X_test = X_test / 255.0

Solve with tinygrad

In [3]:
class NeuralNet:
    """Two-layer MLP (784 -> 128 -> 10) built from bias-free tinygrad layers.

    The hidden layer uses ReLU and the output is log-softmax.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(784, 128, bias=False)
        self.layer2 = nn.Linear(128, 10, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return log-softmax class scores for a batch of inputs.

        ``x`` is reshaped to ``(-1, 784)`` before the first layer.
        """
        flat = x.reshape(-1, 28 * 28)
        hidden = self.layer1(flat).relu()
        return self.layer2(hidden).logsoftmax()

In [4]:
# Hyperparameters. `iterations` drives both training runs below;
# `batch_size` is only used by the NumPy loop; `learning_rate` is used by
# Adam here and by the NumPy SGD step.
iterations = 1000
batch_size = 64
learning_rate = 1e-3
# Seed NumPy's global RNG before building the model. The NumPy training loop
# later draws its minibatch indices from this same RNG.
# NOTE(review): presumably this also makes the tinygrad weight init
# reproducible — confirm tinygrad draws from NumPy's RNG.
np.random.seed(42)
model = NeuralNet()
optimizer = Adam(get_parameters(model), lr=learning_rate)

In [5]:
# training loop
# NOTE(review): `batch_size` is not passed here, so `train` presumably uses
# its own default batch size — confirm in extra.training. The learning rate
# comes from `optimizer`.
train(model, X_train, Y_train, optimizer, iterations)

loss 0.12 accuracy 0.97: 100%|██████████| 1000/1000 [00:26<00:00, 37.62it/s]


In [6]:
# test accuracy
# Evaluate the trained tinygrad model on the held-out test split. The
# trailing semicolon suppresses the notebook's display of the return value.
evaluate(model, X_test, Y_test);

100%|██████████| 79/79 [00:00<00:00, 165.74it/s]

test set accuracy is 0.959200





Sanity check: reproduce the trained model's predictions with NumPy

In [7]:
# Copy the trained tinygrad weights out as float64 NumPy arrays so that the
# pure-NumPy `forward` can reproduce the model's predictions.
w1, w2 = (
    layer.weight.detach().numpy().astype(np.float64)
    for layer in (model.layer1, model.layer2)
)

In [8]:
def forward(x: np.ndarray) -> np.ndarray:
    """Return the raw class scores (logits) of the two-layer MLP.

    Reads the module-level weights ``w1`` and ``w2``, which are laid out as
    (out_features, in_features), hence the transposes. No softmax is
    applied; that is enough for ``argmax`` predictions.
    """
    hidden = np.maximum(x @ w1.T, 0)  # ReLU
    return hidden @ w2.T

In [9]:
# Test accuracy of the NumPy re-implementation using the copied weights; it
# is expected to match the tinygrad test accuracy reported above.
pred = forward(X_test).argmax(axis=1)
accuracy = (pred == Y_test).mean()
accuracy

0.9592

Now train from scratch with NumPy

In [10]:
def layer_init(m: int, h: int) -> np.ndarray:
    """Create an (m, h) weight matrix drawn uniformly from [-1, 1), scaled by 1/sqrt(m*h).

    Samples from NumPy's global RNG, so the result depends on ``np.random.seed``.
    """
    samples = np.random.uniform(low=-1.0, high=1.0, size=(m, h))
    return samples / np.sqrt(m * h)

In [11]:
# Fresh random weights for the NumPy model. This overwrites the copies taken
# from the tinygrad model above. Shapes are (out_features, in_features).
w1 = layer_init(128, 784)
w2 = layer_init(10, 128)

In [12]:
def saved_forward(x0: np.ndarray) -> tuple[np.ndarray, ...]:
    """Run the two-layer MLP and keep every intermediate for backpropagation.

    Reads the module-level weights ``w1`` (hidden x inputs) and ``w2``
    (classes x hidden).

    Returns:
        ``(x3, x2, x1, x0)``:
        - ``x3``: softmax probabilities, batch x classes.
        - ``x2``: ReLU output, batch x hidden.
        - ``x1``: hidden pre-activation, batch x hidden.
        - ``x0``: the input batch itself.
    """
    x1 = x0 @ w1.T  # pre-activation, batch_size x hidden
    x2 = np.maximum(x1, 0)  # ReLU, batch_size x hidden
    logits = x2 @ w2.T  # batch_size x classes
    # Softmax is invariant to a per-row shift. Subtracting the row max keeps
    # np.exp from overflowing to inf, which would turn the output into NaN.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    x3 = exps / exps.sum(axis=1, keepdims=True)  # softmax probabilities
    return x3, x2, x1, x0

In [13]:
def cross_entropy_loss(pred: np.ndarray, labels: np.ndarray) -> tuple[float, np.ndarray]:
    """Mean cross-entropy of softmax probabilities, plus the logit-space error.

    Args:
        pred: batch x num_classes softmax probabilities.
        labels: integer class index per row (used for fancy indexing).

    Returns:
        ``(loss, error)``:
        - ``loss`` is the batch-mean cross-entropy. It is a NumPy scalar, so
          ``.item()`` works.
        - ``error`` is ``pred - one_hot(labels)``. This is the gradient of
          the *summed* (not mean) cross-entropy with respect to the logits,
          i.e. batch_size times the gradient of the returned ``loss``.
    """
    n = labels.shape[0]
    rows = np.arange(n)
    # One-hot encoding of labels; the number of classes is taken from pred.
    actual = np.zeros((n, pred.shape[1]))
    actual[rows, labels] = 1
    # Only the target-class term of sum(one_hot * log(pred)) is non-zero, so
    # pick it out directly. The full product would give 0 * -inf = NaN as soon
    # as any non-target probability underflows to 0.
    loss = -np.mean(np.log(pred[rows, labels]))
    error = pred - actual
    return loss, error

In [14]:
def backward(error: np.ndarray, xs: tuple[np.ndarray, ...]) -> tuple[np.ndarray, np.ndarray]:
    """Backpropagate the logit-space error through the two-layer MLP.

    Args:
        error: gradient of the loss with respect to the output logits,
            batch_size x classes (e.g. ``pred - one_hot`` from
            ``cross_entropy_loss``).
        xs: ``(x3, x2, x1, x0)`` as returned by ``saved_forward``.

    Returns:
        ``(dw2, dw1)``: weight gradients with the same shapes as the
        module-level ``w2`` and ``w1``.
    """
    # The softmax output is not needed because `error` already accounts for it.
    _, x2, x1, x0 = xs
    dx3 = error  # batch_size x classes, gradient w.r.t. the logits
    dw2 = dx3.T @ x2  # classes x hidden, same shape as w2
    dx2 = dx3 @ w2  # batch_size x hidden, gradient w.r.t. the ReLU output
    dx1 = (x1 > 0) * dx2  # ReLU passes the gradient only where x1 > 0
    dw1 = dx1.T @ x0  # hidden x inputs, same shape as w1
    assert dw2.shape == w2.shape
    assert dw1.shape == w1.shape
    return dw2, dw1

In [15]:
def update_weights(dws: tuple[np.ndarray], lr: float = 1e-3) -> tuple[np.ndarray]:
    """Take one plain SGD step, ``w -= lr * dw``, on the module-level weights.

    ``dws`` is ordered ``(dw2, dw1)`` to match ``backward``. The arrays ``w2``
    and ``w1`` are updated in place, and those same objects are returned as
    ``(w2, w1)``.
    """
    params = (w2, w1)
    for param, grad in zip(params, dws):
        np.subtract(param, lr * grad, out=param)
    return params

In [16]:
# NumPy training loop. Each iteration:
#   1. draws a minibatch of indices with replacement (np.random.randint);
#   2. runs the forward pass;
#   3. computes the loss and its logit-space error;
#   4. backpropagates;
#   5. takes one SGD step on w1/w2.
for i in (t := trange(iterations)):
    sample = np.random.randint(0, len(X_train), size=batch_size)
    xs = saved_forward(X_train[sample])  # (probs, relu_out, pre_act, input)
    # Labels are cast to int because cross_entropy_loss uses them as indices.
    loss, error = cross_entropy_loss(xs[0], Y_train[sample].astype(int))
    dws = backward(error, xs)  # (dw2, dw1)
    update_weights(dws, lr=learning_rate)  # mutates w1/w2 in place
    # The displayed loss is that of the current minibatch only.
    t.set_description(f"Iteration {i} loss {loss.item():.2f}")

Iteration 999 loss 0.26: 100%|██████████| 1000/1000 [00:03<00:00, 296.59it/s]


In [17]:
# Test accuracy of the NumPy-trained weights. `forward` reads the same
# module-level w1/w2 that the training loop updated in place.
pred = forward(X_test).argmax(axis=1)
accuracy = (pred == Y_test).mean()
accuracy

0.9099