# HAIKU MNIST example
https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py

In [1]:
# Sets how much GPU memory JAX preallocate
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.8


In [2]:
from typing import Iterator, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
import lovely_jax as lj
lj.monkey_patch()

### Loading Data

In [3]:
class Batch(NamedTuple):
  image: np.ndarray  # [B, H, W, 1]
  label: np.ndarray  # [B]
  
def load_dataset(
    split: str,
    *,
    shuffle: bool,
    batch_size: int,
) -> Iterator[Batch]:
  """Loads the MNIST dataset."""
  ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  if shuffle:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  ds = ds.map(lambda x: Batch(**x))
  return iter(tfds.as_numpy(ds))

In [11]:
 bs = 1_000
 train = load_dataset("train", shuffle=True, batch_size=bs)
 batch = next(train)
 type(batch), len(batch)

2023-01-13 15:23:57.343196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2023-01-13 15:23:57.345746: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


(__main__.Batch, 2)

### Model

In [7]:
NUM_CLASSES = 10

def net_fn(images: jnp.ndarray) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network."""
  x = images.astype(jnp.float32) / 255.
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(NUM_CLASSES),
  ])
  return mlp(x)

In [13]:
network = hk.without_apply_rng(hk.transform(net_fn))

In [14]:
initial_params = network.init(jax.random.PRNGKey(seed=0), next(train).image)
initial_params

{'linear': {'w': Array[784, 300] n=235200 x∈[-0.071, 0.071] μ=-5.580e-05 σ=0.031 gpu:0,
  'b': Array[300] [38;2;127;127;127mall_zeros[0m gpu:0},
 'linear_1': {'w': Array[300, 100] n=30000 x∈[-0.115, 0.115] μ=0.000 σ=0.051 gpu:0,
  'b': Array[100] [38;2;127;127;127mall_zeros[0m gpu:0},
 'linear_2': {'w': Array[100, 10] n=1000 x∈[-0.196, 0.197] μ=-0.001 σ=0.091 gpu:0,
  'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]}}

In [15]:
logits = network.apply(initial_params, batch.image)
logits

Array[1000, 10] n=10000 x∈[-0.449, 0.470] μ=-0.006 σ=0.121 gpu:0

In [18]:
@jax.jit
def evaluate(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Evaluation metric (classification accuracy)."""
    logits = network.apply(params, batch.image)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == batch.label)

evaluate(initial_params, batch)

Array gpu:0 0.112

### Loss

In [19]:
def loss(params: hk.Params, batch: Batch, wd=1e-4) -> jnp.ndarray:
    """Cross-entropy classification loss, regularised by L2 weight decay."""
    batch_size, *_ = batch.image.shape
    logits = network.apply(params, batch.image)
    labels = jax.nn.one_hot(batch.label, NUM_CLASSES)

    l2_regulariser = 0.5 * sum(
        jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
    log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))

    return -log_likelihood / batch_size + wd * l2_regulariser

In [20]:
loss(initial_params, batch)

Array gpu:0 2.326

### Learning

In [24]:
class TrainingState(NamedTuple):
  params: hk.Params
  avg_params: hk.Params
  opt_state: optax.OptState

optimiser = optax.adam(1e-3)
initial_opt_state = optimiser.init(initial_params)
state = TrainingState(initial_params, initial_params, initial_opt_state)
state

