In [None]:
# Reload edited modules automatically before each cell runs, and render plots inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#export 
from exp.tests import MNIST_URL, test, test_eq, test_near
from fastai import datasets
import pickle, gzip
from torch import tensor

def get_data():
    """Fetch the MNIST archive and return its train/validation arrays as tensors.

    The archive path comes from ``datasets.download_data(MNIST_URL, ext='.gz')``.
    The gzip-compressed pickle is expected to hold three entries; the first two
    are ``(x, y)`` pairs and the third is ignored.

    Returns:
        A lazy ``map`` object that yields, in order, ``tensor(x_tr)``,
        ``tensor(y_tr)``, ``tensor(x_vl)`` and ``tensor(y_vl)``. It can be
        consumed only once (e.g. by tuple unpacking).
    """
    archive = datasets.download_data(MNIST_URL, ext='.gz')
    # NOTE(review): pickle.load can execute arbitrary code from the file, so this
    # is only safe while MNIST_URL points at a trusted source.
    with gzip.open(archive, 'rb') as fh:
        train, valid, _ = pickle.load(fh, encoding='latin-1')
    (x_tr, y_tr), (x_vl, y_vl) = train, valid
    return map(tensor, (x_tr, y_tr, x_vl, y_vl))

In [None]:
import matplotlib.pyplot as plt
import torch
import matplotlib as mpl

In [None]:
# Unpack the four tensors yielded by get_data(): training inputs/labels and validation inputs/labels.
x_tr, y_tr, x_vl, y_vl = get_data()

In [None]:
# Take the first training sample and display it as a 28x28 image.
x = x_tr[0]
mpl.rcParams['image.cmap'] = 'gray'  # sets matplotlib's default image colormap
plt.imshow(x.view(28,28))

In [None]:
# Mean and standard deviation of the training and validation inputs before normalization.
tr_mean, tr_std = x_tr.mean(), x_tr.std()
vl_mean, vl_std = x_vl.mean(), x_vl.std()
tr_mean, tr_std, vl_mean, vl_std

In [None]:
#export
def normalize(x, m=None, s=None):
    """Standardize ``x`` as ``(x - m) / s``.

    ``m`` and ``s`` must be passed together. When both are omitted they
    default to ``x.mean()`` and ``x.std()``.

    Raises:
        AssertionError: if exactly one of ``m`` / ``s`` is given.
    """
    # Either both statistics are supplied or neither is.
    assert (m is None) == (s is None)
    if m is None:
        m, s = x.mean(), x.std()
    return (x - m) / s
def get_stats(x):
    """Return the pair ``(x.mean(), x.std())``."""
    mean, std = x.mean(), x.std()
    return mean, std

In [None]:
# Sanity check: normalizing x with its own statistics should give mean ~0 and std ~1.
get_stats(normalize(x))

In [None]:
#export
def test_near_zero(a, tol=1e-3): assert a.abs()<tol,f"Near zero: {a}"

In [None]:
# Smoke test: 1e-4 is below the default tolerance of 1e-3, so this must not raise.
test_near_zero(tensor([0.0001]))

In [None]:
# The training labels are expected to span exactly the range 0..9.
assert y_tr.max() == 9
assert y_tr.min() == 0

In [None]:
# NOTE: try a larger standard deviation for the initial weights later.
# It may improve precision, since the distance between two points
# in weight space could then be represented more precisely.

In [None]:
nh = 50  # number of hidden units (second dimension of w1)
n, m = x_tr.shape  # rows and columns of the training input matrix
c = y_tr.max() + 1  # largest label + 1; note this is a 0-dim tensor, not a Python int
n, m, c

In [None]:
#export 
import math
def init_weights(i,o,isForward=True):
    """Draw an ``(i, o)`` weight matrix from a scaled standard normal.

    The samples are multiplied by ``sqrt(2 / fan)``, where ``fan`` is ``i``
    when ``isForward`` is true and ``o`` otherwise. This is the He/Kaiming
    style scaling (factor 2 over a single fan), not Xavier/Glorot.
    """
    fan = i if isForward else o
    scale = math.sqrt(2 / fan)
    return torch.randn(i, o) * scale

In [None]:
# Parameters of a two-layer network: m inputs -> nh hidden units -> 1 output.
w1 = init_weights(m, nh)
b1 = torch.zeros(nh)
w2 = init_weights(nh, 1)
b2 = torch.zeros(1)
w1.shape, b1.shape, w2.shape, b2.shape

In [None]:
# Inspect (mean, std) of the freshly initialized weight matrices.
get_stats(w1), get_stats(w2)

In [None]:
# Compare the observed std of w1 with 1/sqrt(m). init_weights scales by sqrt(2/m),
# so the two values are expected to differ by a factor of about sqrt(2).
w1.std(), tensor(1/math.sqrt(m))

In [None]:
# init_weights draws w1 with std sqrt(2/m), so this difference should be close to zero.
# (The previous 1/math.sqrt(2/m) equals sqrt(m/2), the reciprocal of the intended value.)
w1.std() - math.sqrt(2/m)

In [None]:
# Check that the empirical std of w1 matches the sqrt(2/m) scale used by init_weights.
test_near_zero(w1.std() - math.sqrt(2/m))

