In [1]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np
from pathlib import Path
from torch import tensor
from fastcore.test import test_close
torch.manual_seed(42)

mpl.rcParams['image.cmap'] = 'gray'
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'

In [2]:
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

## initialize params

In [3]:
n,m = x_train.shape
c = y_train.max() + 1
n,m,c

(50000, 784, tensor(10))

In [4]:
nh = 200

In [5]:
w1=torch.randn(m,nh)
b1=torch.zeros(nh)
w2=torch.randn(nh, 1)
b2 = torch.zeros(1)

## code for model's layers, loss, and gradients

In [6]:
def lin(x, w, b):
    return x@w + b

In [7]:
def relu(x):
    return x.clamp_min(0.)

In [8]:
y_train,y_valid=map(lambda t: t.float(), (y_train,y_valid))
y_train,y_valid

(tensor([5., 0., 4.,  ..., 8., 4., 8.]),
 tensor([3., 8., 6.,  ..., 5., 6., 8.]))

In [9]:
def mse(output, targ):
    return (output[:,0]-targ).pow(2).mean()

In [10]:
def lin_grad(inp, out, w, b):
    inp.g = out.g @ w.t()
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)

In [11]:
def forward_and_backward(xb, yb):
    # forward
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    preds = lin(l2, w2, b2)
    loss = mse(preds, yb)

    # backpropagation
    n=xb.shape[0]
    '''
    preds is shape (batch_size, num_activation_features) = (batch_size, 1), aka a column vector
    yb shape is (batch_size) and needs to be transformed into column vector aka (batch_size, 1)
    hence yb[:,None], in order to do elementwise subtraction
    '''
    preds.g = (2./n) * (preds-yb[:,None]) 
    lin_grad(l2, preds, w2, b2)
    l1.g = (l1 > 0).float() * l2.g
    lin_grad(xb, l1, w1, b1)

In [12]:
forward_and_backward(x_train, y_train)

## comparing gradients to pytorch computed gradients based on same (cloned) params and same (duplicated) forward pass

In [13]:
def get_homemade_grad(t):
    return t.g.clone()

In [14]:
tensors = w1, b1, w2, b2, x_train
my_grads = tuple(map(get_homemade_grad, tensors))

In [15]:
def clone_tensor_with_grads(t):
    return t.clone().requires_grad_(True)

In [16]:
pt_tensors = tuple(map(clone_tensor_with_grads, tensors))
w1c, b1c, w2c, b2c, x_trainc = pt_tensors

In [17]:
def forward_only(inp, targ):
    l1 = lin(inp, w1c, b1c)
    l2 = relu(l1)
    out = lin(l2, w2c, b2c)
    return mse(out, targ)

In [18]:
loss = forward_only(x_trainc, y_train)
loss.backward()

In [19]:
for my_grad,their_tensor in zip(my_grads, pt_tensors):
    test_close(my_grad, their_tensor.grad, eps=0.01)

## Refactor model into object oriented

In [20]:
class Relu():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)
        return self.out

    def backward(self):
        self.inp.g = (self.inp > 0).int() * self.out.g

In [21]:
class Lin():
    def __init__(self, w, b):
        self.w = w
        self.b = b
    def __call__(self, inp):
        self.inp = inp
        self.out = inp @ self.w + self.b
        return self.out

    def backward(self):
        self.w.g = self.inp.t() @ self.out.g
        self.b.g = self.out.g.sum(0)
        self.inp.g = self.out.g @ self.w.t()

In [22]:
class Mse():
    def __call__(self, preds, targ):
        self.preds = preds
        self.diff = preds - targ[:,None]
        self.out = (self.diff).pow(2).mean()
        return self.out

    def backward(self):
        self.preds.g = (2. / self.diff.shape[0]) * self.diff

In [39]:
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [
            Lin(w1, b1),
            Relu(),
            Lin(w2, b2)
        ]
        self.loss = Mse()

    def __call__(self, x, y):
        for l in self.layers:
            x = l(x)
        loss = self.loss(x, y)
        return (x, loss)

    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()


In [24]:
model = Model(w1, b1, w2, b2)
preds,loss = model(x_train, y_train)

In [25]:
w1,w2,b1,b2,x_train,y_train = tuple(map(lambda t: t.to('cuda'), (w1,w2,b1,b2,x_train,y_train,)))

