# Imports

In [1]:
import torch
from torch import FloatTensor, LongTensor
import math

In [2]:
from load_script_deep_framework import load_dataset

In [9]:
%load_ext autoreload
%autoreload 2

In [31]:
import torch.nn as nn

In [34]:
?nn.Linear.forward

# Loading the dataset

In [11]:
# Unpack the pair returned by load_dataset() into the names used for training.
train, target_train = load_dataset()

In [12]:
# Unpack the pair returned by a second load_dataset() call into the names used
# for testing.
# NOTE(review): this is the same argument-less call as for the training pair;
# confirm that load_dataset() yields a different sample on each call, otherwise
# the test data is identical to the training data.
test, target_test = load_dataset()

# Generic Module Class

In [86]:
class Module():
    """Abstract base class for the building blocks of the framework.

    Subclasses override ``forward`` and ``backward``. Calling an instance
    (``module(x)``) is a shortcut for ``module.forward(x)``.
    """

    def forward(self, input):
        """Compute the module output; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # NotImplementedError is the exception; the bare ``NotImplemented``
        # constant is not raisable and would produce a TypeError instead.
        raise NotImplementedError

    def backward(self, input):
        """Propagate a gradient back through the module; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def param(self):
        """Return the module parameters; the base class has none and returns None."""
        return None

    def __call__(self, *input):
        """Forward all positional arguments to ``forward`` and return its result."""
        return self.forward(*input)
    

# Specific classes

In [126]:
class Linear(Module):
    """Linear layer without bias: ``output = weights.mv(input)``.

    The weight matrix has shape ``(output_features, input_features)``.
    ``_gradient`` has the same shape and accumulates the weight gradient over
    successive ``backward`` calls; nothing in this class resets it.
    """

    def __init__(self, input_features, output_features):
        """Create the weight matrix and a zero-filled gradient accumulator.

        Args:
            input_features: number of entries of the (flattened) input vector.
            output_features: number of entries of the output vector.
        """
        super(Linear, self).__init__()

        self._input_features = input_features
        self._output_features = output_features

        # torch.rand draws the initial weights uniformly from [0, 1).
        self._weights = torch.rand(self._output_features, self._input_features)
        self._gradient = torch.zeros(self._weights.shape)

    def forward(self, input):
        """Flatten ``input`` to a vector, store it, and return ``weights @ input``.

        Returns:
            A copy of the output vector (the original is kept on ``_output``).
        """
        self._input = input.view(-1)

        self._output = self._weights.mv(self._input)
        return self._output.clone()

    def backward(self, d_dy):
        """Accumulate the weight gradient and return the gradient w.r.t. the input.

        Args:
            d_dy: gradient of the loss with respect to this layer's output.

        Returns:
            ``weights.t() @ d_dy``, the gradient with respect to the stored input.
        """
        # Flatten once so the outer product and the mv below use the same
        # 1-D vector.
        d_dy = d_dy.view(-1)

        # In-place add_ is required: Tensor.add returns a new tensor, so the
        # accumulator was never updated when its result was discarded.
        self._gradient.add_(d_dy.view(-1, 1) * self._input.view(1, -1))

        d_dx = self._weights.t().mv(d_dy)
        return d_dx
    

In [113]:
class ReLU(Module):
    """Rectifier activation: entries below zero are replaced by zero.

    ``forward`` keeps private copies of its input and output; ``backward``
    uses the stored input to decide where the incoming gradient is zeroed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return a copy of ``input`` with every entry below zero set to 0.

        Copies of the input and of the result are stored on ``_input`` and
        ``_output``; the caller's tensor is not modified.
        """
        self._input = input.clone()

        rectified = self._input.clone()
        rectified[rectified < 0] = 0
        self._output = rectified

        return rectified.clone()

    def backward(self, d_dy):
        """Return a copy of ``d_dy`` zeroed wherever the stored input was below zero."""
        d_dx = d_dy.clone()
        was_negative = self._input < 0
        d_dx[was_negative] = 0

        return d_dx
        

In [114]:
class LossMSE(Module):
    """Squared-error loss between a prediction and a target.

    Note that ``forward`` sums the squared differences; it does not divide by
    the number of elements.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the sum of squared differences between ``input`` and ``target``.

        The difference ``input - target`` is stored on ``_input`` for use by
        ``backward``; the loss value is stored on ``_output`` and a copy of it
        is returned.
        """
        residual = input - target
        self._input = residual

        squared_error = residual.pow(2).sum()
        self._output = squared_error
        return squared_error.clone()

    def backward(self):
        """Return twice the difference stored by the last ``forward`` call."""
        return 2 * self._input