In [26]:
# %load_ext autoreload
# %autoreload 2
# from nn.value import Value
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
%matplotlib inline
np.random.seed(1337)
random.seed(1337)

In [27]:
from nn.value import Value, exp

In [47]:
x1 = torch.tensor(5.,requires_grad=True)
x2 = torch.tensor(3.,requires_grad=True)
x3 = torch.exp(x1+x2) 
x3.backward()
x1.grad, x2.grad
z = x1.grad.item()

In [45]:
x1 = Value(5.)
x2 = Value(3.)
x3 = exp(x1+x2) 
x3.backward()
x1.grad, x2.grad

(2980.9579870417283, 2980.9579870417283)

In [48]:
z

2980.9580078125

In [32]:
x.grad

tensor(148.4132)

In [21]:
w1 = Value(5.)
b1 = Value(6.)


exp(Value(5))

Value(148.4131591025766, grad=0.0)

In [10]:
from dataclasses import dataclass
from contextlib import contextmanager
INFERENCE = False

@contextmanager
def inference_mode():
    global INFERENCE
    INFERENCE = true
    yield INFERENCE
    INFERENCE = false


class Module:
    def parameters(self,):
        return self._parameters

class BatchNorm:
    def __init__(self,):
        self.running_mean = None
        self.running_std = None
        self._parameters = None

    def __call__(self, x):
        mean = x.mean(0, keepdims=True)
        std = x.std(0, keepdims=True)

        if INFERENCE or self.running_mean is None:
            mean = self.running_mean
            std = self.running_std
        else:                        
            self.running_mean  = self.running_mean * .999 + mean * .01
            self.running_std  = self.running_std * .999 + std * .01
        
        x = x - mean / std
        return x

@dataclass
class Linear(Module):
    n_in: int
    n_out: int
    
    def __post_init__(self):
        self.w = [Value(random.random(), name='w') for i in range(self.n_out)]
        self.b = Value(random.random(), name='b')

        self._parameters = self.w + [self.b]

    def __call__(self, x):
        assert len(x) == self.n_in
        return [sum((x[i]*self.w[j] for i in range(self.n_in)), start=self.b) for j in range(self.n_out)]

@dataclass
class Relu(Module):
    def __init__(self):
        self._parameters = []
        
    def __call__(self, x):
        return [max(0, x_.data) * x_ for x_ in x]

@dataclass
class Tanh(Module):
    def __init__(self):
        self._parameters = []
        
    def __call__(self, x):
        return [(exp(2*x_)-1 ) / (exp(2*x_)+1) for x_ in x]

@dataclass
class MLP(Module):
    n_in: int
    n_out: int
    n_layers: int
    n_hidden: int
    
    def __post_init__(self):
        assert self.n_layers >= 1
        layers = [Linear(self.n_in, self.n_hidden), Tanh()]
        
        for i in range(self.n_layers-2):
            layers += [Linear(self.n_hidden, self.n_hidden), Tanh()]
            
        layers += [Linear(self.n_hidden, self.n_out)]

        self.layers = layers

    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters() ]
    
    def __call__(self, x):
        for layer in self.layers:
            x = [layer(x_) for x_ in x]
        return x

X = inputs = [[Value(5., name='x1'), Value(6., name='x2')],
              [Value(3., name='x1'), Value(5., name='x2')]]
y = 2

mlp = MLP(2, 1, n_layers=1, n_hidden=3)

In [20]:
lr = .001

#forward
y_hat = mlp(X)
loss = sum((y - y_hat_[0])**2 for y_hat_ in y_hat)

print(mlp.parameters())

#backward
for p in mlp.parameters():
    p.grad = 0

loss.backward()

for p in mlp.parameters():
    p.data -= p.grad * p.data

[Value(w=-140811395.6500865, grad=0.0), Value(w=-54089237872.0708, grad=0.0), Value(w=-189.08374681259593, grad=0.0), Value(b=-4443179871.905231, grad=0.0), Value(w=-8168478052095251.0, grad=-541593871.5921905), Value(b=20590960663592.914, grad=180531290.53073016)]


In [18]:
loss

Value(4073943357586362.5, grad=1.0)

In [10]:
mlp.parameters()

[Value(w=-392.03045006390767, grad=0.0),
 Value(w=-296.0164755065418, grad=0.0),
 Value(w=-145.44493717290726, grad=0.0),
 Value(b=-91.33222568579471, grad=0.0),
 Value(w=-926.7610552891709, grad=0.0),
 Value(b=-302363111.0525898, grad=-34780.68310128581)]

In [6]:
x

Value(622.5901125806454, grad=1.0)