In [94]:
import numpy as np 
import jax.numpy as jnp
import jax
import haiku as hk
from itertools import cycle
import tensorflow_datasets as tfds
# improving result with Adam from optax
import optax


In [26]:
# one_hot returns vector with only one 1 for the label index.
jax.nn.one_hot(jnp.array([0, 0, -1, 2, 2]), 4)

DeviceArray([[1., 0., 0., 0.],
             [1., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 1., 0.]], dtype=float32)

In [125]:
# Fetch the dataset directly
mnist = tfds.image.MNIST()
# useful to know what's inside the dataset
mnist.info

tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/home/clement/tensorflow_datasets/mnist/3.0.1',
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)

In [32]:
mnist.download_and_prepare()

[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/clement/tensorflow_datasets/mnist/3.0.1...[0m


HBox(children=(HTML(value='Dl Completed...'), FloatProgress(value=0.0, max=4.0), HTML(value='')))



[1mDataset mnist downloaded and prepared to /home/clement/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [124]:
datasets = mnist.as_dataset(batch_size=1000)


In [126]:
def softmax_cross_entropy(logits, labels):
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)


def net_fn(images):
    # LNET 300 100 10
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300),
        hk.Linear(100),
        hk.Linear(10)])
    return mlp(images)


# hk.Conv2D(output_channels=10, kernel_shape=6), jax.nn.relu,
        

# There are two transforms in Haiku, hk.transform and hk.transform_with_state.
# If our network updated state during the forward pass (e.g. like the moving
# averages in hk.BatchNorm) we would need hk.transform_with_state, but for our
# simple MLP we can just use hk.transform.
net_fn_t = hk.transform(net_fn)
# MLP is deterministic once we have our parameters, as such we will not need to
# pass an RNG key to apply. without_apply_rng is a convenience wrapper that will
# make the rng argument to `loss_fn_t.apply` default to `None`.
net_fn_t = hk.without_apply_rng(net_fn_t)


def loss_fn(params, images, labels):
    logits = net_fn_t.apply(params, images)
    return jnp.mean(softmax_cross_entropy(logits, labels))

In [127]:

def result_on_test_set(parameters):
    # Test on test set : 
    it_test = datasets['test'].as_numpy_iterator()

    def error_rate(y, y_hat):
        return jnp.mean(y != y_hat)

    errs = []
    for o in it_test:
        images, labels = o['image'].astype(jnp.float32) / 255., o['label']
        label_hat = net_fn_t.apply(parameters, images).argmax(axis=1)
        errs.append(error_rate(labels, label_hat))
    print("Error on test set ", np.mean(errs))


In [128]:
# `init` runs your function, as such we need an example input. Typically you can
# pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
# is not usually data dependent.

it = cycle(datasets['train'].as_numpy_iterator())

o = next(it)
images, labels = o['image'].astype(jnp.float32) / 255., o['label']
rng = jax.random.PRNGKey(42)

# The result of `init` is a nested data structure of all the parameters in your
# network. You can pass this into `apply`.
params = loss_fn_t.init(rng, images, labels)


def sgd(param, update):
    return param - 0.01 * update


for i in range(500):
    o = next(it)
    images, labels = o['image'].astype(jnp.float32) / 255., o['label']
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    params = jax.tree_multimap(sgd, params, grads)
    if i % 50 == 0:
        print("Loss on train set ", loss)
        result_on_test_set(params)

Loss on train set  2.2960014
Error on test set  0.8116
Loss on train set  1.888028
Error on test set  0.3524
Loss on train set  1.5463176
Error on test set  0.2642
Loss on train set  1.2657787
Error on test set  0.22259998
Loss on train set  1.0318453
Error on test set  0.2006
Loss on train set  0.89268565
Error on test set  0.18149999
Loss on train set  0.7739114
Error on test set  0.16849999
Loss on train set  0.717366
Error on test set  0.1593
Loss on train set  0.6926508
Error on test set  0.1504
Loss on train set  0.6375088
Error on test set  0.1418


