Fix Flax params dtype #13098
Conversation
I think I agree with your solution @patil-suraj but I'll defer to @patrickvonplaten's final say as I don't know Jax/Flax enough to have a strong opinion.
Initializing a model in bfloat16 or float16 seems like the wrong thing to do.
I like the design and I think it follows jax-design quite nicely (similar to how optax optimizers mask certain weights).
This PR will necessarily have some breaking changes, as after it loading a model with [...]. Also, it would be great if @avital could maybe quickly give his opinion on the API here.
Hi folks, sorry for the radio silence, I'm back now. @jheek has thought carefully about the exact meaning of dtypes in Flax modules so I'd like to hear his take on this confusion.
I think masking is the right approach here. The right dtype is very context dependent. During inference half precision is almost always better, while during training it's pretty much never worth it. And then of course there is fine-tuning, where the masked weights are basically in inference mode. The mask-based API captures this complexity really well.
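As a rough illustration of that idea (a minimal sketch using flax.traverse_util, not necessarily the API this PR ends up with), a boolean mask with the same structure as the params tree can decide which leaves get cast to half precision:

```python
import jax.numpy as jnp
from flax import traverse_util


def cast_with_mask(params, mask, dtype=jnp.bfloat16):
    """Cast only the masked leaves of a params tree to `dtype`; leave the rest in fp32."""
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = traverse_util.flatten_dict(mask)
    flat_cast = {
        path: (leaf.astype(dtype) if flat_mask[path] else leaf)
        for path, leaf in flat_params.items()
    }
    return traverse_util.unflatten_dict(flat_cast)
```

During fine-tuning, such a mask could select only the frozen (inference-only) sub-trees, which matches the context dependence described above.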
Just noticed a sneaky bug in some flax models: in transformers/src/transformers/models/bart/modeling_flax_bart.py (lines 400 to 405 at 3fbb55c), the attention never receives dtype, so it's always in fp32 even if the user passed bf16. Same with T5 here.
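For reference, a minimal sketch (hypothetical module and field names) of what the fix amounts to: the parent module has to forward its computation dtype to every submodule, otherwise that submodule silently runs in fp32:

```python
import flax.linen as nn
import jax.numpy as jnp


class EncoderLayer(nn.Module):
    hidden_size: int
    dtype: jnp.dtype = jnp.float32  # computation dtype

    def setup(self):
        # The bug: omitting dtype=self.dtype here left the attention in fp32
        # even when the user asked for bf16.
        self.self_attn = nn.SelfAttention(num_heads=8, dtype=self.dtype)
        self.dense = nn.Dense(self.hidden_size, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.self_attn(hidden_states)
        return self.dense(hidden_states)
```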
@patrickvonplaten @sgugger I propose we make [...]
Hey @patrickvonplaten! Would be awesome if you could take a quick final look :)
Great, the functionality looks good. Left a couple of suggestions for improvement for the docstring. I wonder whether we should even leave a couple more sentences under to_bfloat16(...) and to_float16(...) that state why one would call such a method, e.g. bfloat16 on TPU to save memory and float16 on GPU to save memory. Should we maybe also adapt the docstring of the dtype input that one can pass to a Flax model to really make sure that people understand that dtype never influences the weights but just the computations? Also, it would be great to check that the docstring is nicely shown in the docs.

Finally, we should leave a big statement in the PR description with 🚨 that this PR is backwards breaking. We should really bring across the message that dtype == computation precision & to_...() converts parameters.

Thanks for finishing the PR!
Think we just need to update the Flax templates now and we're good to go :-)
* Init Flax implementation for Blenderbot
* Add a majority of stuff except for tests
* make style quality
* Add tests and fix some bugs
* Add tests
* Clean source code and fix some bugs
* Fix copies and docs
* Fix jax device condition for tests
* Fix layer norm in the encoder
* Fix a few typos in the test file
* make fix-copies
* make fix-copies
* fix layer norm
* Fix Flax params dtype (#13090)
* Fix PR reference (#13098)
* make fix-copies
* Update tests/test_modeling_flax_blenderbot.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
* fix inits
* fix embed dtype
* fix embed dtype
* add test to check default dtype
* quality
* add type conversion methods for flax models
* more robust casting
* cast sinusoidal positions
* update pegasus
* update albert
* update test
* make sure dtype is passed to every module
* style
* fix electra dense
* fix t5
* quality
* add more tests
* better name
* use the dtype for lm head computation
* fix albert
* style
* fix albert embed dtype
* more tests
* fix vision enc-dec
* cleanup
* fix embed dtype pegasus
* fix default param test
* doc
* update template
* fix final_logits_bias dtype
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc
* fix doc
* add detailed docstring for dtype parameter
* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
What does this PR do?
The dtype argument in flax models is used ambiguously. This argument is actually supposed to specify the dtype of computation and not the dtype of model parameters. But in some models/modules, it's passed to kernel initializers, which causes the kernel parameters to be initialized with that dtype. This causes the following issue: we never pass bias_init to Dense layers since the default value works as expected by our models, so if we pass dtype=jnp.bfloat16 it's only passed to kernel_init, and for a dense layer the kernel params end up in bfloat16 while the bias params are in fp32.
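To make the mismatch concrete, here is a small, purely illustrative snippet (shapes and stddev are made up) showing how an initializer created with a half-precision dtype yields bf16 params while the default zeros initializer used for the bias stays in fp32:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Kernel initializer built with a bf16 dtype, as some modules were doing.
kernel_init = jax.nn.initializers.normal(stddev=0.02, dtype=jnp.bfloat16)
kernel = kernel_init(key, (8, 4))

# The bias uses the default zeros initializer, which produces fp32.
bias = jax.nn.initializers.zeros(key, (4,))

print(kernel.dtype)  # bfloat16
print(bias.dtype)    # float32
```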
This PR corrects the usage of dtype in flax models and adds to_bf16, to_fp16 and to_fp32 methods in FlaxPreTrainedModel. These methods accept any arbitrary params tree and change its dtype, so if users want, they can keep certain params in bf16 and others in fp32 however they like, just by passing the right params to these methods.

To allow keeping only certain params in half-precision, the to_bf16 method accepts a mask that specifies which params to keep in bf16 and which in fp32, for example as in the sketch below.
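A possible usage along these lines (the model name and the exact mask construction are illustrative; the mask keys depend on the model) keeps the LayerNorm params in fp32 while casting everything else to bf16:

```python
import jax.numpy as jnp
from flax import traverse_util
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Boolean mask over the flattened params: True means "cast this leaf to bf16",
# False keeps it in fp32 (here, the LayerNorm scale and bias).
flat_params = traverse_util.flatten_dict(model.params)
mask = {
    path: path[-2:] not in [("LayerNorm", "scale"), ("LayerNorm", "bias")]
    for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)

model.params = model.to_bf16(model.params, mask)
```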
This PR also fixes an issue in some models where the dtype was never passed to some modules, so those modules were always doing computation in fp32 even if the user passed a bf16 or fp16 dtype.

This should now help enable mixed-precision training in flax models, as we can keep the params dtype and the computation dtype separate.

🚨 BREAKING CHANGE 🚨
Note that this will be a breaking change, since the meaning of dtype has changed: it is now only used to specify the data type of computation and does not influence the data type of the model parameters.
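To illustrate the new behavior (a sketch; the model name is just an example), the dtype passed at load time only selects the computation precision, and the params stay in fp32 until one of the new conversion methods is called:

```python
import jax
import jax.numpy as jnp
from transformers import FlaxBertModel

# dtype only controls the precision of the computation ...
model = FlaxBertModel.from_pretrained("bert-base-uncased", dtype=jnp.bfloat16)
print(jax.tree_util.tree_leaves(model.params)[0].dtype)  # float32: params untouched

# ... the params themselves have to be converted explicitly.
model.params = model.to_bf16(model.params)
```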