<a href="https://colab.research.google.com/github/hunsii/jax-practice/blob/main/jax_lan_p1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit, vmap
from jax import lax
from jax.example_libraries import optimizers
from jax.tree_util import tree_flatten, tree_unflatten
from jax.nn import softmax, relu
from jax.scipy.special import logsumexp
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist


In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax.nn import relu, softmax, log_softmax

# Load MNIST dataset
from tensorflow.keras.datasets import mnist
# mnist.load_data() returns one (images, labels) pair per split.
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Flatten every image into a 784-element row and divide by 255.0
# (presumably mapping 8-bit pixel intensities into [0, 1]).
train_images = train_images.reshape(len(train_images), 784) / 255.0
test_images = test_images.reshape(len(test_images), 784) / 255.0

# Turn the integer class ids into 10-way one-hot rows, as expected by
# the loss and accuracy functions below.
train_labels = jax.nn.one_hot(train_labels, 10)
test_labels = jax.nn.one_hot(test_labels, 10)

# Define neural network architecture
def net(params, x):
  """Forward pass of a two-layer perceptron.

  Args:
    params: 4-tuple ``(w1, b1, w2, b2)`` holding the weights and biases
      of the hidden layer followed by those of the output layer.
    x: input batch; its last axis must be compatible with ``w1``.

  Returns:
    Unnormalized class scores (logits); no softmax is applied here.
  """
  hidden_w, hidden_b, out_w, out_b = params
  # Affine map followed by ReLU for the hidden layer.
  activations = relu(jnp.dot(x, hidden_w) + hidden_b)
  # Plain affine map for the output layer.
  return jnp.dot(activations, out_w) + out_b

# Define loss function
def loss(params, x, y):
  """Mean cross-entropy between the network's predictions and ``y``.

  Args:
    params: parameter tuple forwarded to ``net``.
    x: batch of inputs.
    y: batch of target rows (one-hot labels in this notebook), with the
      same shape as the logits returned by ``net``.

  Returns:
    Scalar: the per-example cross-entropy averaged over the batch.
  """
  log_probs = log_softmax(net(params, x))
  # Cross-entropy of each example: -sum_k y_k * log p_k.
  per_example = -jnp.sum(y * log_probs, axis=1)
  return jnp.mean(per_example)

# Define accuracy metric
def accuracy(params, images, labels):
  """Fraction of examples whose top-scoring class matches the label.

  Args:
    params: parameter tuple forwarded to ``net``.
    images: batch of inputs.
    labels: per-example rows whose arg-max along axis 1 is the true
      class (one-hot labels in this notebook).

  Returns:
    Scalar mean of the element-wise "prediction == target" comparison.
  """
  predicted_class = jnp.argmax(net(params, images), axis=1)
  target_class = jnp.argmax(labels, axis=1)
  return jnp.mean(predicted_class == target_class)

# Initialize network parameters
key = random.PRNGKey(0)
input_shape = (-1, 784)
hidden_shape = 256
output_shape = 10

# A JAX PRNG key must not be consumed by more than one sampling call:
# reusing it makes the draws statistically dependent. Derive one subkey
# per weight matrix instead.
w1_key, w2_key = random.split(key)

# Scale the Gaussian draws by the layer fan-in so that activations and
# logits start at O(1) magnitude instead of O(sqrt(fan_in)):
#   * hidden layer feeds a ReLU -> He scaling, std = sqrt(2 / fan_in)
#   * output layer is linear    -> std = sqrt(1 / fan_in)
w1 = random.normal(w1_key, (784, hidden_shape)) * jnp.sqrt(2.0 / 784)
b1 = jnp.zeros(hidden_shape)
w2 = random.normal(w2_key, (hidden_shape, output_shape)) * jnp.sqrt(1.0 / hidden_shape)
b2 = jnp.zeros(output_shape)

# Parameter order matches the unpacking in `net`: (w1, b1, w2, b2).
params = (w1, b1, w2, b2)

# Define optimizer
# Learning rate handed to Adam.
step_size = 0.001

# optimizers.adam returns an (init, update, get_params) triple of
# functions; build the initial optimizer state from the fresh parameters.
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
opt_state = opt_init(params)

# Define training step
@jit
def update(params, x, y, opt_state, step=0):
  """Apply one Adam update computed from a single mini-batch.

  Args:
    params: current parameters; the gradient of `loss` is evaluated here.
      They are expected to be the parameters held in `opt_state`.
    x: batch of inputs.
    y: batch of targets for `loss`.
    opt_state: optimizer state to advance.
    step: zero-based index of this update. It is forwarded to
      `opt_update`, which uses it for Adam's bias correction; passing a
      constant would keep the correction stuck at its first-step value.
      Defaults to 0 so existing four-argument calls still work.

  Returns:
    Tuple ``(new_opt_state, new_params)``.
  """
  grads = grad(loss)(params, x, y)
  new_opt_state = opt_update(step, grads, opt_state)
  # Read the parameters out of the *updated* state. Extracting them from
  # the incoming `opt_state` would hand back pre-update parameters, so
  # the next gradient would be evaluated at stale values.
  return new_opt_state, get_params(new_opt_state)

# Train the network
num_epochs = 10
batch_size = 128
# Floor division: a trailing partial batch is not used for training.
num_batches = train_images.shape[0] // batch_size

for epoch in range(num_epochs):
  for batch in range(num_batches):
    start_idx = batch * batch_size
    end_idx = start_idx + batch_size
    batch_images = train_images[start_idx:end_idx]
    batch_labels = train_labels[start_idx:end_idx]
    # Global update counter across epochs, needed by the optimizer.
    step = epoch * num_batches + batch
    opt_state, params = update(params, batch_images, batch_labels, opt_state, step)
  # Report accuracy on the full train and test sets once per epoch.
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {}: train acc = {:.3f}, test acc = {:.3f}".format(epoch+1, train_acc, test_acc))


Epoch 1: train acc = 0.811, test acc = 0.809
Epoch 2: train acc = 0.853, test acc = 0.847
Epoch 3: train acc = 0.874, test acc = 0.867
Epoch 4: train acc = 0.888, test acc = 0.879
Epoch 5: train acc = 0.897, test acc = 0.888
Epoch 6: train acc = 0.905, test acc = 0.893
Epoch 7: train acc = 0.912, test acc = 0.898
Epoch 8: train acc = 0.917, test acc = 0.902
Epoch 9: train acc = 0.922, test acc = 0.906
Epoch 10: train acc = 0.927, test acc = 0.911
