diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md
index d86b9f37b25a..d77e72bb0f71 100644
--- a/docs/source/en/api/pipelines/hunyuan_video15.md
+++ b/docs/source/en/api/pipelines/hunyuan_video15.md
@@ -56,8 +56,8 @@ export_to_video(video, "output.mp4", fps=15)
 
 - HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.
 
-  - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
-  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen`
+  - **H100/H800:** `_flash_3_hub` or `_flash_3_varlen_hub`
+  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen_hub`
   - **Other GPUs:** `sage_hub`
 
 Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index bce5a8adaee9..e640c4a5451a 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -141,10 +141,12 @@ Refer to the table below for a complete list of available attention backends and
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 from kernels |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
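Note: a minimal usage sketch for the two new backend names documented above. It assumes the `set_attention_backend` helper and the `attention_backend` context manager that the attention backends guide already describes; the checkpoint id and prompt are illustrative only and not part of this patch.

```python
# Hedged sketch: selecting the new varlen hub backends on a loaded pipeline.
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = DiffusionPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Persistent selection on the transformer (H100/H800: FA3 varlen from kernels).
pipe.transformer.set_attention_backend("_flash_3_varlen_hub")

# Scoped selection (A100/A800/RTX 4090: FA2 varlen from kernels).
with attention_backend("flash_varlen_hub"):
    video = pipe(prompt="a cat walking on grass").frames[0]
```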
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 3660e8d1d3ac..ffad94cc7f27 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -168,10 +168,11 @@ class AttentionBackendName(str, Enum):
     FLASH = "flash"
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_VARLEN_HUB = "flash_varlen_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
+    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
 
     # `aiter`
     AITER = "aiter"
@@ -263,9 +264,17 @@ class _HubKernelConfig:
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
     ),
+    AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_varlen_func",
+        # revision="fake-ops-return-probs",
+    ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
     ),
+    AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+    ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
         repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
     ),
@@ -425,8 +434,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
-    # TODO: add support Hub variant of varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
+    elif backend in [
+        AttentionBackendName.FLASH_HUB,
+        AttentionBackendName.FLASH_VARLEN_HUB,
+        AttentionBackendName._FLASH_3_HUB,
+        AttentionBackendName._FLASH_3_VARLEN_HUB,
+        AttentionBackendName.SAGE_HUB,
+    ]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
@@ -1387,6 +1401,63 @@ def _flash_attention_hub(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_varlen_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+    out = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
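Note: the packing logic added above leans on `_prepare_for_flash_attn_or_sage_varlen`. The snippet below is a rough, self-contained illustration of the bookkeeping it is expected to return (per-sample lengths, cumulative offsets, max length) and of the `key_valid`/`value_valid` packing loop; it is not the helper itself.

```python
import torch
import torch.nn.functional as F

batch_size, seq_len_kv, heads, dim = 2, 6, 4, 8
key = torch.randn(batch_size, seq_len_kv, heads, dim)

# Boolean key-padding mask: True = valid token; the second sample has 4 valid keys.
mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 0, 0]], dtype=torch.bool)

seqlens_k = mask.sum(dim=1, dtype=torch.int32)                      # tensor([6, 4])
cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)   # tensor([0, 6, 10])
max_seqlen_k = int(seqlens_k.max())                                 # 6

# Pack only the valid K rows, mirroring the key_valid/value_valid loop in the diff.
key_packed = torch.cat([key[b, : int(seqlens_k[b])] for b in range(batch_size)], dim=0)
print(cu_seqlens_k.tolist(), max_seqlen_k, tuple(key_packed.shape))  # [0, 6, 10] 6 (10, 4, 8)
```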
@@ -1509,6 +1580,60 @@ def _flash_attention_3_hub(
     return (out[0], out[1]) if return_attn_probs else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_3_varlen_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+    out, lse, *_ = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
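Note: for context on what the two new `_HubKernelConfig` entries resolve to at runtime, the sketch below shows the usual lookup flow with the `kernels` package, assuming its `get_kernel(repo_id)` API; the caching and error handling in the real dispatch code are omitted.

```python
# Hedged sketch of resolving a hub-hosted varlen kernel (assumes `pip install kernels`).
from kernels import get_kernel

repo_id = "kernels-community/flash-attn2"       # repo_id from the registry entry above
kernel = get_kernel(repo_id)                    # downloads/loads the compiled kernel; a
                                                # revision can be pinned when configured
flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func")  # function_attr lookup
```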