
[Flax] Add FlaxBlenderbot #13633

Merged: 28 commits merged into huggingface:master on Nov 30, 2021
Conversation

stancld (Contributor) commented Sep 17, 2021

What does this PR do?

This PR adds a Flax implementation of Blenderbot.
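
As a quick illustration of the intended usage (a hedged sketch: the checkpoint name and the generation call assume the standard Bart-like Flax API in transformers, not code taken from this PR):

from transformers import BlenderbotTokenizer, FlaxBlenderbotForConditionalGeneration

# facebook/blenderbot-400M-distill is an existing PyTorch checkpoint;
# from_pt=True lets from_pretrained convert the weights to Flax on the fly.
model = FlaxBlenderbotForConditionalGeneration.from_pretrained(
    "facebook/blenderbot-400M-distill", from_pt=True
)
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=64)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))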

TODOs:

  • fix PT-Flax model equivalence

Who can review?

@patrickvonplaten @patil-suraj

@stancld stancld changed the title [Flax] Add FlaxBlenderbot [WIP] [Flax] Add FlaxBlenderbot Sep 17, 2021
@stancld stancld changed the title [WIP] [Flax] Add FlaxBlenderbot [Flax] Add FlaxBlenderbot Sep 17, 2021
stancld (Contributor, Author) commented Oct 15, 2021

@patrickvonplaten I would like to kindly ping you for a review. :) I've been struggling to achieve PT-Flax equivalence; however, I cannot find the difference/bug in this new Flax implementation.

Thanks a lot! :)

patrickvonplaten (Contributor) commented

Hey @stancld,

Thanks a lot for the PR! The difference between PT and Flax in your PR is actually very small (< 0.1), so it might very well be that the implementation is correct!

I'll try to take a deeper look at the end of next week. Could you try one last thing: add print statements such as

print("PT", hidden_states.sum())  # in PyTorch

and

print("FX", hidden_states.sum())  # in Flax

before the word embeddings, after the word embeddings, after each encoder transformer layer, before the decoder word embeddings, after the decoder attention layers, ... to see where the activations start to diverge. If they diverge gradually, it might very well be that the model is correct and there is just a small numerical difference. If they diverge all of a sudden at some point, then there is probably a subtle bug.
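
For illustration, one way to get those sums on the PyTorch side without editing the modeling file is a forward hook per layer (a sketch; it assumes the facebook/blenderbot-400M-distill checkpoint and the encoder.layers / decoder.layers layout of Bart-like models in transformers):

import torch
from transformers import BlenderbotModel

model = BlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")

def make_hook(name):
    def hook(module, inputs, output):
        # encoder/decoder layers return a tuple; the first element is hidden_states
        print(f"PT {name}:", output[0].sum())
    return hook

for i, layer in enumerate(model.encoder.layers):
    layer.register_forward_hook(make_hook(f"encoder layer {i}"))
for i, layer in enumerate(model.decoder.layers):
    layer.register_forward_hook(make_hook(f"decoder layer {i}"))
# the hooks fire on the next model(...) forward pass

Flax modules have no hook mechanism, so on the Flax side the print("FX", hidden_states.sum()) calls have to go directly into the relevant __call__ methods.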

stancld (Contributor, Author) commented Oct 25, 2021

@patrickvonplaten Thank you for the tip! I'll have a look :)

stancld (Contributor, Author) commented Oct 28, 2021

Hello @patrickvonplaten, I ran a few tests; one output is below. There is some divergence, but I'm not sure whether it's too severe. I'm going to check the Flax code once again today :)

===PyTorch===
---Encoder---
PT first hidden-states:  tensor(-1.2589)
PT encoder after self-attn:  tensor(0.5862)
PT encoder: tensor(-0.7895)
PT encoder after self-attn:  tensor(0.0465)
PT encoder last hidden states before norm:  tensor(-0.2601)
PT encoder last hidden states after norm:  tensor(0.)
---Decoder---
PT decoder after self-attn:  tensor(1.1000)
PT decoder after cross-attn:  tensor(0.1547)
PT decoder: tensor(-0.0142)
PT decoder after self-attn:  tensor(0.9638)
PT decoder after cross-attn:  tensor(1.7759)
PT decoder: tensor(2.7198)
PT decoder last hidden states before norm:  tensor(2.7198)
PT decoder last hidden states after norm:  tensor(-5.7220e-06)
PT output:  tensor(-5.7220e-06)

===Flax===
---Encoder---
FX first hidden-states:  -1.2589027
FX encoder after self-attn:  0.59013414
FX encoder: -0.7862803
FX encoder after self-attn:  0.04762589
FX encoder last hidden states before norm:  -0.25001374
FX encoder last hidden states after norm:  4.053116e-06
---Decoder---
FX decoder after self-attn:  1.1029385
FX decoder after cross-attn:  0.15325405
FX decoder: -0.013041288
FX decoder after self-attn:  0.96520036
FX decoder after cross-attn:  1.7912248
FX decoder last hidden states before norm:  2.735363
FX decoder last hidden states after norm:  -1.1697412e-06
FX output:  -1.1697412e-06

@@ -405,7 +405,7 @@ def setup(self) -> None:
             num_heads=self.config.encoder_attention_heads,
             dropout=self.config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
+        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
Review comment on the diff (Contributor):

@patil-suraj - the default in PyTorch is 1e-05, so I adapted it for all Bart-like models. Given that the PT and Flax tests were passing for Bart before, I think this "bug correction" is fine in terms of backwards compatibility.
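
For context, the two frameworks simply disagree on the default: torch.nn.LayerNorm uses eps=1e-05 while flax.linen.LayerNorm uses epsilon=1e-06, which is enough to nudge every post-norm activation. A minimal standalone sketch of the mismatch (not PR code):

import jax
import jax.numpy as jnp
import numpy as np
import torch
import flax.linen as nn

x = np.random.randn(4, 8).astype("float32")

pt_ln = torch.nn.LayerNorm(8)   # default eps=1e-05
fx_ln = nn.LayerNorm()          # default epsilon=1e-06
params = fx_ln.init(jax.random.PRNGKey(0), jnp.asarray(x))

# both layers initialize scale=1 and bias=0, so only epsilon differs
pt_out = pt_ln(torch.from_numpy(x)).detach().numpy()
fx_out = np.asarray(fx_ln.apply(params, jnp.asarray(x)))
print(np.abs(pt_out - fx_out).max())  # small but nonzero drift

fx_ln_fixed = nn.LayerNorm(epsilon=1e-05)  # matches the PyTorch default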

patrickvonplaten (Contributor) left a comment

Good for me now!

stancld (Contributor, Author) commented Nov 3, 2021

@patrickvonplaten Thank you very much for spotting the problem! :]

patrickvonplaten (Contributor) commented

Tests on master seem to be broken currently :-/

But I think the PR is good to go. @patil-suraj, could you maybe take a look once you're back (and maybe rebase onto master with @stancld to fix the CircleCI runner)?

patrickvonplaten (Contributor) commented

Awesome - I'll let you merge, @patil-suraj, once you're back :-)

Comment on the diff:

    self.embed_dim,
    use_bias=self.bias,
    dtype=self.dtype,
    kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
patil-suraj (Contributor) commented on the diff:

We should not pass the dtype to the kernel_init anymore; it's meant to specify the dtype of the computation, not of the parameters. This was a bug in all Flax models, which is fixed by #13098.

@stancld Could you please rebase the branch onto master again and fix this according to what is explained in #13098?
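
Per #13098, the fix amounts to dropping the dtype argument from the initializer, so parameters are always created in float32 while self.dtype keeps controlling only the computation dtype; the change would look like:

- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),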

stancld (Contributor, Author) replied:

@patil-suraj Thank you for providing me with the context. Should be fixed now :]

patil-suraj (Contributor) left a comment

Thank you @stancld for adding this, LGTM!

Will push a couple of Flax checkpoints and then merge :)

Comment on lines +466 to +472
if is_flax_available():
    import jax

    jax_device = jax.default_backend()
else:
    jax_device = None

Review comment (Contributor):
nice!
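
One plausible consumption of this flag is in an equivalence test, where the backend decides the tolerance (a hypothetical test body, for illustration only; TPUs compute matmuls in bfloat16 by default, so looser tolerances are typical there):

import unittest
import numpy as np

try:
    import jax
    jax_device = jax.default_backend()  # "cpu", "gpu", or "tpu"
except ImportError:
    jax_device = None

class EquivalenceTest(unittest.TestCase):  # hypothetical name
    def test_outputs_close(self):
        tol = 1e-3 if jax_device == "tpu" else 1e-5
        pt_out = np.ones(4, dtype=np.float32)  # stand-in for PyTorch outputs
        fx_out = np.ones(4, dtype=np.float32)  # stand-in for Flax outputs
        self.assertTrue(np.allclose(pt_out, fx_out, atol=tol))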

@patil-suraj patil-suraj merged commit faacd74 into huggingface:master Nov 30, 2021