
[feat] Add a fast implementation of the Rabe and Staats algorithm (memory-efficient attention) on GPU #161

Closed
blefaudeux opened this issue Dec 21, 2021 · 3 comments · Fixed by #267 or #281

@blefaudeux
Contributor

🚀 Feature

Implement https://arxiv.org/pdf/2112.05682v2.pdf using Triton

Motivation

There are existing implementations in PyTorch, but they're bound to be a little slow. It's actually not that much work to write this in Triton, so let's give it a shot. Given the forward speed (which should be similar to vanilla attention, without the quadratic memory cost) and the expected backward speed (about 60% of vanilla attention), it feels like a trade-off many users would take.

Pitch

The required kernel is actually not that far from some of the kernels we already have, at least for the forward pass. The chunk strategy proposed by the paper is fairly classic in that field, nothing out of the ordinary, so it should be pretty fast if implemented correctly.
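For reference, a minimal sketch of that chunked computation in plain PyTorch (not xformers code; the function name and chunk sizes are placeholders): queries and keys are processed in chunks, and the per-chunk partial results are renormalised under a shared max, so the full attention matrix is never materialised.

```python
import torch

def chunked_attention(q, k, v, q_chunk=1024, k_chunk=1024):
    # Chunked softmax attention, structured like Rabe & Staats
    # (https://arxiv.org/abs/2112.05682): process queries and keys in chunks,
    # keep per-chunk partial sums, then combine them under a shared max.
    scale = q.shape[-1] ** -0.5
    out = []
    for q_start in range(0, q.shape[-2], q_chunk):
        qc = q[..., q_start:q_start + q_chunk, :] * scale
        maxes, sums, values = [], [], []

        for k_start in range(0, k.shape[-2], k_chunk):
            kc = k[..., k_start:k_start + k_chunk, :]
            vc = v[..., k_start:k_start + k_chunk, :]

            s = qc @ kc.transpose(-2, -1)          # (..., q_chunk, k_chunk)
            m = s.amax(dim=-1, keepdim=True)       # per-row max of this chunk
            p = torch.exp(s - m)                   # stabilised exponentials
            maxes.append(m)
            sums.append(p.sum(dim=-1, keepdim=True))
            values.append(p @ vc)

        # Combine the per-chunk partial results under a single shared max.
        m_all = torch.stack(maxes).amax(dim=0)
        num = torch.zeros_like(values[0])
        den = torch.zeros_like(sums[0])
        for m, s_, v_ in zip(maxes, sums, values):
            alpha = torch.exp(m - m_all)
            num = num + v_ * alpha
            den = den + s_ * alpha
        out.append(num / den)

    return torch.cat(out, dim=-2)
```

With chunk sizes on the order of sqrt(n), this gives the O(sqrt(n)) memory behaviour described in the paper; a fused kernel would do the same computation without the round-trips through global memory.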

Alternatives

At least support a pure PyTorch variant in xformers?

@erip
Contributor

erip commented Dec 21, 2021

Another reference impl can be found here -- same caveats as outlined above.

@blefaudeux
Contributor Author

I've started something; it feels like some of the logic would need to change a bit to make sense at the kernel level, at least for Triton. In particular it's hard to sequence things outside of a kernel, and reproducing the paper's logic verbatim would lead to big intermediate buffers (if the computation is tiled), which removes much of the benefit. The best approach seems to be a kernel owning a whole attention row, handling a couple of query rows at a time to help with data-fetch reuse.
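As a rough illustration of that per-row strategy, here is plain PyTorch standing in for what a single (hypothetical) Triton program instance could compute over the query rows it owns; names and block sizes are placeholders, not actual kernel code:

```python
import torch

def one_query_block(qc, k, v, k_chunk=128):
    # One hypothetical kernel instance: stream over the key/value chunks for
    # the query rows it owns, keeping only a running max `m`, a running
    # normaliser `den` and a running weighted sum `acc` (online softmax),
    # so no O(n) buffer of partial results is ever kept around.
    scale = qc.shape[-1] ** -0.5
    qc = qc * scale

    m = qc.new_full(qc.shape[:-1] + (1,), float("-inf"))
    den = torch.zeros_like(m)
    acc = qc.new_zeros(qc.shape[:-1] + (v.shape[-1],))

    for start in range(0, k.shape[-2], k_chunk):
        kc = k[..., start:start + k_chunk, :]
        vc = v[..., start:start + k_chunk, :]

        s = qc @ kc.transpose(-2, -1)                         # scores for this chunk
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        alpha = torch.exp(m - m_new)                          # rescale what we have so far
        p = torch.exp(s - m_new)

        den = den * alpha + p.sum(dim=-1, keepdim=True)
        acc = acc * alpha + p @ vc
        m = m_new

    return acc / den
```

The outer loop over query blocks would map onto the kernel grid, and keeping only the running max, normaliser and accumulator is what avoids the big intermediate buffers mentioned above.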

@blefaudeux
Contributor Author

> Another reference impl can be found here -- same caveats as outlined above.

Thanks!
