Adds "full eval" HOWTO. #2111

Merged
merged 4 commits into from May 19, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -15,7 +15,7 @@ vNext
-
-
-
-
- Added `flax.jax_utils.pad_shard_unpad()` by @lucasb-eyer
-
-
-
1 change: 1 addition & 0 deletions docs/flax.jax_utils.rst
@@ -20,3 +20,4 @@ Multi device utilities

.. autofunction:: pmean

.. autofunction:: pad_shard_unpad
161 changes: 161 additions & 0 deletions docs/howtos/full_eval.rst
@@ -0,0 +1,161 @@
.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb

Processing the entire Dataset
=============================

For efficiency reasons, we form batches that contain multiple examples and
process them in parallel. Especially when evaluating a model, it is important
that we process all examples and **avoid losing the remainder** of examples that
does not form a complete batch at the end.
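
As a quick illustration of what is at stake (a small sketch, assuming the MNIST
test split of 10,000 examples and a batch size of 512, matching the examples
below): simply dropping the incomplete final batch silently removes examples
from the metric computation.

.. code-block:: python

  num_examples = 10_000                  # size of the MNIST test split
  batch_size = 512
  dropped = num_examples % batch_size    # 272 examples are silently ignored
  evaluated = num_examples - dropped     # metrics would be computed on 9728 only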


The problem
-----------

Collaborator:
  It seems the problem is two-fold: double compilation (both in training and
  eval), and incorrect metric results (in eval). Is this correct? Maybe it is
  worth emphasizing this. Currently you state that it is especially important
  during eval, but you don't explain why.

Collaborator Author:
  I see it more like this:

  1. In eval, we care about processing the full dataset because otherwise the
     metrics are off. (In training we usually do multiple epochs, and using
     some examples one time less for training does not matter.)
  2. When we want to avoid losing data at the end, we run into other problems
     (e.g. multiple compilations).

  I think 1. is mentioned in the first paragraph ("Especially when evaluating
  a model, it is important that we process all examples"), and 2. is mentioned
  further down as a disadvantage of some solutions.

When evaluating on a single device, one can either drop the last incomplete
batch, or one can form a last batch with a shape different from the preceding
batches. Doing the latter has the disadvantage that this will trigger a
**recompilation** of the ``eval_step()`` because XLA is not shape polymorphic.

.. code-block:: python

  collections.Counter(
      tuple(batch['image'].shape)
      for batch in tfds.load('mnist', split='test').batch(per_device_batch_size)
  )
  # output:
  # Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19})

The problem is accentuated when using multiple devices for data parallelism. If
the batch size is not **divisible by the number of devices**, then that last
step must be executed on a single device (or a subset of devices). Usually one
would drop the last batch, but this will lead to incorrect results.


.. code-block:: python

  sum(
      np.prod(batch['label'].shape)
      for batch in tfds.load('mnist', split='test')
          .batch(per_device_batch_size, drop_remainder=True)
          .batch(jax.local_device_count())
  )
  # output:
  # 9728

Using multiple hosts further complicates the situation because JAX uses the SPMD
paradigm and every host must execute the same program. We would usually form
non-overlapping splits for different hosts with |tfds.split_for_jax_process()|_,
but this can lead to **different numbers of examples on different hosts**,
resulting in different JAX programs when all examples are to be processed.

.. code-block:: python

  process_count = 6
  # dataset_name is assumed to be 'mnist' here (10,000 test examples).
  [
      len(tfds.load(dataset_name, split=tfds.split_for_jax_process(
          'test', process_index=process_index, process_count=process_count)))
      for process_index in range(process_count)
  ]
  # output:
  # [1667, 1667, 1667, 1667, 1666, 1666]



.. |tfds.split_for_jax_process()| replace:: ``tfds.split_for_jax_process()``
.. _tfds.split_for_jax_process(): https://www.tensorflow.org/datasets/api_docs/python/tfds/split_for_jax_process


The solution: padding
---------------------

Even though it's possible to solve this problem by cleverly adjusting the number
of batches executed by different devices on different hosts, such a solution
quickly becomes complicated and makes the main eval loop hard to read, with a
lot of cumbersome logic.

The more straightforward solution to this problem is to use padding at the end
of the dataset to make sure that the last batch has the same size as the
preceding batches.


Manual implementation
~~~~~~~~~~~~~~~~~~~~~

The last batch is manually padded to contain the same number of examples as in
the preceding batches. The predictions for the padded examples are discarded
from the computation.

.. code-block:: python

  # pad → shard → predict → unshard → unpad, done manually for every batch.
  shard = lambda x: einops.rearrange(
      x, '(d b) ... -> d b ...', d=jax.local_device_count())
  unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    images = batch['image']
    n = len(images)
    # Pad the (possibly incomplete) last batch up to the full per-host batch size.
    padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)
    padded_images = np.concatenate([images, padding])
    # Keep only the predictions for the real (non-padded) examples.
    preds = unshard(get_preds(vs_p, shard(padded_images)))[:n]
    total += n
    correct += (batch['label'] == preds.argmax(axis=-1)).sum()


Using ``pad_shard_unpad()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The above pattern, namely the pad→shard→predict→unshard→unpad sequence, can be
extracted into the utility wrapper ``pad_shard_unpad()``, which greatly
simplifies the above evaluation loop.

.. code-block:: python

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    preds = flax.jax_utils.pad_shard_unpad(get_preds)(
        vs, batch['image'], min_device_batch=per_device_batch_size)
    total += len(batch['image'])
    correct += (batch['label'] == preds.argmax(axis=-1)).sum()
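
For intuition, below is a simplified sketch of what such a wrapper does
internally. This is *not* the actual ``flax.jax_utils.pad_shard_unpad()``
implementation (which is more general), and the helper name is made up:

.. code-block:: python

  import jax
  import numpy as np

  def naive_pad_shard_unpad(fn):
    """Illustrates the pad → shard → call → unshard → unpad pattern."""
    def wrapper(variables, x, *, min_device_batch):
      d = jax.local_device_count()
      n = len(x)
      full = d * min_device_batch
      if n < full:  # pad the incomplete last batch with zeros
        padding = np.zeros((full - n, *x.shape[1:]), x.dtype)
        x = np.concatenate([x, padding])
      x = x.reshape(d, min_device_batch, *x.shape[1:])  # shard across devices
      y = fn(variables, x)  # e.g. a pmapped prediction function
      return y.reshape(d * min_device_batch, *y.shape[2:])[:n]  # unshard & unpad
    return wrapper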


Adding "infinite padding"
~~~~~~~~~~~~~~~~~~~~~~~~~

Member:
  Another solution is to let each host process independently and do the
  pmean(metrics) at the very end in a separate pmapped program.

Collaborator Author:
  Yes, this is actually very similar to what I do with count_correct_p() below.

  Note that the above two sections, Adding "infinite padding" and Computing
  metrics in eval_step(), can easily be combined.

  So I think the different use cases are covered with the subsections, but if
  you feel there is a specific combination that should be added, feel free to
  add some more specific comments and I'll write it up.

Member:
  Ah, indeed a parallel stopping criterion is not so different when you use
  padding. I was thinking more towards a note on what to do when you don't
  want to pad. You might not want to include that route for simplicity though.

Collaborator Author:
  But without padding, different hosts would have different batch sizes, and
  this would also lead to re-compilations?

  (In that case I think we should keep it simple with "solution=padding" since
  it seems superior, at the cost of slightly more complicated code, but that
  added complication is quite small, especially in the case where one computes
  the metrics in the main eval loop.)


The above solution works in most cases, but it has some limitations:

1. In the rare case where even splitting of the dataset on multiple hosts leads
   to a different number of batches. Imagine having a dataset of ``n=4097``
   examples, and evaluating this on ``h=8`` hosts, each having ``d=8`` local
   devices, and forming on-device batch sizes of ``b=64``. With even dataset
   splitting, the first host would get ``4096/8+1==513`` examples, and all
   other hosts would get ``4096/8==512`` examples. Forming per-host batches of
   ``d*b==512``, this would lead to two batches on the first host, and a single
   batch on all other hosts, violating SPMD principles and hanging the
   multi-host setup in the last ``psum()`` directive (which would only be
   executed by the first host, but not the others).

2. When dropping examples dynamically by using ``ds.filter()``.

In these more complicated cases we could add "infinite padding" to the dataset,
on each of the hosts independently, and continue processing examples until
*all* hosts run out of unpadded examples.

.. code-block:: python

  correct = total = 0
  for batch in ds.as_numpy_iterator():
    # Count the non-padded examples in this batch, summed over all hosts/devices.
    n = count_p(batch['mask'])[0].item()  # adds sync barrier
    if not n: break  # all hosts have run out of real (non-padded) examples

    preds = get_preds(vs, batch['image']).argmax(axis=-1)
    total += n
    correct += count_correct_p(batch['label'], preds, batch['mask'])[0]
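
The helpers ``count_p()`` and ``count_correct_p()`` used above are defined in
the Colab linked below. As a rough sketch (not the exact definitions from the
Colab; the ``axis_name`` and the exact signatures are assumptions), such
pmapped counters could look like this:

.. code-block:: python

  import jax

  # Sums the number of non-padded examples across all devices (and hosts).
  count_p = jax.pmap(
      lambda mask: jax.lax.psum(mask.sum(), axis_name='batch'),
      axis_name='batch')

  # Sums the number of correct predictions on non-padded examples.
  count_correct_p = jax.pmap(
      lambda labels, preds, mask:
          jax.lax.psum((mask * (labels == preds)).sum(), axis_name='batch'),
      axis_name='batch')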

As for the other examples in this HOWTO, the complete executable code can be
found in the Colab:

.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb
1 change: 1 addition & 0 deletions docs/index.rst
@@ -44,6 +44,7 @@ For a quick introduction and short example snippets, see our `README
howtos/convert_pytorch_to_flax
howtos/optax_update_guide
howtos/linen_upgrade_guide
howtos/full_eval


.. toctree::