In [1]:
%pip install numpy datasets Pillow

Note: you may need to restart the kernel to use updated packages.


# Implementation

In [2]:
import numpy as np

In [3]:
def softmax(x):
    """Numerically stable softmax over the last axis.

    The per-row maximum is subtracted before exponentiating so ``np.exp``
    cannot overflow; the shift cancels out in the normalization.
    """
    exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
    normalizer = np.sum(exps, axis=-1, keepdims=True)
    return exps / normalizer

def log_softmax(x):
    """Numerically stable log-softmax over the last axis (log-sum-exp trick).

    Computes ``x - log(sum(exp(x)))`` after shifting by the per-row maximum,
    so large logits do not overflow.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    log_normalizer = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return shifted - log_normalizer

In [4]:
def loss_fn(logits, y):
    """Cross-entropy loss of ``logits`` against integer class label(s) ``y``.

    Args:
        logits: unnormalized scores with classes along the last axis -- either
            one example of shape ``(C,)`` or a batch of shape ``(B, C)``.
        y: target class index; a scalar int for one example, or a length-``B``
            sequence of ints for a batch.

    Returns:
        ``-log_softmax(logits)`` at the target class: a scalar for a single
        example, or an array of ``B`` per-example losses for a batch (no
        averaging is done here).
    """
    log_probs = log_softmax(logits)
    if log_probs.ndim == 1:
        return -log_probs[y]
    # Batched input: plain ``log_probs[y]`` would select whole *rows*; pick the
    # target column of each row instead.
    targets = np.asarray(y)[..., np.newaxis]
    return -np.take_along_axis(log_probs, targets, axis=-1)[..., 0]

def d_loss_fn(logits, y):
    """Gradient of the cross-entropy loss w.r.t. ``logits``.

    Equals ``softmax(logits) - one_hot(y)``. Works for a single example
    (``(C,)`` logits, scalar ``y``) or a batch (``(B, C)`` logits, length-``B``
    labels). No averaging over the batch happens here.
    """
    num_classes = logits.shape[-1]
    one_hot = np.eye(num_classes)[y]
    return softmax(logits) - one_hot

In [5]:
def relu(x):
    """Elementwise rectified linear unit, ``max(0, x)``.

    Accepts scalars or arrays of any shape.
    """
    return np.maximum(0, x)

def d_relu(x):
    """Derivative of ReLU: 1 where ``x > 0`` and 0 elsewhere (including at 0).

    The result is an integer array, so multiplying a float array by it keeps
    the float dtype.
    """
    positive = x > 0
    return np.where(positive, 1, 0)

In [6]:
def init_params(in_size: int, out_size: int, hidden_sizes: list[int], rng=None):
    """Create Xavier/Glorot-uniform weights and zero biases for an MLP.

    Args:
        in_size: number of input features.
        out_size: number of outputs (logits) of the final layer.
        hidden_sizes: widths of the hidden layers, in order. Any sequence of
            ints (list, tuple, ...) is accepted; an empty one gives a single
            linear layer.
        rng: optional random generator with a ``uniform(low, high, size)``
            method (e.g. ``np.random.default_rng(seed)``) for reproducible
            initialization. ``None`` (default) uses NumPy's global RNG.

    Returns:
        A list with one ``[w, b]`` pair per layer, ``w`` of shape
        ``(nin, nout)`` and ``b`` of shape ``(nout,)``. The pairs are mutable
        lists so callers can reassign entries (e.g. ``params[k][0] -= ...``).
    """
    if rng is None:
        rng = np.random
    dims = [in_size, *hidden_sizes, out_size]
    params = []
    for nin, nout in zip(dims[:-1], dims[1:]):
        # Xavier/Glorot uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
        bound = (6 / (nin + nout)) ** 0.5
        w = rng.uniform(-bound, bound, size=[nin, nout])
        b = np.zeros(nout)
        params.append([w, b])
    return params

def forward(x, params):
    """Run the MLP forward pass and record what backprop needs.

    Each layer computes ``z = a @ w + b`` followed by ReLU. The last layer's
    pre-activation ``z`` is returned as the logits; its ReLU output is
    computed but not kept.

    Args:
        x: input, a single example ``(in_size,)`` or a batch ``(B, in_size)``.
        params: sequence of ``(w, b)`` pairs, one per layer.

    Returns:
        ``(logits, tape)``, where ``tape[i] = (z_i, a_i)`` holds the activation
        ``a_i`` fed into layer ``i`` and the pre-activation ``z_i`` that
        produced it (``None`` for the raw input at ``i == 0``).
    """
    records = [(None, x)]
    for weight, bias in params:
        pre_activation = records[-1][1] @ weight + bias
        records.append((pre_activation, relu(pre_activation)))
    # The final record belongs to the output layer; drop it so the tape holds
    # exactly one (z, input-activation) entry per layer.
    return pre_activation, records[:-1]

def backprop(x, y, params):
    """Gradients of the batch-mean cross-entropy loss w.r.t. every parameter.

    Args:
        x: a batch of inputs ``(B, in_size)``. A single example of shape
            ``(in_size,)`` is also accepted and treated as a batch of one.
        y: integer class label(s), length ``B`` (a scalar is fine when B == 1).
        params: sequence of ``(w, b)`` pairs, one per layer.

    Returns:
        A list of ``(grad_w, grad_b)`` tuples in the same order as ``params``,
        each averaged over the batch.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        # Without a batch axis, shape[0] would be the feature count and the
        # reductions below would run over features instead of examples.
        x = x[np.newaxis, :]
    batch_size = x.shape[0]

    logits, tape = forward(x, params)
    error = d_loss_fn(logits, y)  # dL/dz of the output layer, per example

    grads = []
    for (z, a), (w, _) in zip(reversed(tape), reversed(params)):
        # a.T @ error sums the per-example outer products a_i (x) error_i
        # without materializing a (B, nin, nout) temporary.
        grad_w = a.T @ error / batch_size
        grad_b = np.sum(error, axis=0) / batch_size
        grads.append((grad_w, grad_b))
        if z is not None:
            # Propagate through this layer's weights and the previous ReLU.
            error = (error @ w.T) * d_relu(z)
    grads.reverse()
    return grads

