
Question: is it possible to implement flash attention with keops #286

Closed
jaak-s opened this issue Jan 21, 2023 · 5 comments
@jaak-s

jaak-s commented Jan 21, 2023

Hi, I'm new to PyKeOps and was wondering if it would be possible to implement FlashAttention, which is used to remove the quadratic memory requirement on the sequence length:
https://github.com/HazyResearch/flash-attention

The basic idea is that one does not need O(N^2) memory, because each row of the attention matrix can be computed independently and then multiplied with V, so the whole N×N matrix never needs to be stored.
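
To illustrate the idea (just a rough PyTorch sketch with a made-up function name, not FlashAttention's actual tiled / online-softmax kernel):

```python
import torch

def blockwise_attention(Q, K, V, block_size=1024):
    # Q, K: (N, D), V: (N, Dv). Queries are processed in blocks, so only a
    # (block_size, N) slice of the attention matrix is ever materialized,
    # never the full (N, N) one.
    out = Q.new_empty(Q.shape[0], V.shape[1])
    scale = K.shape[1] ** 0.5
    for start in range(0, Q.shape[0], block_size):
        q = Q[start:start + block_size]                # (B, D)
        scores = (q @ K.T) / scale                     # (B, N) slice only
        out[start:start + block_size] = scores.softmax(dim=-1) @ V
    return out
```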

Thanks!

@jeanfeydy
Contributor

Hi @jaak-s,

Thanks for your interest in our library!
This is entirely doable, since KeOps and FlashAttention both rely on the same core numerical scheme that was initially documented by Nvidia for N-body computations in physics. (By the way, KeOps is discussed in appendix D of the FlashAttention paper.)

I actually did this back in April 2021 in the branch "attention" with a plug-in replacement for the MultiheadAttention layer.
Benchmarks are available here.
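
For readers landing on this issue, the gist of such a layer fits in a few lines of LazyTensor code; something along these lines should work (a minimal single-head sketch of my own, not the code from the "attention" branch, and assuming I recall the sumsoftmaxweight reduction correctly):

```python
import torch
from pykeops.torch import LazyTensor

def keops_attention(q, k, v):
    # q, k: (N, D), v: (N, Dv). The (N, N) score matrix only exists
    # symbolically: KeOps streams over the "j" index on the GPU.
    q_i = LazyTensor(q[:, None, :])              # (N, 1, D)
    k_j = LazyTensor(k[None, :, :])              # (1, N, D)
    v_j = LazyTensor(v[None, :, :])              # (1, N, Dv)
    s_ij = (q_i | k_j) / q.shape[1] ** 0.5       # symbolic (N, N) scores
    return s_ij.sumsoftmaxweight(v_j, axis=1)    # softmax(S) @ V -> (N, Dv)
```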

Please note, however, that KeOps is not competitive for implementing standard attention layers with attention heads of size > 16:

  • KeOps is optimized for generality. Our main motivations are mathematical statistics (K-NN methods, kernels, Gaussian processes, optimal transport...) and geometric deep learning (especially for 3D data). As a consequence, we have put a lot of emphasis on generic operators and tutorials, but have not (yet) added widespread support for low-precision numerical types (float16, etc.) and tensor cores, which only target linear-like computations.
  • The main limitation of KeOps is that it is not optimized for "formulas" that involve more than ~100 arithmetic operations. This is clear in the benchmarks linked above, with good KeOps performance primarily tied to the use of small attention heads.
  • On the other hand, FlashAttention has been designed for the NLP community: it only supports "one" operation (the standard attention layer) but does it extremely well, with attention paid to tensor cores, etc.

In this context, I think that KeOps may be of interest to people who want to experiment with "original" attention layers (as we sometimes do in geometric deep learning), but not really a competitive option for Natural Language Processing.
I hope that this answers your question!

If you would like to ask anything else, please let me know.

Best regards,
Jean

jeanfeydy self-assigned this Jan 21, 2023
@jaak-s
Author

jaak-s commented Jan 21, 2023

Thanks a lot for the detailed answer!

jaak-s closed this as completed Jan 21, 2023
@jeanfeydy
Contributor

You're very welcome - that's an important question in today's context :-)

@jaak-s
Author

jaak-s commented Jan 21, 2023

Agreed, it is a hot topic.

Even though the current implementation of flash-attention is well optimized for NLP, there are applications outside NLP that need slight modifications, such as relative position encodings or distance-based biases (ALiBi), which are not yet supported (Dao-AILab/flash-attention#17).

With a KeOps-based implementation, these changes feel like one-liner modifications, which would make such customizations quite straightforward :-).
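
For example, starting from a LazyTensor formulation of attention like the one sketched above, an ALiBi-style distance bias would just be one extra term in the score formula (a hypothetical sketch; the function name and slope value are mine):

```python
import torch
from pykeops.torch import LazyTensor

def keops_attention_alibi(q, k, v, slope=0.0625):
    # q, k: (N, D), v: (N, Dv), self-attention with an ALiBi-like bias.
    N, D = q.shape
    q_i = LazyTensor(q[:, None, :])
    k_j = LazyTensor(k[None, :, :])
    v_j = LazyTensor(v[None, :, :])
    # Token positions as symbolic variables, so (i - j) stays lazy as well:
    pos = torch.arange(N, dtype=q.dtype, device=q.device).view(N, 1)
    i_pos = LazyTensor(pos[:, None, :])          # (N, 1, 1)
    j_pos = LazyTensor(pos[None, :, :])          # (1, N, 1)
    s_ij = (q_i | k_j) / D ** 0.5
    s_ij = s_ij - slope * (i_pos - j_pos).abs()  # the "one-liner" bias term
    return s_ij.sumsoftmaxweight(v_j, axis=1)
```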

@jeanfeydy
Contributor

I see, thanks for the pointers :-)

We are not close enough to the Transformer community to implement competitive layers ourselves (I already have my hands full applying KeOps to anatomical data and drug consumption records!), but I'm more than happy to provide performance tips and/or add useful features to KeOps if this could help the "attention" community.

Our priorities for 2023 lie closer to transparent usage on generic hardware (a 100%-compatible NumPy interface, CPU support...) than to bleeding-edge performance on Nvidia GPUs (with automatic mixed precision, etc.), but these are certainly interesting research directions.

Best regards,
Jean
