Community contribution: Adding Flash Attention 2 support for more architectures #26350

Open
19 of 24 tasks
younesbelkada opened this issue Sep 22, 2023 · 103 comments · Fixed by #29226
Labels
Good Second Issue Issues that are more difficult to do than "Good First" issues - give it a try if you want!

Comments

@younesbelkada
Contributor

younesbelkada commented Sep 22, 2023

Feature request

Flash Attention 2 is a library that provides attention operation kernels for faster and more memory efficient inference and training: https://github.com/Dao-AILab/flash-attention

Let's try to add Flash Attention 2 support for more architectures! Currently supported architectures are:

  • Llama
  • Falcon

It would be great to add support for more architectures, such as:

... and many more

Adding this feature requires following the same protocol as in #25598. First, create a new module inside the corresponding modeling file, named xxxFlashAttention, that inherits from xxxAttention and overrides the forward method to use the public methods from flash-attn. Make sure you have access to a GPU that supports Flash Attention 2.
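The inherit-and-override protocol can be illustrated without a GPU. The sketch below is a toy, pure-Python analogue, not the transformers or flash-attn code: `ToyAttention` stands in for a hypothetical `xxxAttention`, and `ToyFlashAttention` inherits from it and overrides only `forward`. Instead of calling flash-attn's public functions (e.g. `flash_attn_func`), it reproduces the blockwise "online softmax" idea behind FlashAttention, so both classes should return the same output up to floating-point error, which is the kind of equivalence a port is expected to preserve.

```python
import math
from typing import List

Matrix = List[List[float]]


def _dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


class ToyAttention:
    """Hypothetical stand-in for an `xxxAttention` module (naive softmax attention)."""

    def __init__(self, head_dim: int):
        self.scale = 1.0 / math.sqrt(head_dim)

    def forward(self, q: Matrix, k: Matrix, v: Matrix) -> Matrix:
        out = []
        for qi in q:
            # Materialize the full row of scores, then a numerically stable softmax.
            scores = [_dot(qi, kj) * self.scale for kj in k]
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                        for d in range(len(v[0]))])
        return out


class ToyFlashAttention(ToyAttention):
    """Inherits from ToyAttention and overrides only `forward`.

    Keys/values are visited in blocks while a running max and running sum are
    kept per query (an "online softmax"), so the full score row is never stored.
    """

    def __init__(self, head_dim: int, block_size: int = 2):
        super().__init__(head_dim)
        self.block_size = block_size

    def forward(self, q: Matrix, k: Matrix, v: Matrix) -> Matrix:
        out = []
        for qi in q:
            run_max, run_sum = -math.inf, 0.0
            acc = [0.0] * len(v[0])
            for start in range(0, len(k), self.block_size):
                k_blk = k[start:start + self.block_size]
                v_blk = v[start:start + self.block_size]
                scores = [_dot(qi, kj) * self.scale for kj in k_blk]
                new_max = max(run_max, max(scores))
                # Rescale the partial sums accumulated so far to the new running max.
                correction = math.exp(run_max - new_max)
                run_sum *= correction
                acc = [a * correction for a in acc]
                for s, vj in zip(scores, v_blk):
                    w = math.exp(s - new_max)
                    run_sum += w
                    acc = [a + w * x for a, x in zip(acc, vj)]
                run_max = new_max
            out.append([a / run_sum for a in acc])
        return out


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.5, -1.0]]
    k = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
    v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
    eager = ToyAttention(head_dim=2).forward(q, k, v)
    flash = ToyFlashAttention(head_dim=2, block_size=2).forward(q, k, v)
    # Maximum absolute difference between the two forwards (should be ~0).
    print(max(abs(a - b) for ra, rb in zip(eager, flash) for a, b in zip(ra, rb)))
```

Keeping the parent's constructor and changing only `forward` is what lets the flash variant reuse the same weights and config; in a real port the body of `forward` would call flash-attn instead of the toy loop.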

Given that this issue is slightly challenging, I'm labelling it as a Good Second Issue!

If you are interested in taking up the challenge, comment below with the name of the architecture you want to integrate and open a PR!

Once you open a PR, feel free to ping @LysandreJik @ArthurZucker @amyeroberts @younesbelkada @fxmarty @SunMarc @pacman100 for a review

Motivation

Making LLMs faster and more memory efficient!

Your contribution

Reviewing PRs and possibly adding the support for more models

@younesbelkada younesbelkada added the Good Second Issue Issues that are more difficult to do than "Good First" issues - give it a try if you want! label Sep 22, 2023
@sahilbhosale63
Contributor

sahilbhosale63 commented Sep 22, 2023

Hi @younesbelkada - I want to work on adding Flash Attention 2 support for GPTBigCode (Starcoder). Can I take this task? Can you please assign this task to me?

@flozi00
Contributor

flozi00 commented Sep 22, 2023

Will definitely take a look next week
Great to see it merged now 💪

@rajveer43
Contributor

I would like to work on MPT @younesbelkada

@susnato
Contributor

susnato commented Sep 24, 2023

I would like to work on OPT.

@ZeusFSX

ZeusFSX commented Sep 25, 2023

Is it possible to add FlashAttention2 to GPT2 models?

@younesbelkada
Contributor Author

@sahilbhosale63 @flozi00 @rajveer43 @susnato thanks very much for your interest! Indeed, it would be great if you could help us!
Before assigning you to this issue, can you confirm you have access to a GPU that supports Flash Attention 2 (https://github.com/Dao-AILab/flash-attention#installation-and-features), so that you are able to run the tests?
@ZeusFSX, yes, I think that is possible. I'll update the list accordingly.
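A quick way to answer the hardware question is to compare the GPU's CUDA compute capability against a minimum. This is a minimal sketch, assuming FA-2 needs an Ampere-or-newer GPU (compute capability 8.0+, consistent with the flash-attn README's list of Ampere, Ada and Hopper cards); `supports_flash_attention_2` and `MIN_FA2_CAPABILITY` are hypothetical names, and only the `torch.cuda` calls are real PyTorch APIs.

```python
from typing import Tuple

# Assumption: FlashAttention-2 targets Ampere, Ada and Hopper GPUs,
# i.e. CUDA compute capability 8.0 or higher.
MIN_FA2_CAPABILITY: Tuple[int, int] = (8, 0)


def supports_flash_attention_2(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability meets the assumed minimum."""
    return tuple(capability) >= MIN_FA2_CAPABILITY


if __name__ == "__main__":
    try:
        import torch

        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability()  # e.g. (8, 6) on an RTX 3060 Ti
            print(torch.cuda.get_device_name(), cap, supports_flash_attention_2(cap))
        else:
            print("No CUDA device visible")
    except ImportError:
        print("PyTorch is not installed")
```

For example, a T4 reports (7, 5) and would fail this check, while A4000/A5000/A6000 and RTX 30xx cards report (8, 6) and pass.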

@rajveer43
Contributor

@younesbelkada Yes, I have.

@younesbelkada
Contributor Author

younesbelkada commented Sep 25, 2023

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the Flash Attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/
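In that command, `-m flash_attn_test` selects tests carrying the `@pytest.mark.flash_attn_test` marker, and `RUN_SLOW=1` un-skips tests decorated as slow. The sketch below is a hypothetical, stdlib-only re-implementation of that environment-flag gating to show the mechanism; `flag_from_env`, `slow` and `ToyFlashAttentionTest` are illustrative names, not the transformers test utilities themselves.

```python
import os
import unittest


def flag_from_env(key: str, default: bool = False) -> bool:
    """Interpret an environment variable such as RUN_SLOW=1 as a boolean."""
    value = os.environ.get(key)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


def slow(test_case):
    """Skip the decorated test unless RUN_SLOW is truthy at decoration (import) time."""
    return unittest.skipUnless(flag_from_env("RUN_SLOW"), "test is slow")(test_case)


class ToyFlashAttentionTest(unittest.TestCase):
    @slow
    def test_flash_attn_matches_eager(self):
        # Placeholder: a real test would compare FA2 and eager outputs on a GPU.
        self.assertTrue(True)

# Run with e.g.: RUN_SLOW=1 python -m pytest this_file.py
```

Because the flag is read when the decorator runs, setting `RUN_SLOW` after the test module has been imported has no effect on already-decorated tests.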

@susnato
Contributor

susnato commented Sep 25, 2023

@younesbelkada Yes, I have.

@younesbelkada
Contributor Author

Thanks @susnato, perfect then! Let me know whenever you start the PR and if you have any questions. Check out my instructions above for more details.

@sahilbhosale63
Contributor

@younesbelkada Unfortunately, my GPU is not supported.

@rajveer43
Contributor

Sure, I will work on it!

@jeromeku

@younesbelkada Would like to work on Persimmon. I have access to A4000, A5000, and A6000, which I believe should be compatible with FA2.

@younesbelkada
Contributor Author

Perfect, sounds great! Thanks for your help, I will assign you to Persimmon!

@susnato
Contributor

susnato commented Sep 26, 2023

Since @sahilbhosale63 is not working on GPTBigCode (Starcoder) (as he said here), can I take that, @younesbelkada?

@younesbelkada
Contributor Author

Yes, no problem! Thanks very much for offering your help on this! As a starting point, you can have a look at @pacman100's implementation here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/starcoder_flash_attn_monkey_patch.py

@sorenmc

sorenmc commented Sep 26, 2023

@younesbelkada I would like to implement it for BERT if it hasn't already been done. A lot of the models topping MTEB still rely on this architecture! I have tested that I can run Flash Attention 2 on my NVIDIA GeForce RTX 3060 Ti.

@younesbelkada
Contributor Author

Awesome, thanks a lot for your help! OK, I will assign you to BERT then!

@DougTrajano
Contributor

Hi everyone, I would like to help implement this with GPT2 if you want.

@jeromeku

jeromeku commented Sep 27, 2023

@younesbelkada

I have a working version for Persimmon that passes the flash_attn_v2 tests except for generate_padding_right, as the original PersimmonFlashAttention does not take padding_mask as a keyword argument (as opposed to the Llama and Falcon flash implementations). Is this something that needs to be changed in both Persimmon Flash v1 and v2?
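For context on why a padding mask matters for the FA2 path: integrations such as Llama's use the mask to drop padding tokens before calling flash-attn's variable-length kernel (`flash_attn_varlen_func`), which takes the real tokens packed together plus cumulative sequence lengths and a maximum length. The helper below is a hypothetical, pure-Python sketch of that bookkeeping, not the transformers implementation; it only shows what information the mask supplies.

```python
from typing import List, Tuple


def get_unpad_data(attention_mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    """Hypothetical analogue of the unpadding step before a varlen FlashAttention call.

    attention_mask: batch of rows, 1 for real tokens and 0 for padding.
    Returns:
      indices    - flat positions (row * seq_len + col) of the real tokens,
      cu_seqlens - cumulative sequence lengths starting at 0 (length batch + 1),
      max_seqlen - longest real sequence in the batch.
    """
    indices: List[int] = []
    cu_seqlens = [0]
    max_seqlen = 0
    for row_idx, row in enumerate(attention_mask):
        seq_len = len(row)
        length = 0
        for col, keep in enumerate(row):
            if keep:
                indices.append(row_idx * seq_len + col)
                length += 1
        cu_seqlens.append(cu_seqlens[-1] + length)
        max_seqlen = max(max_seqlen, length)
    return indices, cu_seqlens, max_seqlen


if __name__ == "__main__":
    mask = [[1, 1, 1, 0], [1, 1, 0, 0]]  # right padding
    tokens = [[11, 12, 13, 0], [21, 22, 0, 0]]
    flat = [t for row in tokens for t in row]
    indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
    packed = [flat[i] for i in indices]  # pad tokens removed
    print(packed, cu_seqlens, max_seqlen)  # [11, 12, 13, 21, 22] [0, 3, 5] 3
```

The same helper handles left padding; only the `indices` change, while `cu_seqlens` still delimits each sequence inside the packed batch.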

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

@younesbelkada
Contributor Author

Hi @DougTrajano
Awesome! Can you confirm you have access to hardware that is supported by FA-2?

@jeromeku awesome, thanks! Can you move forward with Persimmon by opening a PR so that I can have a look?

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

If that is something that can fit nicely into the API without any breaking changes, that would be great!

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

I think Mistral's attention has been released in the latest version of FA-2. Would you be happy to open a PoC PR so that I can play with it and see what we can do?

Again thanks a lot!

@younesbelkada
Contributor Author

younesbelkada commented Sep 28, 2023

Hi @jeromeku
I had to check internally for Mistral: given the very recent release and the urgency, we'll take this over (#26464). If you have already started a PR, I'm very happy to start from it or to add you as a co-author to the PR!
We might also refactor things a bit to support the local attention introduced by Mistral, so that needs further investigation. I'll keep you posted.
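Mistral's local (sliding-window) attention restricts each token to a fixed number of preceding tokens. A minimal sketch of which positions such a mask keeps, assuming a window that covers the previous `window` keys plus the current one; recent flash-attn releases expose this kind of pattern through a `window_size=(left, right)` argument, but `sliding_window_causal_mask` below is a hypothetical helper for illustration only.

```python
from typing import List


def sliding_window_causal_mask(seq_len: int, window: int) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j.

    Assumed convention: causal (j <= i) and local, keeping the `window`
    previous keys plus the current one, i.e. i - window <= j <= i
    (analogous to window_size=(window, 0) combined with causal attention).
    """
    return [[i - window <= j <= i for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    for row in sliding_window_causal_mask(seq_len=5, window=2):
        print("".join("x" if keep else "." for keep in row))
    # x....
    # xx...
    # xxx..
    # .xxx.
    # ..xxx
```

When `window` is at least `seq_len - 1`, the mask reduces to ordinary causal attention, which is why the sliding window only changes behaviour on long sequences.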

@rajveer43
Contributor

@younesbelkada what is the expected deadline for completing MPT? I have other issues to tackle, so I can plan accordingly.

@susnato
Contributor

susnato commented Sep 28, 2023

Hi @younesbelkada, I am taking this up for GPT-Neo.

@younesbelkada
Contributor Author

Awesome @susnato ! Thanks !
@rajveer43 thanks for taking up MPT, will check it out!

@b-albar

b-albar commented Feb 21, 2024

@b-albar, I'm intrigued by your statement:

(I have a custom optimized implementation of T5 using this patch and few other tricks).

That's not open source by any chance, is it? 🥺

@Taytay It is now: https://github.com/catie-aq/flashT5
I mention it here because if the PR adding attention biases to FA2 finally gets merged 🤞, it could probably help with the T5 PR here.
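T5 adds a learned relative-position bias to the attention scores before the softmax, which is why bias support in the FA2 kernel matters for a T5 port. A minimal pure-Python sketch of score-level biasing, where `bias[i][j]` is an arbitrary matrix standing in for T5's bucketed relative-position bias; `biased_attention` is a hypothetical helper, and `scale` defaults to 1.0 on the assumption that, as in the transformers T5 code, scores are not divided by sqrt(head_dim).

```python
import math
from typing import List

Matrix = List[List[float]]


def biased_attention(q: Matrix, k: Matrix, v: Matrix, bias: Matrix,
                     scale: float = 1.0) -> Matrix:
    """Compute softmax(scale * q @ k^T + bias) @ v one query row at a time.

    `bias[i][j]` is added to the score of query i against key j before the
    softmax; in T5 this would be the relative-position bias (here: any matrix).
    """
    out = []
    for i, qi in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) + bias[i][j]
                  for j, kj in enumerate(k)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out
```

A kernel that only accepts q, k and v has no place to inject `bias`, so the bias would have to be fused into the blockwise score computation.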

@EduardoPach
Contributor

Working on GPT2; here is the PR: #29226

@jprivera44
Contributor

Hello @younesbelkada, can I take on GPT-2 if no one else has taken it?

@EduardoPach
Contributor

There's already an open PR here #29226

@jprivera44
Contributor

Hey, thank you for letting me know about the existing PR, @EduardoPach! @younesbelkada & @ArthurZucker, are there any available architectures left to implement? I see there are "many more" that are not included in this list. Can you give me an example of one?

@jprivera44
Contributor

It seems that T5 is still open?

@rubenweitzman

@jprivera44 yes, it seems so; I cannot find any FA/SDPA version. It would be great if you could get that working. I am also looking at integrating SDPA into ESM. @ddofer or others, I would love your help if there is interest!

@William-WANG2

Hello! My classmate @DELTA-DoubleWise and I are trying to write a project proposal for a course and we would like to extend this to RAG model (which is not listed above). Would you mind assigning this to us?
Thank you very much! @EduardoPach, @younesbelkada & @ArthurZucker

@ddofer

ddofer commented Mar 7, 2024

@rubenweitzman I wish I could help, but I'm only familiar with keras, not pytorch :\

@ArthurZucker
Collaborator

Hey @William-WANG2, we usually don't assign issues, and rather let the code talk: if a PR is open and pinned then that means someone is working on something and the entire community can check the progress 😉 we try to prevent work duplication with this!

@sayakpaul
Member

Does adding FA2 to CLIP make any sense?

@amyeroberts
Collaborator

@sayakpaul Yes! It would be great to add. I had a draft in #27444 but was hitting large differences on just some of the elements, which I wasn't able to track down. I don't have the bandwidth at the moment, so I'm very happy for anyone to take it up!
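When only some elements disagree, it can help to list exactly which positions fail and by how much instead of getting a single pass/fail. A hypothetical debugging helper, assuming the two outputs have been flattened to lists of floats; it applies the same |actual - expected| <= atol + rtol * |expected| rule as `torch.allclose` and `numpy.isclose`, and the default tolerances are placeholders rather than the values used in the transformers tests.

```python
from typing import List, Tuple


def find_mismatches(expected: List[float], actual: List[float],
                    atol: float = 1e-3, rtol: float = 1e-3) -> List[Tuple[int, float]]:
    """Return (index, abs_diff) for every element outside atol + rtol * |expected|,
    sorted from the largest difference to the smallest."""
    if len(expected) != len(actual):
        raise ValueError("outputs have different lengths")
    bad = []
    for idx, (e, a) in enumerate(zip(expected, actual)):
        diff = abs(a - e)
        if diff > atol + rtol * abs(e):
            bad.append((idx, diff))
    return sorted(bad, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    eager_out = [1.0, 2.0, 3.0, 4.0]
    fa2_out = [1.0005, 2.0, 3.2, 4.0]
    print(find_mismatches(eager_out, fa2_out))  # only index 2 is reported
```

Mapping the reported flat indices back to (batch, position, channel) can then show whether the failures cluster, for example on padded positions or a particular head.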

@miladm

miladm commented Apr 22, 2024

Integration question:
I'd love to see this implementation available for TPU backends. Is there anything we can leverage from this line of work out of the box, or should TorchXLA focus on a Pallas implementation under optimum-tpu?

cc @philschmid @alanwaketan @shauheen @allenwang28

@philschmid
Member

Adding @mfuntowicz @tengomucho to answer this

@tengomucho

@miladm for now there is a check on models to see if flash attention is available, but flash_attn is implemented in CUDA. To allow it to run smoothly on TPU, IMO the cleanest option would be to contribute a Torch XLA (or Pallas) alternative to flash_attn, used when available. Another alternative would be to implement it directly in optimum-tpu, but that would mean we would need to patch models to use it. If you choose the second path, I will be happy to help you integrate your contribution.
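The availability check mentioned above can be pictured as a small selector that only honours a FlashAttention-2 request when the CUDA-based package is installed. A minimal sketch with a hypothetical `pick_attention_backend` function; the backend strings are borrowed from the `attn_implementation` values used by recent transformers versions, `importlib.util.find_spec` only tests whether the package is importable (not whether a suitable GPU is present), and a TPU/Pallas path would be an additional branch that does not exist here.

```python
import importlib.util


def pick_attention_backend(requested: str = "flash_attention_2",
                           package: str = "flash_attn") -> str:
    """Hypothetical selector: return "flash_attention_2" only if it was requested
    and `package` can be found on the import path; otherwise fall back to "eager"."""
    if requested == "flash_attention_2" and importlib.util.find_spec(package) is not None:
        return "flash_attention_2"
    return "eager"


if __name__ == "__main__":
    # On a machine without flash-attn installed this prints "eager".
    print(pick_attention_backend())
```

Contributing an XLA kernel to flash_attn itself, as suggested, would keep this kind of check unchanged for model code, whereas the optimum-tpu route would need its own branch or patch.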

@Ingvarstep

For anyone interested in an optimized T5 version: I just finished my project creating a flash attention version with fused attention bias calculation. It fixes the major drawbacks of T5 and allows running it on 100k-length sequences on a single L4 GPU (22.5 GB). Check it out here.

@EduardoPach
Contributor

Can we update the current state of the list of models in the issue description? For instance, GPT2 has already been merged

@michaelshekasta

Hi @LysandreJik @ArthurZucker @amyeroberts @younesbelkada @fxmarty @SunMarc @pacman100 ,

I want to try adding Flash Attention 2 to xlmr-large.

Do you have any guidelines?

@amyeroberts
Collaborator

@michaelshekasta I'd recommend referring to other PRs which have added this for models e.g. for GPT2 and reading the contributing guidelines

@michaelshekasta

@amyeroberts Thanks for your comment! I noticed that @DavidAfonsoValente has already implemented the majority of the code. You can find their work on this pull request: #28713. What are the differences that still need to be addressed before merging it?

@amyeroberts
Collaborator

@michaelshekasta Following the PR history, I don't believe there was much more to add, there was just a dependency on another piece of work and the PR eventually became inactive

@davidgxue
Contributor

Has anyone worked on adding FA2 to T5? I see someone has an open PR adding SDPA support for T5 (#30375) that is almost done. Is there still a point in adding FA2 for T5?
