
Adds "full eval" HOWTO. #2111

Merged
merged 4 commits into from May 19, 2022

Conversation

Collaborator

@andsteing andsteing commented May 10, 2022

Adds a HOWTO explaining the problem and possible solutions for "full dataset processing", which is especially relevant for computing evaluation metrics.

The added function flax.jax_utils.pad_shard_unpad() is copied verbatim from big_vision and was created by @lucasb-eyer - thanks!
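For readers of this thread who haven't seen big_vision: the idea behind pad_shard_unpad() can be sketched in a few lines of plain numpy (helper names below are illustrative, not the actual flax API):

```python
import numpy as np

def pad_shard(batch, num_devices):
    """Pad along axis 0 to a multiple of num_devices and reshape to
    (num_devices, per_device, ...); also return a mask marking which
    rows are real examples (True) vs padding (False)."""
    n = batch.shape[0]
    per_device = -(-n // num_devices)  # ceil division
    pad = per_device * num_devices - n
    mask = np.concatenate([np.ones(n, bool), np.zeros(pad, bool)])
    padded = np.concatenate(
        [batch, np.zeros((pad,) + batch.shape[1:], batch.dtype)])
    return (padded.reshape((num_devices, per_device) + batch.shape[1:]),
            mask.reshape(num_devices, per_device))

def unpad(sharded_out, mask):
    """Undo the sharding and drop the padding rows again."""
    flat = sharded_out.reshape((-1,) + sharded_out.shape[2:])
    return flat[mask.reshape(-1)]

batch = np.arange(10.0)              # 10 examples, pretend we have 8 devices
sharded, mask = pad_shard(batch, 8)  # padded to 16 = 8 devices x 2 examples
out = unpad(sharded * 2, mask)       # "sharded * 2" stands in for a pmapped fn
```

The actual flax helper additionally deals with pytrees of arrays and with static arguments/returns, as discussed further down in this thread.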

For reviewing this PR, see:

https://flax--2111.org.readthedocs.build/en/2111/howtos/full_eval.html

https://colab.research.google.com/github/andsteing/flax/blob/doc/docs/notebooks/full_eval.ipynb

Note that for running the Colab you'll need to replace the line

!pip install -q git+https://github.com/google/flax

with

!pip install -q git+https://github.com/andsteing/flax@doc



codecov-commenter commented May 10, 2022

Codecov Report

Merging #2111 (d0822db) into main (7c56ba9) will increase coverage by 0.19%.
The diff coverage is 96.87%.

@@            Coverage Diff             @@
##             main    #2111      +/-   ##
==========================================
+ Coverage   74.91%   75.10%   +0.19%     
==========================================
  Files          59       59              
  Lines        5042     5094      +52     
==========================================
+ Hits         3777     3826      +49     
- Misses       1265     1268       +3     
Impacted Files Coverage Δ
flax/jax_utils.py 66.66% <96.87%> (+10.28%) ⬆️
flax/linen/__init__.py 100.00% <0.00%> (ø)
flax/linen/activation.py 100.00% <0.00%> (ø)
flax/linen/partitioning.py 81.49% <0.00%> (+0.01%) ⬆️
flax/traverse_util.py 98.97% <0.00%> (+0.03%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@andsteing andsteing marked this pull request as ready for review May 10, 2022 13:11
docs/howtos/full_eval.rst (review comment, outdated/resolved)
@andsteing andsteing requested a review from cgarciae May 10, 2022 15:05
Collaborator

@marcvanzee marcvanzee left a comment


Looks great!

> does not form a complete batch at the end.


The problem
Collaborator


It seems the problem is two-fold: double compilation (both in training and eval), and incorrect metric results (in eval). Is this correct? Maybe it is worth emphasizing this. Currently you state that it is especially important during eval, but you don't explain why.

Collaborator Author


I see it more like this:

  1. In eval, we care about processing the full dataset because otherwise the metrics are off (in training we usually run multiple epochs, so training on some examples one time fewer does not matter).
  2. When we want to avoid losing data at the end, we run into other problems (e.g. multiple compilations).

I think 1. is mentioned in the first paragraph ("Especially when evaluating a model, it is important that we process all examples"), and 2. is mentioned further down as a disadvantage of some solutions.
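To make point 1 concrete, here is a small hypothetical numpy example of how dropping the incomplete final batches skews an eval metric (all numbers here are made up for illustration):

```python
import numpy as np

labels = np.zeros(100, dtype=int)
preds = labels.copy()
preds[90:] = 1                     # model gets the last 10 examples wrong

batch_size = 32                    # 100 = 3 * 32 + 4 -> remainder of 4 dropped
n_kept = (len(labels) // batch_size) * batch_size   # 96 examples survive

acc_full = (preds == labels).mean()                 # true accuracy: 90/100
acc_truncated = (preds[:n_kept] == labels[:n_kept]).mean()  # 90/96, inflated
```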

docs/howtos/full_eval.rst (review comment, outdated/resolved)
@cgarciae
Copy link
Collaborator

One case that might be worth discussing is what to do if you update the metric state inside a compiled train_step/test_step. It would seem that you would have to use padding plus some form of masking?

@marcvanzee added the label "Priority: P1 - soon" (response within 5 business days, resolution within 30 days, assignee required) and removed the label "Priority: P2 - no schedule" (best effort response and resolution, no plan to work on this at the moment) on May 12, 2022
The added sub-section "Computing metrics in ``eval_step()``" and
corresponding Colab cell show how to use the new argument to compute
metrics inside the `eval_step()`.

In response to comment by @cgarciae
@andsteing
Collaborator Author

> One case that might be worth discussing is what to do if you update the metric state inside a compiled train_step/test_step. It would seem that you would have to use padding plus some form of masking?

@cgarciae, please check out the added section about using an eval_step() with the new pad_shard_unpad(static_return) argument
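As a plain-numpy sketch of the masking idea (illustrative names only; the HOWTO shows the actual flax/jax version): keeping per-batch metric state as (correct, total) pairs together with a padding mask lets you sum states across padded batches without the padding biasing the result:

```python
import numpy as np

def eval_step_metrics(logits, labels, mask):
    """Per-batch metric state under padding: count only examples where
    mask is True, and return (correct, total) so states can be summed
    across batches and devices."""
    correct = ((logits.argmax(-1) == labels) & mask).sum()
    total = mask.sum()
    return correct, total

# one batch of 4, where the last row is padding added to fill the batch
logits = np.array([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 0.0]])
labels = np.array([0, 1, 1, 0])
mask = np.array([True, True, True, False])
c, t = eval_step_metrics(logits, labels, mask)
accuracy = c / t   # padding row is excluded from both numerator and denominator
```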

Adding "infinite padding"
~~~~~~~~~~~~~~~~~~~~~~~~~

Above solution works in most cases, but it has some limitations:
Member


another solution is to let each host process independently and do the pmean(metrics) at the very end in a separate pmapped program

Collaborator Author


yes, this is actually very similar to what I do with count_correct_p() below

note that the above two sections Adding "infinite padding" and Computing metrics in eval_step() can easily be combined

so I think the different use cases are covered by the subsections, but if you feel there is a specific combination that should be added, feel free to add some more specific comments and I'll write it up.
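The combination being discussed reduces to summing per-host (correct, total) states at the very end; with jax.pmap that final reduction would be a psum/pmean inside one small pmapped program, but the arithmetic is just (illustrative numbers):

```python
# per-host partial metric states (correct, total); made-up numbers
host_states = [(45, 50), (48, 50), (12, 13)]  # last host saw a short batch
correct = sum(c for c, _ in host_states)      # 105
total = sum(t for _, t in host_states)        # 113
accuracy = correct / total                    # weighted correctly by count
```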

Member


Ah indeed, a parallel stopping criterion is not so different when you use padding. I was thinking more towards a note on what to do when you don't want to pad. You might not want to include that route for simplicity though

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but without padding different hosts would have different batch sizes? and this would also lead to re-compilations?

(in that case I think we should keep it simple with "solution=padding" since it seems superior, at the cost of slightly more complicated code; but that added complication is quite small, especially in the case where one computes the metrics in the main eval loop)

@andsteing andsteing requested a review from jheek May 16, 2022 12:05
@cgarciae
Collaborator

cgarciae commented May 18, 2022

Left a new comment about static_return which is not urgent.
A previous comment that was marked as resolved says that maybe you were going to change the name vs_p to variables, but I still see vs_p. Maybe I misunderstood; just wanted to clarify that in case it's not intended.

@andsteing andsteing requested a review from cgarciae May 19, 2022 05:50
Collaborator

@cgarciae cgarciae left a comment


LGTM @andsteing! Enjoyed the read, also made me realize all clu and jax_metrics metrics should support masking.

@andsteing andsteing removed the request for review from jheek May 19, 2022 11:49
@andsteing
Collaborator Author

@cgarciae yes exactly, ideally all metrics support a single mask argument that exists exactly for this reason; how the mask is incorporated into intermediate value updates then depends on the individual metric ...

@copybara-service copybara-service bot merged commit 315da77 into google:main May 19, 2022
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) pull ready

5 participants