In [34]:
from datasets import load_dataset

# Load MNIST through `datasets`, asking for JAX arrays when columns are read.
mnist = load_dataset("mnist").with_format("jax")

# Training split: images and their integer labels.
mnistData = mnist["train"]
X_img, y = mnistData["image"], mnistData["label"]

# Test split: flatten every image into one row so it can be fed to the MLP.
# NOTE(review): pixels are used exactly as loaded, no rescaling happens in
# this cell; confirm whether input normalisation is intended.
X_img_test, y_test = mnist["test"]["image"], mnist["test"]["label"]
n_test_samples = X_img_test.shape[0]
X_test = X_img_test.reshape(n_test_samples, -1)

In [35]:
# Flatten each training image into a single feature vector.
# The shape is indexed instead of unpacked into `_` on purpose: a later cell
# imports `_` from better_partial as its argument placeholder, and assigning
# to `_` here would clobber that placeholder whenever this cell is re-run.
n_samples = X_img.shape[0]
X = X_img.reshape((n_samples, -1))
# Input width is the flattened feature count.
d_in = X.shape[1]
# Output width is the number of distinct labels present in the training set.
d_out = len(set(y.tolist()))

In [49]:
from jax import vmap, jit, grad
from jax.nn import sigmoid, softmax
from jax.numpy import zeros
from better_partial import _, partial as F
import sys
sys.path.append('../berries')
import init_utils, random_utils
from init_utils import zerO_init_2D
import optax

def affine(x, W, b):
    """Apply the affine map ``W.T @ x + b``.

    Args:
        x: Input array; the weight matrix is transposed before multiplying,
            so for a 1-D ``x`` its length must match ``W``'s first axis.
        W: Weight matrix, stored as (inputs, outputs).
        b: Bias added to the projected input.

    Returns:
        The projected input plus the bias.
    """
    projected = W.T @ x
    return projected + b

def swish(x):
    """Swish activation: the input scaled element-wise by its own sigmoid."""
    gate = sigmoid(x)
    return x * gate

def mlp2l(x, W1, b1, W2, b2):
    """Two-layer perceptron: affine -> swish -> affine.

    The layers are called directly rather than through better_partial
    placeholders. The result is the same, but no partial wrappers are rebuilt
    on every call and the function no longer depends on the module-level name
    `_`, which is easy to overwrite in a notebook (e.g. as a throwaway
    variable in tuple unpacking).

    Args:
        x: Input passed to the first affine layer (batching is done by
            wrapping this function in `vmap`).
        W1: First-layer weights.
        b1: First-layer bias.
        W2: Second-layer weights.
        b2: Second-layer bias.

    Returns:
        The output of the second affine layer (un-normalised logits).
    """
    hidden = swish(affine(x, W1, b1))
    return affine(hidden, W2, b2)

def accuracy(logits, y):
    """Fraction of predictions that match the labels.

    The predicted class is the arg-max over the last axis of `logits`.
    """
    predicted = logits.argmax(-1)
    hits = predicted == y
    return hits.mean()

def get_accuracy(x, y, W):
    """Run `mlp2l` on `x` with the parameter tuple `W` and score it against `y`.

    `W` is unpacked positionally into the model, so it must be ordered as
    (W1, b1, W2, b2).
    """
    logits = mlp2l(x, *W)
    return accuracy(logits, y)


# Per-sample accuracy: map over the leading axis of inputs and labels while
# sharing one parameter tuple across all samples.
get_accuracy_b = vmap(get_accuracy, in_axes=(0, 0, None), out_axes=0)


def get_accuracy_b_d(W):
    """Mean accuracy of parameters `W` over the training set (`X`, `y`)."""
    return get_accuracy_b(X, y, W).mean()


def get_accuracy_b_t(W):
    """Mean accuracy of parameters `W` over the test set (`X_test`, `y_test`)."""
    return get_accuracy_b(X_test, y_test, W).mean()

# Fixed seed so parameter initialisation is repeatable; `key_gen` is advanced
# with next() each time an initialiser needs a key.
seed = 0
key_gen = random_utils.infinite_safe_keys(seed)

print(X.shape)
d_h1 = 128  # hidden-layer width
normal_init_std = 0.01  # passed to init_utils.normal_init for both weight matrices