In [147]:
@jax.jit
def ema_update(
    avg_params: hk.Params,
    new_params: hk.Params,
    epsilon: float = 0.001) -> hk.Params:
    return jax.tree_multimap(lambda p1, p2: (1 - epsilon) * p1 + epsilon * p2,
         avg_params, new_params)


def optimize(steps=1000):
    opt = optax.adam(learning_rate=1e-3)


    o = next(it)
    images, labels = o['image'].astype(jnp.float32) / 255., o['label']
    rng = jax.random.PRNGKey(42)
    # The result of `init` is a nested data structure of all the parameters in your
    # network. You can pass this into `apply`.
    params = net_fn_t.init(rng, images)
    avg_params = params
    opt_state = opt.init(params)



    for i in range(steps):
        o = next(it)
        images, labels = o['image'].astype(jnp.float32) / 255., o['label']
        loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)

        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        avg_params = ema_update(avg_params, params)
        if i % 100 == 0:
            print("Loss on train set ", loss)
            result_on_test_set(avg_params)

In [148]:
optimize(1000)

Loss on train set  2.3423822
Error on test set  0.93509996
Loss on train set  0.2918328
Error on test set  0.29759997
Loss on train set  0.2829472
Error on test set  0.17999999
Loss on train set  0.24500448
Error on test set  0.147
Loss on train set  0.27240166
Error on test set  0.12920001
Loss on train set  0.2591893
Error on test set  0.118900016
Loss on train set  0.23423465
Error on test set  0.11029999
Loss on train set  0.2619259
Error on test set  0.1039
Loss on train set  0.2514375
Error on test set  0.099199995
Loss on train set  0.23181894
Error on test set  0.095


In [162]:
# Improve model using convolutions

def net_fn(images):
    # LNET 300 100 10
    mlp = hk.Sequential([
        hk.Conv2D(output_channels=32, kernel_shape=3),
        hk.MaxPool(window_shape=(2,2), padding='SAME', strides=1),
        hk.Flatten(),
        hk.Linear(300),
        hk.Linear(200),
        hk.Linear(10)])
    return mlp(images)
net_fn_t = hk.without_apply_rng(hk.transform(net_fn))
optimize(5000)

Loss on train set  2.347464
Error on test set  0.9476
Loss on train set  0.15585364
Error on test set  0.1347
Loss on train set  0.06889652
Error on test set  0.13960001
Loss on train set  0.05627359
Error on test set  0.15330002
Loss on train set  0.033857062
Error on test set  0.1635
Loss on train set  0.029153084
Error on test set  0.16849999
Loss on train set  0.065876655
Error on test set  0.16600001
Loss on train set  0.021700073
Error on test set  0.1596
Loss on train set  0.012087165
Error on test set  0.15309998
Loss on train set  0.01262256
Error on test set  0.1405
Loss on train set  0.012617512
Error on test set  0.1232
Loss on train set  0.027628185
Error on test set  0.107599996
Loss on train set  0.020550411
Error on test set  0.093
Loss on train set  0.00369231
Error on test set  0.0795
Loss on train set  0.018902548
Error on test set  0.0685
Loss on train set  0.01752936
Error on test set  0.058699995
Loss on train set  0.013117411
Error on test set  0.052599996
Loss o

DeviceArray([2, 0, 4, 8, 7, 6, 0, 6, 3, 1, 6, 0, 7, 9, 8, 4, 5, 3, 9, 0,
             6, 6, 3, 0, 2, 3, 6, 6, 7, 4, 0, 3, 8, 9, 5, 4, 2, 8, 5, 8,
             5, 2, 9, 2, 4, 2, 9, 0, 5, 1, 0, 7, 9, 9, 9, 6, 3, 8, 8, 6,
             9, 0, 5, 4], dtype=int32)