Update jax trainer function to save memory buffer. #897

Merged: 3 commits into keras-team:main on Sep 18, 2023

Conversation

@qlzh727 (Member) commented Sep 15, 2023

This is another attempt to save memory in the training function (in addition to #888).

With jax.jit(..., donate_argnums=...), we can tell JAX to reuse the input argument's memory buffer for the output, which saves one copy of the state in memory.

Since a donated buffer can't be accessed afterwards, I had to update the eval/test_on_batch functions to use the output trainable_variables, because the original input copies have been donated.
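
To illustrate the pattern (a minimal sketch, not the actual trainer code; the function and variable names here are made up):

import functools

import jax
import jax.numpy as jnp

# Donate argument 0 (the state) so XLA may reuse its buffers for the outputs.
@functools.partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    # Illustrative update only; the real train step computes gradients, etc.
    new_state = jax.tree_util.tree_map(lambda v: v * 0.9, state)
    loss = jnp.sum(batch)
    return new_state, loss

state = {"w": jnp.ones((1024, 1024))}
batch = jnp.ones((32,))
state, loss = train_step(state, batch)  # keep using the *returned* state
# On backends that support donation (e.g. GPU/TPU), reading the old state
# buffers after this call can raise:
#   XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed:
#   buffer has been deleted or donated.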

The xprof results for the OPT model already show a positive effect:
Before: https://xprof.corp.google.com/memory_profile/scottzhu-15065682269222877644
After: https://xprof.corp.google.com/memory_profile/scottzhu-6823282872073158074.

As you can see, heap allocation is greatly reduced.

codecov bot commented Sep 15, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.01% 🎉

Comparison is base (9d39e9a) 76.83% compared to head (4aecd5b) 76.85%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #897      +/-   ##
==========================================
+ Coverage   76.83%   76.85%   +0.01%     
==========================================
  Files         329      329              
  Lines       31434    31435       +1     
  Branches     6114     6114              
==========================================
+ Hits        24151    24158       +7     
+ Misses       5719     5715       -4     
+ Partials     1564     1562       -2     
Flag Coverage Δ
keras_core 76.75% <100.00%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Files Changed Coverage Δ
keras_core/backend/jax/trainer.py 96.09% <100.00%> (+0.01%) ⬆️

... and 1 file with indirect coverage changes

@fchollet (Member) commented:

Thanks for the PR! The code looks good.

Seems there's a CI failure though:

FAILED keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed: buffer has been deleted or donated.

@qlzh727 (Member, Author) commented Sep 17, 2023

Thanks for the PR! The code looks good.

Seems there's a CI failure though:

FAILED keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed: buffer has been deleted or donated.

So this is a bit complicated. For the particular test that failed, there are variables that appear in both the non-trainable variables and the metrics variables. This was causing the issue: the memory gets donated first for the non-trainable variable, and the access then fails when the same variable is used as a metrics variable.

In general, I am not sure whether we should make the non-trainable variables and the metric variables mutually exclusive (which I think would make a lot of sense; e.g. the optimizer variables are not considered trainable/non-trainable variables either).
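
Roughly, a minimal sketch of that failure mode (simplified, and assuming a backend where donation is supported; in the real trainer the variable values are passed as lists through the jitted step functions):

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def update_non_trainable(non_trainable):
    return [v + 1.0 for v in non_trainable]

@jax.jit
def update_metrics(metrics):
    return [m + 1.0 for m in metrics]

shared = jnp.zeros(())                     # the same value tracked in *both* lists
new_nt = update_non_trainable([shared])    # shared's buffer is donated here
# update_metrics([shared])                 # this second use of the old buffer raises:
#   XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed:
#   buffer has been deleted or donated.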

@qlzh727 (Member, Author) commented Sep 17, 2023

Also, taking a closer look at the trainer/layer code, it seems that we have non_trainable_variables, which contains metric variables, and non_trainable_weights, which doesn't. Just curious why we have this distinction?

I am not sure which approach is better and makes more logical sense:

  1. Update trainer.non_trainable_variables to exclude anything that is also in metrics_variables (see the sketch after this list).
  2. Use non_trainable_weights as the input to the model training function; this might be complicated because of the RNG seeds.
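
For option 1, a hypothetical helper along these lines could do the filtering (a sketch that assumes variables are compared by identity; not the actual implementation):

def exclude_metric_variables(non_trainable_variables, metrics_variables):
    # Drop anything that is also tracked as a metric variable, so each
    # underlying buffer is only ever donated once per step.
    metric_ids = {id(v) for v in metrics_variables}
    return [v for v in non_trainable_variables if id(v) not in metric_ids]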

@fchollet (Member) commented:

"non trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage trainable_variables + non_trainable_variables as being the whole state of the model, included any metrics attached to it, plus random seeds, etc.

We could certainly change that, and exclude metric variables. The fact that our JAX train_step separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).

The other route we could go is to actually double down on the idea that non_trainable_variables include metrics variables. In that case we'd try to stop special casing metric_variables in the JAX trainer.

Which is better?

@qlzh727 (Member, Author) commented Sep 17, 2023

"non trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage trainable_variables + non_trainable_variables as being the whole state of the model, included any metrics attached to it, plus random seeds, etc.

We could certainly change that, and exclude metric variables. The fact that our JAX train_step separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).

The other route we could go is to actually double down on the idea that non_trainable_variables include metrics variables. In that case we'd try to stop special casing metric_variables in the JAX trainer.

Which is better?

I see. I think it might make sense to exclude the metrics from the non-trainable variables. My understanding is that trainable/non-trainable variables are used by the training/inference process and affect the numerical output of the model. E.g. the beta/gamma of BN is a good example of a non-trainable variable. Same for the seed, which is used to generate the RNGs, either for initializers or for dropout. The metric variables, on the other hand, don't affect the training/inference outcome; even the optimizer weights have a bigger impact on the model output.

Also, when I print out the state for the model under test, the non-trainable variables don't even include all the metrics variables. They include the weights of the metrics attached to the model, but not those of the metrics passed to model.compile().

keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking 
print(model.non_trainable_variables)
[<KerasVariable shape=(), dtype=float32, path=my_metric/total>, <KerasVariable shape=(), dtype=float32, path=my_metric/count>]
print(model.metrics_variables)
[<KerasVariable shape=(), dtype=float32, path=loss/total>, <KerasVariable shape=(), dtype=float32, path=loss/count>, <KerasVariable shape=(), dtype=float32, path=my_metric/total>, <KerasVariable shape=(), dtype=float32, path=my_metric/count>]

Having said that, if we take this approach, then we won't have an easy way to visit those variables, unless we explicitly visit all the metrics of the layer and collect their weights. Should we add a metrics_variables attribute to the layer?
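
For illustration, a rough sketch of what such a Layer-level property could look like (hypothetical; attribute names like _metrics and _layers are guesses rather than the actual keras_core internals):

class Layer:
    """Minimal sketch; the real Layer class carries much more state."""

    def __init__(self):
        self._metrics = []   # Metric objects attached to this layer (assumed name)
        self._layers = []    # sublayers (assumed name)

    @property
    def metrics_variables(self):
        # Variables owned by attached metrics, plus those of sublayers,
        # collected without going through non_trainable_variables.
        variables = []
        for metric in self._metrics:
            variables.extend(metric.variables)
        for sublayer in self._layers:
            variables.extend(sublayer.metrics_variables)
        return variables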

@fchollet (Member) commented:

Ok, sounds good, let's keep both lists separate.

Should we add a metrics_variables attribute for the layer?

Yes, let's do that. It should be on the Layer class, but also overridden on the Trainer class to add compile metrics variables if the model is compiled.
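
A hedged sketch of that override (the _compile_metrics attribute name and the base class here are assumptions for illustration, not the actual keras_core code):

class LayerWithMetrics:
    # Stand-in for the Layer-level metrics_variables property sketched earlier.
    @property
    def metrics_variables(self):
        return []

class Trainer(LayerWithMetrics):
    def __init__(self):
        self._compile_metrics = None   # set by compile() (assumed name)

    @property
    def metrics_variables(self):
        # Start from the Layer-level metric variables...
        variables = list(super().metrics_variables)
        # ...then add the compile() metrics' variables if the model is compiled.
        if self._compile_metrics is not None:
            for metric in self._compile_metrics:
                variables.extend(metric.variables)
        return variables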

@qlzh727 (Member, Author) commented Sep 18, 2023

Ack. Done #910.

@qlzh727 (Member, Author) commented Sep 18, 2023

Rebased this PR, and the unit test should pass now.

@fchollet (Member) left a review comment:

Great work!

@fchollet merged commit 2dfe475 into keras-team:main on Sep 18, 2023
8 checks passed
@qlzh727 deleted the jit_dontate branch on September 18, 2023 22:52