Adds "full eval" HOWTO. #2111
Adds "full eval" HOWTO. #2111
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main    #2111      +/-   ##
==========================================
+ Coverage   74.91%   75.10%    +0.19%
==========================================
  Files          59       59
  Lines        5042     5094       +52
==========================================
+ Hits         3777     3826       +49
- Misses       1265     1268        +3

Continue to review full report at Codecov.
Looks great!
does not form a complete batch at the end.

The problem
It seems the problem is two-fold: double compilation (both in training and eval), and incorrect metric results (in eval). Is this correct? Maybe it is worth emphasizing this. Currently you state that it is especially important during eval, but you don't explain why.
I see it more like this:

1. In eval, we care about processing the full dataset because otherwise the metrics are off. (In training we usually do multiple epochs, and using some examples one time less for training does not matter.)
2. When we want to avoid losing data at the end, we run into other problems (e.g. multiple compilations).

I think 1. is mentioned in the first paragraph "Especially when evaluating a model, it is important that we process all examples", and 2. is mentioned further down as a disadvantage of some solutions.
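To make the first point concrete, here is a tiny numeric sketch (the numbers are made up): dropping the incomplete last batch skews the metric, while padding the batch and masking the dummy examples keeps it exact.

```python
import jax.numpy as jnp

# Hypothetical numbers: 10 examples, per-device batch size 4, so the last
# 2 examples do not form a complete batch.
correct = jnp.array([1., 1., 0., 1., 1., 1., 0., 1., 1., 1.])  # 8/10 correct

# Dropping the incomplete batch skews the metric:
acc_dropped = correct[:8].mean()                    # 0.75 instead of 0.8

# Padding to a full batch and masking the padded examples keeps it exact:
padded = jnp.concatenate([correct, jnp.zeros(2)])   # 2 dummy examples
mask = jnp.concatenate([jnp.ones(10), jnp.zeros(2)])
acc_padded = (padded * mask).sum() / mask.sum()     # 0.8
```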
One case that might be worth discussing is what to do if you update the metric state inside a compiled `eval_step()`.
The added sub-section "Computing metrics in ``eval_step()``" and the corresponding Colab cell show how to use the new argument to compute metrics inside the `eval_step()`, in response to the comment by @cgarciae above.
@cgarciae, please check out the added section about using a …
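For context, a hedged sketch (not the HOWTO's actual cell; `model` and the batch fields are placeholders) of what updating metric counts inside a compiled, pmapped `eval_step()` with a per-example mask can look like:

```python
import functools
import jax
import jax.numpy as jnp

# The batch carries a `mask` entry that is 1 for real examples and 0 for
# padding, so the metric computed inside the compiled step ignores padding.
@functools.partial(jax.pmap, axis_name='batch')
def eval_step(params, batch):
  logits = model.apply({'params': params}, batch['image'])  # `model` assumed
  correct = (logits.argmax(-1) == batch['label']) * batch['mask']
  # Sum over all devices so every device returns the same global counts.
  count_correct = jax.lax.psum(correct.sum(), axis_name='batch')
  count_total = jax.lax.psum(batch['mask'].sum(), axis_name='batch')
  return count_correct, count_total
```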
Adding "infinite padding" | ||
~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
Above solution works in most cases, but it has some limitations: |
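A hedged sketch of the "infinite padding" idea named by this heading (not the HOWTO's code; the limitations referred to above are discussed in the HOWTO itself): a batch iterator keeps yielding all-padding batches after the real data is exhausted, so every host can keep calling the same compiled eval step until a shared stopping criterion fires.

```python
import numpy as np

# `examples` is assumed to be a list of dicts of arrays. Each yielded batch
# tags every example with a `mask` (1.0 = real, 0.0 = padding); once the real
# data is exhausted, all-padding batches are yielded forever.
def infinite_padded_batches(examples, batch_size):
  dummy = {k: np.zeros_like(v) for k, v in examples[0].items()}
  batch = []
  for example in examples:
    batch.append({**example, 'mask': 1.0})
    if len(batch) == batch_size:
      yield batch
      batch = []
  while True:                          # "infinite padding"
    while len(batch) < batch_size:
      batch.append({**dummy, 'mask': 0.0})
    yield batch
    batch = []
```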
Another solution is to let each host process independently and do the pmean(metrics) at the very end in a separate pmapped program.
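A rough sketch of that alternative (summing counts rather than a literal pmean; `host_correct` and `host_total` are made-up names for this host's accumulated counts):

```python
import functools
import jax
import jax.numpy as jnp

# A separate pmapped program that combines per-host counts across all devices
# (and therefore across all hosts, since every host runs it).
@functools.partial(jax.pmap, axis_name='all')
def combine_counts(correct, total):
  return jax.lax.psum(correct, 'all'), jax.lax.psum(total, 'all')

# Usage sketch: place each host's counts on its first local device only
# (zeros elsewhere), so the global psum counts every host exactly once.
n = jax.local_device_count()
correct, total = combine_counts(
    jnp.zeros([n]).at[0].set(host_correct),
    jnp.zeros([n]).at[0].set(host_total))
accuracy = correct[0] / total[0]
```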
Yes, this is actually very similar to what I do with `count_correct_p()` below.

Note that the above two sections, Adding "infinite padding" and Computing metrics in `eval_step()`, can easily be combined, so I think the different use cases are covered with the subsections. But if you feel there is a specific combination that should be added, feel free to add some more specific comments and I'll write it up.
Ah, indeed, a parallel stopping criterion is not so different when you use padding. I was thinking more of a note on what to do when you don't want to pad. You might not want to include that route for simplicity, though.
But without padding, different hosts would have different batch sizes, and this would also lead to re-compilations?

(In that case I think we should keep it simple with "solution=padding", since it seems superior at the cost of a little more complicated code, and that added complication is quite small, especially in the case where one computes the metrics in the main eval loop.)
Left a new comment about …
LGTM @andsteing! Enjoyed the read, it also made me realize all `clu` and `jax_metrics` metrics should support masking.
@cgarciae yes exactly, ideally all metrics support a single `mask` argument.
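To illustrate what such masking support could look like in the simplest case (this is a plain-JAX sketch, not the `clu` or `jax_metrics` API):

```python
import jax.numpy as jnp

# A metric helper that takes a per-example `mask` and ignores padded examples.
def masked_mean(values, mask):
  return (values * mask).sum() / jnp.maximum(mask.sum(), 1)

# A padded batch of 4 where only the first 3 examples are real:
correct = jnp.array([1., 0., 1., 1.])   # the last entry is a padded dummy
mask = jnp.array([1., 1., 1., 0.])
print(masked_mean(correct, mask))        # 2/3, the padded example is ignored
```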
Adds a HOWTO explaining the problem and possible solutions for "full dataset processing", which is especially relevant for computing evaluation metrics.
The added function `flax.jax_utils.pad_shard_unpad()` is copied verbatim from `big_vision` and was created by @lucasb-eyer - thanks! (A rough usage sketch is included at the end of this description.)

For reviewing this PR, see:
https://flax--2111.org.readthedocs.build/en/2111/howtos/full_eval.html
https://colab.research.google.com/github/andsteing/flax/blob/doc/docs/notebooks/full_eval.ipynb
Note that for running the Colab you'll need to replace the line … with …
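As mentioned above, a rough usage sketch of the new wrapper (the linked HOWTO is the authoritative reference; `model`, `variables` and `images` are placeholders):

```python
import functools
import flax
import jax

# `pad_shard_unpad()` pads the (possibly incomplete) last batch so it can be
# sharded over all devices, calls the wrapped pmapped function, and removes
# the padding from the per-example outputs again; the first argument (the
# variables) is passed through unchanged by default.
@functools.partial(jax.pmap, axis_name='batch')
def eval_step(variables, images):
  return model.apply(variables, images).argmax(-1)

predictions = flax.jax_utils.pad_shard_unpad(eval_step)(variables, images)
```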