In [None]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np
from pathlib import Path
from torch import tensor
from fastcore.test import test_close
torch.manual_seed(42)

# Display defaults: grayscale images and compact, non-scientific number printing.
mpl.rcParams['image.cmap'] = 'gray'
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

# The pickle holds a 3-tuple: the first two elements are (inputs, labels) pairs
# and the third is discarded.
# NOTE: pickle.load can execute arbitrary code; only use it on a trusted file.
path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')
# Convert every loaded object to a torch tensor.
x_train, y_train, x_valid, y_valid = (tensor(arr) for arr in (x_train, y_train, x_valid, y_valid))

## Foundations version

### Basic architecture

In [None]:
# n = number of training rows, m = number of columns (input features) per row
n,m = x_train.shape
# c = largest label value + 1; this is a 0-dim tensor, not a Python int
# (presumably the number of classes, if labels run 0..c-1 -- confirm)
c = y_train.max()+1
n,m,c

In [None]:
# number of hidden units (the width of the hidden layer)
nh = 50

In [None]:
# Layer 1 maps the m input features to nh hidden units; layer 2 maps those to a
# single output value per row.
# Weights are plain standard-normal draws (no scaling); biases start at zero.
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [None]:
def lin(x, w, b):
    """Affine layer: matrix-multiply `x` by the weights `w`, then add the bias `b`."""
    projected = x @ w
    return projected + b

In [None]:
# sanity-check the first linear layer on the validation inputs
t = lin(x_valid, w1, b1)
t.shape

In [None]:
def relu(x):
    """Rectified linear unit: values below zero become zero, the rest pass through."""
    floor = 0.
    return x.clamp_min(floor)

In [None]:
# apply the activation to the layer-1 output; negative entries of t become 0
t = relu(t)
t

In [None]:
def model(xb):
    """Run a batch through the network: linear -> ReLU -> linear.

    Reads the module-level parameters w1, b1, w2 and b2.
    """
    hidden = relu(lin(xb, w1, b1))
    return lin(hidden, w2, b2)

In [None]:
# full forward pass on the validation inputs; the last dimension has size 1
# because w2 has a single output column
res = model(x_valid)
res.shape

### Loss function: MSE

(Of course, MSE is not a suitable loss function for multi-class classification; we'll use a better loss function soon. We'll use MSE for now to
keep things simple.)

In [None]:
# compare shapes before subtracting: res carries a trailing dimension of size 1
res.shape,y_valid.shape

In [None]:
# pitfall: assuming y_valid is 1-D, broadcasting turns (rows,1) - (rows,) into a
# (rows,rows) matrix rather than a per-row difference
(res-y_valid).shape

We need to get rid of the trailing dimension of size 1 (the `1` in `(n, 1)`) in order to use MSE.

In [None]:
# option 1: indexing column 0 drops the trailing size-1 dimension
res[:,0].shape

In [None]:
# option 2: squeeze() with no argument removes every size-1 dimension, not just the last
res.squeeze().shape

In [None]:
# with the trailing dimension removed, the subtraction no longer broadcasts to a
# matrix (assuming y_valid is 1-D)
(res[:,0]-y_valid).shape

In [None]:
# cast the labels to float before using them as regression targets
y_train = y_train.float()
y_valid = y_valid.float()

# predictions of the current (randomly initialised) model on the training inputs
preds = model(x_train)
preds.shape

In [None]:
def mse(output, targ):
    """Mean squared error between column 0 of `output` and `targ`."""
    err = output[:,0] - targ
    return err.pow(2).mean()

In [None]:
# loss of the randomly initialised model on the training set
mse(preds, y_train)

### Gradients and backward pass

In [None]:
from sympy import symbols,diff
# symbolic variables for checking derivatives
x,y = symbols('x y')
# derivative of x**2 with respect to x
diff(x**2, x)

In [None]:
# the constant term drops out when differentiating
diff(3*x**2+9, x)

In [None]:
def lin_grad(inp, out, w, b):
    """Backward pass for the linear layer ``out = inp @ w + b``.

    Reads the upstream gradient from ``out.g`` and stores the gradients with
    respect to the input, weights and bias on ``inp.g``, ``w.g`` and ``b.g``.
    Expects 2-D ``inp`` (rows, in_features) and ``out.g`` (rows, out_features).
    """
    # gradient w.r.t. the input: push the upstream gradient back through the weights
    inp.g = out.g @ w.t()
    # gradient w.r.t. the weights: the sum over rows of outer(inp[i], out.g[i]).
    # Written as a matmul so the (rows, in_features, out_features) intermediate
    # that a broadcast-multiply-then-sum would materialise is never allocated.
    w.g = inp.t() @ out.g
    # gradient w.r.t. the bias: b is added to every row, so sum over the rows
    b.g = out.g.sum(0)

In [None]:
def forward_and_backward(inp, targ):
    """Run the network forward under an MSE loss, then back-propagate by hand.

    Uses the module-level parameters w1, b1, w2, b2. Nothing is returned: the
    gradients are stored in the ``.g`` attribute of `inp` and of each parameter.
    """
    # ---- forward ----
    pre_act = lin(inp, w1, b1)
    act = relu(pre_act)
    out = lin(act, w2, b2)
    err = out[:,0] - targ
    loss = err.pow(2).mean()  # computed for reference only; not used below

    # ---- backward ----
    # derivative of mean(err**2) with respect to each output: 2*err / number of rows
    n_rows = inp.shape[0]
    out.g = 2. * err[:,None] / n_rows
    lin_grad(act, out, w2, b2)
    # ReLU passes the gradient through only where its input was positive
    pre_act.g = act.g * (pre_act > 0).float()
    lin_grad(inp, pre_act, w1, b1)

In [None]:
# one full pass over the training set; leaves the hand-computed gradients in the
# .g attribute of w1, b1, w2, b2 and x_train
forward_and_backward(x_train, y_train)

In [None]:
# save for testing against later
def get_grad(x):
    """Return a copy of the hand-computed gradient stored on `x.g`."""
    stored = x.g
    return stored.clone()
chks = w1,w2,b1,b2,x_train
# snapshot the .g of every checked tensor, in the same order as chks
grads = tuple(get_grad(p) for p in chks)
w1g,w2g,b1g,b2g,ig = grads

We cheat a little bit and use PyTorch autograd to check our results.

In [None]:
def mkgrad(x):
    """Return a copy of `x` that has gradient tracking switched on."""
    copy = x.clone()
    return copy.requires_grad_(True)
# autograd-enabled copies of the parameters and input, in the same order as chks
ptgrads = tuple(mkgrad(p) for p in chks)
w12,w22,b12,b22,xt2 = ptgrads

In [None]:
def forward(inp, targ):
    """Forward pass and MSE loss using the autograd-enabled copies w12, b12, w22, b22."""
    hidden = relu(lin(inp, w12, b12))
    pred = lin(hidden, w22, b22)
    return mse(pred, targ)

In [None]:
# same loss, computed on the autograd-tracked copies; backward() fills their .grad
loss = forward(xt2, y_train)
loss.backward()

In [None]:
# each hand-computed gradient should match the one autograd produced, within eps
for a,b in zip(grads, ptgrads):
    test_close(a, b.grad, eps=0.01)