Update jax trainer function to save memory buffer. #897
Conversation
Codecov Report
Additional details and impacted files

@@ Coverage Diff @@
## main #897 +/- ##
==========================================
+ Coverage 76.83% 76.85% +0.01%
==========================================
Files 329 329
Lines 31434 31435 +1
Branches 6114 6114
==========================================
+ Hits 24151 24158 +7
+ Misses 5719 5715 -4
+ Partials 1564 1562 -2
Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
Thanks for the PR! The code looks good. Seems there's a CI failure though:
So this is a bit complicated. For the particular test that failed, there are variables that appear in both the non-trainable variables and the metrics variables. This was causing an issue: the memory gets donated first for the non-trainable variable, and the access then fails when the same variable is read as a metrics variable. In general, I am not sure whether we should make the non-trainable variables and the metric variables mutually exclusive (which I think makes a lot of sense, e.g. the optimizer variables are not considered trainable/non-trainable variables).
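Roughly, the failure mode looks like the sketch below (a toy function with illustrative names, not the trainer's actual signature): the same array is referenced from both lists, its buffer gets donated through the non-trainable argument, and the stale reference is then touched again.

```python
import jax
import jax.numpy as jnp

def step(non_trainable_vars, metrics_vars):
    # Toy update for both lists; the real train step does far more here.
    return ([v + 1.0 for v in non_trainable_vars],
            [m + 1.0 for m in metrics_vars])

# Only argument 0 (the non-trainable variables) is donated.
step = jax.jit(step, donate_argnums=(0,))

shared = jnp.zeros((8,))  # one variable tracked in *both* lists
new_nt, new_metrics = step([shared], [shared])

# `shared`'s buffer was handed over via the donated argument, so reading
# the old reference again (e.g. through the metrics list) can fail with a
# "buffer was donated/deleted" style error on backends that perform
# donation (on CPU, donation is skipped and only a warning is emitted).
print(shared)
```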
Also, taking a closer look at the trainer/layer code, it seems that we have non_trainable_variables, which contains metrics, and non_trainable_weights, which doesn't contain metrics. Just curious why we have this differentiation? I am not sure which approach is better and makes more logical sense:
"non trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage We could certainly change that, and exclude metric variables. The fact that our JAX The other route we could go is to actually double down on the idea that Which is better? |
I see. I think it might make sense to exclude the metrics from the non-trainable variables. My understanding is that trainable/non-trainable variables are used in the training/inference process and affect the numerical output of the model. E.g. the moving mean/variance for BN is a good case of a non-trainable variable. Same for the seed, which is used to generate the RNGs, either for an initializer or for dropout. The metrics variables, on the other hand, don't affect the training/inference outcome, whereas even the optimizer weights have a big contribution to the model output. Also, when I print out the state for the model under test, the non-trainable variables don't even include all the metrics variables. They include the weights that are attached to the model as metrics, but not those created under model.compile().
Having said that, if we take this approach, then we won't have an easy way to visit those variables, unless we explicitly visit all the metrics for the layer and find their weights. Should we add a dedicated property for that?
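For illustration, such a property could look roughly like this (a hypothetical sketch with a made-up property name and a toy layer class, not the actual implementation):

```python
class Layer:
    """Toy stand-in for the real base layer, just to illustrate the idea."""

    def __init__(self):
        self._metrics = []  # metric objects attached to this layer

    @property
    def metrics_variables(self):
        # Hypothetical name: explicitly walk the attached metrics and
        # collect their weights, instead of folding them into the
        # non-trainable variables.
        variables = []
        for metric in self._metrics:
            variables.extend(metric.variables)
        return variables
```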
Ok, sounds good, let's keep both lists separate.
Yes, let's do that. It should be on the base layer.
Ack. Done in #910.
Rebased this PR, and the unit test should pass now.
Great work!
This is another attempt to save memory for the training function (in addition to #888).
With jax.jit(donate_argnums=...), we can force JAX to reuse the input argument's memory buffer for the output, which saves one full copy of the state.
Since a donated input buffer can't be read after the call, I had to update the eval/test_on_batch functions to use the output trainable_variables, since the original input copy has been donated.
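As a minimal sketch of the mechanism (illustrative names, not the actual trainer code):

```python
import jax
import jax.numpy as jnp

def train_step(trainable_variables, x):
    # Toy "gradient update"; the real train step computes grads and metrics.
    new_vars = jax.tree_util.tree_map(
        lambda v: v - 0.1 * jnp.mean(x), trainable_variables)
    loss = jnp.mean(x)
    return loss, new_vars

# donate_argnums tells XLA it may reuse the buffers of argument 0 for the
# outputs, so the updated variables don't require a second full copy.
train_step = jax.jit(train_step, donate_argnums=(0,))

trainable_variables = {"w": jnp.ones((1024, 1024))}
x = jnp.ones((32, 1024))

loss, trainable_variables = train_step(trainable_variables, x)
# From here on, only the returned trainable_variables may be used; the
# original input buffers were donated and must not be read again.
```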
The xprof results for the OPT model already show some positive results:
Before: https://xprof.corp.google.com/memory_profile/scottzhu-15065682269222877644
After: https://xprof.corp.google.com/memory_profile/scottzhu-6823282872073158074.
As you can see, the heap allocation is greatly reduced.