In [1]:
from datasets import load_dataset

# Load the Iris dataset published as "scikit-learn/iris" on the Hugging Face
# Hub; the 'train' split of the result is used by the next cell.
iris = load_dataset(path="scikit-learn/iris")


In [2]:

from jax.numpy import vstack, array


# Column names as they appear in the dataset.
feature_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']

# Single source of truth for label -> class index. An unexpected label raises
# KeyError here instead of being silently counted as class 2.
species_to_class = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

iris_train = iris['train']

# set_format switches the output format of the dataset object itself.
iris_train.set_format('jax', columns=feature_columns + ['Species'])

# NOTE(review): vstack stacks one feature per *row*, so X is presumably
# (4, n_samples) while y_1hot below is (n_samples, 3). Left unchanged because
# later cells may rely on this layout -- confirm, and transpose where X is
# consumed as (n_samples, n_features).
X = vstack(tuple(iris_train[name] for name in feature_columns))

# The species column comes back as a sequence of strings.
y_str = iris_train['Species']
class_ids = [species_to_class[s] for s in y_str]
n_classes = len(species_to_class)

y = array(class_ids)
# One-hot rows derived from the same class ids so the two encodings cannot drift.
y_1hot = array([[int(k == c) for k in range(n_classes)] for c in class_ids])
y_1hot.shape



(150, 3)

In [67]:
def affine(x, W, b):
    """Apply the affine map ``x @ W + b``.

    Args:
        x: Left operand of the matrix product (e.g. a batch of inputs).
        W: Weights, multiplied onto ``x`` from the right.
        b: Offset added to the product; broadcasting is left to the
            array library.

    Returns:
        The value of ``x @ W + b``.
    """
    projected = x @ W
    return projected + b

from bokeh.io import output_notebook, push_notebook, show
from jax.numpy import arange
import sys
from importlib import reload

# Make the local helper modules importable. Guarded so that re-running this
# cell does not keep appending the same entry to sys.path.
if '../berries' not in sys.path:
    sys.path.append('../berries')

from positional_encoding import get_positional_encoding
import init_utils, plot_utils

# Reload BEFORE the from-imports: reload() updates the module object in place
# but does not rebind names that were already copied out of it, so importing
# zerO_init_2D / plot_xys first would leave them pointing at the stale,
# pre-reload functions after the helper files are edited.
reload(init_utils)
reload(plot_utils)
from init_utils import zerO_init_2D
from plot_utils import plot_xys

# Route bokeh output to the notebook.
output_notebook()


# Problem size for the synthetic regression task.
n_samples = 100
dim_x, dim_y, dim_hidden = 10, 1, 5

# One encoding table with dim_x + dim_y columns: column 0 is taken as the
# regression target and the remaining columns as the inputs.
Xy_ = get_positional_encoding(n_samples, dim_x + dim_y)
y_, X_ = Xy_[:, 0], Xy_[:, 1:]



In [68]:
def loss_fn(W, b):
    """Mean squared error of the affine model on the notebook-level data.

    Reads the module-level arrays ``X_`` (inputs) and ``y_`` (targets).

    Args:
        W: Weights passed to ``affine``.
        b: Bias passed to ``affine``.

    Returns:
        The mean of the squared per-sample residuals.
    """
    y_pred = affine(X_, W, b)
    # y_ is a 1-D slice (Xy_[:, 0]); with a 2-D W the prediction keeps a
    # trailing axis, and (n, 1) - (n,) would broadcast to an (n, n) matrix of
    # all pairwise differences. Align the prediction with the target first so
    # the residual is taken sample by sample.
    residual = y_pred.reshape(y_.shape) - y_
    return (residual ** 2).mean()


from jax import grad, jit
from bokeh.plotting import figure
import optax
from jax.numpy import array

# Optimiser: plain SGD with a fixed learning rate.
lr = 0.1
opt = optax.sgd(learning_rate=lr)

# Trainable parameters; the optimiser state is built from the (W_, b_) pair.
b_ = array(0.)
W_ = init_utils.zerO_init_2D((dim_x, dim_y))
state = opt.init((W_, b_))

# Training history, seeded with the loss of the initial parameters at step 0.
ts, losses = [0], [loss_fn(W_, b_)]


@jit
def update(W, opt_state):
    """Run one optimiser step on the parameters.

    Args:
        W: The parameters. Despite the name this is the whole parameter
            tuple: it is star-unpacked into ``loss_fn``'s two arguments
            and the caller passes ``(W_, b_)``.
        opt_state: Optimiser state, from ``opt.init`` or a previous call.

    Returns:
        A ``(new_parameters, new_opt_state)`` pair.
    """
    # Gradients with respect to both positional arguments of loss_fn.
    loss_grads = grad(loss_fn, argnums=(0, 1))
    grads = loss_grads(*W)
    updates, next_state = opt.update(grads, opt_state)
    return optax.apply_updates(W, updates), next_state


# Loss figure. notebook_handle=True returns a handle through which later
# updates can be pushed into the already-rendered plot.
loss_plot = figure(
    title="loss", x_axis_label="epoch", y_axis_label="loss", width=900, height=300
)
loss_curve = loss_plot.circle(array(ts), array(losses))
lp_target = show(loss_plot, notebook_handle=True)

# Predictions of the current parameters. This must be computed here: nothing
# earlier in the notebook defines `y_pred` at module level, so using it below
# without this line raises NameError on a fresh kernel. Reshaped to match y_
# so the line glyph gets a 1-D column of the same length as its x values.
y_pred = affine(X_, W_, b_).reshape(y_.shape)

# Data figure: first input feature, target, and the model's prediction.
xy_plot = figure(
    title="xy", x_axis_label="index", y_axis_label="y", width=900, height=300
)
x_line = xy_plot.line(arange(n_samples), X_[:, 0], legend_label="x0", color="green")
y_line = xy_plot.line(arange(n_samples), y_, legend_label="y")
y_pred_line = xy_plot.line(
    arange(n_samples), y_pred, color="red", legend_label="y_pred"
)
xy_target = show(xy_plot, notebook_handle=True)


n_epochs = 100
refresh_every = 10  # redraw the plots every this many epochs

for i in range(n_epochs):
    (W_, b_), state = update((W_, b_), state)

    # ts/losses already hold the step-0 entry for the initial parameters, so
    # log AFTER the step under epoch i + 1. Logging before the step would
    # record epoch 0 twice and never record the loss of the final parameters.
    epoch = i + 1
    ts.append(epoch)
    losses.append(loss_fn(W_, b_))

    # Keyed on the 1-based epoch so the last refresh shows the final state.
    if epoch % refresh_every == 0:
        # Flatten to the 1-D shape of y_ so the column length matches the
        # glyph's x column.
        y_pred = affine(X_, W_, b_).reshape(y_.shape)
        y_pred_line.data_source.data["y"] = y_pred
        push_notebook(handle=xy_target)

        # Replace both columns in one assignment so their lengths always
        # agree, and hand bokeh plain Python numbers rather than device arrays.
        loss_curve.data_source.data = {
            "x": list(ts),
            "y": [float(v) for v in losses],
        }
        push_notebook(handle=lp_target)