In [None]:
from torch import empty
import torch
torch.set_grad_enabled(False)

### The `Module` base class, the `Linear` layer and the SGD optimizer

In [None]:
class Module:
    """Interface shared by every layer in this notebook.

    Subclasses override ``forward`` (outputs from inputs), ``backward``
    (gradients w.r.t. inputs from gradients w.r.t. outputs) and, if they own
    trainable tensors, ``param``. The base class only defines the contract.
    """

    def forward(self, *input):
        """Compute the layer outputs; must be overridden by subclasses."""
        raise NotImplementedError

    def backward(self, *gradwrtoutput):
        """Back-propagate output gradients; must be overridden by subclasses."""
        raise NotImplementedError

    def param(self):
        """Return the trainable parameters; the base class has none."""
        return list()


class Linear(Module):
    """Fully connected layer computing ``W @ x (+ b)`` for each sample passed in.

    ``forward`` and ``backward`` take a variable number of samples and return a
    tuple with one result per sample. Samples are presumably 1-D tensors of
    length ``in_features`` -- TODO confirm before using other shapes.

    Args:
        in_features: number of input features (columns of ``W``).
        out_features: number of output features (rows of ``W``).
        bias: if true, a bias vector ``b`` of length ``out_features`` is added.

    Parameters are initialised to ones (deterministic, not random).
    ``param()`` returns ``{'weight': [W, (b)], 'grad': [grads_W, (grads_b)]}``
    where each ``grads_*`` is the very list this layer appends per-sample
    gradients to, so an optimizer holding the dict sees (and can clear) them.
    """

    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.params = {}
        self.w = empty(out_features, in_features).fill_(1)
        self.params['weight'] = [self.w]
        # Per-sample gradients of W (entry i <-> i-th sample); this list object
        # is shared with self.params['grad'].
        self.gradwrt_w = []
        self.params['grad'] = [self.gradwrt_w]
        if bias:
            self.b = empty(out_features).fill_(1)
            self.params['weight'].append(self.b)
            self.gradwrt_b = []
            self.params['grad'].append(self.gradwrt_b)
        # Inputs of the last forward call, needed by backward for dL/dW.
        self.input = None

    def forward(self, *input):
        """Return ``W @ x (+ b)`` for every ``x`` in ``input``, as a tuple.

        The inputs are remembered so ``backward`` can compute weight gradients.
        """
        self.input = input
        if self.bias:
            return tuple(self.w @ tensor + self.b for tensor in input)
        return tuple(self.w @ tensor for tensor in input)

    def backward(self, *gradwrtoutput):
        """Back-propagate one output gradient per sample of the last ``forward``.

        For sample ``i`` with input ``x_i`` and output gradient ``g_i`` this
          * returns ``g_i @ W`` (the gradient w.r.t. the input, ``W^T g_i``),
          * accumulates ``g_i x_i^T`` into ``self.gradwrt_w[i]``,
          * accumulates ``g_i`` into ``self.gradwrt_b[i]`` if the layer has a bias.

        Gradient slots are created on first use; calling ``backward`` again
        without clearing the lists sums into the existing slots.

        Returns:
            tuple of input gradients, one per entry of ``gradwrtoutput``.
        """
        grads_input = []
        for i, grad_out in enumerate(gradwrtoutput):
            # with respect to the input
            grads_input.append(grad_out @ self.w)

            # with respect to the weight: outer product shaped exactly like W.
            # No squeeze -- with in_features == 1 a squeezed (k,) gradient
            # would broadcast against the (k, 1) weight during the update.
            grad_w = grad_out.reshape(-1, 1).mm(self.input[i].reshape(1, -1))
            self._accumulate(self.gradwrt_w, i, grad_w)

            # with respect to the bias
            if self.bias:
                self._accumulate(self.gradwrt_b, i, grad_out.reshape_as(self.b))

        return tuple(grads_input)

    @staticmethod
    def _accumulate(slots, i, grad):
        """Add ``grad`` into ``slots[i]``, creating that slot on first use.

        ``backward`` visits samples in order starting at 0, so a missing slot is
        always the next one. ``clone`` gives the slot its own storage instead of
        aliasing the caller's gradient tensor.
        """
        if i < len(slots):
            slots[i].add_(grad)
        else:
            slots.append(grad.clone())

    def param(self):
        """Return the ``{'weight': [...], 'grad': [...]}`` dict built in ``__init__``."""
        return self.params


class SGD(object):
    """Plain gradient descent over a ``Linear.param()``-style parameter dict.

    ``params`` is ``{'weight': [p_0, p_1, ...], 'grad': [g_0, g_1, ...]}`` where
    each ``g_k`` is a list of gradient tensors for ``p_k``. ``step`` subtracts
    ``lr * g`` for every ``g`` in ``g_k`` (i.e. descends along the sum of the
    stored gradients), updating the parameter tensors in place so the owning
    module sees the new values.

    Args:
        params: the parameter/gradient dict described above.
        lr: learning rate.
    """

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def step(self):
        """Subtract ``lr`` times every stored gradient from its parameter, in place."""
        for weight, grads in zip(self.params['weight'], self.params['grad']):
            # Each parameter walks *its own* gradient list (the original bounded
            # every parameter by the length of the first list).
            for grad in grads:
                # reshape_as undoes squeezed size-1 dims so a (k,) gradient
                # updates a (k, 1) weight instead of broadcasting to (k, k).
                weight.sub_(self.lr * grad.reshape_as(weight))

    def zero_grad(self):
        """Empty every gradient list in place.

        With ``Linear.param()`` these are the layer's own gradient lists, so
        clearing them here also resets the layer.
        """
        for grads in self.params['grad']:
            grads.clear()

### Example: forward pass, backward pass and one SGD step through two stacked `Linear` layers

In [None]:
# ------------------------------------------------------------------
# Reproducibility
# ------------------------------------------------------------------
# Both layers below are initialised with constant ones, so nothing here is
# random yet; seeding keeps the cell reproducible if random data is added.
torch.manual_seed(0)

# ------------------------------------------------------------------
# Model: Linear(2 -> 3) followed by Linear(3 -> 4), one SGD per layer
# ------------------------------------------------------------------
m = Linear(2, 3)
m2 = Linear(3, 4)
sgd, sgd2 = (SGD(layer.param(), 0.1) for layer in (m, m2))


def _report(stage):
    """Print the parameter/gradient dict of both layers, tagged with ``stage``."""
    for label, layer in (("m", m), ("m2", m2)):
        print(f"{label} params {stage} {layer.param()}\n")


# ------------------------------------------------------------------
# Data: one 2-feature sample (fed three times) and an arbitrary upstream
# gradient of size 4 for each of the three resulting outputs
# ------------------------------------------------------------------
input = empty(2)
input[0], input[1] = 1, 2
grad_loss = tuple(empty(4).fill_(v) for v in (10, 5, 1))

# start from empty gradients, as one would in a training loop
sgd.zero_grad()
sgd2.zero_grad()

# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
x = m2.forward(*m.forward(input, input, input))
_report("after forward")

# ------------------------------------------------------------------
# Backward pass
# ------------------------------------------------------------------
x = m2.backward(*grad_loss)
output = m.backward(*x)
print(f"output (error with respect to input) {output}\n")
_report("after backward")

# ------------------------------------------------------------------
# Gradient step, then clear the gradients again
# ------------------------------------------------------------------
sgd.step()
sgd2.step()
_report("after step")

sgd.zero_grad()
sgd2.zero_grad()
_report("after zeroing the gradients")