Description
Now that we have an integration with the `kernels` lib to use Flash Attention 3 (FA3), which delivers a significant speedup on Hopper GPUs, it'd be nice to gather community interest about which other kernels we should try to incorporate in the library through `kernels`.
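For anyone unfamiliar with the lib: it pulls pre-compiled kernels straight from the Hugging Face Hub at runtime. A minimal usage sketch, following the example in the `kernels` README (the `kernels-community/activation` repo and its `gelu_fast` op):

```python
import torch
from kernels import get_kernel

# Download an optimized kernel bundle from the Hugging Face Hub.
activation = get_kernel("kernels-community/activation")

x = torch.randn((1024, 1024), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)

# The kernel writes the GELU output into the pre-allocated `y`.
activation.gelu_fast(y, x)
```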
I have done some work in the `kernelize` branch to see if replacing `GELU`, `SiLU`, and `RMSNorm` with their optimized kernels would yield any speedups on Flux. So far, it hasn't. Benchmarking script: https://gist.github.com/sayakpaul/35236dd96e15d9f7d658a7ad11918411. The changes can be compared here: https://github.com/huggingface/diffusers/compare/kernelize?expand=1.
> [!NOTE]
> The changes in the `kernelize` branch are quite hacky as we're still evaluating things.
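For context, the mechanism looks roughly like the sketch below: a layer name is mapped to a kernel layer on the Hub, and `kernelize()` swaps the matching modules' `forward()` at runtime. The repo id and layer names here are illustrative assumptions, not the exact ones used in the branch:

```python
from kernels import kernelize, register_kernel_mapping, LayerRepository, Mode

# Point a locally registered layer name at a kernel layer hosted on the Hub.
# The repo id and layer name below are placeholders for whichever kernel
# is being evaluated.
register_kernel_mapping(
    {
        "RMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/triton-layer-norm",
                layer_name="LlamaRMSNorm",
            )
        }
    }
)

# `model` is assumed to be an nn.Module (e.g. pipe.transformer) whose layers
# are registered for kernelization; see the decorator sketch further below.
model = ...
model = kernelize(model, mode=Mode.INFERENCE)
```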
Please use this issue to let us know which kernels we should try to support in Diffusers. Some notes to keep in mind:
- Layers where the `forward()` method is easily replaceable with the `kernelize()` mechanism will be prioritized. A reference is here: Add kernelize to transformers (huggingface/transformers#38205). See the sketch after this list.
- Even if a kernel isn't directly compatible with `kernels`, we can try to make it so, like we have for https://huggingface.co/kernels-community/flash-attn3.
- Not all kernels contribute non-trivial gains in terms of speedup, so please bear that in mind when proposing a kernel.
Cc: @MekkCyber