Fix Flax params dtype #13098

Merged · 36 commits merged into huggingface:master on Nov 11, 2021

Conversation

@patil-suraj (Contributor) commented on Aug 12, 2021

What does this PR do?

The dtype argument in Flax models is currently used ambiguously. It is supposed to specify the dtype of the computation, not the dtype of the model parameters. But in some models/modules it is passed to kernel_init, which causes the kernel parameters to be initialized in that dtype. This causes the following issues:

  • In Flax models we don't pass bias_init to Dense layers, since the default value is what our models expect. So if we pass dtype=jnp.bfloat16, it only reaches kernel_init: for a dense layer the kernel params end up in bfloat16 while the bias params stay in fp32 (see the sketch after this list).
  • This also causes issues with saving and loading models as explained in [Flax] from_pretrained does not consider the passed dtype #12534
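
A minimal sketch of that mismatch, using plain JAX initializers outside of any model (shapes and stddev are arbitrary):

import jax
import jax.numpy as jnp

# an initializer constructed with a dtype bakes that dtype into the params it creates
kernel_init = jax.nn.initializers.normal(0.02, jnp.bfloat16)
kernel = kernel_init(jax.random.PRNGKey(0), (4, 4))

# the default bias initializer produces fp32
bias = jax.nn.initializers.zeros(jax.random.PRNGKey(0), (4,))

print(kernel.dtype, bias.dtype)  # bfloat16 float32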

This PR corrects the usage of dtype in Flax models and adds to_bf16, to_fp16 and to_fp32 methods to FlaxPreTrainedModel. These methods accept any arbitrary params tree and change its dtype, so users can keep some params in bf16 and others in fp32, however they like, simply by passing the right parameters to these methods.
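
A minimal sketch of the simpler, maskless usage (assuming the bert-base-uncased Flax checkpoint; any checkpoint works the same way):

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")

# cast the whole params tree to half precision, e.g. for GPU inference
model.params = model.to_fp16(model.params)

# and back to full precision, e.g. before training
model.params = model.to_fp32(model.params)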

To allow keeping only certain params in half precision, to_bf16 accepts a mask that specifies which params to keep in bf16 and which in fp32. For example:

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import FlaxBertModel, BertConfig

config = BertConfig(num_hidden_layers=1)
model = FlaxBertModel(config, dtype=jnp.dtype("bfloat16"))

# keep layer norm in fp32
def mask_fn(params):
    flat_params = flatten_dict(params)
    flat_mask = {
        path: path[-2:] not in (("LayerNorm", "bias"), ("LayerNorm", "scale"))
        for path in flat_params
    }
    return unflatten_dict(flat_mask)

mask = mask_fn(model.params)
params = model.to_bf16(model.params, mask)

jax.eval_shape(lambda x: x, freeze(params))  # view the dtypes

  • This PR also fixes an issue in some models where dtype was never passed to some modules, so those modules always computed in fp32 even if the user passed bf16 or fp16.

  • This should now enable mixed-precision training in Flax models, since the params dtype and the computation dtype are kept separate.


🚨 BREAKING CHANGE 🚨
Note that this is a breaking change: dtype now only specifies the data type of the computation and no longer influences the data type of the model parameters. To cast the parameters themselves, use the new to_bf16/to_fp16/to_fp32 methods, as in the sketch below.
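
A sketch of what the new semantics mean in practice (again assuming the bert-base-uncased checkpoint):

import jax
import jax.numpy as jnp
from transformers import FlaxBertModel

# dtype only sets the computation precision; the loaded weights stay in fp32
model = FlaxBertModel.from_pretrained("bert-base-uncased", dtype=jnp.bfloat16)
print(jax.tree_util.tree_map(lambda x: x.dtype, model.params))  # all float32

# to actually cast the weights, call the conversion method explicitly
model.params = model.to_bf16(model.params)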

@sgugger (Collaborator) left a comment

I think I agree with your solution @patil-suraj but I'll defer to @patrickvonplaten 's final say as I don't know Jax/Flax enough to have a strong opinion.

Initializing a model in bfloat16 or float16 seems like the wrong thing to do.

@patrickvonplaten (Contributor) commented

I like the design and I think it follows JAX design conventions quite nicely (similar to how optax optimizers mask certain weights).

This PR will necessarily have some breaking changes: after it, loading a model with dtype=bfloat16 won't convert the weights to bfloat16 anymore, so we should announce it well.

It would also be great if @avital could quickly give his opinion on the API here.

@avital (Contributor) commented on Sep 9, 2021

Hi folks, sorry for the radio silence, I'm back now. @jheek has thought carefully about the exact meaning of dtypes in Flax modules so I'd like to hear his take on this confusion.

@jheek commented on Sep 10, 2021

I think masking is the right approach here. The right dtype is very context dependent. During inference half precision is almost always better while during training it's pretty much never worth it. And then of course there is fine-tuning where the masked weights are basically in inference mode. The mask based API captures this complexity really well.
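
For illustration, a hypothetical fine-tuning mask (not from this PR) that keeps a frozen backbone in bf16 while the trainable pooler head stays in fp32:

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import BertConfig, FlaxBertModel

config = BertConfig(num_hidden_layers=1)
model = FlaxBertModel(config, dtype=jnp.bfloat16)

# True -> cast to bf16 (frozen backbone), False -> keep in fp32 (trainable head)
def backbone_mask(params):
    flat = flatten_dict(params)
    return unflatten_dict({path: "pooler" not in path for path in flat})

params = model.to_bf16(model.params, backbone_mask(model.params))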

@patil-suraj force-pushed the fix-flax-dtype branch 2 times, most recently from 97e24de to 11d1afd on September 15, 2021
@patil-suraj (Contributor, Author) commented on Sep 15, 2021

Just noticed a sneaky bug in some Flax models: dtype is never passed to some modules. For example, here in Bart:

self.self_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
)

the attention module never receives dtype, so it always computes in fp32 even if the user passed bf16.

same with T5 here

self.encoder = FlaxT5Stack(encoder_config, self.shared)
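
The fix (as this PR eventually does) is to thread the parent module's own dtype through to every submodule; roughly, for the Bart attention call above (other arguments unchanged):

self.self_attn = FlaxBartAttention(
    config=self.config,
    embed_dim=self.embed_dim,
    num_heads=self.config.encoder_attention_heads,
    dropout=self.config.attention_dropout,
    dtype=self.dtype,  # forward the computation dtype instead of silently defaulting to fp32
)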

@patrickvonplaten @sgugger I propose we make dtype required for all modules except the user-facing ones: all main model classes keep a default dtype (fp32), but for every other submodule it is required, to avoid bugs like this (a sketch of the idea follows).
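
For illustration, trimmed-down hypothetical module definitions (the real classes have many more fields):

import flax.linen as nn
import jax.numpy as jnp

class FlaxBartAttention(nn.Module):  # hypothetical, trimmed-down submodule
    dtype: jnp.dtype                 # no default: every parent must pass it explicitly

class FlaxBartModel(nn.Module):      # user-facing class keeps a default
    dtype: jnp.dtype = jnp.float32

FlaxBartAttention(dtype=jnp.bfloat16)  # ok
FlaxBartModel()                        # ok, defaults to fp32
# FlaxBartAttention() would raise a TypeError, so a forgotten dtype fails loudly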

@patil-suraj (Contributor, Author) commented

Hey @patrickvonplaten!

  • added a couple more tests as you suggested
  • updated the templates
  • ran tests on both GPU and TPU and they pass

Would be awesome if you could take a quick final look :)

@patrickvonplaten (Contributor) left a comment

Great, the functionality looks good. I left a couple of suggestions for improving the docstrings. I wonder whether we should add a couple more sentences under to_bfloat16(...) and to_float16(...) that state why one would call such a method, e.g. bfloat16 on TPU and float16 on GPU to save memory. Should we maybe also adapt the docstring of the dtype input that one can pass to a Flax model, to really make sure people understand that dtype never influences the weights but just the computations?

Also, it would be great to check that the docstring is nicely shown in the docs.

Finally, we should leave a big statement in the PR description with 🚨 that this PR is backwards breaking.

We should really bring across the message that dtype == computation precision & to_...() converts parameters

(Several review threads on src/transformers/modeling_flax_utils.py, all resolved.)
@patrickvonplaten (Contributor) commented

Thanks for finishing the PR!

@patrickvonplaten (Contributor) commented

Think we just need to update the Flax templates now and we're good to go :-)

@patil-suraj merged commit e92190c into huggingface:master on Nov 11, 2021
@patil-suraj deleted the fix-flax-dtype branch on November 11, 2021
@patil-suraj mentioned this pull request on Nov 29, 2021
stancld added a commit to stancld/transformers that referenced this pull request Nov 29, 2021
patil-suraj added a commit that referenced this pull request Nov 30, 2021
* Init Flax implementation for Blenderbot

* Add a majority of stuff except for tests

* make style quality

* Add tests and fix some bugs

* Add tests

* Clean source code and fix some bugs

* Fix copies and docs

* Fix jax device condition for tests

* Fix layer norm in the encoder

* Fix a few typos in the test file

* make fix-copies

* make fix-copies

* fix layer norm

* Fix Flax params dtype (#13090)

* Fix PR reference (#13098)

* make fix-copies

* Update tests/test_modeling_flax_blenderbot.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022