In [1]:
# Reload imported modules automatically before each cell executes, so edits
# to the exported `exp` modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
#export
from exp.nb_01 import *

def get_data():
    """Fetch the pickled dataset at MNIST_URL and return its splits as tensors.

    The archive is downloaded via `datasets.download_data`, opened with gzip
    and unpickled. The pickle holds three entries; the first two are
    (inputs, labels) pairs for the training and validation splits and the
    third is discarded.

    Returns:
        A lazy `map` object yielding, in order, `tensor(x_train)`,
        `tensor(y_train)`, `tensor(x_valid)`, `tensor(y_valid)`. It is a
        single-use iterator: unpack it once rather than indexing it.
    """
    archive = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(archive, 'rb') as fh:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only safe as long as MNIST_URL points at a trusted source.
        train_split, valid_split, _ = pickle.load(fh, encoding='latin-1')
    x_train, y_train = train_split
    x_valid, y_valid = valid_split
    return map(tensor, (x_train, y_train, x_valid, y_valid))

def normalize(x, m, s):
    """Shift `x` by `m` and divide by `s`, i.e. return ``(x - m) / s``."""
    centered = x - m
    return centered / s

In [7]:
# Unpack the four tensors produced by get_data(): training inputs/labels and
# validation inputs/labels.
x_train,y_train,x_valid,y_valid = get_data()

In [10]:
# Mean and standard deviation over every element of x_train; the bare tuple on
# the last line is displayed as the cell's output.
# NOTE(review): the recorded output (mean ~0, std 1) suggests this cell was
# last run after x_train had already been normalized — confirm the intended
# execution order (stats first, then normalization).
train_mean,train_std = x_train.mean(),x_train.std()
train_mean,train_std

(tensor(-7.6999e-06), tensor(1.))

In [9]:
# Normalize both splits with the *training* statistics so the validation data
# is put on the same scale as the training data.
# NOTE: this overwrites x_train/x_valid, so re-running the cell normalizes the
# already-normalized data a second time.
x_train = normalize(x_train,train_mean,train_std)
x_valid = normalize(x_valid,train_mean,train_std)

In [11]:
# n, m: the two dimensions of x_train (presumably samples x features — confirm).
n,m = x_train.shape
# One more than the largest label value; presumably the number of classes,
# assuming labels start at 0.
c = y_train.max()+1

## Basic architecture

In [12]:
# Number of hidden units: the output width of the first linear layer (w1 is
# created with shape (m, nh)).
nh = 50

In [32]:
# Simplified Kaiming / He init: draw weights from N(0, 1) and scale each
# matrix by sqrt(2 / fan_in), where fan_in is the number of inputs feeding
# *that* layer. Biases start at zero.
w1 = torch.randn(m,nh)*math.sqrt(2/m)    # layer 1 receives m inputs
b1 = torch.zeros(nh)
# Layer 2 receives the nh hidden activations, so its fan-in is nh (not m).
w2 = torch.randn(nh,1)*math.sqrt(2/nh)
b2 = torch.zeros(1)

In [38]:
def lin(x, w, b):
    """Affine layer: matrix-multiply `x` by `w`, then add the bias `b`."""
    projected = x @ w
    return projected + b
def relu(x):
    """Shifted ReLU: clamp `x` below at 0, then subtract 0.5.

    The constant shift moves the activation's output downwards by 0.5; it
    does not affect which inputs are passed through versus zeroed.
    """
    rectified = x.clamp_min(0.)
    return rectified - 0.5

In [39]:
# Output of the first linear layer on the validation inputs (before activation).
t = lin(x_valid,w1,b1)

In [45]:
# Inspect the mean and spread of the layer output under the current init.
t.mean(),t.std()

(tensor(0.0559), tensor(0.8323))

In [36]:
# Apply the activation to the first-layer output.
# NOTE: this overwrites t, so re-running the cell applies relu a second time.
t = relu(t)

In [41]:
#export
from torch.nn import init

In [44]:
# Re-initialise w1 with PyTorch's built-in Kaiming-normal init (which fills the
# tensor in place and returns it). w1 is laid out (m, nh); for a 2-D tensor
# torch takes fan_out from dim 0, so mode="fan_out" scales by m, the number of
# inputs to this layer.
w1 = init.kaiming_normal_(torch.zeros(m, nh), mode="fan_out")
# Activated first-layer output on the validation inputs with the new weights.
t = relu(lin(x_valid, w1, b1))

In [52]:
def model(x):
    """Forward pass: linear (w1, b1) -> relu -> linear (w2, b2).

    Reads the module-level parameters w1, b1, w2 and b2 at call time.
    """
    hidden = relu(lin(x, w1, b1))
    return lin(hidden, w2, b2)

In [56]:
# Sanity check: taking column 0 of the model output yields exactly one
# prediction per validation target.
assert model(x_valid)[:,0].shape==y_valid.shape

## Loss function: MSE

In [57]:
# Shape of the raw model output; note the trailing unit axis, which the loss
# has to remove before comparing against the targets.
model(x_valid).shape

torch.Size([10000, 1])

In [58]:
#export
def mse(output, targ):
    """Mean squared error between `output` and `targ`.

    The last axis of `output` is squeezed away (if it has size 1) before the
    element-wise difference is taken, squared and averaged.
    """
    err = output.squeeze(-1) - targ
    return err.pow(2).mean()

In [59]:
# Cast the targets to float before using them in the MSE computation.
y_train, y_valid = y_train.float(),y_valid.float()

In [60]:
# Forward pass over the whole training set.
preds = model(x_train)

In [62]:
# MSE of the model on the training set with the current parameters.
mse(preds, y_train)

tensor(29.7456)

## Gradients and backward pass

In [63]:
def mse_grad(inp, targ):
    """Store the gradient of the MSE loss w.r.t. `inp` in `inp.g`.

    For ``loss = mean((inp.squeeze(-1) - targ)**2)`` the gradient is
    ``2 * (inp - targ) / N`` with ``N = inp.shape[0]``.

    Args:
        inp: predictions with a trailing unit axis, shape (N, 1).
        targ: targets, shape (N,).

    Side effects:
        Sets ``inp.g`` to a tensor with the same shape as `inp`.
    """
    # Drop the trailing axis before subtracting: (N, 1) - (N,) would otherwise
    # broadcast to (N, N) and produce a gradient of the wrong shape.
    diff = inp.squeeze(-1) - targ
    # Restore the trailing axis so the gradient matches inp's shape.
    inp.g = 2. * diff.unsqueeze(-1) / inp.shape[0]

In [64]:
def relu_grad(inp, out):
    """Backward pass of the ReLU: store the gradient w.r.t. `inp` in `inp.g`.

    The upstream gradient `out.g` is passed through where `inp` is strictly
    positive and zeroed everywhere else.
    """
    passthrough = (inp > 0).float()
    inp.g = passthrough * out.g

In [65]:
def lin_grad(inp, out, w, b):
    """Backward pass of ``out = inp @ w + b``; stores gradients in `.g` attributes.

    Args:
        inp: layer input, 2-D with shape (batch, n_in).
        out: layer output; ``out.g`` must already hold the upstream gradient,
            shape (batch, n_out).
        w: weight matrix, shape (n_in, n_out).
        b: bias vector, shape (n_out,).

    Side effects:
        Sets ``inp.g`` (batch, n_in), ``w.g`` (n_in, n_out) and ``b.g`` (n_out,).
    """
    inp.g = out.g @ w.t()
    # Sum over the batch of outer products inp[i] (x) out.g[i], written as one
    # matmul so the (batch, n_in, n_out) intermediate is never materialised.
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)