Adapted from Flax's annotated MNIST example.

In [1]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import numpy as np
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
# Hide all GPUs from TensorFlow; TF is only used here (via tfds) to load data.
# NOTE(review): presumably this keeps TF from reserving GPU memory that JAX
# needs for training — confirm.
tf.config.experimental.set_visible_devices([], "GPU")

In [2]:
class CNN(nn.Module):
    """Small LeNet-style convnet for MNIST.

    Two (3x3 conv -> ReLU -> 2x2 average-pool) stages with 32 and 64 filters,
    then a 256-unit ReLU dense layer and a 10-way dense head. The output is
    ``log_softmax`` of the head, i.e. per-class log-probabilities.
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Convolutional feature extractor; modules are created in the same
        # order as an unrolled version, so Flax names them Conv_0, Conv_1.
        for num_filters in (32, 64):
            x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

        # Flatten everything except the batch axis for the dense classifier.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=256)(flat))
        return nn.log_softmax(nn.Dense(features=10)(hidden))

def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Mean negative log-likelihood of integer ``labels`` under ``logits``.

    Args:
        logits: Array whose last axis indexes classes. Despite the name, the
            values are used directly as log-probabilities (no softmax is
            applied here), so pass e.g. the ``log_softmax`` output of ``CNN``.
        labels: Integer class indices, one per row of ``logits``.

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    # Derive the class count from the model output instead of hard-coding 10,
    # so the loss works for any classifier head width.
    num_classes = logits.shape[-1]
    onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    # one-hot mask picks out the log-probability of the true class per row.
    return -jnp.mean(jnp.sum(onehot * logits, axis=-1))

def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.ndarray]:
    """Compute batch loss and accuracy.

    Args:
        logits: Per-class log-probabilities, classes on the last axis.
        labels: Integer class indices, one per row of ``logits``.

    Returns:
        ``{'loss': ..., 'accuracy': ...}``, both scalar arrays. ``loss`` is the
        mean cross-entropy; ``accuracy`` is the fraction of rows whose arg-max
        prediction equals the label.
    """
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    return {'loss': loss, 'accuracy': accuracy}

In [3]:
def get_datasets():
    """Load the MNIST train and test splits from TFDS as in-memory dicts.

    Each split is fetched whole (``batch_size=-1``) and converted to NumPy;
    the ``'image'`` entry is then cast to float32 and scaled from [0, 255]
    to [0, 1]. Returns ``(train_ds, test_ds)``.
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def _load_split(split_name):
        # Materialise the entire split, then normalise pixel intensities.
        split = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        split['image'] = jnp.float32(split['image']) / 255.0
        return split

    return _load_split('train'), _load_split('test')

def create_train_state(rng, learning_rate, momentum):
    """Initialise ``CNN`` parameters and wrap them in a ``TrainState``.

    Parameters are created by tracing the model on a dummy 1x28x28x1 input,
    and the optimiser is SGD with the given learning rate and momentum.
    """
    model = CNN()
    dummy_images = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_images)
    optimizer = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )

In [4]:
@jax.jit
def train_step(state, batch):
    """Run one optimiser update on ``batch``.

    Args:
        state: ``TrainState`` carrying params, optimiser state and ``apply_fn``.
        batch: dict with ``'image'`` and ``'label'`` arrays.

    Returns:
        ``(new_state, metrics)``, where ``metrics`` is the ``compute_metrics``
        dict computed from the logits of the forward pass *before* the update.
    """
    def loss_fn(params):
        # Use the apply_fn stored on the state (set in create_train_state)
        # instead of rebuilding CNN(), so the step follows whatever model the
        # state was created for.
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits, batch['label'])
        return loss, logits

    # has_aux=True: loss_fn returns (loss, logits); only loss is differentiated.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['label'])
    return state, metrics

In [5]:
@jax.jit
def eval_step(params, batch):
    """Score ``batch`` with ``CNN`` under ``params``.

    Returns the ``compute_metrics`` dict (loss and accuracy) for the batch.
    """
    variables = {'params': params}
    images, targets = batch['image'], batch['label']
    logits = CNN().apply(variables, images)
    return compute_metrics(logits, labels=targets)

