New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds "full eval" HOWTO. #2111
Adds "full eval" HOWTO. #2111
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ vNext | |
- | ||
- | ||
- | ||
- | ||
- Added `flax.jax_utils.ad_shard_unpad()` by @lucasb-eyer | ||
- | ||
- | ||
- | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,3 +20,4 @@ Multi device utilities | |
|
||
.. autofunction:: pmean | ||
|
||
.. autofunction:: pad_shard_unpad |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
.. image:: https://colab.research.google.com/assets/colab-badge.svg | ||
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb | ||
|
||
Processing the entire Dataset | ||
============================= | ||
|
||
For efficiency reasons, we form batches that contain multiple examples and | ||
process them in parallel. Especially when evaluating a model, it is important | ||
that we process all examples and **avoid losing the remainder** of examples that | ||
does not form a complete batch at the end. | ||
|
||
|
||
The problem | ||
----------- | ||
|
||
When evaluating on a single device, one can either drop the last incomplete | ||
batch, or one can form a last batch with a shape different from the preceding | ||
batches. Doing the latter has the disadvantage that this will trigger a | ||
**recompilation** of the ``eval_step()`` because XLA is not shape polymorphic. | ||
|
||
.. code-block:: python | ||
|
||
collections.Counter( | ||
tuple(batch['image'].shape) | ||
for batch in tfds.load('mnist', split='test').batch(per_device_batch_size) | ||
) | ||
# output: | ||
# Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19}) | ||
|
||
The problem is accentuated when using multiple devices for data parallelism. If | ||
the batch size is not **divisible by the number devices**, then that last step | ||
must be executed on a single device (or a subset of devices). Usually one would | ||
drop the last batch, but this will lead to incorrect results. | ||
|
||
|
||
.. code-block:: python | ||
|
||
sum( | ||
np.prod(batch['label'].shape) | ||
for batch in tfds.load('mnist', split='test') | ||
.batch(per_device_batch_size, drop_remainder=True) | ||
.batch(jax.local_device_count()) | ||
) | ||
# output: | ||
# 9728 | ||
|
||
Using multiple hosts further complicates the situation because JAX uses the SPMD | ||
paradigm and every host must execute the same program. We would usually form | ||
non-overlapping splits for different hosts with |tfds.split_for_jax_process()|_, | ||
but this can lead to **different numbers for different hosts**, resulting in | ||
different JAX programs when all examples are to be processed. | ||
|
||
.. code-block:: python | ||
|
||
process_count = 6 | ||
[ | ||
len(tfds.load(dataset_name, split=tfds.split_for_jax_process( | ||
'test', process_index=process_index, process_count=process_count))) | ||
for process_index in range(process_count) | ||
] | ||
# output: | ||
# [1667, 1667, 1667, 1667, 1666, 1666] | ||
|
||
|
||
|
||
.. |tfds.split_for_jax_process()| replace:: ``tfds.split_for_jax_process()`` | ||
.. _tfds.split_for_jax_process(): https://www.tensorflow.org/datasets/api_docs/python/tfds/split_for_jax_process | ||
|
||
|
||
The solution: padding | ||
--------------------- | ||
|
||
Even though it's possible to solve this problem by cleverly adjusting the number | ||
jheek marked this conversation as resolved.
Show resolved
Hide resolved
|
||
of batches executed by different devices on different hosts, such a solution | ||
quickly becomes complicated and makes the main eval loop hard to read with a lot | ||
of cumbersome logic. | ||
|
||
The more straight forward solution to this problem is to use padding at the end | ||
andsteing marked this conversation as resolved.
Show resolved
Hide resolved
|
||
of the dataset to make sure that the last batch has the same size as the | ||
preceding batches. | ||
|
||
|
||
Manual implementation | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
The last batch is manually padded to contain the same number of examples as in | ||
the preceding batches. The predictions for the padded examples are discarded | ||
from the computation. | ||
|
||
.. code-block:: python | ||
|
||
shard = lambda x: einops.rearrange( | ||
x, '(d b) ... -> d b ...', d=jax.local_device_count()) | ||
unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...') | ||
|
||
correct = total = 0 | ||
for batch in ds.as_numpy_iterator(): | ||
images = batch['image'] | ||
n = len(images) | ||
padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype) | ||
padded_images = np.concatenate([images, padding]) | ||
preds = unshard(get_preds(vs_p, shard(padded_images)))[:n] | ||
andsteing marked this conversation as resolved.
Show resolved
Hide resolved
|
||
total += n | ||
correct += (batch['label'] == preds.argmax(axis=-1)).sum() | ||
|
||
|
||
Using ``pad_shard_unpad()`` | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
The above pattern, namely the pad→shard→predict→unshard→unpad sequence, can be | ||
extracted into a utility wrapper ``pad_shard_unpad()``, which greatly simplifies | ||
above evaluation loop. | ||
|
||
.. code-block:: python | ||
|
||
correct = total = 0 | ||
for batch in ds.as_numpy_iterator(): | ||
preds = flax.jax_utils.pad_shard_unpad(get_preds)( | ||
vs, batch['image'], min_device_batch=per_device_batch_size) | ||
total += len(batch['image']) | ||
correct += (batch['label'] == preds.argmax(axis=-1)).sum() | ||
|
||
|
||
Adding "infinite padding" | ||
~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
Above solution works in most cases, but it has some limitations: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another solution is to let each host process indepedently and do the pmean(metrics) add the very end in a seperate pmapped program There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, this is actually very similar to what I do with note that the above two sections so I think the different usecases are covered with the subsections, but if you feel there is a specific combination that should be added, feel free to add some more specific comments and I'll write it up. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah indeed a parallel stoppinig criteria is not so different when you use padding. I was thinking more towards a note on what to do when you don't want to pad. You might not want to include that route for simplicity though There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but without padding different hosts would have different batch sizes? and this would also lead to re-compilations? (in that case I think we should keep it simple with "solution=padding" since it seems superior, at the cost of a little more complicated code, but that added complication is quite small, especially in the case where one compute the metrics in the main eval loop) |
||
|
||
1. In the rare case where even splitting of the dataset on multiple hosts leads | ||
to a different number of batches. Imagine having a dataset of ``n=4097`` | ||
examples, and evaluating this on ``h=8``, each having ``d=8`` local devices, | ||
and forming on-device batch sizes of ``b=128``. With even dataset splitting, | ||
the first host would get ``4096/8+1==513`` examples, and all other hosts | ||
would get ``4096/8==512`` examples. Forming per-host batches of ``d*b==512`` | ||
this would lead to two batches on the first host, and a single batch on all | ||
other hosts, violating SPMD principles and hanging the multi-host setup in | ||
the last ``psum()`` directive (which would only be executed by the first | ||
host, but not the others). | ||
|
||
2. When dropping examples dynamically by using ``ds.filter()``. | ||
|
||
In these more complicated cases we could add "infinite padding" to the dataset, | ||
on each of the hosts independently, and continuing processing examples until | ||
*all* hosts run out of unpadded examples. | ||
|
||
.. code-block:: python | ||
|
||
correct = total = 0 | ||
for batch in ds.as_numpy_iterator(): | ||
n = count_p(batch['mask'])[0].item() # adds sync barrier | ||
if not n: break | ||
|
||
preds = get_preds(vs, batch['image']).argmax(axis=-1) | ||
total += n | ||
correct += count_correct_p(batch['label'], preds, batch['mask'])[0] | ||
|
||
As for the other examples in this HOWTO, the complete executable code can be | ||
found in the Colab: | ||
|
||
.. image:: https://colab.research.google.com/assets/colab-badge.svg | ||
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/full_eval.ipynb |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems the problem is two-fold: double compilation (both in training and eval), and incorrect metric results (in eval). Is this correct? Maybe it is worth emphasizing this. Currently you state that is especially important during eval but you don't explain why.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it more like this
I think 1. is mentioned in the first paragraph "Especially when evaluating a model, it is important that we process all examples", and 2. is mentioned further down as disadvantage of some solutions.