In [6]:
from exp.nb_01 import *
from fastai import datasets

In [18]:
def get_data():
    path=datasets.download_data(MNIST_URL,ext='.gz')
    with gzip.open(path) as f:
        ((x_train,y_train),(x_valid,y_valid),_)=pickle.load(f,encoding='latin-1')
    return map(tensor,(x_train,y_train,x_valid,y_valid))

def normalize(x,m,s):return (x-m)/s


In [118]:
#export
import operator
def near(a,b): return torch.allclose(a, b, rtol=1e-3, atol=1e-5)
def test_near(a,b): test(a,b,near)


def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b): test(a,b,operator.eq,'==')

In [19]:
x_train,y_train,x_valid,y_valid=get_data()

In [25]:
x_train_mean=x_train.mean()
x_train_std=x_train.std()
x_train_mean,x_train_std

(tensor(0.1304), tensor(0.3073))

In [26]:
x_train=normalize(x_train,x_train_mean,x_train_std)
# NB: Use training, not validation mean for validation set
x_valid=normalize(x_valid,x_train_mean,x_train_std)

In [27]:
x_train_mean=x_train.mean()
x_train_std=x_train.std()
x_train_mean,x_train_std

(tensor(-7.6999e-06), tensor(1.))

In [28]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(50000, 784, tensor(10))

In [30]:
x_valid.mean(),x_valid.std()

(tensor(-0.0059), tensor(0.9924))

In [31]:
nh=50

In [67]:
w1=torch.randn(m,nh)/math.sqrt(m)
b1=torch.randn(nh)
w2=torch.randn(nh,1)/math.sqrt(nh)     #Kaiming normalization
b2=torch.zeros(1)

In [38]:
def lin(x,w,b):return x@w+b;

In [42]:
t=lin(x_valid,w1,b1)

In [43]:
t.std(),t.mean()

(tensor(0.9421), tensor(-0.0570))

In [53]:
def relu(x): return x.clamp_min(0.)-0.5

In [45]:
t=relu(lin(x_valid,w1,b1))

In [46]:
t.mean(),t.std()

(tensor(0.3448), tensor(0.5351))

In [68]:
w1=torch.randn(m,nh)*math.sqrt(2/m)
t=relu(lin(x_valid,w1,b1))
t.mean(),t.std()

(tensor(0.1842), tensor(0.9690))

In [56]:
#export
from torch.nn import init

In [69]:
#usinh torch kaiming
w1=torch.zeros(m,nh)
init.kaiming_normal_(w1,mode='fan_out')
t=lin(x_valid,w1,b1)
t.mean(),t.std()

(tensor(0.3935), tensor(1.6941))

In [60]:
init.kaiming_normal_??

In [62]:
w1.shape

torch.Size([784, 50])

In [70]:
def model(xb):
    l1=lin(xb,w1,b1)
    l2=relu(l1)
    l3=lin(l2,w2,b2)
    return l3

In [72]:
%timeit -n 10 _=model(x_valid)

4.09 ms ± 524 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [73]:
assert model(x_valid).shape==torch.Size([x_valid.shape[0],1])

In [74]:
model(x_valid).shape

torch.Size([10000, 1])

In [75]:
def mse(output,target):return (output.squeeze(-1)-target).pow(2).mean()

In [76]:
y_train,y_valid = y_train.float(),y_valid.float()

In [77]:
preds=model(x_train)

In [79]:
mse(preds,y_train)

tensor(28.1744)

In [80]:
##Gradients


In [81]:
def mse_grad(inp,targ):
    inp.g=2*(inp.squeeze(-1)-targ).unsqueeze(-1)/inp.shape[0]

In [82]:
def relu_grad(inp,out):
    inp.g=(inp>0).float()*out.g

In [83]:
def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

In [84]:
def forward_backward(inp,targ):
    #forward
    l1=lin(inp,w1,b1)
    l2=relu(l1)             #lin(relu(lin()))
                            # lin -> relu -> lin->mse->targ
    out=lin(l2,w2,b2)
    
    loss=mse(out,targ)
    
    #backward
    mse_grad(out,targ)
    lin_grad(l2,out,w2,b2)
    relu_grad(l1,l2)
    lin_grad(inp,l1,w1,b1)

