In [None]:
%matplotlib inline

from jax import random
import jax.numpy as np
from jax import grad, jit, vjp, vmap

import pandas as pd
import matplotlib.pyplot as plt

import xaby

In [None]:
# Load the Iris data: four numeric measurement columns plus the species name per row.
data = pd.read_csv('data/iris.data',
                   names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'])

# Integer class ids are assigned in order of first appearance of each species name.
labels = data['species'].unique().tolist()
label_to_id = {label: idx for idx, label in enumerate(labels)}

x = data.iloc[:, :4].values
# Standardize each feature column to zero mean / unit variance.
# NOTE(review): the mean/std are computed over ALL rows, including the rows that are
# later held out as the test split, so some test information leaks into preprocessing.
x = (x - x.mean(axis=0)) / x.std(axis=0)
# An explicit dict with .map works for any number of classes (no hard-coded [0, 1, 2])
# and yields an integer column directly, instead of relying on Series.replace's
# deprecated implicit object->int downcast.
y = data['species'].map(label_to_id).values

In [None]:
# Randomly shuffle and split into train/test
split = 25  # number of samples held out as the test set
key = xaby.random.key_manager.key
# Draw ONE permutation of row indices and apply it to features and labels alike, so
# every sample keeps its own label. jax.random.shuffle on a stacked 2-D array permutes
# each column independently (scrambling the feature/label pairing) and is deprecated.
perm = random.permutation(key, len(x))
x_shuffled = np.asarray(x)[perm]
y_shuffled = np.asarray(y)[perm]
train_x, test_x = x_shuffled[:-split], x_shuffled[-split:]
train_y, test_y = y_shuffled[:-split].astype(np.int8), y_shuffled[-split:].astype(np.int8)

In [None]:
from xaby import nn

In [None]:
# Per-iteration loss histories; each entry is the loss divided by the number of samples.
train_losses = []
test_losses = []

# Wrap the train/test splits produced above as xaby tensors
inputs = xaby.Tensor(train_x)
targets = xaby.Tensor(train_y)
test_inputs = xaby.Tensor(test_x)
test_targets = xaby.Tensor(test_y)

n_epochs = 200   # number of gradient steps
log_every = 10   # print progress every `log_every` steps

# Optimize with Stochastic Gradient Descent
optimize = xaby.optim.SGD(lr=0.0003)

# Define model: nn.linear(4, 3) presumably maps the 4 features to 3 class scores,
# followed by log-softmax so outputs are log-probabilities.
model = nn.linear(4, 3) >> xaby.log_softmax

# Backpropagate with Negative Log-Likelihood loss
backprop = model << xaby.losses.nlloss

# Each step uses the whole training set (inputs wraps all of train_x).
for i in range(n_epochs):
    loss, grads = inputs >> backprop << targets
    model >> optimize << grads

    # Convert both losses to Python floats with .item() so the two histories hold the
    # same plain type (the test loss already used .item(); the train loss did not).
    train_losses.append(loss.item() / len(inputs))
    test_loss = test_inputs >> model >> xaby.losses.nlloss << test_targets
    test_losses.append(test_loss.item() / len(test_inputs))

    if i % log_every == 0:
        print(loss.item(), test_loss.item())

In [None]:
# Learning curves: per-sample NLL loss on the train and test splits at each step.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train')
ax.plot(test_losses, label='Test')
ax.set(title='NLL loss during training', xlabel='Step', ylabel='Loss / number of samples')
ax.legend();  # trailing ';' suppresses the Legend object repr

In [None]:
# Evaluate the trained `model` (the previous code referenced an undefined `net`).
# The model ends in log_softmax; axis 1 is presumably the class axis, so argmax picks
# the most likely class id for each test sample.
test_log_probs = (test_inputs >> model).data
predictions = np.argmax(test_log_probs, axis=1)
accuracy = (predictions == test_targets.data).mean()
print(accuracy)

In [None]:
# Show the predicted class id for each test sample (computed in the accuracy cell)
predictions

In [None]:
# The model outputs log-softmax values; exponentiating gives per-class probabilities
# for each test sample. Uses the trained `model` (the previous code referenced an undefined `net`).
np.exp((test_inputs >> model).data)