# Layer 1 parameters. The zerO initialiser below is a disabled alternative.
#W1 = init_utils.zerO_init_2D((d_in, d_h1))
W1 = init_utils.normal_init(next(key_gen), normal_init_std, (d_in, d_h1))
b1 = zeros((d_h1,))
# Layer 2 parameters, same scheme.
#W2 = init_utils.zerO_init_2D((d_h1, d_out))
W2 = init_utils.normal_init(next(key_gen), normal_init_std, (d_h1, d_out))
b2 = zeros((d_out,))

# Batched forward pass: map over the leading axis of the inputs only.
mlp2l_b = vmap(mlp2l, in_axes=(0, None, None, None, None), out_axes=0)
# Shape sanity check on a tiny slice. The previous version ran the whole
# training set through the model here and discarded the result.
assert mlp2l_b(X[:2], W1, b1, W2, b2).shape == (X[:2].shape[0], d_out), \
    "mlp2l_b should return one row of d_out logits per input sample"

def sce_loss(to_logits, x, y, W):
    """Softmax cross-entropy of `to_logits(x, *W)` against integer label `y`.

    Args:
        to_logits: Callable that turns an input plus the unpacked parameters
            into logits.
        x: Model input.
        y: Integer class label(s) for `x`.
        W: Parameter tuple, unpacked positionally into `to_logits`.
    """
    logits = to_logits(x, *W)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y)


# Per-sample loss with the model fixed to `mlp2l`, mapped over the leading
# axis of inputs and labels while the parameter tuple is shared.
loss_b_all = vmap(F(sce_loss)(mlp2l, _, _, _), in_axes=(0, 0, None), out_axes=0)
# Bind the training set so that only the parameters remain free.
loss_d = F(loss_b_all)(X, y, _)


def loss_b_d(X):
    """Mean training loss for a parameter tuple.

    NOTE: despite its name, the argument `X` is the parameter tuple
    (W1, b1, W2, b2) forwarded to `loss_d`, not the input dataset. The name
    is kept unchanged so existing callers are unaffected.
    """
    return loss_d(X).mean()


# Initial parameters and the loss before any training step.
W = (W1, b1, W2, b2)
loss0 = loss_b_d(W)
print(loss0)

# RMSProp optimiser and its state for the parameter tuple.
lr = 0.001
opt = optax.rmsprop(learning_rate=lr)
state = opt.init(W)

@jit
def update(W, opt_state):
    """Take one optimiser step on the mean training loss.

    Args:
        W: Current parameter tuple.
        opt_state: Optimiser state matching `W`.

    Returns:
        A pair (updated parameters, updated optimiser state).
    """
    gradients = grad(loss_b_d)(W)
    deltas, next_state = opt.update(gradients, opt_state)
    return optax.apply_updates(W, deltas), next_state

(60000, 784)
4.149068


In [56]:

# Training over the full training set. `W` and `state` are rebound on every
# step, so re-running this cell continues from the current parameters rather
# than restarting from the initial ones.
for i in range(1000):
    W, state = update(W, state)
    if i % 50:
        continue  # report only on every 50th step
    # train accuracy, test accuracy, mean training loss
    print(get_accuracy_b_d(W), get_accuracy_b_t(W), loss_b_d(W))
# Per-sample training losses after the final step.
print(loss_d(W))



1.0 0.97959995 2.2086206e-07
1.0 0.97959995 2.1861102e-07
1.0 0.97959995 2.1633811e-07
1.0 0.97959995 2.1417051e-07
1.0 0.97959995 2.1201087e-07
1.0 0.97959995 2.099247e-07
1.0 0.97959995 2.0785049e-07
1.0 0.97959995 2.059054e-07
1.0 0.97959995 2.038888e-07
1.0 0.97959995 2.0190397e-07
1.0 0.97959995 2.0007613e-07
1.0 0.97959995 1.982264e-07
1.0 0.97959995 1.9640649e-07
1.0 0.97959995 1.9457067e-07
1.0 0.97959995 1.9279646e-07
1.0 0.9795 1.9108383e-07
1.0 0.9795 1.8931955e-07
1.0 0.9795 1.876546e-07
1.0 0.9795 1.8605124e-07
1.0 0.9795 1.84432e-07
[0. 0. 0. ... 0. 0. 0.]
