# Classical CNN written in JAX, trained on a small MNIST subset (50 training images, 30 test images)

# In [5]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from jax.experimental import optimizers
from jax.experimental.stax import Conv, Relu, MaxPool, Flatten, Dense, LogSoftmax

from jax.scipy.special import logsumexp
from jax.nn import softmax_cross_entropy

from jax.experimental.stax import serial

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def one_hot(x, k, dtype=jnp.float32):
    """Encode a 1-D array of integer labels as one-hot rows.

    Args:
        x: 1-D array of integer class labels.
        k: number of classes (width of each one-hot row).
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``(len(x), k)`` with a 1 at each label's column and 0 elsewhere.
    """
    class_ids = jnp.arange(k)
    # Broadcast (n, 1) against (k,) to get an (n, k) boolean match matrix.
    hits = x[:, None] == class_ids
    return jnp.array(hits, dtype)

def data(n_train=50, n_test=30):
    """Load MNIST via Keras and return a small, network-ready subset.

    Args:
        n_train: number of leading training examples to keep (default 50).
        n_test: number of leading test examples to keep (default 30).

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Images are float32 scaled to
        [0, 1], with a trailing channel axis so they match the
        ``(-1, 28, 28, 1)`` NHWC shape the network is initialised with.
        Labels are float32 one-hot rows over 10 classes.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Slice first so only the examples actually used are normalised/encoded,
    # instead of converting all 60k/10k images and labels.
    x_train, y_train = x_train[:n_train], y_train[:n_train]
    x_test, y_test = x_test[:n_test], y_test[:n_test]

    def _prep(images):
        # Scale uint8 pixels to [0, 1] and append the channel axis that the
        # NHWC convolutions expect (raw MNIST images have no channel axis).
        return (images.astype("float32") / 255.0)[..., None]

    return _prep(x_train), one_hot(y_train, 10), _prep(x_test), one_hot(y_test, 10)

def loss(params, batch):
    """Mean categorical cross-entropy of the network on one batch.

    Args:
        params: network parameters (as returned by ``get_params``).
        batch: ``(inputs, targets)`` pair; ``targets`` are one-hot rows.

    Returns:
        Scalar mean negative log-likelihood over the batch.
    """
    inputs, targets = batch
    # ``predict`` ends in ``LogSoftmax``, so its outputs are already
    # log-probabilities and cross-entropy is simply -sum(targets * log_probs).
    # NOTE: ``jax.nn`` provides no ``softmax_cross_entropy``, so it is not used here.
    log_probs = predict(params, inputs)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=1))

def accuracy(params, batch):
    """Return the fraction of examples whose arg-max prediction matches the target.

    Args:
        params: network parameters.
        batch: ``(inputs, targets)`` pair with one-hot ``targets``.
    """
    inputs, targets = batch
    scores = predict(params, inputs)
    correct = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(correct)

# CNN: two (conv -> ReLU -> 2x2 max-pool) stages, then a 1024-unit hidden
# layer and a 10-way log-softmax output.
_layers = (
    Conv(32, (5, 5), (1, 1), 'SAME'),
    Relu,
    MaxPool((2, 2), (2, 2), 'VALID'),
    Conv(64, (5, 5), (1, 1), 'SAME'),
    Relu,
    MaxPool((2, 2), (2, 2), 'VALID'),
    Flatten,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
init_random_params, predict = serial(*_layers)

# Initialise for 28x28 single-channel NHWC images; -1 is the unspecified batch dimension.
_input_shape = (-1, 28, 28, 1)
_output_shape, initial_params = init_random_params(random.PRNGKey(0), _input_shape)

# Adam optimizer with a fixed learning rate of 1e-3.
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(initial_params)

@jit
def update(i, opt_state, batch):
    """Apply one optimizer step on ``batch`` and return the new optimizer state.

    Args:
        i: step index forwarded to ``opt_update``.
        opt_state: current optimizer state.
        batch: ``(inputs, targets)`` pair passed to ``loss``.
    """
    grads = grad(loss)(get_params(opt_state), batch)
    return opt_update(i, grads, opt_state)

x_train, y_train, x_test, y_test = data()

BATCH_SIZE = 10
NUM_EPOCHS = 100

# Adam's bias correction (and any step-size schedule) is indexed by the step
# number passed to opt_update, so count individual updates, not epochs.
step = 0
for epoch in range(NUM_EPOCHS):
    for start in range(0, len(x_train), BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch = (x_train[start:stop], y_train[start:stop])
        opt_state = update(step, opt_state, batch)
        step += 1
    # Fetch the parameters once per epoch and evaluate on both splits.
    params = get_params(opt_state)
    train_acc = accuracy(params, (x_train, y_train))
    test_acc = accuracy(params, (x_test, y_test))
    print(f"Epoch {epoch}, Train Accuracy: {train_acc}, Test Accuracy: {test_acc}")

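# Output of the cell above: ModuleNotFoundError: No module named 'jax'
# (the `jax` package is not installed in this environment; install it, e.g. `pip install jax`, before running).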