# Training on MNIST

In [7]:
def train_mnist(lr=1e-3, bs=64, n_epochs=10, log_every_n_steps=100, hidden_sizes=(64,)):
    """Train an MLP classifier on MNIST with plain minibatch SGD.

    Test-set accuracy is printed every ``log_every_n_steps`` steps. The
    defaults reproduce the previous hard-coded configuration, except that the
    output layer now has 10 units (one per digit class) instead of 20.

    Args:
        lr: SGD learning rate.
        bs: minibatch size (the last batch of an epoch may be smaller).
        n_epochs: number of passes over the training set.
        log_every_n_steps: logging interval, counted in steps within an epoch.
        hidden_sizes: widths of the hidden layers.
    """
    import datasets

    n_pixels = 784  # flattened 28 x 28 images
    n_classes = 10  # digits 0-9

    mnist = datasets.load_dataset("mnist")
    xtrain = np.array(mnist["train"]["image"]).reshape(-1, n_pixels) / 255.0
    ytrain = np.array(mnist["train"]["label"])
    xtest = np.array(mnist["test"]["image"]).reshape(-1, n_pixels) / 255.0
    ytest = np.array(mnist["test"]["label"])

    def compute_val_acc(params):
        # forward() handles batches, so score the whole test set in one pass
        # instead of one example at a time.
        logits, _ = forward(xtest, params)
        return np.mean(np.argmax(logits, axis=-1) == ytest)

    params = init_params(n_pixels, n_classes, list(hidden_sizes))
    for epoch in range(n_epochs):
        # NOTE(review): examples are visited in the same fixed order every
        # epoch; per-epoch shuffling is the usual SGD practice.
        for step, start in enumerate(range(0, len(xtrain), bs)):
            x, y = xtrain[start:start + bs], ytrain[start:start + bs]

            grad = backprop(x, y, params)

            # SGD step; the arrays inside params are updated in place.
            for (w, b), (grad_w, grad_b) in zip(params, grad):
                w -= lr * grad_w
                b -= lr * grad_b

            if step % log_every_n_steps == 0:
                print(f"epoch: {epoch} | step: {step} | acc: {compute_val_acc(params):.4f}")

train_mnist()

  from .autonotebook import tqdm as notebook_tqdm


epoch: 0 | step: 0 | acc: 0.0547
epoch: 0 | step: 100 | acc: 0.0883
epoch: 0 | step: 200 | acc: 0.1427
epoch: 0 | step: 300 | acc: 0.2094
epoch: 0 | step: 400 | acc: 0.2627
epoch: 0 | step: 500 | acc: 0.3055
epoch: 0 | step: 600 | acc: 0.3634
epoch: 0 | step: 700 | acc: 0.4181
epoch: 0 | step: 800 | acc: 0.4804
epoch: 0 | step: 900 | acc: 0.5387
epoch: 1 | step: 0 | acc: 0.5559
epoch: 1 | step: 100 | acc: 0.5906
epoch: 1 | step: 200 | acc: 0.6230
epoch: 1 | step: 300 | acc: 0.6514
epoch: 1 | step: 400 | acc: 0.6807
epoch: 1 | step: 500 | acc: 0.7035
epoch: 1 | step: 600 | acc: 0.7250
epoch: 1 | step: 700 | acc: 0.7381
epoch: 1 | step: 800 | acc: 0.7479
epoch: 1 | step: 900 | acc: 0.7592
epoch: 2 | step: 0 | acc: 0.7635
epoch: 2 | step: 100 | acc: 0.7697
epoch: 2 | step: 200 | acc: 0.7773
epoch: 2 | step: 300 | acc: 0.7854
epoch: 2 | step: 400 | acc: 0.7950
epoch: 2 | step: 500 | acc: 0.8025
epoch: 2 | step: 600 | acc: 0.8037
epoch: 2 | step: 700 | acc: 0.8057
epoch: 2 | step: 800 | acc

