In [1]:
import torch
import math
#torch.set_grad_enabled(False)

In [2]:
class Module(object):
    """Abstract base class for the building blocks of this mini framework.

    ``forward`` and ``backward`` raise ``NotImplementedError`` and are meant
    to be overridden; ``param`` defaults to an empty list.
    """

    def param(self):
        """Return the module's parameters; the base class has none."""
        return list()

    def forward(self, *input):
        """Forward pass placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def backward(self, *gradwrtoutput):
        """Backward pass placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

In [3]:
class Losses(object):
    """Abstract base class for loss functions.

    Subclasses implement ``function`` (the loss value) and ``derivative``
    (the gradient of the loss with respect to the prediction).
    """

    def function(self, v, t):
        """Return the loss for prediction ``v`` and target ``t``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def derivative(self, v, t):
        """Return the gradient of the loss with respect to ``v``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

In [4]:
class Optimizers(object):
    """Abstract base class for optimizers."""

    def step(self, *args, **kwargs):
        """Perform one optimization step.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

In [5]:
class Parameter():
    """Simple holder with two slots: ``data`` and ``grad``."""

    def __init__(self):
        # Both slots start out empty.
        self.data, self.grad = None, None

In [6]:
class Linear(Module):
    """Fully connected layer computing ``W @ x (+ b)`` for one 1-D input vector.

    ``weight`` (and ``bias`` when enabled) are ``Parameter`` objects whose
    ``data`` is drawn uniformly from ``[-1/sqrt(input_dim), 1/sqrt(input_dim))``.
    Gradients are accumulated into ``weight.grad`` / ``bias.grad`` across
    successive ``backward`` calls.
    """

    def __init__(self, input_dim, out_dim, bias = True):
        """Create the layer.

        Args:
            input_dim: size of the input vector.
            out_dim: size of the output vector.
            bias: when False the layer has no bias term and ``self.bias`` is None.
        """
        super().__init__()
        std = 1/math.sqrt(input_dim)

        self.weight = Parameter()
        # torch.rand samples from [0, 1); rescale to [-std, std).
        self.weight.data = 2*std*torch.rand(out_dim, input_dim) - std

        self.with_bias = bias
        # Always define the attribute so it can be inspected even without a bias.
        self.bias = None
        if bias:
            # Keep the Parameter wrapper and store the tensor in its .data slot.
            self.bias = Parameter()
            self.bias.data = 2*std*torch.rand(out_dim) - std

        # Input of the most recent forward call, needed by backward.
        self.x = None

    def forward(self, x):
        """Return ``weight @ x`` plus the bias (if any) and remember ``x``."""
        self.x = x
        out = self.weight.data.mv(x)
        if self.with_bias:
            out = out + self.bias.data
        return out

    def backward(self, prev_grad):
        """Accumulate parameter gradients and return the gradient w.r.t. the input.

        Args:
            prev_grad: gradient of the loss with respect to this layer's output.

        Returns:
            Gradient with respect to the input, reshaped to a column ``(input_dim, 1)``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.x is None:
            raise RuntimeError("Linear.backward called before Linear.forward")

        grad_out = prev_grad.view(-1, 1)

        if self.weight.grad is None:
            self.weight.grad = torch.zeros_like(self.weight.data)
        # Outer product: (out_dim, 1) * (1, input_dim) -> (out_dim, input_dim).
        self.weight.grad += grad_out*self.x.view(1, -1)

        if self.with_bias:
            if self.bias.grad is None:
                self.bias.grad = torch.zeros_like(self.bias.data)
            self.bias.grad += grad_out.view(-1)

        next_grad = grad_out.view(1, -1)@self.weight.data
        return next_grad.view(-1, 1)

