[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/full_eval.ipynb)

This notebook only contains executable code cells for the examples mentioned in
https://flax.readthedocs.io/en/latest/guides/full_eval.html

Please refer to above link for a an explanation of the problem and the proposed solutions.

### setup

In [1]:
!pip install -q chex einops
# tfds.split_for_jax_process() was added in 4.5.1
!pip install -q tensorflow_datasets -U
# flax.jax_utils.pad_shard_unpad() is only available at HEAD
!pip install -q git+https://github.com/google/flax

[K     |████████████████████████████████| 72 kB 387 kB/s 
[K     |████████████████████████████████| 4.2 MB 4.9 MB/s 
[K     |████████████████████████████████| 140 kB 4.5 MB/s 
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [2]:
import collections

import chex
import einops
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import flax.jax_utils
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

chex.set_n_cpu_devices(8)

In [3]:
per_device_batch_size = 512
dataset_name = 'mnist'

In [4]:
class FakeModel(nn.Module):
  num_classes: int
  @nn.compact
  def __call__(self, x):
    return jax.nn.one_hot(jnp.zeros([len(x)], jnp.int32), self.num_classes)

model = FakeModel(num_classes=10)
variables = {}
inputs = jnp.zeros([2, 28, 28, 1])
model.apply(variables, inputs)



DeviceArray([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

### The problem

In [5]:
# last batch has different shape
collections.Counter(
    tuple(batch['image'].shape)
    for batch in tfds.load('mnist', split='test').batch(per_device_batch_size)
)

[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]

[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19})

In [6]:
# need to drop remainder when using multiple batch levels in a dataparallel
# setup
sum(
    np.prod(batch['label'].shape)
    for batch in tfds.load('mnist', split='test')
        .batch(per_device_batch_size, drop_remainder=True)
        .batch(jax.local_device_count())
)

9728

In [7]:
# having different number of examples for different hosts will result in SPMD
# violation when all examples are to be processed
process_count = 6
[
    len(tfds.load(dataset_name, split=tfds.split_for_jax_process(
        'test', process_index=process_index, process_count=process_count)))
    for process_index in range(process_count)
]

[1667, 1667, 1667, 1667, 1666, 1666]

In [8]:
# baseline: simple batching, keep reminder
# => leads to recompilation & only works on single device

@jax.jit
def get_preds(variables, inputs):
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split='test')
ds = ds.batch(per_device_batch_size, drop_remainder=False)

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = get_preds(variables, batch['image'])
  total += len(batch['label'])
  correct += (batch['label'] == preds.argmax(axis=1)).sum()

correc = correct.item()
correct, total, correct / total

retrigger compilation (512, 28, 28, 1)
retrigger compilation (272, 28, 28, 1)


(DeviceArray(980, dtype=int32), 10000, DeviceArray(0.098, dtype=float32))

In [9]:
# when the remainder is dropped, we can use multiple devices and avoid
# recompilations
# => but results are incorrect

@jax.pmap
def get_preds(variables, inputs):
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
# This `drop_remainder=True` is required so we can do a second batch level.
ds = ds.batch(per_device_batch_size, drop_remainder=True)
# This `drop_remainder=True` is required so we can avoid a recompilation.
ds = ds.batch(jax.local_device_count(), drop_remainder=True)

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = get_preds(variables, batch['image'])
  total += len(batch['label'].flatten())
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correc = correct.item()
correct, total, correct / total

retrigger compilation (512, 28, 28, 1)


(DeviceArray(814, dtype=int32), 8192, DeviceArray(0.09936523, dtype=float32))

### The solution: padding

#### Manual implementation

In [10]:
# manually padding
# => precise & allows for data parallelism

@jax.pmap
def get_preds(variables, inputs):
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)

shard = lambda x: einops.rearrange(
    x, '(d b) ... -> d b ...', d=jax.local_device_count())
unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')

correct = total = 0
for batch in ds.as_numpy_iterator():
  images = batch['image']
  n = len(images)
  padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)
  padded_images = np.concatenate([images, padding])
  preds = unshard(get_preds(variables, shard(padded_images)))[:n]
  total += n
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

retrigger compilation (512, 28, 28, 1)


(980, 10000, 0.098)

#### Using `pad_shard_unpad()`

In [12]:
# same as before, but using @pad_shard_unshard decorator

# manually padding
# => precise & allows for data parallelism

@jax.pmap
def get_preds(variables, inputs):
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      variables, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

retrigger compilation (512, 28, 28, 1)


(980, 10000, 0.098)

#### Computing metrics in `eval_step`

In [1]:
# moving the metrics computation into `eval_step()` and using `static_return`

# this pattern is often used with more complicated `clu.metrics`

def eval_step(metrics, variables, batch):
  print('retrigger compilation', {k: v.shape for k, v in batch.items()})
  preds = model.apply(variables, batch['image'])
  correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()
  total = batch['mask'].sum()
  return dict(
      correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),
      total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),
  )

eval_step = jax.pmap(eval_step, axis_name='batch')
eval_step = flax.jax_utils.pad_shard_unpad(
    eval_step, static_argnums=(0, 1), static_return=True)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)

metrics = flax.jax_utils.replicate(dict(
    correct=jnp.array(0, jnp.int32),
    total=jnp.array(0, jnp.int32),)
)
for batch in ds.as_numpy_iterator():
  batch['mask'] = np.ones_like(batch['label'])
  metrics = eval_step(
      metrics, variables, batch,
      min_device_batch=per_device_batch_size)

correct, total = metrics['correct'][0].item(), metrics['total'][0].item()
correct, total, correct / total

retrigger compilation {'image': (512, 28, 28, 1), 'label': (512,), 'mask': (512,)}
(980, 10000, 0.098)


#### Multi-host complications

In [13]:
# infinite zero padding

def with_infinite_padding(dataset):
  """Adds "infinite padding" to the dataset."""
  filler_element = tf.nest.map_structure(
      lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
  filler_element['mask'] = [False]
  filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
  dataset = dataset.map(
      lambda features: dict(mask=True, **features),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.concatenate(filler_dataset.repeat(None))

@jax.pmap
def get_preds(variables, inputs):
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

count_p = jax.pmap(
    lambda mask: jax.lax.psum(mask.sum(), axis_name='batch'),
    axis_name='batch',
)
count_correct_p = jax.pmap(
    lambda labels, preds, mask:
        jax.lax.psum((mask & (labels == preds)).sum(), axis_name='batch'),
    axis_name='batch',
)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
ds = with_infinite_padding(ds).batch(per_device_batch_size).batch(jax.local_device_count())

correct = total = 0
for batch in ds.as_numpy_iterator():
  n = count_p(batch['mask'])[0].item()  # adds sync barrier
  if not n: break

  preds = get_preds(variables, batch['image']).argmax(axis=-1)
  total += n
  correct += count_correct_p(batch['label'], preds, batch['mask'])[0]

correct = correct.item()
correct, total, correct / total

retrigger compilation (512, 28, 28, 1)


(980, 10000, 0.098)