# MNIST from Scratch

Can we train a model to recognize handwritten digits using only NumPy?

1. Load the MNIST dataset from the web and store as NumPy arrays
2. Train a simple model to solve MNIST using tinygrad
3. Do the same with NumPy by implementing various ML algorithms

In [2]:
import numpy as np
from tqdm import trange
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from extra.training import train, evaluate
from extra.datasets import fetch_mnist
# Seed NumPy's global RNG and tinygrad's Tensor RNG with the same fixed value
# so that weight initialisation and batch sampling are repeatable across runs.
np.random.seed(1337)
Tensor.manual_seed(1337)

Load the MNIST dataset

In [3]:
# Fetch the MNIST train/test splits: X_* hold the images, Y_* the digit labels.
# NOTE(review): later cells divide X_* in place by 255.0 and index Y_* as
# integers, so these are presumably float image arrays and integer label
# arrays -- confirm against extra.datasets.fetch_mnist.
X_train, Y_train, X_test, Y_test = fetch_mnist()

Solve with tinygrad

In [4]:
class Net:
    """Two-layer perceptron: 784 inputs -> 128 hidden (ReLU) -> 10 outputs (log-softmax)."""

    def __init__(self):
        # Weight matrices only; this network has no bias terms.
        self.l1 = Tensor.scaled_uniform(784, 128)
        self.l2 = Tensor.scaled_uniform(128, 10)

    def forward(self, x: Tensor) -> Tensor:
        """Return the log-softmax class scores for the batch ``x``."""
        hidden = x.dot(self.l1).relu()
        return hidden.dot(self.l2).log_softmax()

In [5]:
# Train the tinygrad reference model with plain SGD, then score it on the test set.
model = Net()
optimizer = optim.SGD(get_parameters(model), lr=0.001)
# The progress bar below shows 1000 iterations for this call.
# NOTE(review): BS is presumably the mini-batch size -- confirm against extra.training.train.
train(model, X_train, Y_train, optimizer, 1000, BS=256)
# Last expression of the cell, so the notebook also echoes the returned accuracy.
evaluate(model, X_test, Y_test)

loss 0.14 accuracy 0.95: 100%|██████████| 1000/1000 [00:07<00:00, 131.02it/s]
100%|██████████| 79/79 [00:00<00:00, 729.74it/s]

test set accuracy is 0.963800





0.9638

Sanity check with NumPy

In [6]:
# Pull the trained tinygrad weights out as float64 NumPy arrays so the
# NumPy-only forward pass below can reuse them.
w1, w2 = (
    layer.detach().numpy().astype(np.float64)
    for layer in (model.l1, model.l2)
)

In [7]:
def forward(x: np.ndarray) -> np.ndarray:
    """Run the two-layer network on ``x`` and return the raw class scores.

    The module-level weight matrices ``w1`` and ``w2`` are looked up at call
    time, so this always uses whatever weights are currently bound to those
    names. No softmax is applied to the output.
    """
    hidden = np.maximum(x @ w1, 0)  # ReLU on the first layer's output
    return hidden @ w2

In [8]:
# The predicted digit is the index of the largest class score; accuracy is
# the fraction of predictions that match the test labels.
pred = np.argmax(forward(X_test), axis=1)
accuracy = np.mean(pred == Y_test)
print(f"test set accuracy is {accuracy}")

test set accuracy is 0.9638


Now solve with NumPy

In [9]:
def layer_init(m: int, h: int) -> np.ndarray:
    """Return a randomly initialised ``(m, h)`` weight matrix.

    Entries are drawn from NumPy's global RNG, uniformly over ``[-1, 1)``,
    and then divided by ``sqrt(m * h)``.
    """
    shape = (m, h)
    scale = np.sqrt(m * h)
    return np.random.uniform(low=-1.0, high=1.0, size=shape) / scale

In [10]:
# Fresh random weights for the from-scratch model; w1 is drawn before w2 so
# the RNG is consumed in the same order on every run.
w1, w2 = layer_init(784, 128), layer_init(128, 10)
# Divide the image arrays by 255 in place. Re-running this cell divides them
# again, so run it exactly once per fetch of the data.
X_train /= 255.0
X_test /= 255.0