In [7]:
class ReLu(Module):
    """Rectified linear unit activation: ``max(0, x)`` applied element-wise."""

    def __init__(self):
        super().__init__()
        # Input of the most recent forward call, needed by backward.
        self.x = None

    def forward (self, x):
        """Return ``max(0, x)`` element-wise and remember ``x`` for backward."""
        self.x = x
        return x.clamp(min=0)

    def backward ( self, prev_grad) :
        """Return the gradient with respect to the input.

        Args:
            prev_grad: gradient of the loss with respect to this module's
                output; it is flattened before being multiplied.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.x is None:
            raise RuntimeError("ReLu.backward called before ReLu.forward")

        # d/dx max(0, x) is 1 where x > 0 and 0 elsewhere (0 chosen at x == 0).
        deriv = (self.x > 0).to(self.x.dtype)
        return deriv*prev_grad.view(-1)
    

In [8]:
class MSE(Losses):
    """Mean squared error loss: ``mean((v - t)^2)`` over all elements."""

    def function(self, v, t):
        """Return the mean of the squared element-wise difference of ``v`` and ``t``."""
        return (v - t).pow(2).mean()

    def derivative(self, v, t):
        """Return the gradient of ``function`` with respect to ``v``.

        Because ``function`` averages over all elements, each entry of the
        gradient carries a ``1/N`` factor, where ``N`` is the element count.
        """
        diff = v - t
        return 2 * diff / diff.numel()

In [9]:
# Forward pass: compare our Linear layer against torch.nn.Linear.
linear = Linear(5, 6, True)
builtin_linear = torch.nn.Linear(5, 6)
# Reuse the reference layer's parameters so both layers compute the same function.
linear.weight.data = builtin_linear.weight.data
linear.bias.data = builtin_linear.bias.data

# x is the input vector and y the regression target used by the backward check below.
x = torch.randn(5)
# NOTE(review): b is not used in any visible cell; it still advances the RNG before y is drawn.
b = torch.randn(5)
y = torch.randn(6)
# Display both outputs side by side; they should show the same values.
linear.forward(x), builtin_linear(x)

(tensor([-0.0944,  0.9078,  1.0506, -0.0373,  0.2254, -0.3821]),
 tensor([-0.0944,  0.9078,  1.0506, -0.0373,  0.2254, -0.3821],
        grad_fn=<AddBackward0>))

In [10]:
# Backward pass: compare the gradients accumulated by our Linear layer with torch autograd.
linear = Linear(5, 6, True)
builtin_linear = torch.nn.Linear(5, 6)
# Reuse the reference layer's parameters so both layers compute the same function.
linear.weight.data = builtin_linear.weight.data
linear.bias.data = builtin_linear.bias.data
relu = ReLu()

# Build dLoss/dOutput with autograd; it is the upstream gradient passed to linear.backward.
# x and y are the tensors created in the forward-comparison cell.
builtin_output = builtin_linear(x)
builtin_loss = torch.nn.MSELoss()(builtin_output, y)
builtin_loss_derivative = torch.autograd.grad( builtin_loss, builtin_output,  )[0]  # gradient w.r.t. builtin_output
der = builtin_loss_derivative.detach()
der.requires_grad = False

# linear.forward caches x for the backward call; the activation's output is discarded here.
relu.forward(linear.forward(x))
linear.backward((der))

# Reference gradients: recompute the loss and let autograd fill builtin_linear.weight.grad.
# NOTE(review): presumably recomputed because the earlier graph was consumed by torch.autograd.grad - confirm.
builtin_loss = torch.nn.MSELoss()(builtin_linear(x), y)
builtin_loss.backward()




In [11]:
# Show our weight gradient, the autograd weight gradient, and their element-wise equality.
# NOTE(review): this is an exact float comparison; torch.allclose would tolerate rounding differences.
linear.weight.grad, builtin_linear.weight.grad, linear.weight.grad == builtin_linear.weight.grad 

(tensor([[ 0.0666,  0.0258,  0.2986, -0.0188,  0.1002],
         [-0.0459, -0.0178, -0.2056,  0.0130, -0.0690],
         [-0.0734, -0.0285, -0.3291,  0.0207, -0.1104],
         [ 0.0060,  0.0023,  0.0269, -0.0017,  0.0090],
         [-0.1086, -0.0421, -0.4867,  0.0307, -0.1633],
         [-0.2709, -0.1051, -1.2143,  0.0765, -0.4075]]),
 tensor([[ 0.0666,  0.0258,  0.2986, -0.0188,  0.1002],
         [-0.0459, -0.0178, -0.2056,  0.0130, -0.0690],
         [-0.0734, -0.0285, -0.3291,  0.0207, -0.1104],
         [ 0.0060,  0.0023,  0.0269, -0.0017,  0.0090],
         [-0.1086, -0.0421, -0.4867,  0.0307, -0.1633],
         [-0.2709, -0.1051, -1.2143,  0.0765, -0.4075]]),
 tensor([[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]]))