In [None]:
# Normalize both splits with the training-set statistics, so the validation data
# is scaled exactly like the training data. Note that x_tr and x_vl are overwritten.
x_trm, x_trs = get_stats(x_tr)
x_tr = normalize(x_tr, x_trm, x_trs)
x_vl = normalize(x_vl, x_trm, x_trs)
x_tr.mean(), x_tr.std(), x_vl.mean(), x_vl.std()

In [None]:
def lin(x, w, b):
    """Affine layer: return ``x @ w + b``."""
    projected = x @ w
    return projected + b

In [None]:
# Time the first linear layer on the training set and inspect the statistics of its output.
%time t = lin(x_tr, w1, b1)
t.mean(), t.std()

In [None]:
def relu(x):
    """Shifted ReLU: clamp ``x`` below at 0, then subtract 0.5.

    Unlike a plain ReLU, the result ranges from -0.5 upwards rather than from 0.
    """
    rectified = x.clamp_min(0.)
    return rectified - 0.5

In [None]:
# Time linear layer + activation and inspect the statistics of the activations.
%time t1 = relu(lin(x_tr, w1, b1))
t1.mean(), t1.std()

In [None]:
def model(xb):
    """Forward pass: linear (w1, b1) -> relu -> linear (w2, b2).

    Uses the notebook-level parameters ``w1``, ``b1``, ``w2`` and ``b2``.
    """
    hidden = relu(lin(xb, w1, b1))
    return lin(hidden, w2, b2)

In [None]:
# Benchmark a full forward pass over the training set (10 loops per timing run).
%timeit -n 10 _ = model(x_tr)

In [None]:
# Forward pass on the whole training set; inspect the shape of the predictions.
y_tr_hat = model(x_tr)
y_tr_hat.shape

In [None]:
# Shape of the targets, for comparison with the prediction shape.
y_tr.shape

In [None]:
#export 
def mse(output, target):
    """Mean squared error between ``output`` and ``target``.

    The last dimension of ``output`` is squeezed first, so that an ``(n, 1)``
    prediction lines up with an ``(n,)`` target.
    """
    residual = output.squeeze(-1) - target
    return residual.pow(2).mean()

In [None]:
# MSE of the current predictions against the training targets.
mse(y_tr_hat, y_tr)

mse: sum((yhat - y)^2) / n, where n is the number of elements
gradient w.r.t. each prediction yhat_i: 2*(yhat_i - y_i) / n (not summed: one value per element)

In [None]:
def mse_grad(inp, out):
    """Store the gradient of the MSE loss w.r.t. ``inp`` in ``inp.g``.

    Args:
        inp: predictions; dimension 1 is squeezed before subtracting.
        out: the target values (despite the name), as passed to ``mse``.

    The gradient is ``2 * (inp - out) / out.shape[0]``, with a trailing
    dimension added back after the subtraction. Returns None.
    """
    diff = inp.squeeze(1) - out
    inp.g = 2 * diff.unsqueeze(-1) / out.shape[0]

In [None]:
# Store the loss gradient on the predictions as y_tr_hat.g and check its shape.
mse_grad(y_tr_hat, y_tr)#, mse_grad1(y_tr_hat, y_tr)
y_tr_hat.g.shape

In [None]:
def relu_grad(inp, out):
    """Store the gradient w.r.t. ``inp`` in ``inp.g``.

    The upstream gradient ``out.g`` is passed through where ``inp > 0`` and
    zeroed elsewhere. ``out`` is presumably the activation's output tensor;
    only its ``.g`` attribute is read. Returns None.
    """
    positive_mask = (inp > 0).float()
    inp.g = positive_mask * out.g

 linear: y = x @ w + b
 dy/dx = wT
 
 mse( relu (lin(x)))
 
 x -> lin -> x1 -> relu -> x2 -> mseloss
 
mseloss = sum((x2 - y)^2) / n
dmseloss/dx2 = 2(x2 - y) / n
dmseloss/dx1 = 2(x2 - y)/n * dx2/dx1

dmseloss/dx = dmseloss/dx1 * dx1/dx
x1 = x@w + b

For an explanation, see the handout linked below
(or Jeremy's linear algebra paper):
http://cs231n.stanford.edu/handouts/linear-backprop.pdf
dx1/dx = w.T
dmseloss/dx = dmseloss/dx1 @ w.T

In [None]:
def lin_grad(inp, out, w, b):
    """Backward pass for the linear layer ``out = inp @ w + b``.

    Reads the upstream gradient from ``out.g`` and stores the gradients as
    ``.g`` attributes on ``inp``, ``w`` and ``b``. Returns None.

    The comments below assume ``inp`` is (n, i), ``w`` is (i, o), ``b`` is (o,)
    and ``out.g`` is (n, o), matching how ``lin`` is used in this notebook.
    """
    # dL/dinp = dL/dout @ w^T. t() must be *called*; a bare `w.t` is a bound method.
    inp.g = out.g @ w.t()
    # dL/dw = inp^T @ dL/dout. Uses this layer's own input, not the global `x`.
    w.g = inp.t() @ out.g
    # The bias is broadcast over the batch, so its gradient sums over that dimension.
    b.g = out.g.sum(0)