In [6]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for one epoch over a shuffled pass of ``train_ds``.

    Args:
        state: Current ``TrainState``.
        train_ds: dict of equally long arrays (e.g. ``'image'``, ``'label'``).
        batch_size: Examples per step; the trailing partial batch is dropped.
        epoch: Epoch index, used only in the printed summary.
        rng: PRNG key used to shuffle the example order.

    Returns:
        The updated ``TrainState``. Prints the epoch's mean loss and accuracy.

    Raises:
        ValueError: if ``train_ds`` has fewer than ``batch_size`` examples, so
            no full batch can be formed.
    """
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        # Without this, the loop below never runs and the metric summary fails
        # with an opaque IndexError on an empty list.
        raise ValueError(
            f"batch_size={batch_size} exceeds training set size {train_ds_size}"
        )

    # Shuffle, then drop the remainder so every step sees exactly batch_size
    # examples; each row of `perms` is the index set for one step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Fetch all per-step metrics to host in one device_get call.
    batch_metrics_np = jax.device_get(batch_metrics)
    # All batches have the same size, so the mean of batch means is the
    # epoch mean.
    epoch_metrics_np = {
        k: np.mean([step_metrics[k] for step_metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    print(f"train epoch: {epoch} loss: {epoch_metrics_np['loss']:.4f} accuracy: {epoch_metrics_np['accuracy']:.4f}")
    return state

In [7]:
def eval_model(params, test_ds):
    """Evaluate ``params`` on all of ``test_ds`` as a single batch.

    Returns:
        ``(loss, accuracy)`` as Python scalars (via ``.item()``).
    """
    metrics = jax.device_get(eval_step(params, test_ds))
    # jax.tree_map is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable spelling.
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

In [8]:
# Load the full MNIST train/test splits into memory (tfds downloads and
# prepares the dataset if it is not already available locally).
train_ds, test_ds = get_datasets()

In [9]:
# Root PRNG key for the run (seed 0); one split is reserved for parameter init,
# the other (`rng`) is carried forward for per-epoch shuffling.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

In [10]:
# SGD-with-momentum hyperparameters, then build the initial training state.
learning_rate, momentum = 0.1, 0.9
state = create_train_state(init_rng, learning_rate, momentum)

In [11]:
num_epochs = 10
batch_size = 32

# Each epoch: split off a fresh key for shuffling, run one training pass over
# train_ds, then evaluate the updated params on the whole test set.
for epoch in range(num_epochs):
    rng, input_rng = jax.random.split(rng)
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(f">>> loss:{test_loss:.4f} accuracy: {test_accuracy:.4f}")

train epoch: 0 loss: 0.1334 accuracy: 0.9592
>>> loss:0.0614 accuracy: 0.9796
train epoch: 1 loss: 0.0481 accuracy: 0.9853
>>> loss:0.0540 accuracy: 0.9842
train epoch: 2 loss: 0.0336 accuracy: 0.9898
>>> loss:0.0311 accuracy: 0.9900
train epoch: 3 loss: 0.0246 accuracy: 0.9921
>>> loss:0.0360 accuracy: 0.9912
train epoch: 4 loss: 0.0212 accuracy: 0.9932
>>> loss:0.0340 accuracy: 0.9905
train epoch: 5 loss: 0.0174 accuracy: 0.9948
>>> loss:0.0286 accuracy: 0.9915
train epoch: 6 loss: 0.0114 accuracy: 0.9965
>>> loss:0.0413 accuracy: 0.9888
train epoch: 7 loss: 0.0100 accuracy: 0.9971
>>> loss:0.0462 accuracy: 0.9890
train epoch: 8 loss: 0.0096 accuracy: 0.9971
>>> loss:0.0381 accuracy: 0.9897
train epoch: 9 loss: 0.0083 accuracy: 0.9974
>>> loss:0.0366 accuracy: 0.9917
