
[Flash Attention 2] Add flash attention 2 for GPT-Neo-X #26463

Merged: 14 commits merged into huggingface:main from add-flash-neo-x on Dec 6, 2023

Conversation

younesbelkada (Contributor)

What does this PR do?

Adds flash attention support for GPT-Neo-X

Fixes: #26444
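
For illustration only (this is not part of the PR diff): once this lands, enabling the new path on a GPT-NeoX checkpoint should follow the standard Flash Attention 2 loading flow in transformers, assuming a release that supports the attn_implementation argument and an environment with flash-attn installed. The checkpoint and prompt below are just placeholders.

# Illustrative sketch only, not code from this PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neox-20b"  # any GPT-NeoX checkpoint works; this one is only an example

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,               # FA2 expects fp16 or bf16 inputs
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Flash attention speeds up long prompts because", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))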

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

ArthurZucker (Collaborator) left a comment

LGTM but left a few nits

Comment on lines 390 to 393
query = query.to(torch.float16)
key = key.to(torch.float16)
value = value.to(torch.float16)

ArthurZucker (Collaborator)

We should take into account bfloat16 here as well
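A rough sketch of one way to do that (not necessarily the exact code that landed) is to derive the target dtype from the layer's weights instead of hardcoding float16; the fragment below assumes the same method context as the excerpt above.

# Sketch only: pick the flash-attention dtype from the projection weights
# instead of hardcoding torch.float16, so bfloat16 checkpoints are covered too.
input_dtype = query.dtype
if input_dtype == torch.float32:
    # GPT-NeoX's fused QKV projection; its weight dtype is fp16 or bf16
    # depending on how the model was loaded.
    target_dtype = self.query_key_value.weight.dtype
    query = query.to(target_dtype)
    key = key.to(target_dtype)
    value = value.to(target_dtype)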

Comment on lines 377 to 381
# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states get silently cast to float32 as well. We therefore
# need to cast them back to float16 to make sure everything works as expected.
# This might slow down training and inference, so it is recommended not to cast
# the LayerNorms to fp32. (LlamaRMSNorm handles it correctly)
ArthurZucker (Collaborator)

Is this also true for GPTNeoX? (The comment is the same as in Llama 😓)

btrude commented on Nov 13, 2023

Any plans on completing this or should someone else pick it up? For what it's worth, this implementation is working very well for me 👍

younesbelkada (Contributor, Author)

cc @amyeroberts, let me know if I need to address anything else in this PR!

avnermay

Checking on the progress here. What's the ETA on merging this with the main branch? Thanks!

amyeroberts (Collaborator) left a comment

LGTM - thanks for adding!

Just needs a performance example to be added to the docs before merging
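
For context, the performance example in the docs usually compares generation latency with and without the new kernel; a minimal sketch of such a comparison is below (assumed checkpoint and prompt, no benchmark numbers claimed or implied).

# Illustrative timing harness only; not the example that was added to the docs.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-1b"  # assumed GPT-NeoX-family checkpoint, used as a placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The theory of relativity states that", return_tensors="pt").to("cuda")

for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, attn_implementation=impl
    ).to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128)
    torch.cuda.synchronize()
    print(impl, f"{time.perf_counter() - start:.2f}s")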

younesbelkada merged commit 9270ab0 into huggingface:main on Dec 6, 2023 (18 checks passed).
younesbelkada deleted the add-flash-neo-x branch on December 6, 2023 at 16:22.
Successfully merging this pull request may close these issues: add flash attention support model (#26444)
6 participants