In [129]:
from datasets import load_dataset
from jax.numpy import vstack, array

# Single source of truth for the columns used below; the order of
# FEATURE_COLUMNS is the column order of X.
FEATURE_COLUMNS = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
LABEL_COLUMN = 'Species'
# Species name -> integer class id. An unexpected label raises KeyError
# instead of being silently counted as the last class.
SPECIES_TO_ID = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

# load iris
iris = load_dataset("scikit-learn/iris")
iris_train = iris['train']
iris_train.set_format('jax', columns=FEATURE_COLUMNS + [LABEL_COLUMN])

# One row per feature after vstack; transpose so each row of X is one sample.
X = vstack(tuple(iris_train[name] for name in FEATURE_COLUMNS)).T

# y_str is a list (per the original note), so it is encoded element by element.
y_str = iris_train[LABEL_COLUMN]
class_ids = [SPECIES_TO_ID[s] for s in y_str]

# Integer labels and their one-hot encoding (one row per sample).
y = array(class_ids)
y_1hot = array([[int(c == k) for k in range(len(SPECIES_TO_ID))] for c in class_ids])


In [130]:
from jax import vmap, jit, grad
from jax.nn import sigmoid, softmax
from jax.numpy import zeros
from better_partial import _, partial as F
import sys
sys.path.append('../berries')
import init_utils, plot_utils, random_utils
from init_utils import zerO_init_2D
import optax

def affine(x, W, b):
    """Affine map of `x`: multiply by the transpose of `W`, then add `b`."""
    projected = W.T @ x
    return projected + b

def swish(x):
    """Swish activation: `x` scaled elementwise by `sigmoid(x)`."""
    gate = sigmoid(x)
    return gate * x

def mlp2l(x, W1, b1, W2, b2):
    """Two-layer perceptron: affine -> swish -> affine.

    Args:
        x: input passed to the first affine layer.
        W1, b1: weight and bias of the first (hidden) layer.
        W2, b2: weight and bias of the second (output) layer.

    Returns:
        The output of the second affine layer (no activation is applied
        after it).
    """
    # Direct composition: the layers are applied exactly once each, so there
    # is no need to build partially-applied copies of `affine` on every call.
    hidden = swish(affine(x, W1, b1))
    return affine(hidden, W2, b2)

seed = 0
key_gen = random_utils.infinite_safe_keys(seed)

print(X.shape)
n_samples, d_in = X.shape
n_samples_y, d_out = y_1hot.shape
assert n_samples == n_samples_y, "X and y_1hot disagree on the number of samples"
d_h1 = 128  # width of the hidden layer

# Parameters of the two affine layers. The commented-out lines are the
# alternative random initialisation, which is what `key_gen` is for.
W1 = init_utils.zerO_init_2D((d_in, d_h1))
#W1 = init_utils.normal_init(next(key_gen), 1, (d_in, d_h1))
b1 = zeros((d_h1,))
W2 = init_utils.zerO_init_2D((d_h1, d_out))
#W2 = init_utils.normal_init(next(key_gen), 1, (d_h1, d_out))
b2 = zeros((d_out,))

# Batched forward pass: map over axis 0 of the inputs, share all parameters.
mlp2l_b = vmap(mlp2l, in_axes=(0, None, None, None, None), out_axes=0)

# Smoke test of the batched forward pass: check the output shape instead of
# discarding the result.
logits0 = mlp2l_b(X, W1, b1, W2, b2)
assert logits0.shape == (n_samples, d_out), logits0.shape

def sce_loss(to_logits, x, y, W):
    """Softmax cross-entropy between `to_logits(x, *W)` and integer label(s) `y`.

    `W` is unpacked positionally into `to_logits` after `x`.
    """
    logits = to_logits(x, *W)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y)


# Per-sample loss of `mlp2l`: inputs and labels are mapped over axis 0, the
# parameter tuple (third argument) is shared across the batch.
loss_b_all = vmap(F(sce_loss)(mlp2l, _, _, _), (0, 0, None), 0)
# Bind the dataset so that only the parameter tuple remains as an argument.
loss_d = F(loss_b_all)(X, y, _)


def loss_b_d(W):
    """Mean of the per-sample losses on (X, y) for the parameter tuple `W`.

    Args:
        W: the parameters forwarded to `loss_d`; built below as
            (W1, b1, W2, b2).

    Returns:
        The scalar mean of `loss_d(W)`.
    """
    return loss_d(W).mean()


W = (W1, b1, W2, b2)
loss0 = loss_b_d(W)
print(loss0)

@jit
def update(W, opt_state):
    """Take one optimizer step on `loss_b_d` and return the new parameters and state.

    Reads the module-level `opt`, which is not defined in this cell and must
    exist before the first call.
    NOTE(review): since this function is jitted, `opt` is presumably captured
    when the function is traced; re-run this definition after rebinding `opt`
    -- confirm against the JAX docs on global state.

    Args:
        W: current parameters.
        opt_state: optimizer state matching `W`.

    Returns:
        (new_W, new_opt_state)
    """
    gradient = grad(loss_b_d)(W)
    step, next_state = opt.update(gradient, opt_state)
    return optax.apply_updates(W, step), next_state

(150, 4)
1.3279482


In [131]:
lr = 0.01
N_STEPS = 1000   # number of optimizer steps
LOG_EVERY = 10   # print the mean loss once every this many steps

opt = optax.sgd(lr)
state = opt.init(W)

for i in range(N_STEPS):
    W, state = update(W, state)
    if i % LOG_EVERY:
        continue
    # Mean loss over the dataset after this step.
    print(loss_b_d(W))

# Per-sample losses for the final parameters.
print(loss_d(W))

1.2430714
1.0597048
0.9776005
0.9087785
0.8501591
0.7998119
0.75630563
0.71852314
0.68555737
0.65665466
0.6311825
0.6086086
0.5884853
0.57043713
0.5541501
0.5393617
0.52585274
0.5134391
0.50196797
0.49131024
0.48135743
0.47201848
0.46321666
0.45488563
0.44697022
0.43942282
0.43220252
0.4252752
0.41861042
0.41218182
0.4059676
0.3999486
0.39410797
0.3884308
0.38290495
0.3775192
0.37226376
0.3671308
0.3621122
0.35720196
0.35239476
0.3476857
0.34307003
0.33854416
0.33410487
0.32974884
0.32547405
0.3212776
0.31715792
0.31311262
0.30914062
0.3052399
0.30140936
0.2976474
0.29395306
0.29032513
0.2867626
0.28326437
0.2798294
0.27645695
0.27314606
0.26989585
0.26670521
0.26357344
0.26049984
0.25748318
0.25452295
0.251618
0.24876791
0.24597134
0.2432279
0.24053612
0.23789561
0.23530543
0.2327648
0.23027286
0.22782837
0.225431
0.22307993
0.22077361
0.21851186
0.21629359
0.21411823
0.21198462
0.20989227
0.2078402
0.20582731
0.20385335
0.20191711
0.20001815
0.19815546
0.1963285
0.19453628
0.19277833