In [None]:
%pip install \
    git+https://github.com/deepmind/dm-haiku@v0.0.4 \
    git+https://github.com/deepmind/optax@v0.0.9 \

# MNIST CNN
A small convolutional MNIST classifier built with Haiku and Optax, based on the [dm-haiku MNIST example](https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py).

In [2]:
import haiku as hk

import jax
from jax import jit, partial, vmap, grad
from jax import random
import jax.lax as lax
import jax.nn as nn
import jax.numpy as np

import optax

import tensorflow_datasets as tfds

In [3]:
# Root PRNG key for the notebook; later cells split it before consuming randomness.
rng = random.PRNGKey(seed=42)

In [4]:
def ravel_tree(tree):
    """Flatten every leaf of a pytree and join them into one 1-D array.

    Args:
        tree: Any JAX pytree (tuple, list, dict, ...) whose leaves are arrays.

    Returns:
        A 1-D array holding the raveled leaves in pytree traversal order.
        An empty pytree yields an empty array instead of raising.
    """
    # `jax.tree_util.tree_leaves` is the stable entry point; the top-level
    # `jax.tree_leaves` alias is deprecated and missing from newer JAX releases.
    leaves = jax.tree_util.tree_leaves(tree)
    if not leaves:
        # `concatenate` rejects an empty sequence, so return an empty vector.
        return np.zeros((0,))
    return np.concatenate([np.ravel(leaf) for leaf in leaves])
ravel_tree((np.array([1, 2, 3]), np.array([[4, 5], [6, 7]])))

