In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tinyAG.engine import Value
from tinyAG.draw_dot import draw_dot
from tinyAG.nn import Neuron, Layer, MLP
from tinyAG.losses import losses, MSE

In [2]:
xs = [
    [2.0, 3.0, -1.0],
    [3.0, -1.0, 0.5],
    [0.5, 1.0, 1.0],
    [1.0, 1.0, -1.0],
]

ys = [1.0, -1.0, -1.0, 1.0]

In [3]:
model = MLP(3, [4, 4, 1])
print(model)

MLP of [
 Layer of [tanh Neuron(3), tanh Neuron(3), tanh Neuron(3), tanh Neuron(3)] 
 Layer of [tanh Neuron(4), tanh Neuron(4), tanh Neuron(4), tanh Neuron(4)] 
 Layer of [linear Neuron(4)] 
]


In [4]:
def lossi(X, y, batch_size=None):
    
    # inline DataLoader :)
    if batch_size is None:
        Xb, yb = X, y
    else:
        ri = np.random.permutation(X.shape[0])[:batch_size]
        Xb, yb = X[ri], y[ri]
    inputs = [list(map(Value, xrow)) for xrow in Xb]
    
    # forward the model to get scores
    scores = list(map(model, inputs))
    
    # svm "max-margin" loss
    losses = sum((scorei - yi)**2 for yi, scorei in zip(yb, scores))
    ### [(1 + -yi*scorei).tanh() for yi, scorei in zip(yb, scores)]
    data_loss = losses #sum(losses) * (1.0 / len(losses))
    # L2 regularization
    #alpha = 1e-4
    #reg_loss = alpha * sum((p*p for p in model.parameters()))
    total_loss = data_loss #+ reg_loss
    
    # also get accuracy
    accuracy = [(yi > 0) == (scorei.data > 0) for yi, scorei in zip(yb, scores)]
    return total_loss, sum(accuracy) / len(accuracy)

In [5]:
epochs = 1000

old_loss = 0.0
tot_loss = Value(0.0)
alpha = 1.0
loss = MSE()
for epoch in range(epochs):
    ypred = list(map(model, xs)) 
    old_loss = tot_loss.data 
    tot_loss = loss(ys, ypred) 

    model.zero_grad()
    tot_loss.backward()

    alpha = (alpha * 1e-1) if (tot_loss.data >= old_loss) else alpha
    learning_rate = (1.0 - 9e-1 * epoch / epochs) * alpha
    for p in model.parameters():
        p.data -= learning_rate * p.grad

    if epoch % (epochs/10) == 0:
        print(f"epoch {epoch} -- loss {tot_loss} -- accuracy")


epoch 0 -- loss Value(data=1.2092222610065044) -- accuracy
epoch 100 -- loss Value(data=0.003966250092263554) -- accuracy
epoch 200 -- loss Value(data=0.0005739395659517589) -- accuracy
epoch 300 -- loss Value(data=9.123731109321765e-05) -- accuracy
epoch 400 -- loss Value(data=1.701015813222591e-05) -- accuracy
epoch 500 -- loss Value(data=3.8746461787978954e-06) -- accuracy
epoch 600 -- loss Value(data=1.096045017211613e-06) -- accuracy
epoch 700 -- loss Value(data=3.8733751028002114e-07) -- accuracy
epoch 800 -- loss Value(data=1.713750053320363e-07) -- accuracy
epoch 900 -- loss Value(data=9.499814609335254e-08) -- accuracy
