[Flash Attention 2] Add flash attention 2 for GPT-Neo-X #26463
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
LGTM but left a few nits
query = query.to(torch.float16)
key = key.to(torch.float16)
value = value.to(torch.float16)
We should take into account bfloat16 here as well.
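To address the bfloat16 nit, here is a minimal sketch of inferring the compute dtype from the module's weights instead of hard-coding float16. The helper name and the assumption that the fused QKV projection is at hand are mine, not the PR's actual code:

```python
import torch
from torch import nn


def fa2_target_dtype(qkv_proj: nn.Linear, default: torch.dtype = torch.float16) -> torch.dtype:
    """Hypothetical helper: pick the dtype to run flash attention in.

    Infers the compute dtype from the fused QKV projection weights so that
    bfloat16 models stay in bfloat16 instead of being forced to float16.
    """
    dtype = qkv_proj.weight.dtype
    # If the weights were upcast to float32 (e.g. for training stability),
    # fall back to the half-precision default.
    return default if dtype == torch.float32 else dtype


# Usage inside the attention forward pass (attribute names are illustrative):
# target_dtype = fa2_target_dtype(self.query_key_value)
# query, key, value = (t.to(target_dtype) for t in (query, key, value))
```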
# In PEFT, we usually cast the layer norms in float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to float16 just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended to not cast the LayerNorms
# to fp32. (LlamaRMSNorm handles it correctly)
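For context, a sketch of the cast-back guard this comment describes, assuming the flash attention path can receive float32 tensors when PEFT keeps the layer norms in float32 (the function name is hypothetical):

```python
import torch


def maybe_downcast_for_flash_attn(query, key, value, target_dtype=torch.float16):
    """Hypothetical helper: undo a silent float32 upcast before flash attention.

    Flash attention kernels only accept float16/bfloat16 inputs, so if the
    hidden states were upcast (e.g. by float32 layer norms under PEFT), cast
    query/key/value back to the intended compute dtype.
    """
    if query.dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)
    return query, key, value
```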
Is this also true for GPTNeoX? (Comment is the same as Llama 😓 )
Any plans on completing this or should someone else pick it up? For what it's worth, this implementation is working very well for me 👍
cc @amyeroberts let me know if I need to address anything else in this PR!
Checking on the progress here. What's the ETA on merging this with the main branch? Thanks!
LGTM - thanks for adding!
Just needs a performance example to be added to the docs before merging
What does this PR do?
Adds flash attention support for GPT-Neo-X
Fixes: #26444
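For reference, enabling the new path from user code would look roughly like the sketch below; the checkpoint is just an example GPT-NeoX model, and flash attention 2 additionally requires a supported GPU plus the flash-attn package. Recent transformers versions use attn_implementation="flash_attention_2"; around the time of this PR the flag was use_flash_attention_2=True.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-1.4b"  # example GPT-NeoX checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # flash attention 2 needs fp16 or bf16
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("Flash attention lets GPT-NeoX handle long prompts", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```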