In [None]:
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets

[K     |████████████████████████████████| 2.1 MB 35.8 MB/s 
[K     |████████████████████████████████| 990 kB 54.3 MB/s 
[K     |████████████████████████████████| 71.3 MB 61 kB/s 
[K     |████████████████████████████████| 197 kB 69.9 MB/s 
[K     |████████████████████████████████| 140 kB 72.4 MB/s 
[K     |████████████████████████████████| 4.3 MB 51.9 MB/s 
[K     |████████████████████████████████| 38.1 MB 1.2 MB/s 
[K     |████████████████████████████████| 98 kB 8.5 MB/s 
[K     |████████████████████████████████| 596 kB 69.2 MB/s 
[K     |████████████████████████████████| 217 kB 71.7 MB/s 
[K     |████████████████████████████████| 51 kB 7.9 MB/s 
[K     |████████████████████████████████| 72 kB 696 kB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires i

In [None]:
import jax
import jax.numpy as jnp               # JAX NumPy

from flax import linen as nn          # The Linen API
from flax.training import train_state
import optax                          # The Optax gradient processing and optimization library

import numpy as np                    # Ordinary NumPy
import tensorflow_datasets as tfds    # TFDS for MNIST

In [None]:
class CNN(nn.Module):
  """Small convolutional classifier producing 10 logits per example."""

  @nn.compact
  def __call__(self, x):
    """Apply the network to `x` and return the unnormalised class scores.

    `@nn.compact` lets the submodules be declared inline, at the point where
    they are first applied.
    """
    # Two identical stages: 3x3 convolution -> ReLU -> 2x2 average pooling.
    # The first stage uses 32 filters, the second 64.
    for num_filters in (32, 64):
      x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything except the leading axis before the dense layers.
    leading = x.shape[0]
    hidden = nn.Dense(features=256)(x.reshape((leading, -1)))
    hidden = nn.relu(hidden)

    # There are 10 classes in MNIST.
    logits = nn.Dense(features=10)(hidden)
    return logits

In [None]:
def compute_metrics(logits, labels):
  """Compute the mean softmax cross-entropy loss and the accuracy.

  Args:
    logits: array of unnormalised scores whose last axis indexes the classes.
    labels: integer class indices, compared element-wise against the argmax
      of `logits` over its last axis.

  Returns:
    A dict with the keys 'loss' and 'accuracy'.
  """
  # Read the number of classes from the logits instead of hard-coding 10, so
  # the one-hot targets always match the width of the model output.
  num_classes = logits.shape[-1]
  one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
  loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot_labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {
      'loss': loss,
      'accuracy': accuracy
  }

In [None]:
def get_datasets():
  """Load the MNIST train and test splits as in-memory arrays.

  Returns:
    A `(train_ds, test_ds)` pair. Each element is the mapping returned by
    TFDS for the whole split, with its 'image' entry converted to float32
    and divided by 255.
  """
  mnist = tfds.builder('mnist')
  mnist.download_and_prepare()

  def load_split(split_name):
    # batch_size=-1 asks TFDS for the entire split as a single batch.
    split = tfds.as_numpy(mnist.as_dataset(split=split_name, batch_size=-1))
    # Convert to floating point and rescale (presumably from 0..255 pixel
    # intensities to 0..1 -- confirm against the dataset's dtype).
    split['image'] = jnp.float32(split['image']) / 255.0
    return split

  return load_split('train'), load_split('test')

In [None]:
@jax.jit
def train_step(state, batch):
  """Perform one gradient update on a single batch.

  Args:
    state: training state; its `params` are differentiated and its
      `apply_gradients` method produces the updated state.
    batch: mapping with an 'image' entry (model input) and a 'label' entry
      (integer class indices).

  Returns:
    A `(state, metrics)` pair: the updated state and the metrics computed
    from the logits obtained before the update.
  """
  images, labels = batch['image'], batch['label']

  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    targets = jax.nn.one_hot(labels, num_classes=10)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    # The logits ride along as auxiliary output so the forward pass can be
    # reused for the metrics.
    return jnp.mean(per_example), logits

  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, compute_metrics(logits, labels)

In [None]:
@jax.jit
def eval_step(params, batch):
  """Run the CNN forward on `batch` and return its metrics.

  Args:
    params: model parameters, passed to the CNN as its 'params' collection.
    batch: mapping with an 'image' entry (model input) and a 'label' entry.

  Returns:
    The result of `compute_metrics` for the predicted logits and the labels.
  """
  network = CNN()
  variables = {'params': params}
  logits = network.apply(variables, batch['image'])
  metrics = compute_metrics(logits, batch['label'])
  return metrics

In [None]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for one shuffled pass over `train_ds` and print the mean metrics.

  Args:
    state: training state threaded through `train_step`.
    train_ds: mapping of arrays indexed along their first axis; must contain
      an 'image' entry, whose length defines the dataset size.
    batch_size: number of examples per step. Must be positive and no larger
      than the dataset size.
    epoch: epoch number, used only in the printed summary.
    rng: JAX PRNG key used to shuffle the example order.

  Returns:
    A `(state, training_epoch_metrics)` pair: the state after the last step
    and a dict holding the per-batch metrics averaged over the epoch.

  Raises:
    ValueError: if `batch_size` is not positive, or if it exceeds the number
      of examples so that not a single full batch can be formed.
  """
  train_ds_size = len(train_ds['image'])
  if batch_size <= 0:
    raise ValueError('batch_size must be positive, got %r' % (batch_size,))

  steps_per_epoch = train_ds_size // batch_size
  if steps_per_epoch == 0:
    # Without this guard the epoch would run zero steps and fail later with
    # an opaque IndexError while averaging an empty metrics list.
    raise ValueError(
        'batch_size (%d) exceeds the number of training examples (%d)'
        % (batch_size, train_ds_size))

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []

  for perm in perms:
    # Gather the same shuffled rows from every array in the dataset.
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # Fetch all metrics from the device once, after the loop, then average
  # each metric over the batches.
  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return state, training_epoch_metrics

In [None]:
def eval_model(model, test_ds):
  """Evaluate on `test_ds` and return the loss and accuracy as Python scalars.

  Args:
    model: despite its name, this value is forwarded unchanged to `eval_step`
      as its `params` argument, i.e. it should hold the model parameters.
    test_ds: batch mapping forwarded to `eval_step`.

  Returns:
    A `(loss, accuracy)` tuple, each converted with `.item()`.
  """
  metrics = eval_step(model, test_ds)
  metrics = jax.device_get(metrics)
  # Use `jax.tree_util.tree_map`: the top-level `jax.tree_map` alias is
  # deprecated and no longer available in newer JAX releases.
  eval_summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']

In [None]:
# Download (on first use) and load the MNIST train/test splits into memory.
train_ds, test_ds = get_datasets()

[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...[0m


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]

[1mDataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m




In [None]:
# Root PRNG key created from a fixed seed (0).
rng = jax.random.PRNGKey(0)
# Split off a dedicated key for parameter initialisation; `rng` is replaced by
# the other half of the split and reused later for shuffling.
rng, init_rng = jax.random.split(rng)

In [None]:
cnn = CNN()
# Initialise the parameters by running the model on an all-ones dummy input of
# shape [1, 28, 28, 1] (presumably batch, height, width, channels -- confirm);
# only the 'params' collection of the returned variables is kept.
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

In [None]:
nesterov_momentum = 0.9
learning_rate = 0.001
# `optax.sgd` takes the decay rate through `momentum`; `nesterov` is only a
# boolean switch and has no effect while `momentum` is left at None. Passing
# the decay as `nesterov=` therefore silently trained with plain SGD.
tx = optax.sgd(learning_rate=learning_rate, momentum=nesterov_momentum, nesterov=True)

In [None]:
# Bundle the model's apply function, the initial parameters and the optimizer
# into a single training state object.
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In [None]:
# Number of full passes over the training set.
num_epochs = 10
# Number of examples per training step.
batch_size = 32

In [None]:
# Epochs are numbered from 1 so the printed progress is 1-based.
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one full training epoch (every complete batch of the shuffled data)
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

Training - epoch: 1, loss: 1.7941, accuracy: 62.73
Testing - epoch: 1, loss: 0.93, accuracy: 82.32
Training - epoch: 2, loss: 0.6114, accuracy: 85.10
Testing - epoch: 2, loss: 0.44, accuracy: 88.45
Training - epoch: 3, loss: 0.4128, accuracy: 88.40
Testing - epoch: 3, loss: 0.36, accuracy: 89.89
Training - epoch: 4, loss: 0.3598, accuracy: 89.67
Testing - epoch: 4, loss: 0.32, accuracy: 90.81
Training - epoch: 5, loss: 0.3280, accuracy: 90.50
Testing - epoch: 5, loss: 0.30, accuracy: 91.54
Training - epoch: 6, loss: 0.3047, accuracy: 91.18
Testing - epoch: 6, loss: 0.28, accuracy: 91.95
Training - epoch: 7, loss: 0.2853, accuracy: 91.72
Testing - epoch: 7, loss: 0.26, accuracy: 92.26
Training - epoch: 8, loss: 0.2680, accuracy: 92.16
Testing - epoch: 8, loss: 0.24, accuracy: 92.89
Training - epoch: 9, loss: 0.2522, accuracy: 92.73
Testing - epoch: 9, loss: 0.23, accuracy: 93.12
Training - epoch: 10, loss: 0.2384, accuracy: 92.99
Testing - epoch: 10, loss: 0.22, accuracy: 93.54
