In [2]:
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%run engine.ipynb


In [3]:
import random

class Neuron:
    """One tanh unit: out = tanh(sum_i w_i * x_i + b).

    Each weight and the bias are `Value` objects drawn uniformly from [-1, 1].
    """

    def __init__(self, nin):
        # one weight per input, drawn first, then the bias
        self.w = []
        for _ in range(nin):
            self.w.append(Value(random.uniform(-1, 1)))
        self.b = Value(random.uniform(-1, 1))

    def __call__(self, x):
        """Return tanh(w . x + b) for the input sequence `x`."""
        # Accumulate onto the bias so the pre-activation is b + w0*x0 + w1*x1 + ...
        # NOTE: zip stops at the shorter of w and x, so an input of the wrong
        # length is not reported.
        act = self.b
        for weight, inp in zip(self.w, x):
            act = act + weight * inp
        return act.tanh()

    def parameters(self):
        """Trainable values: all weights followed by the bias (new list each call)."""
        return [*self.w, self.b]

class Layer:
    """A fully connected layer: `nout` independent neurons, each reading all `nin` inputs."""

    def __init__(self, nin, nout):
        self.neurons = []
        for _ in range(nout):
            self.neurons.append(Neuron(nin))

    def __call__(self, x):
        """Apply every neuron to `x`.

        Returns the list of neuron outputs, except that a single-neuron layer
        returns its lone output directly (not wrapped in a list).
        """
        outs = [neuron(x) for neuron in self.neurons]
        if len(outs) == 1:
            return outs[0]
        return outs

    def parameters(self):
        """All trainable values of the layer, concatenated neuron by neuron."""
        params = []
        for neuron in self.neurons:
            params.extend(neuron.parameters())
        return params
        

class MLP:
    """Multi-layer perceptron: a stack of `Layer`s applied in sequence.

    Args:
        nin: number of inputs to the network.
        nouts: sequence (list, tuple, ...) of layer widths; the last entry is
            the output width.
    """

    def __init__(self, nin, nouts):
        # list(...) so tuples or other sequences of widths work, not only lists
        sz = [nin] + list(nouts)
        # consecutive pairs of sizes give each layer's (nin, nout)
        self.layers = [Layer(sz[i], sz[i + 1]) for i in range(len(sz) - 1)]

    def __call__(self, x):
        """Forward pass: feed `x` through every layer in order and return the result.

        A single-neuron Layer returns its output unwrapped. That is wanted only for
        the final layer; a width-1 hidden layer's output is re-wrapped in a list so
        the next layer can iterate over it.
        """
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last and not isinstance(x, list):
                x = [x]
        return x

    def parameters(self):
        """All trainable values of the network, layer by layer."""
        return [p for layer in self.layers for p in layer.parameters()]



In [4]:
# Toy dataset: four inputs with 3 features each, and one target per input.
xs = [
    [2.0, 3.0, -1.0],
    [3.0, -1.0, 0.5],
    [0.5, 1.0, 1.0],
    [1.0, 1.0, -1.0],
]
ys = [1.0, -1.0, -1.0, 1.0]  # desired targets, one per row of xs

# Neuron initialises weights with random.uniform; seed it so the network
# built here is the same on every Restart & Run All.
SEED = 42
random.seed(SEED)
n = MLP(3, [4, 4, 1])  # 3 inputs -> hidden layers of 4 and 4 -> 1 output

In [5]:
# Full-batch gradient descent on the sum-of-squared-errors loss.
N_STEPS = 20          # number of gradient steps
LEARNING_RATE = 0.01  # step size for each parameter update

# parameters() returns the same Value objects every call (only .data is
# mutated below), so the list can be built once outside the loop.
params = n.parameters()
for k in range(N_STEPS):
    # forward pass
    ypred = [n(x) for x in xs]
    loss = sum((yout - ygt) ** 2 for ygt, yout in zip(ys, ypred))

    # backward pass: reset grads first (backward() presumably accumulates
    # into .grad — see engine.ipynb)
    for p in params:
        p.grad = 0.0
    loss.backward()

    # update: step each parameter against its gradient
    for p in params:
        p.data -= LEARNING_RATE * p.grad

    print(k, loss.data)

0 7.7605960830323975
1 7.21910284307103
2 7.043804264128693
3 6.887508203231934
4 6.7043516379517225
5 6.478960599346664
6 6.202383099569554
7 5.875333821160886
8 5.515488340154154
9 5.156831760658026
10 4.832964742500063
11 4.559750998865158
12 4.335666358403729
13 4.1515034537745885
14 3.9970247759439474
15 3.8634159000581807
16 3.743729334257949
17 3.632652057081207
18 3.5261424430472768
19 3.4211229840832567
