[Flax] Addition of FlaxPegasus #13420
Conversation
Hi @patil-suraj and @patrickvonplaten,
You could find the flax version of
Also, Pegasus isn't really intended for QA and classification, so it's okay not to add those heads yet.
Thanks a lot for adding this @bhadreshpsavani , great work!
The PR looks good overall, I've left a few comments below.
Mainly:
- Let's try to use as many `Copied from ...` statements as possible.
- The order of `layer_norm` in `FlaxPegasusEncoder` and `FlaxPegasusDecoder` should be changed. I've left details in the comment.

Let me know if it's not clear or if you need any other help :)
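For context on the `layer_norm` ordering point: Pegasus-style transformer blocks normalize *before* each sub-layer (pre-LN), while BART-style blocks normalize *after* the residual connection (post-LN). The sketch below is a hedged illustration of the two orderings in plain NumPy, not the actual FlaxPegasus code; `sublayer` is a hypothetical stand-in for self-attention or the feed-forward block.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (hidden) dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer(x):
    # Hypothetical stand-in for self-attention or the feed-forward block.
    return x * 0.5

def pre_ln_block(x):
    # Pegasus-style ordering: normalize first, then sub-layer, then residual.
    return x + sublayer(layer_norm(x))

def post_ln_block(x):
    # BART-style ordering: sub-layer and residual first, then normalize.
    return layer_norm(x + sublayer(x))

x = np.random.randn(2, 4, 8)  # (batch, sequence, hidden)
print(pre_ln_block(x).shape)
```

Swapping the two orderings keeps tensor shapes identical, which is why a misplaced `layer_norm` only shows up as a numerical mismatch rather than a shape error.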
Thanks a lot for more or less completing the PR, great job @bhadreshpsavani! It seems like there are some small differences between the PyTorch and Flax models. This could be due to slightly different activation functions or small differences in the position ids. It would be awesome if you could try to debug layer by layer what might be the problem there, @bhadreshpsavani. Another possibility is that there is no real difference and it's just the framework that causes the discrepancy; in that case, we'll just have to accept it and change the tolerance.
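One way to approach this kind of layer-by-layer debugging (a sketch under assumptions, not the project's actual test code) is to run both models with hidden-state outputs enabled, convert each layer's hidden states to NumPy arrays, and report the first layer whose maximum absolute difference exceeds the tolerance. `compare_hidden_states` below is a hypothetical helper operating on such exported arrays:

```python
import numpy as np

def compare_hidden_states(pt_states, flax_states, tol=1e-3):
    """Return per-layer max abs differences between two lists of arrays.

    pt_states / flax_states: hidden states exported from the PyTorch and
    Flax models, one array per layer, in matching order.
    """
    assert len(pt_states) == len(flax_states), "layer counts must match"
    diffs = []
    for i, (pt, fx) in enumerate(zip(pt_states, flax_states)):
        diff = float(np.max(np.abs(np.asarray(pt) - np.asarray(fx))))
        diffs.append(diff)
        status = "OK" if diff <= tol else "MISMATCH"
        print(f"layer {i}: max abs diff = {diff:.2e} [{status}]")
    return diffs

# Toy example: identical states everywhere except the last layer.
a = [np.zeros((1, 4, 8)) for _ in range(3)]
b = [np.zeros((1, 4, 8)) for _ in range(3)]
b[2] = b[2] + 5e-3
diffs = compare_hidden_states(a, b, tol=1e-3)
```

The first layer flagged as MISMATCH localizes where the two implementations start to diverge (e.g. at an activation or the position embeddings); if every layer's difference is small but nonzero, raising the test tolerance is the reasonable fix.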
Sure @patrickvonplaten,
Hi @patil-suraj and @patrickvonplaten,
Thanks a lot for fixing the issues, looks good now. If you could give me access to this branch, I would like to update the slow tests.
Done!
Thanks a lot for adding this model, great work @bhadreshpsavani!
I updated the slow tests and also pushed a couple of Flax checkpoints (`pegasus-large`, `pegasus-xsum`) to the hub. Will also push the remaining official weights later.
@patrickvonplaten do you wanna give it another look?
Awesome!
* added initial files
* fixes pipeline
* fixes style and quality
* fixes doc issue and positional encoding
* fixes layer norm and test
* fixes quality issue
* fixes code quality
* removed extra layer norm
* added layer norm back in encoder and decoder
* added more code copy quality checks
* update tests
* Apply suggestions from code review
* fix import
* fix test

Co-authored-by: patil-suraj <surajp815@gmail.com>
What does this PR do?
This PR adds the Flax implementation of Pegasus (FlaxPegasus).

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. link of PR
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@patrickvonplaten
@patil-suraj