
Commit d176f61

[core] support sage attention + FA2 through kernels (#12439)
* up
* support automatic dispatch.
* disable compile support for now.
* up
* flash too.
* document.
* up
* up
* up
* up
1 parent 354d35a commit d176f61

File tree

docs/source/en/optimization/attention_backends.md
src/diffusers/models/attention_dispatch.py
tests/others/test_attention_backends.py

3 files changed: +91 -10 lines changed

docs/source/en/optimization/attention_backends.md
Lines changed: 2 additions & 0 deletions

@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
+| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
 | `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
 | `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
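Note (not part of this commit): the two new rows only add backend names to the table. As a hedged usage sketch, selecting them should work the same way as the other backends documented earlier in attention_backends.md, via `set_attention_backend()` or the `attention_backend` context manager; the checkpoint and prompt below are illustrative assumptions.

```py
# Hedged sketch, not from this commit: assumes the set_attention_backend() and
# attention_backend() APIs described earlier in attention_backends.md.
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import attention_backend

# Illustrative checkpoint; any pipeline whose model routes through the attention dispatcher works.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# FlashAttention-2 fetched from the Hub via `kernels` (no local flash-attn build needed).
pipe.transformer.set_attention_backend("flash_hub")
image = pipe("a photo of a cat").images[0]

# Or scope a backend to a single call, e.g. the Hub-fetched SageAttention kernel.
with attention_backend("sage_hub"):
    image = pipe("a photo of a cat").images[0]
```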

src/diffusers/models/attention_dispatch.py
Lines changed: 80 additions & 8 deletions
@@ -18,7 +18,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch


@@ -160,16 +160,13 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet

-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"

     # `flash-attn`
     FLASH = "flash"
+    FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"

@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):

     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"

@@ -264,7 +262,13 @@ class _HubKernelConfig:
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
-    )
+    ),
+    AttentionBackendName.FLASH_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
+    ),
+    AttentionBackendName.SAGE_HUB: _HubKernelConfig(
+        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+    ),
 }
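Aside (not part of this diff): the registry entries above are declarative. A hedged sketch of what they presumably resolve to at runtime with the `kernels` package; the exact loading helper in attention_dispatch.py is not shown in this hunk.

```py
# Hedged sketch, not from this diff: roughly how a _HubKernelConfig entry is
# materialized, assuming the `kernels` package's get_kernel() API.
from kernels import get_kernel

# Download (or load from cache) the compiled FlashAttention-2 kernel from the Hub.
flash_attn2 = get_kernel("kernels-community/flash-attn2")

# `function_attr` names the callable to pull off the loaded kernel module.
flash_attn_func = getattr(flash_attn2, "flash_attn_func")
```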
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."

@@ -1350,6 +1354,38 @@ def _flash_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

@@ -1431,6 +1467,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,

@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if _parallel_config:
+        raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
         q=query,

@@ -1938,6 +1978,38 @@ def _sage_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
+    if _parallel_config is None:
+        out = func(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
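For reference, `_sage_attention_hub` passes `tensor_layout="NHD"`, i.e. query/key/value are laid out as (batch, seq_len, num_heads, head_dim), and the registered constraints require fp16/bf16 tensors on CUDA. A hedged standalone sketch of an equivalent direct call to the Hub kernel, mirroring the arguments forwarded above; the tensor sizes are illustrative assumptions.

```py
# Hedged sketch, not from this diff: calling the Hub-fetched SageAttention kernel
# directly with the same arguments _sage_attention_hub forwards to it.
import torch
from kernels import get_kernel

sage = get_kernel("kernels-community/sage_attention")

batch, seq_len, num_heads, head_dim = 1, 1024, 8, 64  # illustrative sizes
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# "NHD" layout = (batch, seq_len, num_heads, head_dim), matching the dispatcher's convention.
out = sage.sageattn(q=q, k=k, v=v, tensor_layout="NHD", is_causal=False, sm_scale=None, return_lse=False)
```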

tests/others/test_attention_backends.py
Lines changed: 9 additions & 2 deletions

@@ -34,7 +34,10 @@

 # fmt: off
 FORWARD_CASES = [
-    ("flash_hub", None),
+    (
+        "flash_hub",
+        torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
+    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),

@@ -54,7 +57,11 @@
 ]

 COMPILE_CASES = [
-    ("flash_hub", None, True),
+    (
+        "flash_hub",
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+        True
+    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
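The expected values above are 16-element bf16 slices per backend. A hedged sketch of how a parametrized check over FORWARD_CASES might consume them; `run_pipeline_with_backend` and the slice convention are hypothetical, not the repository's actual harness.

```py
# Hypothetical harness sketch, not the repository's actual test code.
import pytest
import torch

@pytest.mark.parametrize("backend_name, expected_slice", FORWARD_CASES)
def test_forward(backend_name, expected_slice):
    # run_pipeline_with_backend is a hypothetical helper returning a bf16 output tensor.
    output = run_pipeline_with_backend(backend_name)
    # Assumed slice convention: first 8 and last 8 values of the flattened output.
    output_slice = torch.cat([output.flatten()[:8], output.flatten()[-8:]]).cpu()
    torch.testing.assert_close(output_slice, expected_slice, atol=1e-3, rtol=1e-3)
```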