TrainingState(params={'linear': {'w': Array[784, 300] n=235200 x∈[-0.071, 0.071] μ=-5.580e-05 σ=0.031 gpu:0, 'b': Array[300] [38;2;127;127;127mall_zeros[0m gpu:0}, 'linear_1': {'w': Array[300, 100] n=30000 x∈[-0.115, 0.115] μ=0.000 σ=0.051 gpu:0, 'b': Array[100] [38;2;127;127;127mall_zeros[0m gpu:0}, 'linear_2': {'w': Array[100, 10] n=1000 x∈[-0.196, 0.197] μ=-0.001 σ=0.091 gpu:0, 'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]}}, avg_params={'linear': {'w': Array[784, 300] n=235200 x∈[-0.071, 0.071] μ=-5.580e-05 σ=0.031 gpu:0, 'b': Array[300] [38;2;127;127;127mall_zeros[0m gpu:0}, 'linear_1': {'w': Array[300, 100] n=30000 x∈[-0.115, 0.115] μ=0.000 σ=0.051 gpu:0, 'b': Array[100] [38;2;127;127;127mall_zeros[0m gpu:0}, 'linear_2': {'w': Array[100, 10] n=1000 x∈[-0.196, 0.197] μ=-0.001 σ=0.091 gpu:0, 'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]}}, opt_state=(ScaleByAdamState(count=Arra

In [32]:
grads = jax.grad(loss)(state.params, batch)
grads

{'linear': {'b': Array[300] x∈[-0.011, 0.004] μ=-0.001 σ=0.002 gpu:0,
  'w': Array[784, 300] n=235200 x∈[-0.010, 0.006] μ=-0.000 σ=0.001 gpu:0},
 'linear_1': {'b': Array[100] x∈[-0.016, 0.006] μ=-0.004 σ=0.005 gpu:0,
  'w': Array[300, 100] n=30000 x∈[-0.012, 0.008] μ=-0.000 σ=0.001 gpu:0},
 'linear_2': {'b': Array[10] x∈[-0.023, 0.026] μ=1.593e-08 σ=0.015 gpu:0 [0.026, -0.023, 0.007, 0.011, 0.005, -0.023, 0.009, -0.015, -0.002, 0.006],
  'w': Array[100, 10] n=1000 x∈[-0.026, 0.016] μ=-1.259e-07 σ=0.005 gpu:0}}

In [34]:
updates, opt_state = optimiser.update(grads, state.opt_state)
updates, opt_state

({'linear': {'b': Array[300] x∈[-0.001, 0.001] μ=0.000 σ=0.001 gpu:0,
   'w': Array[784, 300] n=235200 x∈[-0.001, 0.001] μ=7.347e-05 σ=0.001 gpu:0},
  'linear_1': {'b': Array[100] x∈[-0.001, 0.001] μ=0.000 σ=0.001 gpu:0,
   'w': Array[300, 100] n=30000 x∈[-0.001, 0.001] μ=0.000 σ=0.001 gpu:0},
  'linear_2': {'b': Array[10] x∈[-0.001, 0.001] μ=-8.737e-05 σ=0.001 gpu:0 [-0.001, 0.001, -0.001, -0.001, -0.001, 0.001, -0.001, 0.001, 0.001, -0.000],
   'w': Array[100, 10] n=1000 x∈[-0.001, 0.001] μ=-0.000 σ=0.001 gpu:0}},
 (ScaleByAdamState(count=Array i32 gpu:0 3, mu={'linear': {'b': Array[300] x∈[-0.002, 0.001] μ=-0.000 σ=0.000 gpu:0, 'w': Array[784, 300] n=235200 x∈[-0.003, 0.002] μ=-2.849e-05 σ=0.000 gpu:0}, 'linear_1': {'b': Array[100] x∈[-0.003, 0.002] μ=-0.001 σ=0.001 gpu:0, 'w': Array[300, 100] n=30000 x∈[-0.003, 0.002] μ=-7.652e-05 σ=0.000 gpu:0}, 'linear_2': {'b': Array[10] x∈[-0.005, 0.005] μ=3.475e-09 σ=0.003 gpu:0 [0.005, -0.005, 0.001, 0.002, 0.001, -0.004, 0.004, -0.004, -0.00

In [30]:
@jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
    """Learning rule (stochastic gradient descent)."""
    grads = jax.grad(loss)(state.params, batch)
    updates, opt_state = optimiser.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    avg_params = optax.incremental_update(
        params, state.avg_params, step_size=0.001) #lambda new, old: step_size * new + (1.0 - step_size) * old
    return TrainingState(params, avg_params, opt_state)

In [28]:
batch = next(train)
state = update(state,batch)
loss(state.params, batch)

Array gpu:0 2.117

[0;31mSignature:[0m
[0moptax[0m[0;34m.[0m[0mapply_updates[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mparams[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m [0mIterable[0m[0;34m[[0m[0mForwardRef[0m[0;34m([0m[0;34m'ArrayTree'[0m[0;34m)[0m[0;34m][0m[0;34m,[0m [0mMapping[0m[0;34m[[0m[0mAny[0m[0;34m,[0m [0mForwardRef[0m[0;34m([0m[0;34m'ArrayTree'[0m[0;34m)[0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mupdates[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m [0mIterable[0m[0;34m[[0m[0mForwardRef[0m[0;34m([0m[0;34m'ArrayTree'[0m[0;34m)[0m[0;34m][0m[0;34m,[0m [0mMapping[0m[0;34m[[0m[0mAny[0m[0;34m,[0m [0mForwardRef[0m[0;34m([0m[0;34m'ArrayTree'[0m[0;34m)[0m[0;34m][0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mUnion[0m[0;34m[[0m[0mjax[0m[0;34m.[0m[0mArray[0m[0;34m,[0m [0

### Training and evaluation loop

In [37]:
eval_datasets = {split: load_dataset(split, shuffle=False, batch_size=10_000) for split in ("train", "test")}
eval_datasets

{'train': <generator object _eager_dataset_iterator at 0x7fd5383ae120>,
 'test': <generator object _eager_dataset_iterator at 0x7fd5383aedd0>}

In [39]:
for step in range(3001):
    if step % 100 == 0:
      # Periodically evaluate classification accuracy on train & test sets.
      # Note that each evaluation is only on a (large) batch.
      for split, dataset in eval_datasets.items():
        accuracy = np.array(evaluate(state.avg_params, next(dataset))).item()
        print({"step": step, "split": split, "accuracy": f"{accuracy:.3f}"})

    # Do SGD on a batch of training examples.
    state = update(state, next(train))

{'step': 0, 'split': 'train', 'accuracy': '0.123'}
{'step': 0, 'split': 'test', 'accuracy': '0.121'}
{'step': 100, 'split': 'train', 'accuracy': '0.436'}
{'step': 100, 'split': 'test', 'accuracy': '0.440'}
{'step': 200, 'split': 'train', 'accuracy': '0.635'}
{'step': 200, 'split': 'test', 'accuracy': '0.645'}
{'step': 300, 'split': 'train', 'accuracy': '0.781'}
{'step': 300, 'split': 'test', 'accuracy': '0.792'}
{'step': 400, 'split': 'train', 'accuracy': '0.866'}
{'step': 400, 'split': 'test', 'accuracy': '0.864'}
{'step': 500, 'split': 'train', 'accuracy': '0.902'}
{'step': 500, 'split': 'test', 'accuracy': '0.904'}
{'step': 600, 'split': 'train', 'accuracy': '0.930'}
{'step': 600, 'split': 'test', 'accuracy': '0.928'}
{'step': 700, 'split': 'train', 'accuracy': '0.947'}
{'step': 700, 'split': 'test', 'accuracy': '0.942'}
{'step': 800, 'split': 'train', 'accuracy': '0.955'}
{'step': 800, 'split': 'test', 'accuracy': '0.952'}
{'step': 900, 'split': 'train', 'accuracy': '0.966'}
{'step

### Refactor

In [42]:
network = hk.without_apply_rng(hk.transform(net_fn))
optimiser = optax.adam(1e-3)

def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Cross-entropy classification loss, regularised by L2 weight decay."""
    batch_size, *_ = batch.image.shape
    logits = network.apply(params, batch.image)
    labels = jax.nn.one_hot(batch.label, NUM_CLASSES)

    l2_regulariser = 0.5 * sum(
        jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
    log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))

    return -log_likelihood / batch_size + 1e-4 * l2_regulariser

@jax.jit
def evaluate(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Evaluation metric (classification accuracy)."""
    logits = network.apply(params, batch.image)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == batch.label)

@jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
    """Learning rule (stochastic gradient descent)."""
    grads = jax.grad(loss)(state.params, batch)
    updates, opt_state = optimiser.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    avg_params = optax.incremental_update(
        params, state.avg_params, step_size=0.001)
    return TrainingState(params, avg_params, opt_state)

In [44]:
# Training & evaluation loop.
n_epochs = 2
for epoch in range(n_epochs):
    for _, batch in enumerate(train): state = update(state, batch)
    
    for split, dataset in eval_datasets.items():
        accuracy = np.array(evaluate(state.avg_params, next(dataset))).item()
        print({"epoch": epoch, "split": split, "accuracy": f"{accuracy:.3f}"})

{'epoch': 0, 'split': 'train', 'accuracy': '1.000'}
{'epoch': 0, 'split': 'test', 'accuracy': '0.983'}
{'epoch': 1, 'split': 'train', 'accuracy': '1.000'}
{'epoch': 1, 'split': 'test', 'accuracy': '0.983'}


: 

### Creating a Learner

In [None]:
import fastcore.all as fc

In [None]:
class Learner:
    def __init__(self,dls): fc.store_attr()
    def one_batch(self):
        self.xb, self.yb = self.batch
    def one_epoch(self, is_training):
        dl = self.dls.train if is_training else self.dls.valid
        for self.num, self.batch in enumerate(dl): self.one_batch()

    def fit(self, n_epochs):
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            self.one_epoch(False)