
[Flax] from_pretrained does not consider the passed dtype #12534

Open
patil-suraj opened this issue Jul 6, 2021 · 7 comments
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@patil-suraj
Contributor

Environment info

When loading a Flax model with from_pretrained, the dtype argument is not used. The weights are initialized with the dtype of the saved weights.

So if you do

import jax.numpy as jnp
from transformers import FlaxGPT2ForCausalLM

model = FlaxGPT2ForCausalLM.from_pretrained("gpt2", dtype=jnp.dtype("bfloat16"))

# check the dtype of one of the params
model.params["transformer"]["wpe"]["embedding"].dtype
# => dtype("float32")

We should probably cast the weights to self.dtype.
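A minimal sketch of what that cast could look like inside from_pretrained (assuming the loaded state is a pytree of jnp arrays; only floating-point leaves are touched, and self.dtype stands for the dtype the user passed):

import jax
import jax.numpy as jnp

def cast_floating_to(params, dtype):
    # cast every floating-point leaf of the parameter pytree to `dtype`;
    # non-float leaves (e.g. integer position ids) are left untouched
    def maybe_cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_map(maybe_cast, params)

# hypothetical call site inside from_pretrained, after the weights are loaded:
# state = cast_floating_to(state, self.dtype)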

As a workaround for bf16, one could manually cast the weights with

def to_bf16(t):
    return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)

model.params = to_bf16(model.params)

cc @patrickvonplaten

@patrickvonplaten
Contributor

I wonder whether this might be problematic for layer norm weights since those should usually always be of type float32, no?
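If that turns out to be an issue, a rough sketch of a selective cast that leaves the layer-norm parameters in float32 (assuming the layer-norm modules have "ln" in their module path, as in the GPT-2 naming ln_1, ln_2, ln_f; the helper name is made up):

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

def to_bf16_except_layernorm(params):
    # flatten the nested param dict so each leaf is keyed by its full module path
    flat = flatten_dict(params)

    def maybe_cast(path, value):
        # path is a tuple like ("transformer", "h", "0", "ln_1", "scale")
        is_layer_norm = any("ln" in name.lower() or "layernorm" in name.lower() for name in path)
        if value.dtype == jnp.float32 and not is_layer_norm:
            return value.astype(jnp.bfloat16)
        return value

    return unflatten_dict({path: maybe_cast(path, v) for path, v in flat.items()})

model.params = to_bf16_except_layernorm(model.params)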

@patrickvonplaten
Contributor

Would love to hear what @avital @marcvanzee think here

@avital
Contributor

avital commented Jul 6, 2021

I think it's fine to manually port weights to bfloat16 if you want to. In general, all Flax layers accept a dtype attribute; when it's safe to do intermediate computation in bfloat16, you can set dtype=bfloat16 for those layers. Keeping parameters as bfloat16 should only be necessary if the model is huge and the parameters can't fit in device memory, from what I know. I think it's tricky to get that right and requires careful attention to which parameters are safe to keep in bfloat16, but I don't have too much personal context here. I can ask others if that helps.

So I'm first curious whether indeed it's necessary to keep parameters as bfloat16 in this case, and if so, why
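For reference, this is the distinction avital describes, shown with a plain flax.linen layer (a sketch outside the transformers wrappers): the dtype attribute controls the computation dtype, while the parameters themselves stay float32.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense layer that computes in bfloat16 but keeps its parameters in float32
dense = nn.Dense(features=8, dtype=jnp.bfloat16)

variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
print(variables["params"]["kernel"].dtype)              # float32  -> parameter dtype
print(dense.apply(variables, jnp.ones((1, 4))).dtype)   # bfloat16 -> computation dtype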

@XGCWYY111

hello so

@github-actions

github-actions bot commented Aug 5, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@patrickvonplaten
Contributor

will soon be taken care of by @patil-suraj :-)

@patil-suraj patil-suraj added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Sep 2, 2021
@mohamad-amin

This issue still exists. I was trying to make a float64 GPT2 instance (using this) and noticed that the initialized LayerNorm parameters have float32 dtype.
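For float64 specifically, JAX also has to be told to enable 64-bit types, otherwise arrays silently stay float32 even after a cast; a rough workaround along the lines of the snippet in the original post (assuming the manual cast is still needed):

import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit types; do this before creating any arrays

import jax.numpy as jnp
from transformers import FlaxGPT2ForCausalLM

model = FlaxGPT2ForCausalLM.from_pretrained("gpt2", dtype=jnp.dtype("float64"))

# manually cast the loaded float32 weights, since from_pretrained does not do it yet
model.params = jax.tree_map(
    lambda x: x.astype(jnp.float64) if x.dtype == jnp.float32 else x, model.params
)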
