In [1]:
import torch
import math
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x10f6da358>

### Data Generation

In [2]:
def generate_disc_set(nb):
    input_ = torch.empty(nb, 2).uniform_(0,1)
    target_ = torch.empty(nb).long()
    
    for i in range(nb):
        if (torch.norm(input_[i]) < math.sqrt(1/(2*math.pi))):
            target_[i] = 1
        else : target_[i] = 0
    
    return input_, target_

In [3]:
train_input, train_target = generate_disc_set(1000)
test_input, test_target = generate_disc_set(1000)

### Activation functions

In [4]:
def Tanh_fun(x):
    return x.tanh()
    
def d_Tanh(x):
    return (1 - torch.pow(Tanh_fun(x), 2))

def ReLU_fun(x):
    return x * (x > 0).float()

def d_ReLU(x):
    return 1. * (x > 0).float()

### Loss

In [5]:
def LossMSE(v, t):
    return torch.sum(torch.pow(t-v, 2)).item()
    
def d_LossMSE(v, t):
    return (2*(v-t)).item()

### Forward and backward passes

In [6]:
class Module(object):
    
    def forward_pass(self, *input):
        raise NotImplementedError

    def backward_pass(self, *gradwrtoutput):
        raise NotImplementedError
        
    def param(self):
        return []

In [7]:
class Linear(Module):
    
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        self.w = torch.empty(self.out_features, self.in_features).normal_(0,1e-6)
        self.b = torch.empty(self.out_features).normal_(0,1e-6)
        self.dl_dw = torch.zeros(self.w.size())
        self.dl_db = torch.zeros(self.b.size())
        
        self.cache_forward = None
    
    def forward_pass(self, input):
        self.cache_forward = input
        return self.w @ input + self.b

    def backward_pass(self, gradwrtoutput):
        self.dl_dw += gradwrtoutput.view(self.out_features,1) @ self.cache_forward.view(1,self.in_features)
        self.dl_db += gradwrtoutput
        self.cache_forward = None
        return gradwrtoutput @ self.w
        
    def param(self):
        return [(self.w, self.dl_dw), (self.b, self.dl_db)]
    
    def zerograd(self):
        self.dl_dw = torch.zeros(self.w.size())
        self.dl_db = torch.zeros(self.b.size())
        
    def update(self, eta):
        self.w -= eta * self.dl_dw
        self.b -= eta * self.dl_db
    

class ReLU(Module):
    
    def __init__(self):
        super(ReLU, self).__init__()
        self.cache_forward = None
    
    def forward_pass(self, input):
        self.cache_forward = input
        return ReLU_fun(input)

    def backward_pass(self, gradwrtoutput):
        dl_s = d_ReLU(self.cache_forward) * gradwrtoutput
        self.cache_forward = None
        return dl_s
        
    def param(self):
        return []
    
    def zerograd(self):
        return []

class Tanh(Module):
    
    def __init__(self):
        super(Tanh, self).__init__()
        self.cache_forward = None
    
    def forward_pass(self, input):
        self.cache_forward = input
        return Tanh_fun(input)

    def backward_pass(self, gradwrtoutput):
        dl_s = d_Tanh(self.cache_forward) * gradwrtoutput
        self.cache_forward = None
        return dl_s
        
    def param(self):
        return []
    
    def zerograd(self):
        return []

In [8]:
class Sequential(Module):
    
    def __init__(self, *args):
        super(Sequential, self).__init__()
        self.modules = []
        self.backward = None
        for module in args:
            self.modules.append(module)
    
    def forward_pass(self, input):
        self.forward = input
        for module in self.modules:
            self.forward = module.forward_pass(self.forward)
        return torch.argmax(self.forward)
            
    def backward_pass(self, target):
        self.backward = d_LossMSE(torch.argmax(self.forward).item(), target)
        for module in reversed(self.modules):
            self.backward = module.backward_pass(self.backward)
    
    def zerograd(self):
        for module in self.modules:
            module.zerograd()
            
    def update(self, eta):
        for module in self.modules:
            if(len(module.param()) > 0):
                module.update(eta)

In [15]:
def train(model, train_input, train_target, eta, mini_batch_size, epochs):
    losses = []
    for e in range(epochs):
        total_loss = 0
        for b in range(0, train_input.size(0), mini_batch_size):
            for i in range(mini_batch_size):
                output = model.forward_pass(train_input[b+i])
                model.backward_pass(train_target[b+i])
                total_loss += LossMSE(output, train_target[b+i])
            model.update(eta)
            model.zerograd()
        losses.append(total_loss)
    return losses

In [18]:
eta = 0.01
mini_batch_size = 100
epochs = 10

model = Sequential(Linear(2,25), ReLU(), Linear(25,25), ReLU(), Linear(25,2), Tanh())
train(model, train_input, train_target, eta, mini_batch_size, epochs)

[15,
 10,
 14,
 14,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91,
 85,
 90,
 86,
 86,
 85,
 83,
 83,
 92,
 90,
 91]