In [1]:
import torch

In [2]:
class Module():
    def __call__(self, *args):
        self.args = args
        self.y = self.forward(*args)
        return self.y

    def forward(self, *args): raise NotImplementedError
    def backward(self, *args): self.bwd(self.y, *self.args)
    

In [3]:
class relu(Module):
    def forward(self, x):  return x.clamp_min(0.)-0.5
    def bwd(self, y, x): x.g = (x>0).float()* y.g

In [124]:
class linear(Module):
    def __init__(self, in_size, out_size):
        self.w = torch.zeros(in_size, out_size)
        self.b = torch.zeros(out_size)
        self.w.g = torch.zeros(in_size, out_size)
        self.b.g = torch.zeros(out_size)
        torch.nn.init.kaiming_normal_(self.w, mode='fan_out')


    def forward(self, x):
        y = x.T @ self.w + self.b
        print('mean: ', y.mean(), 'std: ' ,y.std())
        return y.T
    
    def bwd(self, y, x):
        x.g = y.g.T @ self.w.T
        self.w.g = torch.einsum('ij,ki->ik', x, y)
        self.b.g = torch.sum(y.g, dim=0)


In [125]:
class Mse(Module):
    def forward (self, x, targ): return (x.squeeze() - targ).pow(2).mean()
    def bwd(self, y, x, targ): x.g = 2*(x.squeeze()-targ).unsqueeze(-1) / targ.shape[0]

In [126]:
class Model():
    def __init__(self):
        self.layers = [linear(16, 5), relu(), linear(5, 2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()


In [127]:
model = Model()

In [128]:
a = torch.randn(16,1)

In [129]:
a.mean(), a.std()

(tensor(0.2433), tensor(0.7794))

In [130]:
model(a, torch.randn(1))

mean:  tensor(-0.2884) std:  tensor(0.9857)
mean:  tensor(0.8023) std:  tensor(0.4716)


tensor(1.4701)

In [131]:
model.backward()

In [139]:
model.layers[0].w.g.mean(), model.layers[0].w.g.std()

(tensor(-0.0702), tensor(0.7368))