
A ring attention with flash attention kernel implementation #4

Closed
zhuzilin opened this issue Feb 21, 2024 · 19 comments

Comments

@zhuzilin

zhuzilin commented Feb 21, 2024

Hi! Thank you for your work on implementing ring attention in PyTorch!

I've just tried to implement a ring_flash_attn_qkvpacked_func (corresponding to flash_attn_qkvpacked_func in flash attention) with the flash attention kernels here: https://github.com/zhuzilin/ring-flash-attention/

Maybe this can help :)


Updates:

  • ring_flash_attn_varlen_qkvpacked_func is also implemented.
@lucidrains
Owner

lucidrains commented Feb 21, 2024

@zhuzilin hey Zilin! this looks like a good start, and what I intended to do at the very end! i was imagining that the ring communication could be done within CUDA using IPC? (however, I am far from a CUDA expert, so I could be wrong and it may not be possible) Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository? That would be a big contribution!

@lucidrains
Owner

lucidrains commented Feb 21, 2024

@zhuzilin if you do embark on the pull request, the minimal features would be the ring IPC, the ability to specify the maximum number of ring passes (as I believe they must have used a curriculum going from local attention to full global attention, or mixed local and global attention by varying the number of ring passes throughout the transformer), and finally, if you have the bandwidth, specialized masking logic for striped autoregressive attention to balance the workload.
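
To make the workload point concrete, here is a small counting sketch (not from the thread; the function name and the assumption that the sequence length divides evenly across ranks are illustrative). It counts how many unmasked query/key pairs each rank's queries need under a causal mask, comparing contiguous sharding with striped sharding, where rank r owns every token t with t % world_size == r.

```python
def causal_work_per_rank(seq_len: int, world_size: int, striped: bool) -> list:
    """Count unmasked (query, key) pairs owned by each rank under a causal mask.

    Contiguous sharding: rank r owns tokens [r * chunk, (r + 1) * chunk).
    Striped sharding:    rank r owns every token t with t % world_size == r.
    The count is summed over the full key sequence, i.e. over all ring passes.
    """
    # Hypothetical simplification: assume the sequence splits evenly across ranks.
    assert seq_len % world_size == 0, "sequence must divide evenly across ranks"
    chunk = seq_len // world_size
    work = [0] * world_size
    for t in range(seq_len):
        rank = t % world_size if striped else t // chunk
        # With a causal mask, query t attends to keys 0..t, i.e. t + 1 keys.
        work[rank] += t + 1
    return work


print(causal_work_per_rank(16, 4, striped=False))
print(causal_work_per_rank(16, 4, striped=True))
```

With 16 tokens on 4 ranks, contiguous sharding gives [10, 26, 42, 58] while striping gives [28, 32, 36, 40], so the last contiguous shard does almost six times the work of the first. A real kernel would additionally need per-pass masks that follow the striped token order.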

@zhuzilin zhuzilin changed the title A ring_flash_attn_qkvpacked_func implementation A ring attention with flash attention kernel implementation Feb 21, 2024
@lucidrains
Owner

thank you! 🚀 ❤️

@lucidrains
Owner

@zhuzilin actually, after looking into CUDA IPC stuff, your approach may be the best for now

@zhuzilin
Author

zhuzilin commented Feb 22, 2024

> Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository?

I'll draft an issue to the flash attention repo to see if they have interest in upstreaming (or designing a better version) in the official repo :)

> after looking into CUDA IPC stuff, your approach may be the best for now

yeah, NCCL-based P2P communication would at least be easier to implement, with acceptable performance.
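
The communication pattern itself is simple. As a rough sketch, simulated in a single process rather than over NCCL (the function name is hypothetical), each rank sends the key/value block it currently holds to rank (r + 1) % world_size and receives one from rank (r - 1) % world_size, so after world_size - 1 passes every rank has seen every block. In an actual multi-GPU implementation each pass would be a pair of point-to-point operations, for example `torch.distributed.isend`/`irecv`, or `P2POp`s submitted together via `torch.distributed.batch_isend_irecv`.

```python
def ring_schedule(world_size: int) -> list:
    """Return schedule[step][rank] = index of the key/value block that rank holds.

    Single-process simulation: at every step each rank "sends" its current block
    to rank (r + 1) % world_size and "receives" from rank (r - 1) % world_size.
    """
    held = list(range(world_size))  # step 0: every rank holds its own block
    schedule = [held[:]]
    for _ in range(world_size - 1):  # world_size - 1 ring passes cover all blocks
        # Simulated P2P exchange: rank r now holds what rank r - 1 held before.
        held = [held[(r - 1) % world_size] for r in range(world_size)]
        schedule.append(held[:])
    return schedule


for step, blocks in enumerate(ring_schedule(4)):
    print(step, blocks)
```

For world_size = 4 this yields [[0, 1, 2, 3], [3, 0, 1, 2], [2, 3, 0, 1], [1, 2, 3, 0]], i.e. at step s rank r holds block (r - s) % world_size.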

@andreaskoepf

andreaskoepf commented Feb 22, 2024

@zhuzilin awesome work, we'll organize a little hack session today at 19:00 UTC on the cuda-mode discord to hack on your impl (do some testing, benchmarking and discussion about the best comms options for single-node and multi-node, etc.) - just fyi https://x.com/neurosp1ke/status/1760558683136589983

@lucidrains
Owner

lucidrains commented Feb 22, 2024

Germany, Beijing, San Francisco

only in open source (and science)

@lucidrains
Owner

@lucidrains
Owner

i also wanted to do some LOTR references, but one meme is enough

@zhuzilin
Author

oh... sorry, I took a day off and missed all the notifications from GitHub...

@lucidrains
Owner

@zhuzilin i think my version is working too now, with a modified forward flash attention kernel to minimize ring passes

thanks for sharing your repo for proof of concept!

@andreaskoepf

@lucidrains thanks a lot for your hard work & very interesting that you used a custom Triton kernel! :-)

@lucidrains
Owner

lucidrains commented Feb 28, 2024

@andreaskoepf thanks! seems like there's still an issue with the backward pass, but i'll leave it to someone or some team to fix. yup, i think the forward pass requires the keys and values to be iterated in the outer loop (to save on extraneous ring passes), so the reduced outputs, row maxes, and lse need to be stored and passed back in on the next ring pass. but i could be wrong and there may be a simpler way
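
A NumPy sketch of that bookkeeping (an illustration only, not the Triton kernel discussed here; the function names are hypothetical): the outer loop walks over key/value blocks as they would arrive on successive ring passes, and the per-row state carried between passes is the running output, row max and row sum, from which the LSE is recovered at the end. This variant carries the unnormalized output; carrying the normalized output plus the LSE is an equivalent alternative.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: full softmax attention, materializing the whole score matrix."""
    s = (q @ k.T) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def ring_forward_kv_outer(q, k_blocks, v_blocks):
    """Attention for one rank's queries, visiting key/value blocks one at a time.

    The running (unnormalized) output, row max and row sum are carried from one
    block to the next, as they would be carried from one ring pass to the next.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    n_q, d_v = q.shape[0], v_blocks[0].shape[-1]
    acc = np.zeros((n_q, d_v))          # running sum of exp(score - row_max) * v
    row_max = np.full(n_q, -np.inf)     # running max score per query row
    row_sum = np.zeros(n_q)             # running sum of exp(score - row_max)
    for k, v in zip(k_blocks, v_blocks):  # outer loop: one key/value block per pass
        s = (q @ k.T) * scale
        new_max = np.maximum(row_max, s.max(axis=1))
        correction = np.exp(row_max - new_max)  # rescale old state to the new max
        p = np.exp(s - new_max[:, None])
        acc = acc * correction[:, None] + p @ v
        row_sum = row_sum * correction + p.sum(axis=1)
        row_max = new_max
    out = acc / row_sum[:, None]
    lse = row_max + np.log(row_sum)     # log-sum-exp over all scores in the row
    return out, lse
```

Splitting `k` and `v` into any number of blocks along the sequence axis should reproduce `naive_attention` up to floating-point error.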

@ericauld

ericauld commented Feb 29, 2024

@lucidrains What is the issue you're referring to with the backward pass?

@lucidrains
Owner

it isn't correct, probably something small with regard to how i'm using the flash attention api

feel free to submit a PR, i likely won't be able to get to this as i'll be running around the Bay Area meeting people next month

@lucidrains
Owner

@ericauld ah, good news, the cuda backwards actually yielded the right gradients (full attention, no causal or key padding mask). it is my naive version that is broken

alright, i guess it is safe to remove the wip

@apaz-cli

apaz-cli commented Mar 4, 2024

@lucidrains Knowing the LSE doesn't actually help you compute the backward pass of softmax though, correct? The derivative of LSE is softmax, not the other way around. What am I missing, and what is the utility of returning the LSE?
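
For reference, the standard identities involved, for a row of attention scores s. The middle one is the derivative the question mentions; the last shows the direction in which a stored LSE is useful, since it rebuilds each softmax probability from its own score without another reduction over the row. FlashAttention-style kernels keep one LSE value per query row from the forward pass for this kind of recomputation.

```latex
\mathrm{LSE}(s) = \log \sum_k e^{s_k}, \qquad
\frac{\partial\, \mathrm{LSE}(s)}{\partial s_j} = \frac{e^{s_j}}{\sum_k e^{s_k}} = \mathrm{softmax}(s)_j, \qquad
\mathrm{softmax}(s)_j = \exp\!\bigl(s_j - \mathrm{LSE}(s)\bigr)
```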

@andreaskoepf

> What am I missing, and what is the utility of returning the LSE?

The returned log-sum-exp is what allows flash-attention to be applied in a blockwise manner (i.e. without it, it wouldn't be possible to use flash-attn to implement ring-attn). See ring_flash_attn/utils.py#L19-L21.
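
The merge step can be sketched in NumPy as follows. This illustrates the idea and is not a copy of the linked utils.py: `attention_with_lse` is a hypothetical stand-in for a flash-attention call that returns both the output and the per-row LSE, and `merge` (also hypothetical) combines two such results computed over disjoint key blocks by rescaling each with exp(lse_block - lse_total).

```python
import numpy as np


def attention_with_lse(q, k, v):
    """Softmax attention over one key/value block, also returning the per-row LSE."""
    s = (q @ k.T) / np.sqrt(q.shape[-1])
    m = s.max(axis=1)
    lse = m + np.log(np.exp(s - m[:, None]).sum(axis=1))  # stable log-sum-exp per row
    out = np.exp(s - lse[:, None]) @ v                      # softmax(s) = exp(s - LSE(s))
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Combine normalized outputs computed over two disjoint key blocks."""
    lse = np.logaddexp(lse_a, lse_b)  # LSE over the union of both key blocks
    out = (np.exp(lse_a - lse)[:, None] * out_a
           + np.exp(lse_b - lse)[:, None] * out_b)
    return out, lse
```

Because the merge is associative and commutative (up to rounding), a ring implementation can fold in each block's result as it arrives on each pass, without ever seeing the full score matrix.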

@apaz-cli

apaz-cli commented Mar 4, 2024

Ah, alright. That's what I'm missing. Makes sense :)

5 participants