In [26]:
def acc(preds, y):
    return (preds[:,0].round() == y).half().mean()

In [27]:
pw1, pb1, pw2, pb2 = w1.clone(), b1.clone(), w2.clone(), b2.clone()

In [28]:
w1, b1, w2, b2= pw1, pb1, pw2, pb2

In [29]:
lr = 0.05
for i in range(1000000):
    model = Model(w1, b1, w2, b2)
    preds,loss = model(x_train, y_train)
    model.backward()
    if i % 1000 == 0:
        print(i, loss,acc(preds, y_train))
    w1 -= lr * w1.g
    b1 -= lr * b1.g
    w2 -= lr * w2.g
    b2 -= lr * b2.g

0 tensor(5253.24, device='cuda:0') tensor(0.00, device='cuda:0', dtype=torch.float16)
1000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
2000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
3000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
4000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
5000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
6000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)
7000 tensor(8.36, device='cuda:0') tensor(0.10, device='cuda:0', dtype=torch.float16)


KeyboardInterrupt: 

## Refactor, use forward() method

In [30]:
class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self): 
        None
    def backward(self):
        # save the args in flight in __call__ to pass them again here. still need to declare them in signature,
        # but don't need to chase down vals to pass them in again
        self.bwd(self.out, *self.args)
    def bwd(self):
        None

In [41]:
class Relu(Module):
    def forward(self, inp):
        return inp.clamp_min(0.)
    def bwd(self, out, inp):
        inp.g = out.g * (inp > 0).float()

In [42]:
class Lin(Module):
    def __init__(self, w, b):
        self.w = w
        self.b = b
    def forward(self, inp):
        return inp @ self.w + self.b

    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = inp.t() @ out.g
        self.b.g = out.g.sum(0)

In [33]:
class Mse(Module):
    def forward(self, inp, targ):
        return (inp.squeeze() - targ).pow(2).mean()

    def bwd(self, out, inp, targ):
        inp.g = 2*(inp.squeeze() - targ).unsqueeze(-1) / targ.shape[0]

In [43]:
model = Model(w1, b1, w2, b2)

In [44]:
loss = model(x_train, y_train)

In [45]:
loss

(tensor([[4.45],
         [4.45],
         [4.45],
         ...,
         [4.45],
         [4.45],
         [4.45]], device='cuda:0'),
 tensor(8.36, device='cuda:0'))

In [46]:
model.backward()

## autograd

In [48]:
from torch import nn
import torch.nn.functional as F

In [61]:
class Linear(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.w = torch.randn(n_in, n_out).to('cuda').requires_grad_()
        self.b = torch.zeros(n_out).to('cuda').requires_grad_()
    def forward(self, inp):
        return inp@self.w + self.b

In [62]:
class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [Linear(n_in, nh), nn.ReLU(), Linear(nh, n_out)]

    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return (x, F.mse_loss(x, targ[:, None]))

In [63]:
model = Model(m, nh, 1)

In [64]:
out,loss = model(x_train, y_train)

In [65]:
loss.backward()

In [67]:
model.layers[0].b.grad

tensor([    25.94,     39.71,     59.51,     13.67,     59.99,    -52.92,    -95.64,      5.11,     26.07,     -4.32,
           -67.83,   -183.01,    -26.47,    -33.72,     70.47,      2.54,    -48.64,   -118.62,     23.59,      9.02,
          -245.84,    -35.43,    -25.58,     24.89,    -14.35,     -4.79,    -40.82,    -68.20,     16.47,    122.15,
          -179.85,    -30.03,    -24.63,    -36.48,    214.56,     35.95,     65.49,     16.53,     15.27,    -74.02,
           -67.38,    -28.08,     11.56,    -83.02,     23.20,     -7.73,     24.34,    -46.75,     99.27,   -162.95,
           -43.80,     69.14,      9.87,   -110.22,    -13.57,     -2.81,     -5.78,    -30.22,    -19.45,     54.48,
            40.27,     62.32,     -7.23,    -16.30,    125.21,      8.19,     68.10,     -1.61,    -28.98,     -2.74,
            32.40,    -23.22,     51.09,     45.71,    -45.75,     -5.73,     11.74,    -11.81,      1.66,    208.56,
            83.04,     63.33,      4.97,    -36.27,    -