In [None]:
import gzip

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import nn
from jax import random
from matplotlib import pyplot as plt

In [None]:
!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

In [None]:
# https://stackoverflow.com/a/53570674

def load_x(file_name):
  """Load a gzip-compressed IDX image file (MNIST format) as a uint8 array.

  The IDX image header is four big-endian uint32 values -- magic number (2051),
  image count, rows, columns -- followed by one unsigned byte per pixel.

  Args:
    file_name: path to the gzip-compressed IDX image file.

  Returns:
    jnp.uint8 array of shape (count, rows, cols); (N, 28, 28) for MNIST.

  Raises:
    ValueError: if the header is truncated, the magic number is not the image
      magic (e.g. a label file was passed), or the payload size disagrees with
      the header.
  """
  with gzip.open(file_name, 'rb') as f:
    header = f.read(16)
    buf = f.read()
  if len(header) != 16:
    raise ValueError(f'{file_name}: truncated IDX header')
  magic, count, rows, cols = (
      int.from_bytes(header[i:i + 4], 'big') for i in range(0, 16, 4))
  if magic != 2051:
    raise ValueError(f'{file_name}: expected IDX image magic number 2051, got {magic}')
  if len(buf) != count * rows * cols:
    raise ValueError(
        f'{file_name}: header declares {count}x{rows}x{cols} pixels but found {len(buf)} bytes')
  return jnp.frombuffer(buf, dtype=jnp.uint8).reshape(count, rows, cols)

def load_y(file_name):
  """Load a gzip-compressed IDX label file (MNIST format) as a 1-D uint8 array.

  The IDX label header is two big-endian uint32 values -- magic number (2049)
  and label count -- followed by one unsigned byte per label.

  Args:
    file_name: path to the gzip-compressed IDX label file.

  Returns:
    jnp.uint8 array of shape (count,).

  Raises:
    ValueError: if the header is truncated, the magic number is not the label
      magic (e.g. an image file was passed), or the number of label bytes
      disagrees with the header.
  """
  with gzip.open(file_name, 'rb') as f:
    header = f.read(8)
    buf = f.read()
  if len(header) != 8:
    raise ValueError(f'{file_name}: truncated IDX header')
  magic, count = (int.from_bytes(header[i:i + 4], 'big') for i in range(0, 8, 4))
  if magic != 2049:
    raise ValueError(f'{file_name}: expected IDX label magic number 2049, got {magic}')
  if len(buf) != count:
    raise ValueError(f'{file_name}: header declares {count} labels but found {len(buf)} bytes')
  return jnp.frombuffer(buf, dtype=jnp.uint8)

# Flatten each 28x28 image into a 784-vector and scale pixels from [0, 255] to [0, 1].
x_train = load_x('train-images-idx3-ubyte.gz').reshape(-1, 28*28) / 255
y_train = load_y('train-labels-idx1-ubyte.gz')
x_test = load_x('t10k-images-idx3-ubyte.gz').reshape(-1, 28*28) / 255
y_test = load_y('t10k-labels-idx1-ubyte.gz')

# Sanity checks: every image has exactly one label, and every label is a valid
# class index for the 10-way output layer (mle_loss indexes logits by label and
# would not complain about an out-of-range value).
assert x_train.shape[0] == y_train.shape[0], 'train images/labels count mismatch'
assert x_test.shape[0] == y_test.shape[0], 'test images/labels count mismatch'
assert int(y_train.max()) < 10 and int(y_test.max()) < 10, 'labels outside 0..9'

In [None]:
# How big is our dataset? What kind of data do we have?
# One line per array: "<shape> ,  <dtype>", with a blank line between train and test.

for split_array in (x_train, y_train):
  print(split_array.shape, ', ', split_array.dtype)
print()
for split_array in (x_test, y_test):
  print(split_array.shape, ', ', split_array.dtype)

In [None]:
# Look at an example.
# Images are monochrome. The raw pixels are integers in [0, 255], but x_train was
# divided by 255 when it was loaded, so the values printed here are floats in [0, 1].
# A scoped printoptions keeps the wide line width local to this cell instead of
# changing how every later cell prints arrays.

with jnp.printoptions(linewidth=1000):
  print(x_train[0].reshape(28, 28))

In [None]:
# Visualize some images and check their labels (each label is shown as the title).
# Uses the top-level `plt` import; a grey colour map matches the monochrome data.

for i in range(6):
  fig, ax = plt.subplots()
  ax.imshow(x_train[i].reshape(28, 28), cmap='gray')
  ax.set_title(f'Label: {int(y_train[i])}')
  ax.axis('off')
  plt.show()
  plt.close(fig)  # release the figure so repeated runs do not accumulate open figures

# Define the model

In [None]:
rkey = random.PRNGKey(0)
INPUT_DIM = x_train.shape[-1]
HIDDEN_DIM = 256
OUTPUT_DIM = 10
SCALE = 1e-2