DeviceArray([1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [5]:
train_batch_size = 128
eval_batch_size = 1024

def load_dataset(split, *, is_training, batch_size):
    """Build an endless iterator of NumPy batches for one MNIST split.

    Args:
        split: TFDS split name passed straight to `tfds.load`.
        is_training: When true, examples are shuffled (fixed seed) before batching.
        batch_size: Number of examples per yielded batch.

    Returns:
        An iterator over batches converted with `tfds.as_numpy`.
    """
    dataset = tfds.load("mnist:3.*.*", split=split)
    # Cache the decoded split, then repeat so the iterator never runs out.
    dataset = dataset.cache().repeat()
    if is_training:
        shuffle_buffer = 10 * batch_size
        dataset = dataset.shuffle(shuffle_buffer, seed=0)
    batched = dataset.batch(batch_size)
    return iter(tfds.as_numpy(batched))
train = load_dataset("train", is_training=True, batch_size=train_batch_size)
train_eval = load_dataset("train", is_training=False, batch_size=eval_batch_size)
test_eval = load_dataset("test", is_training=False, batch_size=eval_batch_size)
# Pull one training batch to inspect shapes; later cells reuse it as a sample input.
batch = next(train)
batch['image'].shape, batch['label'].shape

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


((128, 28, 28, 1), (128,))

In [6]:
label_count = 10

@hk.without_apply_rng
@hk.transform
def model(batch):
    """Two-layer convolutional classifier producing one logit per label.

    Args:
        batch: Mapping with an "image" entry; the images are cast to float32
            and divided by 255 (presumably 8-bit pixel values — confirm).

    Returns:
        The output of the final linear layer, with `label_count` units.
    """
    pixels = batch["image"].astype(np.float32) / 255.
    layers = [
        hk.Conv2D(64, kernel_shape=(3,3)), nn.relu,
        hk.Conv2D(32, kernel_shape=(3,3)), nn.relu,
        hk.Flatten(),
        hk.Linear(label_count),
    ]
    network = hk.Sequential(layers)
    return network(pixels)
print(hk.experimental.tabulate(model, columns=['module', 'input', 'output', 'params_size'])(batch))

+----------------------------+-------------------+-------------------+---------------+
| Module                     | Input             | Output            |   Param count |
| sequential (Sequential)    | f32[128,28,28,1]  | f32[128,10]       |       269,994 |
+----------------------------+-------------------+-------------------+---------------+
| conv2_d (Conv2D)           | f32[128,28,28,1]  | f32[128,28,28,64] |           640 |
|  └ sequential (Sequential) |                   |                   |               |
+----------------------------+-------------------+-------------------+---------------+
| conv2_d_1 (Conv2D)         | f32[128,28,28,64] | f32[128,28,28,32] |        18,464 |
|  └ sequential (Sequential) |                   |                   |               |
+----------------------------+-------------------+-------------------+---------------+
| flatten (Flatten)          | f32[128,28,28,32] | f32[128,25088]    |             0 |
|  └ sequential (Sequential) |             

In [7]:
def loss(params, batch):
    """Mean softmax cross-entropy plus a small L2 penalty on all parameters.

    Args:
        params: Model parameters accepted by `model.apply`.
        batch: Mapping with the model inputs and integer labels under 'label'.

    Returns:
        Scalar: batch-mean cross-entropy + 1e-4 * summed `optax.l2_loss` of
        every parameter value.
    """
    logits = model.apply(params, batch)
    targets = nn.one_hot(batch['label'], label_count)
    cross_entropy = np.mean(optax.softmax_cross_entropy(logits, targets))
    weight_penalty = np.sum(optax.l2_loss(ravel_tree(params)))
    return cross_entropy + 1e-4 * weight_penalty
rng, r = random.split(rng)
# Both names start out bound to the same freshly initialised parameters.
weights = model.init(r, next(train))
average_weights = weights
loss(weights, batch)

DeviceArray(2.3454196, dtype=float32)

In [8]:
@jax.jit
def accuracy(weights, batch):
    """Fraction of examples whose highest-scoring class equals the label.

    Args:
        weights: Model parameters accepted by `model.apply`.
        batch: Mapping with the model inputs and labels under "label".

    Returns:
        Scalar mean of the per-example correctness indicators.
    """
    logits = model.apply(weights, batch)
    predicted_labels = np.argmax(logits, axis=-1)
    is_correct = predicted_labels == batch["label"]
    return np.mean(is_correct)
accuracy(weights, batch)

DeviceArray(0.0234375, dtype=float32)

In [9]:
# Adam with a fixed 1e-3 learning rate; its state is created from the initial weights.
optimizer = optax.adam(learning_rate=1e-3)
optimizer_state = optimizer.init(weights)

In [10]:
@jax.jit
def update(weights, optimizer_state, batch):
    """Apply one optimizer step computed from the gradient of `loss` on `batch`.

    Args:
        weights: Current model parameters.
        optimizer_state: Current state of the module-level `optimizer`.
        batch: Batch passed through to `loss`.

    Returns:
        A `(weights, optimizer_state)` pair holding the updated values.
    """
    gradients = grad(loss)(weights, batch)
    deltas, next_optimizer_state = optimizer.update(gradients, optimizer_state)
    next_weights = optax.apply_updates(weights, deltas)
    return next_weights, next_optimizer_state
np.mean(np.abs(ravel_tree(update(weights, optimizer_state, batch)[0])))

DeviceArray(0.00672121, dtype=float32)

In [11]:
@jax.jit
def ema_update(weights, average_weights):
    """Move the running average of the parameters a small step toward `weights`.

    Args:
        weights: Latest trained parameters (the "new" tensors).
        average_weights: Current running average (the "old" tensors).

    Returns:
        The result of `optax.incremental_update` with a step size of 0.001.
    """
    return optax.incremental_update(
        new_tensors=weights,
        old_tensors=average_weights,
        step_size=0.001,
    )
np.mean(np.abs(ravel_tree(ema_update(weights, average_weights))))

DeviceArray(0.0068145, dtype=float32)

In [12]:
# Train for 2000 steps. Every 100 steps, report the accuracy of the averaged
# weights on a single batch drawn from each evaluation iterator.
for step in range(2000):
    if not step % 100:
        train_accuracy, test_accuracy = (
            accuracy(average_weights, next(eval_batches))
            for eval_batches in (train_eval, test_eval)
        )
        print(f"[Step {step}] Train / Test accuracy: {train_accuracy:.3f} / {test_accuracy:.3f}.")

    # One optimizer step on the next training batch, then refresh the running average.
    weights, optimizer_state = update(weights, optimizer_state, next(train))
    average_weights = ema_update(weights, average_weights)

[Step 0] Train / Test accuracy: 0.071 / 0.061.
[Step 100] Train / Test accuracy: 0.707 / 0.693.
[Step 200] Train / Test accuracy: 0.893 / 0.887.
[Step 300] Train / Test accuracy: 0.921 / 0.921.
[Step 400] Train / Test accuracy: 0.940 / 0.951.
[Step 500] Train / Test accuracy: 0.962 / 0.960.
[Step 600] Train / Test accuracy: 0.971 / 0.972.
[Step 700] Train / Test accuracy: 0.978 / 0.978.
[Step 800] Train / Test accuracy: 0.981 / 0.977.
[Step 900] Train / Test accuracy: 0.980 / 0.979.
[Step 1000] Train / Test accuracy: 0.988 / 0.979.
[Step 1100] Train / Test accuracy: 0.985 / 0.988.
[Step 1200] Train / Test accuracy: 0.995 / 0.981.
[Step 1300] Train / Test accuracy: 0.990 / 0.979.
[Step 1400] Train / Test accuracy: 0.990 / 0.984.
[Step 1500] Train / Test accuracy: 0.995 / 0.987.
[Step 1600] Train / Test accuracy: 0.994 / 0.985.
[Step 1700] Train / Test accuracy: 0.997 / 0.991.
[Step 1800] Train / Test accuracy: 0.998 / 0.987.
[Step 1900] Train / Test accuracy: 0.995 / 0.989.
