In [1]:
import torch
from torchvision.datasets import MNIST

# data preprocessing

In [2]:
# Fetch the MNIST training split into the current directory (downloads on first run).
ds = MNIST(root='.', train=True, download=True)

In [3]:
def normalize(x, m, s):
    """Standardize `x`: subtract `m`, then divide by `s`.

    Args:
        x: value(s) to standardize.
        m: offset subtracted from `x` (the mean, in this notebook).
        s: divisor applied after centering (the standard deviation, in this notebook).

    Returns:
        `(x - m) / s`.
    """
    centered = x - m
    return centered / s

In [4]:
# Flatten every image into one row and scale the raw pixel values by 1/255.
# The row count and the flattened width are derived from the data itself instead
# of being hard-coded, so the cell also works for a split of a different size.
x = ds.data.reshape(len(ds.data), -1) / 255.
y = ds.targets

In [5]:
# Scalar statistics taken over every element of x (all rows and columns at once).
mean, std = x.mean(), x.std()
mean, std

(tensor(0.1307), tensor(0.3081))

In [6]:
# Standardize the inputs with the statistics computed above.
# Caution: this overwrites x, so re-running this cell alone standardizes it a second time.
x = normalize(x, m=mean, s=std)

In [7]:
# Sanity check: after standardization the mean should be ~0 and the std ~1.
x.mean(), x.std()

(tensor(-1.6608e-09), tensor(1.0000))

In [8]:
# n: number of rows in x, m: number of columns in x.
n, m = x.shape[0], x.shape[1]
# c: largest label value plus one (kept as a 0-dim tensor).
c = y.max() + 1

In [9]:
# Display the dataset dimensions computed in the previous cell.
n, m, c

(60000, 784, tensor(10))

# weight initialization

In [10]:
# Layer sizes: flattened 28x28 input, 50 hidden units, a single output value.
n_in, nh, n_out = 28 * 28, 50, 1

In [11]:
# Kaiming / He initialization for ReLU networks: each weight matrix is drawn from
# a standard normal and scaled by sqrt(2 / fan_in), where fan_in is the number of
# inputs feeding *that* layer (n_in for the first layer, nh for the second).
# Note from the original author: a plain ReLU still does not keep the activations
# at zero mean / unit std.
w1 = torch.randn(n_in, nh) * (2 / n_in) ** 0.5
b1 = torch.zeros(nh)
w2 = torch.randn(nh, n_out) * (2 / nh) ** 0.5
b2 = torch.zeros(n_out)

# Function based NN

In [12]:
def lin(x, w, b):
    """Affine layer: matrix-multiply `x` by `w`, then add the bias `b`.

    Args:
        x: input tensor.
        w: weight tensor, right-multiplied onto `x`.
        b: bias added to the product (broadcast by the usual tensor rules).

    Returns:
        `x @ w + b`.
    """
    projected = x @ w
    return projected + b

In [13]:
def relu(x):
    """Rectified linear unit: replace negative entries of `x` with 0.

    Returns:
        The clamped tensor, converted to float32.
    """
    rectified = x.clamp(min=0)
    return rectified.float()

In [14]:
def mse(pred, y):
    """Mean squared error between predictions and targets.

    Args:
        pred: predictions; a trailing dimension of size 1 is squeezed away
            before comparing with `y`.
        y: targets.

    Returns:
        0-dim float tensor holding the mean of the squared differences.
    """
    diff = pred.squeeze(-1) - y
    return diff.pow(2).float().mean()

In [15]:
# Gradient of the MSE loss with respect to the predictions.
def mse_grad(pred, y):
    """Store d(mse)/d(pred) on `pred.g`.

    The gradient is `2 * (pred - y) / N`, with N = `y.shape[0]`, given a
    trailing axis again so it lines up with a column of predictions.
    """
    residual = pred.squeeze() - y
    pred.g = (2. * residual).unsqueeze(-1) / y.shape[0]

In [16]:
def relu_grad(Z, A):
    """Back-propagate through a ReLU.

    Args:
        Z: the ReLU input; receives the gradient in `Z.g`.
        A: the ReLU output, whose upstream gradient is read from `A.g`.

    The upstream gradient passes through where `Z > 0` and is zeroed elsewhere.
    """
    passes_through = (Z > 0).float()
    Z.g = passes_through * A.g

In [17]:
def lin_grad(inp, out, w, b):
    """Back-propagate through the affine layer `out = inp @ w + b`.

    Reads the upstream gradient from `out.g` and stores:
        inp.g: gradient w.r.t. the layer input  (`out.g @ w.T`)
        w.g:   gradient w.r.t. the weights      (`inp.T @ out.g`)
        b.g:   gradient w.r.t. the bias         (`out.g` summed over axis 0)

    Args:
        inp: 2-D input that was fed to the layer.
        out: layer output carrying the upstream gradient in `out.g`.
        w: weight matrix of the layer.
        b: bias of the layer.
    """
    inp.g = out.g @ w.t()
    # A single matmul gives the same result as summing the per-sample outer
    # products, without materializing an (n, n_in, n_out) intermediate tensor.
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)

