In [21]:
%load_ext autoreload

from collections.abc import Iterator
from typing import Mapping, NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import haiku as hk
import jax
import jax.numpy as jnp
import optax

import tensorflow_datasets as tfds

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [22]:
NUM_CLASSES = 10

def net_fn(images: jax.Array, is_training: bool) -> jax.Array:
  # normalize the images
  x = images.astype(jnp.float32) / 255.0 

  # The most simple network - Flatten the image and pass through some
  # convolutions, a mlp, and output the 10 categories as logits
  cnet1 = hk.Sequential([
    hk.Conv2D(32, (3, 3)), jax.nn.relu,
    hk.MaxPool((2,2), 2, "SAME"),
  ])

  cnet2 = hk.Sequential([
    hk.Conv2D(32, (3, 3)), jax.nn.relu,
    hk.MaxPool((2,2), 2, "SAME"),
    hk.Conv2D(32, (3, 3)), jax.nn.relu,
  ])

  dnet = hk.Sequential([
    hk.Flatten(),
    hk.Linear(64), jax.nn.relu,
    hk.Linear(32), jax.nn.relu,
    hk.Linear(NUM_CLASSES),
  ])

  x = cnet1(x)

  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)

  x = cnet2(x)

  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)

  return dnet(x)

In [25]:
# Following the haiku examples, we define a class that inherets NamedTuple and this will hold
# our data. It basically redefines a dictionary to a class for typing purposes I presume. Type
# checking improves readablity and maintenance / bug-fixing
class Batch(NamedTuple):
  image: np.ndarray  # [B, H, W, 1]
  label: np.ndarray  # [B]

def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,) -> (Iterator[Batch], Mapping):
  """Loads the dataset as a generator of batches."""
  ds, ds_info = tfds.load("cifar10:3.*.*", split=split, with_info=True)
  ds = ds.cache()
  if is_training:
    ds = ds.shuffle(ds_info.splits[split].num_examples, seed=0)
    ds = ds.repeat() # We want an infinite iterator of the data
  ds = ds.batch(batch_size)
  ds = ds.map(lambda x: Batch(x['image'], x['label'])) # ** operator unpacks a dictionary
  return iter(tfds.as_numpy(ds)), ds_info

train, ds_info = load_dataset("train", is_training=True, batch_size=256)

test, ds_info = load_dataset("test", is_training=True, batch_size=64)

In [5]:
print(ds_info)

tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='/home/don/tensorflow_datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Le

In [27]:
# We should now turn this into a proper training loop.
import tensorflow as tf
#tf.config.experimental.set_visible_devices([], "GPU")

class TrainingState(NamedTuple):
  params: hk.Params
  rng_key: jax.Array
  opt_state: optax.OptState

SEED = 5

EPOCHS = 4000
LEARNING_RATE = 0.001

optimiser = optax.adam(LEARNING_RATE)

# When there is no rng you can have hk handle this for you with without_apply_rng
net = hk.without_apply_rng(hk.transform(net_fn))
#net = hk.transform(net_fn)

# Training loss (cross-entropy). - right from the examples. pretty simple function
@hk.transform
def my_loss(batch: Batch) -> jnp.ndarray:
  """Compute the loss of the network, including L2."""
  logits = net_fn(batch.image, True)
  labels = jax.nn.one_hot(batch.label, NUM_CLASSES)

  #l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(state.params))
  softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
  softmax_xent /= labels.shape[0] # Using the average of the batch. This is why increased batch size tends to regularize your networks as well.
                                  # More data makes the loss function "smoother". It is less likely to overfit to small numbers of cases and 
                                  # get stuck in local minimums.

  return softmax_xent #+ 1e-3 * l2_loss
  #return softmax_xent

@jax.jit
def evaluate(params: hk.Params, batch: Batch) -> jax.Array:
  """Evaluation metric (classification accuracy)."""
  logits = net.apply(params, batch.image, False)
  predictions = jnp.argmax(logits, axis=-1)
  return jnp.mean(predictions == batch.label)

@jax.jit
def update(
    state: TrainingState,
    batch: Batch,
) -> (TrainingState, float):
  """Learning rule (stochastic gradient descent)."""
  rng, net_rng = jax.random.split(state.rng_key)
  loss_and_grad_fn = jax.value_and_grad(my_loss.apply)
  loss, grads = loss_and_grad_fn(state.params, net_rng, batch)
  updates, opt_state = optimiser.update(grads, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)
  return TrainingState(new_params, rng, opt_state), loss

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_loss.reset_states()

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_loss.reset_states()

test_acc = tf.keras.metrics.Mean(name='test_acc')
test_acc.reset_states()

example = next(train)

rng = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(rng)

# Get the initial parameters
params = my_loss.init(init_rng, example)
opt_state = optimiser.init(params)

state = TrainingState(params, rng, opt_state)


for step in range(EPOCHS):
    state, loss = update(state, next(train))
    train_loss(loss)

    if step % 500 == 0:
      #test, ds_info = load_dataset("test", is_training=False, batch_size=64)
      #for b in test:
      #  acc = evaluate(state.params, b)
      #  test_acc(acc)
      print(f'Epoch {step + 1} Train Loss {train_loss.result():.4f} Test acc {test_acc.result():.4f}')
      #logging.info({train_loss: train_loss})
      #test_acc.reset_states()

Epoch 1 Train Loss 2.3306 Test acc 0.0000
Epoch 501 Train Loss 1.8394 Test acc 0.0000
Epoch 1001 Train Loss 1.6910 Test acc 0.0000


KeyboardInterrupt: 