From 9a28c2f020640f996b51cacc5731a169343ed9bb Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 18 Mar 2026 08:51:56 +0530
Subject: [PATCH 1/3] start fa4 support.

---
 src/diffusers/models/attention_dispatch.py | 38 ++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 5b1f831ed060..233d3c7d9a0d 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
     FLASH_VARLEN_HUB = "flash_varlen_hub"
+    FLASH_4_HUB = "flash_4_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ class _HubKernelConfig:
         function_attr="sageattn",
         version=1,
     ),
+    AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
+        repo_id="kernels-staging/flash-attn4",
+        function_attr="flash_attn_func",
+        version=0,
+    ),
 }
 
 
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
         AttentionBackendName._FLASH_3_HUB,
         AttentionBackendName._FLASH_3_VARLEN_HUB,
         AttentionBackendName.SAGE_HUB,
+        AttentionBackendName.FLASH_4_HUB,
     ]:
         if not is_kernels_available():
             raise RuntimeError(
@@ -2676,6 +2683,37 @@ def _flash_attention_3_varlen_hub(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_4_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_4_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 4.")
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    if isinstance(out, tuple):
+        return (out[0], out[1]) if return_lse else out[0]
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

From ae76da7cdbe5dd09c2c651f044213c3695c31a8d Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 18 Mar 2026 15:22:17 +0530
Subject: [PATCH 2/3] up

---
 docs/source/en/optimization/attention_backends.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index f3ff4781c6ec..6dab9a2b1f50 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
+| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 from kernels |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

From aa6cf2a06ae9cf7b3ca27877b19acef447e79570 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 20 Mar 2026 16:11:46 +0530
Subject: [PATCH 3/3] specify minimum version

---
 src/diffusers/models/attention_dispatch.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 233d3c7d9a0d..c407f59037e6 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -538,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
             )
 
+        if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
+            )
+
     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN:
             raise RuntimeError(
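
A minimal usage sketch for the new `flash_4_hub` backend once this series is applied. It assumes the `set_attention_backend` helper and the `attention_backend` context manager that the existing hub backends in `attention_backends.md` rely on, and the checkpoint name is only a placeholder; treat it as an illustration rather than the canonical API.

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.models.attention_dispatch import attention_backend

    # Placeholder checkpoint; any pipeline whose transformer routes attention
    # through the dispatcher should work the same way.
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Requires `kernels` >= 0.12.3, as enforced by PATCH 3/3.
    # Persistent selection on the model:
    pipe.transformer.set_attention_backend("flash_4_hub")
    image = pipe("a photo of a cat").images[0]

    # Or scoped selection via the context manager:
    with attention_backend("flash_4_hub"):
        image = pipe("a photo of a cat").images[0]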