[Docs / BetterTransformer] Added more details about flash attention + SDPA #25265
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for adding these additional details! 😄
docs/source/en/perf_infer_gpu_one.md (outdated)
As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).

For decoder-based models (e.g. GPT, T5, Llama, etc.), the `BetterTransformer` API will convert all attention operations to use the [`torch.nn.functional.scaled_dot_product_attention` method](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) (SDPA), which is available only from PyTorch 2.0 onwards.
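For reference, a minimal sketch of what that paragraph describes in practice, assuming PyTorch >= 2.0, the `optimum` package installed, and an illustrative `gpt2` checkpoint (the checkpoint is not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")

# Convert the attention layers to the BetterTransformer fastpath, which dispatches
# decoder attention to torch.nn.functional.scaled_dot_product_attention
model = model.to_bettertransformer()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```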
Same comments for the rest of this section as in perf_infer_gpu_many.md
(you can probably copy the changes over) :)
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Thanks a lot for the extensive review @stevhliu! 🎉
Thanks a lot, that is much better.
I'll release on Optimum side to include huggingface/optimum#1225 that allows training with encoder models + SDPA as well.
It could be worth noting that a few models (Falcon, M4) start to have native SDPA support in transformers (but they may not dispatch to flash), see these discussions:
For encoder models, the [`~PreTrainedModel.reverse_bettertransformer`] method reverts to the original model, which should be used before saving the model in order to use the canonical transformers modeling:

```python
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
I think we should not make the distinction between encoder / decoder models when it comes to using `reverse_bettertransformer`.
For example, for encoder-decoder models (e.g. T5), both SDPA (in the decoder) and nested tensors (in the encoder) are used. So in case one wants to save the model, they'll need to use `reverse_bettertransformer`.
To me the distinction is more that you can get speedups for inference with encoder models (since nested tensors are used), while for decoder models the speedup / dispatch to flash will only come (in PyTorch 2.0) for training, and for inference only with batch size = 1.
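A small sketch of the save path described in this comment, assuming an illustrative `t5-small` checkpoint; the point is only that `reverse_bettertransformer` is called before `save_pretrained` regardless of architecture:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# After conversion, the encoder uses nested tensors and the decoder uses SDPA
model = model.to_bettertransformer()

# ... run training or inference with the converted model ...

# Revert to the canonical transformers modeling code before saving
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```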
Thanks for the suggestion! I refactored that section a bit and removed the `reverse_bettertransformer` part, as it is relevant only for training (that section is for inference only).
# Use it for training or inference
```

SDPA can also call [Flash-Attention](https://arxiv.org/abs/2205.14135) kernels under the hood. If you want to force the usage of Flash Attention, use [`torch.backends.cuda.sdp_kernel(enable_flash=True)`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel):
`torch.backends.cuda.sdp_kernel(enable_flash=True)` is not enough. You need `torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)`, as below.
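For illustration, a hedged sketch of that corrected call used as a context manager around generation (reusing the `model`/`inputs` setup from the snippets above, which are assumptions rather than part of this PR):

```python
import torch

# Disabling the math and memory-efficient backends leaves only the Flash Attention
# kernel available; if Flash Attention cannot be used for the given inputs, SDPA
# raises an error instead of silently falling back.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model.generate(**inputs, max_new_tokens=20)
```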
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Looks awesome! I added some minor comments to make it a bit easier to read, and if you could also copy the changes from `perf_infer_gpu_many` to their corresponding sections in `perf_infer_gpu_one`, that'd be great 🤗
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Thanks for working on this! 🚀
…ion + SDPA (huggingface#25265)

* added more details about flash attention
* correct and add more details
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* few modifs
* more details
* up
* Apply suggestions from code review (Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>)
* adapt from suggestion
* Apply suggestions from code review (Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>)
* trigger CI
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* fix nits and copies
* add new section

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
What does this PR do?
as discussed offline with @LysandreJik
This PR clarifies to users how it is possible to use Flash Attention as a backend for the most-used models in transformers. We have seen some questions from users asking whether it is possible to integrate Flash Attention into HF models, whereas you can already benefit from it by calling `model.to_bettertransformer()`, which leverages the `BetterTransformer` API from 🤗 optimum. The information is based on the official documentation of `torch.nn.functional.scaled_dot_product_attention`.
In the near future, we could also have a small blogpost explaining this as well
To do list / To clarify list:
Let me know if I missed anything else
cc @fxmarty @MKhalusova @stevhliu