In [18]:
def forward_backward(x, y, w1, b1, w2, b2):
    """Run one forward and one backward pass of the two-layer network.

    Forward: linear -> relu -> linear -> mse.
    Backward: the `*_grad` helpers are called in reverse order; they store the
    gradients on the `.g` attribute of the tensors they are given.

    Args:
        x: input batch.
        y: targets.
        w1, b1: weight and bias of the first linear layer.
        w2, b2: weight and bias of the second linear layer.

    Returns:
        The loss computed in the forward pass (previously computed and discarded).
    """
    # Forward pass.
    l1 = lin(x, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    loss = mse(out, y)

    # Backward pass, from the loss back to the input.
    mse_grad(out, y)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(x, l1, w1, b1)
    return loss

In [19]:
# One full-batch forward/backward pass; gradients end up on the tensors' .g attributes.
forward_backward(x, y, w1, b1, w2, b2)

# Class based NN

In [20]:
class Module:
    """Minimal base class for layers with a hand-written backward pass.

    Calling an instance runs `forward` and remembers both the arguments and the
    result, so that `backward()` can later hand them to `bwd`.
    Subclasses implement `forward(*args)` and `bwd(out, *args)`.
    """

    def __call__(self, *args):
        # Cache inputs and output; backward() replays them into bwd().
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        """Compute the layer output. Must be overridden.

        Accepts the call arguments so that an un-overridden stub raises the
        intended error instead of a TypeError about the argument count.
        """
        raise NotImplementedError('not implemented')

    def bwd(self, out, *args):
        """Compute gradients from `out` and the cached inputs. Must be overridden."""
        raise NotImplementedError('not implemented')

    def backward(self):
        """Run `bwd` on the output and inputs cached by the last call."""
        self.bwd(self.out, *self.args)

In [21]:
class Relu(Module):
    """ReLU layer whose output is shifted down by 0.5.

    The constant shift does not change the gradient, so `bwd` is the usual
    ReLU gradient.
    """

    def forward(self, x):
        rectified = x.clamp_min(0.)
        return rectified - 0.5

    def bwd(self, out, x):
        # Pass the upstream gradient through only where the input was positive.
        positive_mask = (x > 0).float()
        x.g = positive_mask * out.g
    

In [22]:
class Lin(Module):
    """Affine layer `x @ w + b` with a hand-written backward pass."""

    def __init__(self, w, b):
        # The tensors are kept by reference; gradients are attached to them as .g.
        self.w, self.b = w, b

    def forward(self, x):
        return x @ self.w + self.b

    def bwd(self, out, x):
        upstream = out.g
        x.g = upstream @ self.w.T       # gradient w.r.t. the input
        self.w.g = x.T @ upstream       # gradient w.r.t. the weights
        self.b.g = upstream.sum(0)      # gradient w.r.t. the bias

In [23]:
class Mse(Module):
    """Mean-squared-error loss with a hand-written backward pass."""

    def forward(self, pred, y):
        error = pred.squeeze() - y
        return (error ** 2).mean()

    def bwd(self, out, pred, y):
        # d(mean(error^2))/d(pred) = 2 * error / N, restored to a trailing axis.
        error = pred.squeeze() - y
        pred.g = (2 * error).unsqueeze(-1) / y.shape[0]
        

In [24]:
class Model:
    """Two-layer network (Lin -> Relu -> Lin) followed by an Mse loss."""

    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
        self.loss = Mse()

    def __call__(self, x, y):
        """Feed `x` through every layer in order and return the loss against `y`."""
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(activation, y)

    def backward(self):
        """Back-propagate: the loss first, then the layers from last to first."""
        self.loss.backward()
        for layer in self.layers[::-1]:
            layer.backward()

In [25]:
# Wrap the existing weight/bias tensors in the class-based network.
model = Model(w1, b1, w2, b2)

In [26]:
# Forward pass over the full batch; each layer caches its inputs and output.
loss = model(x, y)

In [27]:
# Backward pass through the loss and then the layers in reverse order.
model.backward()

# PyTorch NN

In [28]:
import torch.nn as nn

In [29]:
def mse(output, targ):
    """Mean squared error between `output` and `targ`.

    A trailing size-1 dimension of `output` is squeezed away before the
    comparison. Note: this definition replaces the `mse` defined earlier in
    the notebook.
    """
    diff = output.squeeze(-1) - targ
    return (diff ** 2).mean()

In [30]:
class Model(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that returns its MSE loss.

    Args:
        n_in: number of input features.
        nh: number of hidden units.
        n_out: number of outputs.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList registers the sub-modules, so their parameters show up in
        # parameters() / state_dict() / .to(device). A plain Python list does not.
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])
        self.loss = mse

    def forward(self, x, targ):
        """Feed `x` through the layers and return the loss against `targ`.

        Implemented as `forward` (rather than overriding `__call__`) so that
        `model(x, targ)` goes through nn.Module's normal call path and hooks.
        """
        for l in self.layers: x = l(x)
        return self.loss(x.squeeze(), targ)

In [31]:
# Instantiate the nn.Module-based network with the layer sizes defined earlier.
model = Model(n_in, nh, n_out)

In [32]:
# Forward pass over the full batch; returns the loss tensor.
loss = model(x, y)

In [33]:
# Let autograd compute the gradients of the loss.
loss.backward()