# Define and initialize the model parameters by sampling each element i.i.d.
# from a normal distribution (scaled by SCALE).
# Each tensor gets its own subkey: JAX keys must not be reused, because draws
# made from the same key are not independent of each other.
rkeys = random.split(rkey, 4)
W1 = SCALE * random.normal(rkeys[0], (INPUT_DIM, HIDDEN_DIM))
W2 = SCALE * random.normal(rkeys[1], (HIDDEN_DIM, OUTPUT_DIM))
B1 = SCALE * random.normal(rkeys[2], (1, HIDDEN_DIM))
B2 = SCALE * random.normal(rkeys[3], (1, OUTPUT_DIM))

# Order matters: forward/update and model_fn unpack this list as [W1, W2, B1, B2].
parameters = [W1, W2, B1, B2]


def relu(x):
  """Rectified linear unit, elementwise: negative entries become 0, others pass through."""
  return jnp.maximum(x, 0)

def model_fn(x, params=None):
  """Two-layer MLP: returns the logits relu(x @ W1 + B1) @ W2 + B2.

  Args:
    x: input batch, one example per row.
    params: [W1, W2, B1, B2]. When omitted, the module-level `parameters` list
      is used. It is looked up at call time, so after training steps rebind
      `parameters` this evaluates the trained weights rather than the initial
      W1/W2/B1/B2 globals.

  Returns:
    Logits with one row per example and one column per class.
  """
  if params is None:
    params = parameters
  w1, w2, b1, b2 = params
  h = relu(x @ w1 + b1)
  return h @ w2 + b2


In [None]:
def mle_loss(logits, labels):
  """Mean negative log-likelihood of integer `labels` under softmax(`logits`).

  `logits` holds one row per example and one column per class; `labels[i]` is
  the column index of example i's target class.
  """
  log_probs = nn.log_softmax(logits, axis=1)
  label_columns = jnp.expand_dims(labels, 1)  # (batch,) -> (batch, 1) for take_along_axis
  target_log_probs = jnp.take_along_axis(log_probs, label_columns, axis=1)
  return -jnp.mean(target_log_probs)

def regularizer(params):
  """Elementwise L1 penalty: the sum of |p| over every entry of every array in `params`.

  Args:
    params: iterable of arrays (here [W1, W2, B1, B2]).

  Returns:
    Scalar penalty (0 for an empty iterable).
  """
  # Deliberately not jnp.linalg.norm(p, 1): for 2-D arrays (all params here,
  # including the (1, N) biases) ord=1 is the induced matrix 1-norm, i.e. the
  # largest absolute column sum, which penalizes only one column per array.
  return sum(jnp.sum(jnp.abs(p)) for p in params)

def accuracy(logits, target):
  """Fraction of rows whose highest-scoring column index equals `target`."""
  is_correct = jnp.argmax(logits, axis=1) == target
  return jnp.mean(is_correct)

In [None]:
def forward(params, x, y, regularizer_weight):
  """Training objective: mean NLL of labels `y` plus a weighted parameter penalty.

  Args:
    params: [W1, W2, B1, B2]; `grad(forward)` differentiates with respect to it.
    x: input batch, one example per row.
    y: integer class labels for the batch.
    regularizer_weight: multiplier applied to `regularizer(params)`.

  Returns:
    Scalar loss.
  """
  # The logits must be a function of `params` itself (same MLP as model_fn).
  # Reading the module-level W1/W2/B1/B2 instead makes the data term constant
  # w.r.t. `params`, so its gradient is zero and only the regularizer would
  # ever change the weights.
  w1, w2, b1, b2 = params
  logits = relu(x @ w1 + b1) @ w2 + b2
  return mle_loss(logits, y) + regularizer_weight * regularizer(params)


@jit
def update(params, x, y, learning_rate, regularizer_weight):
  """Take one plain gradient-descent step on `forward`; returns a new params list.

  The input list is not mutated. Compiled with `jit`.
  """
  step_grads = grad(forward)(params, x, y, regularizer_weight)
  new_params = []
  for param, step_grad in zip(params, step_grads):
    new_params.append(param - learning_rate * step_grad)
  return new_params

In [None]:
batch_size = 100
learning_rate = 1e-2
regularizer_weight = 1e-4
num_epochs = 1000
# Seeded shuffling: together with random.PRNGKey(0) for initialization, this
# makes Restart & Run All reproduce the same training run.
shuffle_rng = np.random.default_rng(0)

for epoch in range(num_epochs):
  idx = shuffle_rng.permutation(x_train.shape[0])  # random ordering of the training set

  ## training step
  for i in range(0, x_train.shape[0], batch_size):
    batch_idx = idx[i:i + batch_size]
    x = x_train[batch_idx]
    y = y_train[batch_idx]

    parameters = update(parameters, x, y, learning_rate, regularizer_weight)

    # Log the loss on the current batch every 100 steps.
    if i % (batch_size * 100) == 0:
      loss_ = float(forward(parameters, x, y, regularizer_weight))
      print('Step:', i // batch_size, '; Loss:', loss_)

  ## evaluation on the full train and test sets
  train_logits = model_fn(x_train)  # Training accuracy
  test_logits = model_fn(x_test)  # Test accuracy
  train_acc = float(accuracy(train_logits, y_train))
  test_acc = float(accuracy(test_logits, y_test))
  print('')
  print('Epoch: %d | Train Accuracy: %.2f | Test Accuracy: %.2f' % (epoch, train_acc, test_acc))
  print('')