# MNIST from Scratch

Can we train a model to recognize handwritten digits using numpy?

1. Load the MNIST dataset from the web and store as NumPy arrays
2. Train a simple model to solve MNIST using PyTorch
3. Do the same with NumPy by implementing various ML algorithms

In [1]:
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from tqdm import trange
from datasets import load_dataset

In [2]:
np.random.seed(0)
rngs = nnx.Rngs(0)

Load mnist dataset

In [3]:
dataset = load_dataset("mnist")
X_train = np.array([np.array(image) for image in dataset["train"]["image"]], dtype=np.float32).reshape(-1, 784) / 255.0
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)
X_test = np.array([np.array(image) for image in dataset["test"]["image"]], dtype=np.float32).reshape(-1, 784) / 255.0
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

Solve with PyTorch

In [4]:
class Net(nnx.Module):
    def __init__(self, *, rngs):
        self.l1 = nnx.Linear(784, 128, rngs=rngs)
        self.l2 = nnx.Linear(128, 10, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.l1(x))
        x = self.l2(x)
        return x

In [5]:
model = Net(rngs=rngs)
y = model(jnp.ones((1, 784)))
nnx.display(y)

[[-0.49278238  1.2711219  -0.0858871  -0.08085708 -0.4907502  -1.1703337
   0.41073963 -0.03593045  0.57198393 -0.7578646 ]]


In [6]:
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [7]:
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, images, labels):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model, metrics, images, labels):
    loss, logits = loss_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)

In [8]:
batch_size = 32
eval_every = 200
train_steps = len(X_train) // batch_size + 1
test_steps = len(X_test) // batch_size + 1

metrics_history = {
    "train_loss": [],
    "train_accuracy": [],
    "test_loss": [],
    "test_accuracy": [],
}

for step in range(train_steps):
    sample = np.random.randint(0, len(X_train), size=batch_size)
    images, labels = X_train[sample], Y_train[sample]
    train_step(model, optimizer, metrics, images, labels)

    if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        for test_step in range(test_steps):
            images = X_test[batch_size*test_step:batch_size*(test_step+1)]
            labels = Y_test[batch_size*test_step:batch_size*(test_step+1)]
            eval_step(model, metrics, images, labels)

        for metric, value in metrics.compute().items():
            metrics_history[f"test_{metric}"].append(value)
        metrics.reset()

        print(
            f"[train] step: {step}, "
            f"loss: {metrics_history['train_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100:.2f}"
        )
        print(
            f"[test] step: {step}, "
            f"loss: {metrics_history['test_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100:.2f}"
        )

[train] step: 200, loss: 0.4570, accuracy: 85.70
[test] step: 200, loss: 0.2741, accuracy: 91.18
[train] step: 400, loss: 0.2540, accuracy: 92.36
[test] step: 400, loss: 0.1993, accuracy: 93.87
[train] step: 600, loss: 0.1826, accuracy: 94.78
[test] step: 600, loss: 0.2247, accuracy: 93.09
[train] step: 800, loss: 0.1789, accuracy: 94.75
[test] step: 800, loss: 0.1624, accuracy: 95.06
[train] step: 1000, loss: 0.1614, accuracy: 95.23
[test] step: 1000, loss: 0.1902, accuracy: 94.13
[train] step: 1200, loss: 0.1473, accuracy: 95.55
[test] step: 1200, loss: 0.1428, accuracy: 95.54
[train] step: 1400, loss: 0.1330, accuracy: 96.05
[test] step: 1400, loss: 0.1361, accuracy: 95.87
[train] step: 1600, loss: 0.1256, accuracy: 96.05
[test] step: 1600, loss: 0.1333, accuracy: 96.03
[train] step: 1800, loss: 0.1266, accuracy: 95.94
[test] step: 1800, loss: 0.1229, accuracy: 96.11
[train] step: 1875, loss: 0.1161, accuracy: 96.29
[test] step: 1875, loss: 0.1328, accuracy: 96.00


Sanity check with NumPy

In [9]:
w1 = np.asarray(model.l1.kernel.value)
w2 = np.array(model.l2.kernel.value)

In [10]:
def forward(x: np.ndarray) -> np.ndarray:
    x = x.reshape(-1, 784)
    x = x @ w1
    x = np.maximum(x, 0)
    x = x @ w2
    return x

In [11]:
pred = forward(X_test).argmax(axis=1)
accuracy = (pred == Y_test).mean() * 100
print(f"test set accuracy is {accuracy}")

test set accuracy is 96.1


Now solve with NumPy

In [12]:
def layer_init(m: int, h: int) -> np.ndarray:
    weights = np.random.uniform(-1., 1., size=(m, h)) / np.sqrt(m * h)
    return weights

In [13]:
w1 = layer_init(784, 128)
w2 = layer_init(128, 10)

In [14]:
def train_forward(x0: np.ndarray) -> tuple[np.ndarray]:
    x1 = x0 @ w1  # batch_size * 128
    x2 = np.maximum(x1, 0)  # batch_size * 128, relu
    x3 = x2 @ w2  # batch_size * 10
    x3 = np.exp(x3) / np.sum(np.exp(x3), axis=1, keepdims=True)  # softmax
    return x3, x2, x1, x0

In [15]:
def cross_entropy_loss(pred: np.ndarray, labels: np.ndarray) -> tuple[float, np.ndarray]:
    # NOTE: actual is the one-hot enconding of labels
    actual = np.zeros((labels.shape[0], 10))
    actual[np.arange(labels.shape[0]), labels] = 1
    loss = -np.mean(np.sum(actual * np.log(pred), axis=1))
    error = pred - actual
    return loss, error

In [16]:
def backward(error: np.ndarray, xs: tuple[np.ndarray]) -> tuple[np.ndarray]:
    _, x2, x1, x0 = xs
    dx3 = error  # batch_size * 10, derivative of squared_error
    dw2 = x2.T @ dx3  # 128 * 10, derivative of dot
    dx2 = dx3 @ w2.T  # batch_size * 128, derivative of dot
    dx1 = (x1 > 0).astype(np.float64) * dx2  # 128 * batch_size
    dw1 = x0.T @ dx1  # 784 * 128
    assert dw2.shape == w2.shape
    assert dw1.shape == w1.shape
    return dw2, dw1

In [17]:
def update_weights(dws: tuple[np.ndarray], lr: float = 1e-3) -> tuple[np.ndarray]:
    global w1, w2
    ws = (w2, w1)
    for wi, dwi in zip(ws, dws):
        wi -= lr * dwi
    return ws

In [18]:
for i in (t := trange(1000)):
    sample = np.random.randint(0, len(X_train), size=256)
    xs = train_forward(X_train[sample])
    loss, error = cross_entropy_loss(xs[0], Y_train[sample].astype(int))
    dws = backward(error, xs)
    update_weights(dws, lr=0.001)
    t.set_description(f"loss {loss.item():.2f}")
pred = forward(X_test).argmax(axis=1)
accuracy = (pred == Y_test).mean() * 100
print(f"test set accuracy is {accuracy}")

loss 0.17: 100%|██████████| 1000/1000 [00:02<00:00, 416.61it/s]


test set accuracy is 95.0
