
[Flax] Addition of FlaxPegasus #13420

Merged

Conversation


@bhadreshpsavani commented Sep 4, 2021

What does this PR do?

Adds the Flax implementation of PEGASUS (FlaxPegasus) to transformers.
Who can review?

@patrickvonplaten
@patil-suraj

@bhadreshpsavani
Contributor Author

Hi @patil-suraj and @patrickvonplaten,
I was not able to figure out how to add PegasusSinusoidalPositionalEmbedding in the Flax version. Also, the QuestionAnswering and Classification classes are not added yet, since the original PyTorch version doesn't have them. Shall we add them?
Please review this PR and let me know your thoughts.

@patil-suraj
Contributor

patil-suraj commented Sep 6, 2021

You can find the Flax version of SinusoidalPositionalEmbedding in FlaxMarian:

self.embed_positions = create_sinusoidal_positions(

Also, Pegasus isn't really intended for QA and classification, so it's okay not to add those heads yet.
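For reference, the idea behind a fixed sinusoidal position table can be sketched in plain NumPy (a simplified illustration of what create_sinusoidal_positions computes, not the exact library code — the library's layout of the sin/cos halves may differ):

```python
import numpy as np

def sinusoidal_positions(n_pos: int, dim: int) -> np.ndarray:
    """Build an (n_pos, dim) table of fixed sinusoidal position embeddings."""
    # Each dimension pair uses a different wavelength, as in "Attention Is All You Need".
    position = np.arange(n_pos)[:, None]                           # (n_pos, 1)
    div_term = np.power(10000.0, 2 * np.arange(dim // 2) / dim)    # (dim/2,)
    table = np.zeros((n_pos, dim))
    table[:, 0::2] = np.sin(position / div_term)  # even indices: sin
    table[:, 1::2] = np.cos(position / div_term)  # odd indices: cos
    return table

pos = sinusoidal_positions(128, 64)
print(pos.shape)  # (128, 64)
```

Because the table is deterministic, it needs no training and no weight conversion between frameworks, which is what makes it a good fit for a direct Flax port.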

Contributor

@patil-suraj left a comment


Thanks a lot for adding this @bhadreshpsavani , great work!

The PR looks good overall, I've left a few comments below.

Mainly

  • Let's try to use as many copied from ... statements as possible
  • The order of layer_norm in FlaxPegasusEncoder and FlaxPegasusDecoder should be changed. I've left details in the comment.

Let me know if it's not clear or if you need any other help :)
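To illustrate the layer_norm ordering point above: Pegasus uses a pre-LN residual block (normalize before the sublayer, with a final layer norm after the stack), whereas a post-LN block normalizes after the residual add. A minimal NumPy sketch — the function names are illustrative, not the library's API:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (hidden) dimension.
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_ln_block(x, sublayer):
    # Pre-LN (Pegasus-style): normalize first, then apply the sublayer and add the residual.
    return x + sublayer(layer_norm(x))

def post_ln_block(x, sublayer):
    # Post-LN (original Transformer style): sublayer first, then normalize the residual sum.
    return layer_norm(x + sublayer(x))

x = np.random.randn(2, 8, 16)
identity = lambda h: h
out_pre = pre_ln_block(x, identity)
out_post = post_ln_block(x, identity)
```

Porting a pre-LN model with post-LN ordering produces outputs that drift from the reference implementation even with identical weights, which is why the ordering matters for numerical equivalence.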

Review comments (all resolved) were left on:
  • src/transformers/models/auto/modeling_flax_auto.py
  • src/transformers/models/pegasus/modeling_flax_pegasus.py
  • tests/test_modeling_flax_pegasus.py
@patrickvonplaten
Contributor

Thanks a lot for more or less completing the PR - great job @bhadreshpsavani !

It seems like there are some small differences between the PyTorch and Flax models. This could be due to slightly different activation functions or small differences in the position ids.

It would be awesome if you could try to debug layer by layer what might be the problem there, @bhadreshpsavani.

Another possibility is that there is no real difference and the discrepancy just comes from the frameworks. In that case, we'll just have to accept it and increase the tolerance.
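One way to do the layer-by-layer comparison is to run both models with output_hidden_states=True and diff the per-layer hidden states. A sketch with NumPy arrays standing in for the real tensors (compare_hidden_states is a hypothetical helper, not a transformers API):

```python
import numpy as np

def compare_hidden_states(pt_states, fx_states, atol=1e-3):
    """Print and return the max absolute difference per layer between two runs."""
    diffs = []
    for i, (pt, fx) in enumerate(zip(pt_states, fx_states)):
        max_diff = float(np.max(np.abs(np.asarray(pt) - np.asarray(fx))))
        diffs.append(max_diff)
        status = "OK" if max_diff < atol else "DIVERGES"
        print(f"layer {i}: max |pt - fx| = {max_diff:.2e}  {status}")
    return diffs

# Fabricated example: the first two layers agree, the third drifts.
rng = np.random.default_rng(0)
base = [rng.standard_normal((1, 4, 8)) for _ in range(3)]
other = [base[0], base[1] + 1e-5, base[2] + 1e-2]
diffs = compare_hidden_states(base, other)
```

The first layer where the difference jumps usually points at the mismatched component (activation, layer norm placement, or position ids); if the drift grows smoothly across all layers, it is more likely accumulated framework-level float noise, and raising the test tolerance is the right call.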

@bhadreshpsavani
Contributor Author

Sure @patrickvonplaten,
I will compare the code and debug it :)

@bhadreshpsavani
Contributor Author

bhadreshpsavani commented Sep 10, 2021

Hi @patil-suraj and @patrickvonplaten,
Thanks for the review and suggestions. Please let me know if anything is missing in the PR.

@patil-suraj
Contributor

Thanks a lot for fixing the issues, this looks good now. If you could give me access to this branch, I'd like to update the slow tests.

@bhadreshpsavani
Contributor Author

Done!
Please go ahead with the fix for the slow tests.
Once this is merged, I will create another PR for the typo fix in BART and PEGASUS that I came across.

Contributor

@patil-suraj left a comment


Thanks a lot for adding this model, great work @bhadreshpsavani !

I updated the slow tests and also pushed a couple of Flax checkpoints (pegasus-large, pegasus-xsum) to the hub. Will also push the remaining official weights later.

@patrickvonplaten do you wanna give it another look?

Contributor

@patrickvonplaten left a comment


Awesome!

@patrickvonplaten merged commit c1e47bf into huggingface:master Sep 14, 2021
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 13, 2022
* added initial files

* fixes pipeline

* fixes style and quality

* fixes doc issue and positional encoding

* fixes layer norm and test

* fixes quality issue

* fixes code quality

* removed extra layer norm

* added layer norm back in encoder and decoder

* added more code copy quality checks

* update tests

* Apply suggestions from code review

* fix import

* fix test

Co-authored-by: patil-suraj <surajp815@gmail.com>
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
(same commit message as above)