# Verifying Backprop

We can verify that our backprop computes the correct gradients by comparing them with a numerical approximation of the gradient. The central-difference formula below, evaluated with a small value of $h$, gives that approximation:

$$
f'(x) = \lim_{h\rightarrow 0} \frac{f(x + h) - f(x - h)}{2h}
$$

In [8]:
import copy

def numerical_grad(x, y, params, h=1e-6):
    """Estimate d(loss)/d(params) by central finite differences.

    Each scalar parameter ``p`` is set to ``p + h`` and then ``p - h`` in
    place, the loss is re-evaluated each time, and
    ``(loss(p + h) - loss(p - h)) / (2h)`` is stored. The parameter is always
    restored afterwards, even if evaluating the loss raises.

    Args:
        x: input example passed to ``forward``.
        y: its integer class label.
        params: nested ``[w, b]`` pairs; temporarily mutated, restored on exit.
        h: finite-difference step size.

    Returns:
        A deep copy of ``params`` (same nesting and shapes) whose entries hold
        the numerical gradient.
    """
    def compute_loss():
        return loss_fn(forward(x, params)[0], y)

    grad = copy.deepcopy(params)
    for i, layer in enumerate(params):
        for j, tensor in enumerate(layer):
            for k in np.ndindex(tensor.shape):
                original = tensor[k]
                try:
                    tensor[k] = original + h
                    loss_plus = compute_loss()
                    tensor[k] = original - h
                    loss_minus = compute_loss()
                finally:
                    # Never leave the caller's parameters perturbed.
                    tensor[k] = original
                grad[i][j][k] = (loss_plus - loss_minus) / (2 * h)

    return grad

def check_gradients(seed=None):
    """Compare analytic ``backprop`` gradients with central finite differences.

    Uses a random 256-d input, a random target class in ``[0, 4)`` and a
    256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 network. It reports two error
    metrics over every weight and bias gradient:

    * per tensor: ``||a - n|| / (||a|| + ||n||)``;
    * per element: ``|a - n| / |a|`` (relative to the analytic gradient).

    Args:
        seed: optional seed for NumPy's global RNG (used here and by
            ``init_params``) to make the check reproducible. ``None`` leaves
            the RNG state untouched.

    Raises:
        AssertionError: if either metric exceeds its tolerance.
    """
    if seed is not None:
        np.random.seed(seed)

    x, y = np.random.randn(256), np.random.randint(4)

    params = init_params(256, 4, [128, 64, 32, 16, 8])
    agrad = backprop(x.reshape(1, -1), y, params)
    ngrad = numerical_grad(x, y, params)

    diff = []
    rels = []
    for analytic_layer, numeric_layer in zip(agrad, ngrad):
        for a, n in zip(analytic_layer, numeric_layer):  # weight, then bias
            norm_sum = np.linalg.norm(a) + np.linalg.norm(n)
            # Two all-zero gradients (e.g. behind dead ReLUs) agree perfectly;
            # guard the 0/0 that would otherwise produce NaN.
            diff.append(np.linalg.norm(a - n) / norm_sum if norm_sum > 0 else 0.0)
            # Same reference (the analytic gradient) for weights and biases;
            # the tiny term only prevents division by zero.
            tiny = np.finfo(a.dtype).smallest_subnormal
            rels.append((np.abs(a - n) / (np.abs(a) + tiny)).ravel())

    diff = np.array(diff)
    rels = np.concatenate(rels)

    print("Max Difference Norm:", diff.max())
    print("Mean Difference Norm:", diff.mean())
    print()
    print("Max Relative Diff:", rels.max())
    print("Mean Relative Diff:", rels.mean())
    print()

    assert diff.max() < 1e-6
    assert rels.max() < 0.05
    print("gradient checking passed ✅")
    print()

check_gradients()

Max Difference Norm: 3.4893667378829324e-09
Mean Difference Norm: 1.2138554306040649e-09

Max Relative Diff: 5.8688532911883646e-05
Mean Relative Diff: 2.8477597187627406e-08

gradient checking passed ✅

