A ring attention with flash attention kernel implementation #4
Comments
@zhuzilin hey Zilin! This looks like a good start, and what I intended to do at the very end! I was imagining that the ring communication could be done within CUDA using IPC? (However, I am far from a CUDA expert, so I could be wrong and it may not be possible.) Are you planning on upstreaming the finalized implementation to Tri Dao's official flash attention repository? That would be a big contribution!
@zhuzilin if you do embark on the pull request, the minimal features would be the ring IPC, the ability to specify the maximum number of ring passes (as I believe they must have curriculum-learned the local attention up to full global attention, or mixed local and global using variable ring passes throughout the transformer), and finally, if you have the bandwidth, specialized masking logic for striped autoregressive attention to balance the workload.
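A minimal sketch of the striped layout, under the assumption that token i is assigned to rank i % world_size (the helper name stripe_split and the tensor shape are illustrative, not from either repo). The point is that with a causal mask, every rank then has roughly the same amount of unmasked work on each ring pass, unlike a contiguous split where the lowest ranks finish their causal blocks much earlier than the highest ones:

```python
import torch

def stripe_split(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # x: (batch, seq_len, ...) with seq_len divisible by world_size.
    # Keep every world_size-th token starting at this rank's offset, so the
    # causal-mask workload is balanced across ranks on every ring pass.
    return x[:, rank::world_size]
```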
ring_flash_attn_qkvpacked_func implementation
@zhuzilin actually, after looking into the CUDA IPC stuff, your approach may be the best for now.
I'll draft an issue for the flash attention repo to see if they have interest in upstreaming (or designing a better version) in the official repo :)
Yeah, using NCCL-based p2p communication would at least be an easier way to implement, with acceptable performance.
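As a sketch of what an NCCL-based p2p ring pass could look like with torch.distributed (the helper name ring_pass and the blocking wait are simplifying assumptions, not the repo's actual communication code):

```python
import torch
import torch.distributed as dist

def ring_pass(tensor: torch.Tensor) -> torch.Tensor:
    """Send `tensor` to the next rank in the ring and receive the previous
    rank's tensor, using NCCL p2p ops. Returns the received buffer."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size

    recv_buf = torch.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, send_to),
        dist.P2POp(dist.irecv, recv_buf, recv_from),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf
```

In a full implementation, the communication for the next block would typically be overlapped with the flash-attention compute on the current block rather than waited on immediately.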
@zhuzilin awesome work, we'll organize a little hack session today at 19:00 UTC on the cuda-mode Discord to hack on your impl (do some testing and benchmarking, and discuss the best comms options for single node, multi node, etc.) - just FYI: https://x.com/neurosp1ke/status/1760558683136589983
Germany, Beijing, San Francisco: only in open source (and science)
I also wanted to do some LOTR references, but one meme is enough
Oh... sorry, I took a day off and missed all the notifications from GitHub...
@zhuzilin I think my version is working too now, with a modified forward flash attention kernel to minimize ring passes. Thanks for sharing your repo as a proof of concept!
@lucidrains thanks a lot for your hard work & very interesting that you used a custom Triton kernel! :-)
@andreaskoepf thanks! Seems like there's still an issue with the backwards pass, but I'll leave it to someone or some team to fix. Yup, I think the forwards requires the keys and values to be iterated on the outer loop (to save on extraneous ring passes), so the reduced outputs, row maxes, and LSE need to be stored and passed back in on the next ring pass. But I could be wrong and there may be a simpler way.
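A minimal sketch of that forward structure, assuming a hypothetical block_flash_attn(q, k, v) that returns a block's output together with its row-wise LSE (standing in for a flash-attention forward that exposes the softmax LSE, with the LSE shaped to broadcast against the output), plus the ring_pass helper sketched above; masking and the backward pass are ignored:

```python
import torch

def ring_attention_forward(q, k, v, world_size):
    # The k/v blocks travel around the ring on the outer loop; the local q
    # block stays put, and (out, lse) holds the running blockwise reduction.
    out, lse = None, None
    for step in range(world_size):
        block_out, block_lse = block_flash_attn(q, k, v)  # hypothetical per-block kernel
        if out is None:
            out, lse = block_out, block_lse
        else:
            # LSE-based merge of the running result with this block's result.
            new_lse = torch.logaddexp(lse, block_lse)
            out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
            lse = new_lse
        if step != world_size - 1:
            k, v = ring_pass(k), ring_pass(v)  # hand k/v on to the next rank
    return out, lse
```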
@lucidrains What is the issue you're referring to with the backward pass? |
It isn't correct, probably something small with regards to how I'm using the flash attention API. Feel free to submit a PR; I likely won't be able to get to this, as I'll be running around the Bay Area meeting people next month.
@ericauld ah, good news: the CUDA backwards actually yielded the right gradients (full attention, no causal or key padding mask). It is my naive version that is broken. Alright, I guess it is safe to remove the WIP.
@lucidrains Knowing the LSE doesn't actually help you compute the backwards for softmax though, correct? The derivative of LSE is softmax, not the other way around. What am I missing, and what is the utility of returning the LSE? |
The returned log-sum-exp is what allows flash attention to be applied in a blockwise manner (i.e. without it, it wouldn't be possible to use flash-attn to implement ring-attn). See ring_flash_attn/utils.py#L19-L21.
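Concretely, the recombination that the LSE enables looks like this (a toy, unoptimized check of the identity, not the repo's code; the single query row and the d ** 0.5 scaling are just for clarity):

```python
import torch

torch.manual_seed(0)
d = 16
q = torch.randn(1, d)                           # a single query row for clarity
k1, v1 = torch.randn(8, d), torch.randn(8, d)   # first key/value block
k2, v2 = torch.randn(8, d), torch.randn(8, d)   # second key/value block

def block_attn(q, k, v):
    """Attention over one block, returning the output and the row-wise LSE."""
    s = q @ k.t() / d ** 0.5
    lse = torch.logsumexp(s, dim=-1, keepdim=True)
    return torch.exp(s - lse) @ v, lse

out1, lse1 = block_attn(q, k1, v1)
out2, lse2 = block_attn(q, k2, v2)

# LSE-based merge: rescale each block's output by its share of the total softmax mass.
lse = torch.logaddexp(lse1, lse2)
out = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2

# Reference: attention over the concatenated blocks in one shot.
ref, _ = block_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(out, ref, atol=1e-6)
```

The merge rescales each block's partial output by its share of the total softmax mass, which is exactly what lets a ring pass fold one k/v block at a time into the running result.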
Ah, alright. That's what I'm missing. Makes sense :) |
Hi! Thank you for your work on implementing the ring attention in PyTorch!
I've just tried to implement a ring_flash_attn_qkvpacked_func (corresponding to flash_attn_qkvpacked_func in flash attention) with the flash attention kernels here: https://github.com/zhuzilin/ring-flash-attention/
Maybe this can help :)
Updates: ring_flash_attn_varlen_qkvpacked_func is also implemented.
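For reference, a usage sketch under the assumption that ring_flash_attn_qkvpacked_func mirrors the flash_attn_qkvpacked_func keyword arguments and that each rank holds its local shard of the packed qkv tensor; the shapes, the causal=True flag, and the import path are illustrative rather than documentation of the repo's API. It would be launched with one process per GPU, e.g. via torchrun:

```python
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_qkvpacked_func  # assumed import path

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

batch, local_seqlen, nheads, headdim = 2, 1024, 8, 64
# Each rank holds its own slice of the full sequence, packed as (q, k, v).
qkv = torch.randn(batch, local_seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16, requires_grad=True)

# Communication between ranks (the ring passes) is assumed to happen inside the call.
out = ring_flash_attn_qkvpacked_func(qkv, causal=True)
out.sum().backward()
```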