In [None]:
import gzip
import pickle
from pathlib import Path
from urllib.request import urlretrieve
import torch

In [None]:
# Raw-download link for the gzip-compressed MNIST pickle hosted in the
# mnielsen/neural-networks-and-deep-learning GitHub repository.
MNIST_URL = "https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true"

In [None]:
def get_mnist():
    """Download (once) and load the MNIST pickle, returning the train/validation splits.

    The archive is cached as ``data/mnist.pkl.gz`` relative to the current
    working directory and re-used on later calls. The pickle holds three
    ``(x, y)`` pairs; the third pair is discarded.

    Returns:
        tuple: ``(x_train, y_train, x_valid, y_valid)`` exactly as stored in
        the pickle (presumably NumPy arrays -- they only need a ``.shape``).

    Raises:
        Whatever ``urlretrieve`` raises when the download fails; in that case
        no partial archive is left behind under the cached name.
    """
    path_data = Path("data")
    path_data.mkdir(exist_ok=True)
    path_gz = path_data/"mnist.pkl.gz"
    if not path_gz.exists():
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer can never be mistaken for a valid cached archive.
        path_part = path_data/"mnist.pkl.gz.part"
        try:
            urlretrieve(MNIST_URL, path_part)
            path_part.replace(path_gz)
        except BaseException:
            if path_part.exists():
                path_part.unlink()
            raise
    # NOTE(security): unpickling can execute arbitrary code embedded in the file.
    # This is acceptable only because MNIST_URL points at a trusted source.
    with gzip.open(path_gz, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    data = x_train, y_train, x_valid, y_valid
    print(f"train: x.shape={x_train.shape}, y.shape={y_train.shape}")
    print(f"valid: x.shape={x_valid.shape}, y.shape={y_valid.shape}")
    return data

In [None]:
# Fetch MNIST (or reuse the cached archive) and unpack the train / validation splits.
x_trn, y_trn, x_val, y_val = get_mnist()

In [None]:
def test(t1, t2, eps=1e-8): assert torch.allclose(t1, t2, atol=eps)

In [None]:
def grad(t):
    """Switch on gradient tracking for `t` in place and hand back the result.

    Thin wrapper around `t.requires_grad_(True)` so parameters can be created
    and marked as trainable in a single expression.
    """
    tracked = t.requires_grad_(True)
    return tracked

In [None]:
class Module:
    """Base class for layers with a hand-written backward pass.

    Subclasses implement `forward` (invoked when the instance is called) and
    `backward`. Values needed by `backward` are stashed with `cache` during
    the forward pass and read back with `get`.
    """

    def __init__(self):
        # Per-instance store for values saved during the forward pass.
        self._cache = dict()

    def __call__(self, *args, **kwargs):
        """Calling the module runs its `forward` with the same arguments."""
        return self.forward(*args, **kwargs)

    def cache(self, **kwargs):
        """Remember each keyword argument under its name, overwriting older entries."""
        for name, value in kwargs.items():
            self._cache[name] = value

    def get(self, key, *keys):
        """Return the cached values for the given names as a list, in order.

        Names that were never cached yield `None` in the corresponding slot.
        """
        return [self._cache.get(name) for name in (key, *keys)]

    def backward(self):
        """Must be overridden by subclasses; the base version always raises."""
        raise NotImplementedError()

    def __repr__(self):
        return type(self).__name__

    __str__ = __repr__

In [None]:
from torch import tensor as t
from torch.nn import functional as F

## MSE Loss

$$
\begin{aligned}
L &= MSE(y, \hat{y}) = \mathbb{E}\left[(\hat{y} - y)^2\right] = \frac{1}{N}\sum_{i=1}^{N} d_i^2, \qquad d = \hat{y} - y \\
\frac{\partial L}{\partial \hat{y}_i} &= \frac{2\, d_i}{N}
\end{aligned}
$$

In [None]:
class MSE(Module):
    """Mean-squared-error loss with a hand-written backward pass."""
    def forward(self, pred, gt):
        """Return the mean of `(pred - gt)**2` over every element.

        The prediction and the element-wise difference are cached for `backward`.
        """
        d = pred - gt
        self.cache(pred=pred, d=d)
        return d.pow(2).mean()
    def backward(self):
        """Store the loss gradient w.r.t. the prediction on `pred.g`.

        Must be called after `forward`, which populates the cache.
        """
        pred, d = self.get("pred", "d")
        # `forward` averages over *all* elements, so normalise by the element
        # count rather than the leading dimension; the two only coincide for
        # 1-D or (N, 1) inputs.
        pred.g = 2 * d / d.numel()

In [None]:
# Toy prediction / target pair, both tracked by autograd.
x = grad(t([1.5, 0.3, 2.0]))
y = grad(t([1., 1., 3.]))

# Forward: the hand-written loss must agree with PyTorch's.
mse = MSE()
my = mse(x, y)
ref = F.mse_loss(x, y)
test(my, ref)

# Backward: the manual gradient (`x.g`) must agree with autograd's (`x.grad`).
mse.backward()
ref.backward()
test(x.g, x.grad)

## ReLU

In [None]:
class ReLU(Module):
    """Rectified linear unit with a hand-written backward pass."""
    def forward(self, inp):
        """Clamp negative entries of `inp` to zero; cache input and output for `backward`."""
        activated = inp.clamp_min(0)
        self.cache(i=inp, o=activated)
        return activated
    def backward(self):
        """Propagate the gradient stored on the cached output (`.g`) to the cached input.

        The gradient passes through only where the input was strictly positive.
        """
        inputs, outputs = self.get("i", "o")
        positive_mask = (inputs > 0).float()
        inputs.g = positive_mask * outputs.g

In [None]:
from torch.nn import MSELoss

In [None]:
# Inputs cover a negative value, exactly zero and a positive value.
x = grad(t([-1.5, 0., 2.0]))
y = grad(t([1., 1., 3.]))

# Reference gradient from autograd, stored in `x.grad`.
F.mse_loss(F.relu(x), y).backward()

# Same computation with the hand-written modules; backward runs in reverse order.
mse = MSE()
relu = ReLU()
mse(relu(x), y)
mse.backward()
relu.backward()

test(x.g, x.grad)

## Linear

In [None]:
# NOTE(review): this re-declares the MSE class already defined earlier in the
# notebook and shadows it; keep the two in sync or drop one of them.
class MSE(Module):
    """Mean-squared-error loss with a hand-written backward pass."""
    def forward(self, pred, gt):
        """Return the mean of `(pred - gt)**2` over every element.

        The prediction and the element-wise difference are cached for `backward`.
        """
        d = pred - gt
        self.cache(pred=pred, d=d)
        return d.pow(2).mean()
    def backward(self):
        """Store the loss gradient w.r.t. the prediction on `pred.g`.

        Must be called after `forward`, which populates the cache.
        """
        pred, d = self.get("pred", "d")
        # `forward` averages over *all* elements, so normalise by the element
        # count rather than the leading dimension; the two only coincide for
        # 1-D or (N, 1) inputs.
        pred.g = 2 * d / d.numel()

In [None]:
class Linear(Module):
    """Affine layer `inp @ W + b` with a hand-written backward pass."""
    def __init__(self, W, b):
        """Keep references to the weight `W` and bias `b`; gradients are later stored on their `.g`."""
        super().__init__()
        self.W, self.b = W, b
    def forward(self, inp):
        """Compute `inp @ W + b`, caching input and output for `backward`."""
        out = inp @ self.W + self.b
        self.cache(i=inp, o=out)
        return out
    def backward(self):
        """Propagate the gradient stored on the cached output (`o.g`).

        Writes `i.g` (input), `W.g` (weight) and `b.g` (bias). Expects `o.g`
        to have been set by the backward pass of whatever consumed the output.
        """
        i, o = self.get("i", "o")
        i.g = o.g @ self.W.t() # scalar analogue: i.g = W * o.g
        self.W.g = i.t() @ o.g # scalar analogue: W.g = i * o.g
        # Reduce over every axis along which the bias was broadcast in `forward`.
        # For a bias that already has the output's shape this is just `o.g`;
        # for a per-feature bias it sums over the batch axis.
        self.b.g = o.g.sum_to_size(self.b.shape)

In [None]:
# Toy batch: three samples with four features each, as float32.
X = torch.tensor(
    [[1, -2, 3, 1],
     [0, 1, 0, 1],
     [-1, 0, 1, 0]]
).float()
# One integer target per sample, laid out as a column.
y = torch.tensor([[1], [0], [3]])
X, y

In [None]:
# Layer 1: identity weights. The bias is allocated with the full
# (samples, features) output shape rather than one value per feature.
W1 = grad(torch.eye(4))
b1 = grad(torch.ones(X.shape[0], W1.shape[-1]))
# Layer 2: a single output column, again with a full-shape bias.
W2 = grad(t([[1.], [0.], [-1.], [1.]]))
b2 = grad(torch.ones(X.shape[0], W2.shape[1]))

In [None]:
# Reference gradients: the Linear -> ReLU -> Linear -> ReLU -> mean-squared-error
# computation written with plain tensor ops, so autograd fills `.grad` on
# W1, b1, W2 and b2.
(F.relu(F.relu(X@W1 + b1)@W2 + b2) - y).pow(2).mean().backward()

In [None]:
# Build the same two-layer network from the hand-written modules.
mse = MSE()
lin1 = Linear(W1, b1)
rel1 = ReLU()
lin2 = Linear(W2, b2)
rel2 = ReLU()

# Forward pass through the whole stack, ending in the loss.
mse(rel2(lin2(rel1(lin1(X)))), y)

# Backward pass: each module in the exact reverse of the forward order.
mse.backward()
rel2.backward()
lin2.backward()
rel1.backward()
lin1.backward()

In [None]:
# Every hand-computed gradient (`.g`) must match autograd's (`.grad`).
# A plain loop is used because the checks are run only for their side effect.
for param in (W1, W2, b1, b2):
    test(param.g, param.grad)

## Model

In [None]:
class MLP(Module):
    """Sequential container that chains layers forward and unwinds them backward."""
    def __init__(self, layers):
        """Store `layers`, the modules to apply in order."""
        super().__init__()
        self.layers = layers
    def forward(self, inp):
        """Feed `inp` through every layer in order and return the final output."""
        out = inp
        for layer in self.layers:
            out = layer(out)
        return out
    def backward(self):
        """Call `backward` on every layer, last layer first."""
        for layer in reversed(self.layers):
            layer.backward()

In [None]:
# Fresh parameters, so no gradients are left over from earlier cells.
W1 = grad(torch.eye(4))
b1 = grad(torch.ones(X.shape[0], W1.shape[-1]))
W2 = grad(t([1., 0., -1., 1.]).unsqueeze(-1))
b2 = grad(torch.ones(X.shape[0], W2.shape[1]))

mse = MSE()
lin1, rel1 = Linear(W1, b1), ReLU()
lin2, rel2 = Linear(W2, b2), ReLU()
model = MLP([lin1, rel1, lin2, rel2])

# Hand-written backward pass through the loss and then the whole model.
loss = mse(model(X), y)
mse.backward()
model.backward()

# The forward pass above is built from ordinary tensor ops on tracked
# parameters, so autograd can differentiate the same loss; use it as the
# reference and check every manual gradient against it.
loss.backward()
for param in (W1, W2, b1, b2):
    test(param.g, param.grad)