In [11]:
def saved_forward(x0: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the forward pass and keep every intermediate needed for backprop.

    Reads the module-level weight matrices ``w1`` and ``w2``.

    Args:
        x0: Input batch, one sample per row.

    Returns:
        ``(x3, x2, x1, x0)`` where ``x3`` is the row-wise softmax of the
        output layer, ``x2`` the hidden activations after ReLU, ``x1`` the
        hidden pre-activations and ``x0`` the unchanged input.
    """
    x1 = x0 @ w1  # batch_size * 128
    x2 = np.maximum(x1, 0)  # batch_size * 128, relu
    logits = x2 @ w2  # batch_size * 10
    # Softmax is invariant to subtracting a per-row constant. Shifting by the
    # row maximum makes the largest exponent exactly 0, so np.exp can no
    # longer overflow to inf and turn the row into inf / inf = nan.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)  # computed once and reused for the normaliser
    x3 = exps / np.sum(exps, axis=1, keepdims=True)  # softmax
    return x3, x2, x1, x0

In [12]:
def cross_entropy_loss(pred: np.ndarray, labels: np.ndarray) -> tuple[float, np.ndarray]:
    """Compute the mean cross-entropy loss and the output-layer error.

    Args:
        pred: ``(batch_size, num_classes)`` array of predicted class
            probabilities (each row is a softmax output).
        labels: ``(batch_size,)`` array of integer class indices.

    Returns:
        ``(loss, error)``. ``loss`` is the mean negative log-probability
        assigned to the true class, as a NumPy float scalar. ``error`` is
        ``pred - one_hot(labels)``, i.e. the per-sample gradient of the
        cross-entropy with respect to the pre-softmax scores. It is NOT
        divided by the batch size, so gradients built from it are summed
        over the batch rather than averaged.
    """
    # Take the class count from the predictions instead of hard-coding 10.
    batch_size, num_classes = pred.shape
    rows = np.arange(batch_size)
    # one-hot encoding of the labels
    actual = np.zeros((batch_size, num_classes))
    actual[rows, labels] = 1
    # Only the true-class probability contributes to the loss. Selecting it
    # directly avoids 0 * log(0) = nan when some other class's probability
    # has underflowed to exactly zero.
    true_class_prob = pred[rows, labels].astype(np.float64)
    # Floor at the smallest positive normal float so a zero probability for
    # the true class gives a large finite loss instead of inf.
    floor = np.finfo(np.float64).tiny
    loss = -np.mean(np.log(np.maximum(true_class_prob, floor)))
    error = pred - actual
    return loss, error

In [13]:
def backward(error: np.ndarray, xs: tuple[np.ndarray]) -> tuple[np.ndarray]:
    """Backpropagate ``error`` through both layers and return ``(dw2, dw1)``.

    ``xs`` is the 4-tuple ``(x3, x2, x1, x0)`` saved by the forward pass; only
    ``x2`` (hidden activations after ReLU), ``x1`` (hidden pre-activations)
    and ``x0`` (input batch) are used. ``error`` is taken as the gradient
    with respect to the output layer's pre-softmax scores. Reads the
    module-level weight matrices ``w1`` and ``w2``.
    """
    _, hidden, pre_act, inputs = xs
    grad_scores = error  # batch_size * 10
    grad_w2 = hidden.T @ grad_scores  # 128 * 10, gradient of the second matmul w.r.t. w2
    grad_hidden = grad_scores @ w2.T  # batch_size * 128, gradient flowing back into the hidden layer
    relu_mask = (pre_act > 0).astype(np.float64)  # 1 where ReLU passed its input through, else 0
    grad_pre_act = relu_mask * grad_hidden  # batch_size * 128
    grad_w1 = inputs.T @ grad_pre_act  # 784 * 128, gradient of the first matmul w.r.t. w1
    # Sanity checks: every gradient must have the shape of the weight it updates.
    assert grad_w2.shape == w2.shape
    assert grad_w1.shape == w1.shape
    return grad_w2, grad_w1

In [14]:
def update_weights(dws: tuple[np.ndarray], lr: float = 1e-3) -> tuple[np.ndarray]:
    """Apply one SGD step to the module-level weights and return them.

    ``dws`` must be ordered ``(dw2, dw1)`` so that it pairs up with
    ``(w2, w1)``. The returned tuple is ``(w2, w1)``: the same array objects,
    updated in place.
    """
    weights = (w2, w1)
    for weight, grad in zip(weights, dws):
        # Writing the result back into `weight` mutates the module-level
        # array itself, so the names w1/w2 never need to be rebound.
        np.subtract(weight, lr * grad, out=weight)
    return weights

In [15]:
# Mini-batch SGD: 1000 steps, each on 256 training rows drawn uniformly at
# random (with replacement) from the training set.
t = trange(1000)
for i in t:
    sample = np.random.randint(len(X_train), size=256)
    xs = saved_forward(X_train[sample])
    loss, error = cross_entropy_loss(xs[0], Y_train[sample].astype(int))
    dws = backward(error, xs)
    update_weights(dws, lr=0.001)
    t.set_description(f"loss {loss.item():.2f}")
# update_weights mutates w1/w2 in place, so forward() now evaluates the
# freshly trained weights.
pred = np.argmax(forward(X_test), axis=1)
accuracy = np.mean(pred == Y_test)
print(f"test set accuracy is {accuracy}")

loss 0.11: 100%|██████████| 1000/1000 [00:14<00:00, 70.66it/s]


test set accuracy is 0.9512