In [85]:
forward_backward(x_train,y_train)

In [86]:
# Save for testing against later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

In [122]:
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [88]:
def forward(inp,targ):
    l1=inp@w12+b12
    l2=relu(l1)
    out=l2@w22+b22
    return mse(out,targ)
    

In [123]:
loss=forward(xt2,y_train)

In [124]:
loss.backward()

# Model

In [125]:
class Relu():
    def __call__(self,inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)-0.5
        return self.out
    def backward(self):
        self.inp.g=(self.inp>0).float()*self.out.g
        

In [126]:
class Lin():
    def __init__(self, w, b): self.w,self.b = w,b
    def __call__(self,inp):
        self.inp=inp
        self.out=inp@self.w+self.b
        return self.out
    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)

In [127]:
class Mse():
    def __call__(self,inp,targ):
        self.inp=inp
        self.targ=targ
        self.out=(self.inp.squeeze(-1)-targ).pow(2).mean()
        return self.out;
    def backward(self):
        self.inp.g=2*(self.inp.squeeze(-1)-self.targ).unsqueeze(-1)/self.inp.shape[0]

In [128]:
class Model():
    def __init__(self,w1,b1,w2,b2):
        self.layer=[Lin(w1,b1),Relu(),Lin(w2,b2)]
        self.loss=Mse()
    def __call__(self,x,targ):
        for l in self.layer: x =l(x)
        return self.loss(x,targ)
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layer): l.backward()
    

In [129]:
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model(w1, b1, w2, b2)

In [130]:
m=model(x_train,y_train)

In [131]:
m.shape

torch.Size([])

In [132]:
%time model.backward()

CPU times: user 3.59 s, sys: 3.99 s, total: 7.58 s
Wall time: 1.96 s


In [135]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

*Module Forward*

In [136]:
class Module():
    def __call__(self,*args):
        self.args=args
        self.out=self.forward(*args)
        return self.out
    def forward(self):raise Exception('not implemented')
    def backward(self): self.bwd(self.out, *self.args)

In [137]:
class Relu(Module):
    def forward(self, inp): return inp.clamp_min(0.)-0.5
    def bwd(self, out, inp): inp.g = (inp>0).float() * out.g

In [138]:
class Lin(Module):
    def __init__(self, w, b): self.w,self.b = w,b
        
    def forward(self, inp): return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = torch.einsum("bi,bj->ij", inp, out.g)
        self.b.g = out.g.sum(0)

In [139]:
class Mse(Module):
    def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()
    def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]

In [140]:
class Model():
    def __init__(self,w1,b1,w2,b2):
        self.layer=[Lin(w1,b1),Relu(),Lin(w2,b2)]
        self.loss=Mse()
    def __call__(self,x,targ):
        for l in self.layer: x =l(x)
        return self.loss(x,targ)
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layer): l.backward()

In [141]:
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model(w1, b1, w2, b2)

In [142]:
%time model(x_train,y_train)

CPU times: user 136 ms, sys: 8 ms, total: 144 ms
Wall time: 36.4 ms


tensor(28.1744)

In [143]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

# Using pytorch Module

In [147]:
#export 
from torch import nn

In [148]:
class Model(nn.Module):
    def __init__(self,n_in,n_h,n_out):
        super().__init__()
        self.layers=[nn.Linear(n_in,n_h),nn.ReLU(),nn.Linear(n_h,1)]
        self.loss=mse
    def __call__(self,x,target):
        for l in self.layers: x=l(x)
        return self.loss(x.squeeze(-1),target)
        

In [149]:
m

tensor(28.1744)

In [150]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(50000, 784, tensor(10.))

In [155]:
model=Model(m,50,1)
%time loss=model(x_train,y_train)

CPU times: user 88 ms, sys: 0 ns, total: 88 ms
Wall time: 22.5 ms


In [156]:
%time loss.backward()

CPU times: user 144 ms, sys: 0 ns, total: 144 ms
Wall time: 37.3 ms


In [157]:
!./notebook2script.py 02_fully_connected.ipynb

Converted 02_fully_connected.ipynb to exp/nb_02.py


In [158]:
nn.Module??