From 6900b293fc3f971a7e3d04bcec9c2a7626a586d9 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 22:57:57 +0800 Subject: [PATCH 01/29] Adds backend auto-selection API Exposes backend availability flags to let callers probe supported runtimes without import errors. Provides auto-selection helper to fall back to the first available backend for attention execution. --- flash_sparse_attn/__init__.py | 97 +++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 flash_sparse_attn/__init__.py diff --git a/flash_sparse_attn/__init__.py b/flash_sparse_attn/__init__.py new file mode 100644 index 0000000..5f5c536 --- /dev/null +++ b/flash_sparse_attn/__init__.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, Jingze Shi. + +from typing import Optional + +__version__ = "1.2.3" + + +# Import CUDA functions when available +try: + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func, flash_sparse_attn_varlen_func + CUDA_AVAILABLE = True +except ImportError: + CUDA_AVAILABLE = False + flash_sparse_attn_func, flash_sparse_attn_varlen_func = None, None + +# Import Triton functions when available +try: + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + triton_sparse_attn_func = None + +# Import Flex functions when available +try: + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + FLEX_AVAILABLE = True +except ImportError: + FLEX_AVAILABLE = False + flex_sparse_attn_func = None + + +def get_available_backends(): + """Return a list of available backends.""" + backends = [] + if CUDA_AVAILABLE: + backends.append("cuda") + if TRITON_AVAILABLE: + backends.append("triton") + if FLEX_AVAILABLE: + backends.append("flex") + return backends + + +def flash_sparse_attn_func_auto(backend: Optional[str] = None, **kwargs): + """ + Flash Dynamic Mask Attention function with automatic backend selection. + + Args: + backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). + If None, will use the first available backend in order: cuda, triton, flex. + **kwargs: Arguments to pass to the attention function. + + Returns: + The attention function for the specified or auto-selected backend. + """ + if backend is None: + # Auto-select backend + if CUDA_AVAILABLE: + backend = "cuda" + elif TRITON_AVAILABLE: + backend = "triton" + elif FLEX_AVAILABLE: + backend = "flex" + else: + raise RuntimeError("No flash attention backend is available. Please install at least one of: triton, transformers, or build the CUDA extension.") + + if backend == "cuda": + if not CUDA_AVAILABLE: + raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") + return flash_sparse_attn_func + + elif backend == "triton": + if not TRITON_AVAILABLE: + raise RuntimeError("Triton backend is not available. Please install triton: pip install triton") + return triton_sparse_attn_func + + elif backend == "flex": + if not FLEX_AVAILABLE: + raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers") + return flex_sparse_attn_func + + else: + raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}") + + +__all__ = [ + "CUDA_AVAILABLE", + "TRITON_AVAILABLE", + "FLEX_AVAILABLE", + "flash_sparse_attn_func", + "flash_sparse_attn_varlen_func", + "triton_sparse_attn_func", + "flex_sparse_attn_func", + "get_available_backends", + "flash_sparse_attn_func_auto", +] From 71add0060e77b8bcdc6cb6284ee0f605639ae8fc Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 22:58:47 +0800 Subject: [PATCH 02/29] Fix docstring for flash_sparse_attn_func_auto to reflect correct function name --- flash_sparse_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_sparse_attn/__init__.py b/flash_sparse_attn/__init__.py index 5f5c536..23309da 100644 --- a/flash_sparse_attn/__init__.py +++ b/flash_sparse_attn/__init__.py @@ -44,7 +44,7 @@ def get_available_backends(): def flash_sparse_attn_func_auto(backend: Optional[str] = None, **kwargs): """ - Flash Dynamic Mask Attention function with automatic backend selection. + Flash Sparse Attention function with automatic backend selection. Args: backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). From e02668c0b47e6db62c33857ec34051ec860b385b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 22:59:27 +0800 Subject: [PATCH 03/29] Adds Flex flash-sparse attention hook Introduces a Flex Attention forward path that constructs causal block masks, normalizes mask and bias defaults, and applies compile-friendly kernel options to ease sparse Flash workloads. --- flash_sparse_attn/flash_sparse_attn_flex.py | 80 +++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 flash_sparse_attn/flash_sparse_attn_flex.py diff --git a/flash_sparse_attn/flash_sparse_attn_flex.py b/flash_sparse_attn/flash_sparse_attn_flex.py new file mode 100644 index 0000000..055c2c8 --- /dev/null +++ b/flash_sparse_attn/flash_sparse_attn_flex.py @@ -0,0 +1,80 @@ +from typing import Optional, Tuple +import math +import torch +from torch.nn.attention.flex_attention import create_block_mask +from transformers.integrations.flex_attention import compile_friendly_flex_attention + + +def flex_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + softmax_scale: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch, seqlen_q, nheads, dhead = query.shape + _, seqlen_k, _, _ = key.shape + query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] + key = key.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] + value = value.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] + if attn_mask is not None: + attn_mask = attn_mask[:, :, :, : key.shape[-2]] + else: + attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) + if attn_bias is not None: + attn_bias = attn_bias[:, :, :, : key.shape[-2]] + else: + attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) + if is_causal is None: + is_causal = True + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(dhead) + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx] + return score + + def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + # It looks like you're attempting to use a Tensor in some data-dependent control flow. + # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . + # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0 + return q_idx >= kv_idx + + block_mask = create_block_mask( + mask_mod=causal_mask_mod, + B=query.shape[0], + H=None, + Q_LEN=query.shape[2], + KV_LEN=key.shape[2], + device=query.device, + _compile=True, + ) + + kernel_options = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_DMODEL": 32, + "num_stages": 1, + "num_warps": 8, + } + attn_output = compile_friendly_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask if is_causal else None, + scale=softmax_scale, + kernel_options=kernel_options, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. + return_lse=False, + training=False, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output + +flex_sparse_attn_func = flex_attention_forward \ No newline at end of file From 508d2d18badd2aba880f870acece9f1dc54f326f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 22:59:40 +0800 Subject: [PATCH 04/29] Remove unused files related to Flash Dynamic Mask Attention integration - Deleted `modeling_flash_dynamic_mask_attention_utils.py` as it contained redundant code and was not being utilized. - Removed `mask.py` and `padding.py` files which were not necessary for the current implementation, streamlining the codebase. --- flash_dmattn/__init__.py | 97 -- flash_dmattn/flash_dmattn_flex.py | 80 -- flash_dmattn/flash_dmattn_interface.py | 760 ---------- flash_dmattn/flash_dmattn_triton.py | 1246 ----------------- flash_dmattn/flash_dmattn_triton_special.py | 1244 ---------------- .../flash_dynamic_mask_attention.py | 111 -- flash_dmattn/integrations/import_utils.py | 95 -- ...ling_flash_dynamic_mask_attention_utils.py | 597 -------- flash_dmattn/utils/mask.py | 240 ---- flash_dmattn/utils/padding.py | 170 --- 10 files changed, 4640 deletions(-) delete mode 100644 flash_dmattn/__init__.py delete mode 100644 flash_dmattn/flash_dmattn_flex.py delete mode 100644 flash_dmattn/flash_dmattn_interface.py delete mode 100644 flash_dmattn/flash_dmattn_triton.py delete mode 100644 flash_dmattn/flash_dmattn_triton_special.py delete mode 100644 flash_dmattn/integrations/flash_dynamic_mask_attention.py delete mode 100644 flash_dmattn/integrations/import_utils.py delete mode 100644 flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py delete mode 100644 flash_dmattn/utils/mask.py delete mode 100644 flash_dmattn/utils/padding.py diff --git a/flash_dmattn/__init__.py b/flash_dmattn/__init__.py deleted file mode 100644 index 484e8d1..0000000 --- a/flash_dmattn/__init__.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) 2025, Jingze Shi. - -from typing import Optional - -__version__ = "1.2.2" - - -# Import CUDA functions when available -try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func, flash_dmattn_varlen_func - CUDA_AVAILABLE = True -except ImportError: - CUDA_AVAILABLE = False - flash_dmattn_func, flash_dmattn_varlen_func = None, None - -# Import Triton functions when available -try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - TRITON_AVAILABLE = True -except ImportError: - TRITON_AVAILABLE = False - triton_dmattn_func = None - -# Import Flex functions when available -try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - FLEX_AVAILABLE = True -except ImportError: - FLEX_AVAILABLE = False - flex_dmattn_func = None - - -def get_available_backends(): - """Return a list of available backends.""" - backends = [] - if CUDA_AVAILABLE: - backends.append("cuda") - if TRITON_AVAILABLE: - backends.append("triton") - if FLEX_AVAILABLE: - backends.append("flex") - return backends - - -def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): - """ - Flash Dynamic Mask Attention function with automatic backend selection. - - Args: - backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). - If None, will use the first available backend in order: cuda, triton, flex. - **kwargs: Arguments to pass to the attention function. - - Returns: - The attention function for the specified or auto-selected backend. - """ - if backend is None: - # Auto-select backend - if CUDA_AVAILABLE: - backend = "cuda" - elif TRITON_AVAILABLE: - backend = "triton" - elif FLEX_AVAILABLE: - backend = "flex" - else: - raise RuntimeError("No flash attention backend is available. Please install at least one of: triton, transformers, or build the CUDA extension.") - - if backend == "cuda": - if not CUDA_AVAILABLE: - raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") - return flash_dmattn_func - - elif backend == "triton": - if not TRITON_AVAILABLE: - raise RuntimeError("Triton backend is not available. Please install triton: pip install triton") - return triton_dmattn_func - - elif backend == "flex": - if not FLEX_AVAILABLE: - raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers") - return flex_dmattn_func - - else: - raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}") - - -__all__ = [ - "CUDA_AVAILABLE", - "TRITON_AVAILABLE", - "FLEX_AVAILABLE", - "flash_dmattn_func", - "flash_dmattn_varlen_func", - "triton_dmattn_func", - "flex_dmattn_func", - "get_available_backends", - "flash_dmattn_func_auto", -] diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py deleted file mode 100644 index 379f984..0000000 --- a/flash_dmattn/flash_dmattn_flex.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Optional, Tuple -import math -import torch -from torch.nn.attention.flex_attention import create_block_mask -from transformers.integrations.flex_attention import compile_friendly_flex_attention - - -def flex_attention_forward( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - is_causal: Optional[bool] = None, - softmax_scale: Optional[float] = None, - **kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: - batch, seqlen_q, nheads, dhead = query.shape - _, seqlen_k, _, _ = key.shape - query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] - key = key.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] - value = value.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] - if attn_mask is not None: - attn_mask = attn_mask[:, :, :, : key.shape[-2]] - else: - attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) - if attn_bias is not None: - attn_bias = attn_bias[:, :, :, : key.shape[-2]] - else: - attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) - if is_causal is None: - is_causal = True - if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(dhead) - - def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): - score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx] - return score - - def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): - # It looks like you're attempting to use a Tensor in some data-dependent control flow. - # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . - # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0 - return q_idx >= kv_idx - - block_mask = create_block_mask( - mask_mod=causal_mask_mod, - B=query.shape[0], - H=None, - Q_LEN=query.shape[2], - KV_LEN=key.shape[2], - device=query.device, - _compile=True, - ) - - kernel_options = { - "BLOCK_M": 64, - "BLOCK_N": 64, - "BLOCK_DMODEL": 32, - "num_stages": 1, - "num_warps": 8, - } - attn_output = compile_friendly_flex_attention( - query, - key, - value, - score_mod=score_mod, - block_mask=block_mask if is_causal else None, - scale=softmax_scale, - kernel_options=kernel_options, - # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. - # For simplification, we thus always return it as no additional computations are introduced. - return_lse=False, - training=False, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output - -flex_dmattn_func = flex_attention_forward \ No newline at end of file diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py deleted file mode 100644 index d34adb0..0000000 --- a/flash_dmattn/flash_dmattn_interface.py +++ /dev/null @@ -1,760 +0,0 @@ -# Copyright (c) 2025, Jingze Shi. - -from typing import Optional, Tuple, Any -from packaging import version -import torch - -import flash_dmattn_cuda as flash_dmattn_gpu # type: ignore - - -def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None: - for t in tensors: - if t is not None and isinstance(t, torch.Tensor): - torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf) - - -def _get_block_size_n(device, head_dim, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128 - if head_dim <= 64: - return 128 - elif head_dim <= 96: - return 64 - elif head_dim <= 128: - if is_sm8x: - return 64 if (is_causal) else 32 - else: - return 64 - elif head_dim <= 192: - return 64 - elif head_dim <= 224: - return 64 - elif head_dim <= 256: - return 64 - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -# torch.compile() support is only enabled for pytorch >= 2.4 -# The reason for this is that we are using the new custom_op and register_fake -# APIs, which support inplace modification of inputs in the function itself -if version.parse(torch.__version__) >= version.parse("2.4.0"): - _torch_custom_op_wrapper = torch.library.custom_op - _torch_register_fake_wrapper = torch.library.register_fake -else: - def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - if fn is None: - return wrap - return fn - def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): - def wrap(func): - return func - if fn is None: - return wrap - return fn - _torch_custom_op_wrapper = noop_custom_op_wrapper - _torch_register_fake_wrapper = noop_register_fake_wrapper - - -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - softmax_scale: float, - is_causal: bool, - softcap: float, - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd( - q, - k, - v, - mask, - bias, - None, - softmax_scale, - is_causal, - softcap, - return_softmax, - ) - # _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) - return out, softmax_lse, S_dmask - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_forward") -def _flash_dmattn_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - softmax_scale: float, - is_causal: bool, - softcap: float, - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - batch_size, seqlen_q, num_heads, head_size = q.shape - seqlen_k = k.shape[1] - out = torch.empty_like(q) - softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - - return out, softmax_lse, p - - -_wrapped_flash_dmattn_forward = _flash_dmattn_forward - - -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( - q, - k, - v, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - return_softmax, - ) - # _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) - return out, softmax_lse, S_dmask - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") -def _flash_dmattn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - return out, softmax_lse, p - - -_wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward - - -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") -def _flash_dmattn_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, -) -> torch.Tensor: - dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] - ( - dq, - dk, - dv, - dbias, - softmax_d, - ) = flash_dmattn_gpu.bwd( - dout, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - softmax_scale, - is_causal, - softcap, - deterministic, - ) - # _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0) - return softmax_d - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_backward") -def _flash_dmattn_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, -) -> torch.Tensor: - dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - if dbias is None: - dbias = torch.empty_like(bias) - batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) - - return softmax_d - - -_wrapped_flash_dmattn_backward = _flash_dmattn_backward - - -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") -def _flash_dmattn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - ( - dq, - dk, - dv, - softmax_d, - ) = flash_dmattn_gpu.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - deterministic, - ) - # _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0) - return softmax_d - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_backward") -def _flash_dmattn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -_wrapped_flash_dmattn_varlen_backward = _flash_dmattn_varlen_backward - - -class FlashDMAttnFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - softmax_scale: Optional[float], - is_causal: Optional[bool], - softcap: Optional[float], - deterministic: Optional[bool], - return_softmax: Optional[bool], - is_grad_enabled: bool = True, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if is_causal is None: - is_causal = False - if softcap is None: - softcap = 0.0 - if deterministic is None: - deterministic = False - if return_softmax is None: - return_softmax = False - seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0 - - # Padding to multiple of 8 for 16-bit memory allocations - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - seqlen_k_rounded = round_multiple(k.shape[1], 128) - if mask is not None and mask.shape[-1] != seqlen_k_rounded: - if mask.shape[-1] == 1: - mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded) - else: - mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]]) - if bias is not None and bias.shape[-1] != seqlen_k_rounded: - if bias.shape[-1] == 1: - bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded) - else: - bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]]) - - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( - q, - k, - v, - mask, - bias, - softmax_scale, - is_causal=is_causal, - softcap=softcap, - return_softmax=return_softmax, - ) - - if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) - ctx.softmax_scale = softmax_scale - ctx.is_causal = is_causal - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.seqlen_k_bias_og = seqlen_k_bias_og - - out = out_padded[..., :head_size_og] - - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - dout: torch.Tensor, - *args: Any, - ): - q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) - dbias = torch.zeros_like(bias).contiguous() if bias is not None else None - - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - _wrapped_flash_dmattn_backward( - dout_padded, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - ctx.softmax_scale, - ctx.is_causal, - ctx.softcap, - ctx.deterministic, - ) - - # We could have padded the head dimension - dq = dq[..., : dout.shape[-1]] - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - - if dbias is not None: - dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]] - - return dq, dk, dv, None, dbias, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: Optional[bool], - softcap: Optional[float], - deterministic: Optional[bool], - return_softmax: Optional[bool], - block_table: Optional[torch.Tensor], - is_grad_enabled: bool = True, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if is_causal is None: - is_causal = False - if softcap is None: - softcap = 0.0 - if deterministic is None: - deterministic = False - if return_softmax is None: - return_softmax = False - - # Padding to multiple of 8 for 16-bit memory allocations - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - is_causal=is_causal, - softcap=softcap, - return_softmax=return_softmax, - block_table=block_table, - ) - - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k - ) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.is_causal = is_causal - ctx.softcap = softcap - ctx.deterministic = deterministic - - out = out_padded[..., :head_size_og] - - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - _wrapped_flash_dmattn_varlen_backward( - dout_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.softmax_scale, - ctx.is_causal, - ctx.softcap, - ctx.deterministic, - ) - - # We could have padded the head dimension - dq = dq[..., : dout.shape[-1]] - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_dmattn_func( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - return_attn_probs: Optional[bool] = None, -): - """ - Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA. - For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension - of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 - of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. - - If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - Arguments: - query: torch.Tensor. The query tensor of shape (batch_size, seqlen, nheads, headdim) - key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) - value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) - attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores. - If None, no mask is applied. - attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores. - If None, no bias is applied. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - softcap: float. Anything > 0 activates softcapping attention. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). - """ - return FlashDMAttnFunc.apply( - query, - key, - value, - attn_mask, - attn_bias, - softmax_scale, - is_causal, - softcap, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_dmattn_varlen_func( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - return_attn_probs: Optional[bool] = None, - block_table: Optional[torch.Tensor] = None, -): - """ - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - Arguments: - query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q. - cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - softcap: float. Anything > 0 activates softcapping attention. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). - """ - return FlashAttnVarlenFunc.apply( - query, - key, - value, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - is_causal, - softcap, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py deleted file mode 100644 index 66141cb..0000000 --- a/flash_dmattn/flash_dmattn_triton.py +++ /dev/null @@ -1,1246 +0,0 @@ -from typing import Optional -import math - -import torch -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - ], - key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'HAS_MASK', 'HAS_BIAS', 'BLOCK_HEADDIM'] -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_kernel( - Q, - K, - V, - Mask, - Bias, - Out, - Lse, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_mb, - stride_mh, - stride_mm, - stride_bb, - stride_bh, - stride_bm, - stride_ob, - stride_oh, - stride_om, - nheads, - nheads_k, - nheads_mask, - nheads_bias, - h_h_k_ratio, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q: tl.constexpr, - CACHE_KEY_SEQLEN_K: tl.constexpr, - IS_CAUSAL: tl.constexpr, - HAS_MASK: tl.constexpr, - HAS_BIAS: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_hq = off_hb % nheads - off_hk = off_hq // h_h_k_ratio - if HAS_MASK: - if nheads_mask == 1: - off_hmask = 0 - elif nheads_mask == nheads_k: - off_hmask = off_hk - else: - off_hmask = off_hq - if HAS_BIAS: - if nheads_bias == 1: - off_hbbias = 0 - elif nheads_bias == nheads_k: - off_hbbias = off_hk - else: - off_hbbias = off_hq - # off_b = tl.program_id(1) - # off_h = tl.program_id(2) - # off_hb = off_b * nheads + off_h - - # Initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to Q, K, V, Mask, Bias - q_ptrs = ( - Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - ) - k_ptrs = ( - K + off_b * stride_kb + off_hk * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) - ) - v_ptrs = ( - V + off_b * stride_vb + off_hk * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) - ) - m_ptrs = ( - Mask + off_b * stride_mb + off_hmask * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) - ) if HAS_MASK else None - b_ptrs = ( - Bias + off_b * stride_bb + off_hbbias * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) - ) if HAS_BIAS else None - - # Initialize pointer to m and l - lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - - # Load q: it will stay in SRAM throughout - if EVEN_M: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 - ) - - # Scale q - q = (q * softmax_scale).to(q.dtype) - - # Loop over k, v and update accumulator - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - if HAS_MASK: - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs + start_n) - else: - mask = tl.load( - m_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), - other=False - ) - - # Check if any element in mask is non-zero - any_active = tl.reduce_or(mask, axis=None) - else: - any_active = True - - # Skip this iteration if no active elements - if any_active: - - # Load k - if EVEN_N: - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - - if HAS_BIAS: - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load( - b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - acc_s = bias - else: - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - # Compute acc_s - acc_s += tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the three masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - if HAS_MASK: - acc_s += tl.where(mask, 0, float("-inf")) - - # Compute p - m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) - p = tl.exp(acc_s - m_ij[:, None]) - l_ij = tl.sum(p, 1) - - # Scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - - # Update output accumulator - acc_o = acc_o * acc_o_scale[:, None] - - # Load v - if EVEN_N: - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Compute acc_o - acc_o += tl.dot(p.to(v.dtype), v) - - # Update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - - o_scale = tl.exp(m_i - lse_i) - acc_o = acc_o * o_scale[:, None] - # Rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # Write back l and m - lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m - tl.store(lse_ptrs, lse_i) - # Initialize pointers to output - offs_d = tl.arange(0, BLOCK_HEADDIM) - out_ptrs = ( - Out - + off_b * stride_ob - + off_hq * stride_oh - + (offs_m[:, None] * stride_om + offs_d[None, :]) - ) - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store( - out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) - ) - - -@triton.jit -def _bwd_preprocess_do_o_dot( - Out, - DO, - Delta, - stride_ob, - stride_oh, - stride_om, - stride_dob, - stride_doh, - stride_dom, - nheads, - seqlen_q, - seqlen_q_rounded, - headdim, - BLOCK_M: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # Initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Load o - o = tl.load( - Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - do = tl.load( - DO - + off_b * stride_dob - + off_h * stride_doh - + offs_m[:, None] * stride_dom - + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # Write back - tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) - - -@triton.jit -def _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Mask, - Bias, - DO, - DQ, - DK, - DV, - DBias, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_mm, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - stride_dbm, - seqlen_q, - seqlen_k, - headdim, - IS_CAUSAL: tl.constexpr, - HAS_MASK: tl.constexpr, - HAS_BIAS: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - ATOMIC_ADD: tl.constexpr, - ACCUM_DBIAS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) - begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # Initialize row/col offsets - offs_qm = begin_m + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - if HAS_MASK: - m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) - if HAS_BIAS: - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) - do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) - dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) - db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :]) - # Initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - # There seems to be some problem with Triton pipelining that makes results wrong for - # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step, - # and pipelining with the bias matrix could screw it up. So we just exit early. - if begin_m >= seqlen_q: - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - - if EVEN_N: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - return - - # Load k and v, them will stay in SRAM throughout - if EVEN_N: - if EVEN_HEADDIM: - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - else: - k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - else: - k = tl.load( - k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 - ) - v = tl.load( - v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 - ) - - # Scale k - k = (k * softmax_scale).to(k.dtype) - - # Initialize accumulator for dbias if needed - acc_dbias = tl.zeros([BLOCK_N], dtype=tl.float32) if (HAS_BIAS and ACCUM_DBIAS) else None - - # Loop over q and update accumulators - num_block_m = tl.cdiv(seqlen_q, BLOCK_M) - for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m_curr = start_m + offs_m - - if HAS_MASK: - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs) - else: - mask = tl.load( - m_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=False, - ) - - # Check if any element in mask is non-zero - any_active = tl.reduce_or(mask, axis=None) - else: - any_active = True - - # Skip this iteration if no active elements - if any_active: - # Load q - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - - if HAS_BIAS: - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load( - b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - acc_s = bias - else: - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - - # Compute acc_s - acc_s += tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the three masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) - if HAS_MASK: - acc_s += tl.where(mask, 0, float("-inf")) - - lse_i = tl.load(LSE + offs_m_curr) - # p = tl.exp(acc_s - lse_i[:, None]) - p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) - - # Load do - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # There's a race condition if we just use m_mask and not d_mask. - do = tl.load( - do_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Compute dv - dv += tl.dot(tl.trans(p.to(do.dtype)), do) - - # Compute dp - dp = tl.dot(do, tl.trans(v)) - - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - - # Compute ds - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (p * (dp - Di[:, None])).to(q.dtype) - - # Write back - if not (EVEN_M & EVEN_N): - tl.debug_barrier() - if HAS_BIAS: - if ACCUM_DBIAS: - acc_dbias += tl.sum(ds, axis=0) - else: - if EVEN_M & EVEN_N: - tl.store( - db_ptrs, - ds, - ) - else: - tl.store( - db_ptrs, - ds, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - ) - - # Compute dk - dk += tl.dot(tl.trans(ds), q) - - # Compute dq - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k).to(ds.dtype) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - else: - if EVEN_HEADDIM: - dq = tl.load( - dq_ptrs, - mask=offs_m_curr[:, None] < seqlen_q, - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k).to(ds.dtype) - tl.store( - dq_ptrs, - dq, - mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last", - ) - else: - dq = tl.load( - dq_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k).to(ds.dtype) - tl.store( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last", - ) - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k).to(ds.dtype) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) - else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) - else: - tl.atomic_add( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - ) - - # Increment pointers - do_ptrs += BLOCK_M * stride_dom - dq_ptrs += BLOCK_M * stride_dqm - if HAS_BIAS: - db_ptrs += BLOCK_M * stride_dbm - q_ptrs += BLOCK_M * stride_qm - if HAS_MASK: - m_ptrs += BLOCK_M * stride_mm - if HAS_BIAS: - b_ptrs += BLOCK_M * stride_bm - - # Scale dk - dk = (dk * softmax_scale).to(dk.dtype) - - # Write back - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - if HAS_BIAS and ACCUM_DBIAS: - if EVEN_N: - tl.store(DBias + offs_n, acc_dbias) - else: - tl.store(DBias + offs_n, acc_dbias, mask=(offs_n < seqlen_k)) - - if EVEN_N: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - -def init_to_zero(names): - if isinstance(names, str): - names = [names] - def init_func(nargs): - for name in names: - nargs[name].zero_() - return init_func - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero(["DQ", "DBias"]), - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero(["DQ", "DBias"]), - ), - ], - key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "HAS_INDICE", "BLOCK_HEADDIM"], -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - "ACCUM_DBIAS": lambda args: args["HAS_BIAS"] and (args["stride_dbm"] == 0) and (args["seqlen_q"] > 1), - } -) -@triton.jit -def _bwd_kernel( - Q, - K, - V, - Mask, - Bias, - DO, - DQ, - DK, - DV, - DBias, - LSE, - D, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_mb, - stride_mh, - stride_mm, - stride_bb, - stride_bh, - stride_bm, - stride_dob, - stride_doh, - stride_dom, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dvb, - stride_dvh, - stride_dvn, - stride_dbb, - stride_dbh, - stride_dbm, - nheads, - nheads_k, - nheads_mask, - nheads_bias, - h_h_k_ratio, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, - IS_CAUSAL: tl.constexpr, - HAS_MASK: tl.constexpr, - HAS_BIAS: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - ACCUM_DBIAS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_hq = off_hb % nheads - off_hk = off_hq // h_h_k_ratio - if HAS_MASK: - if nheads_mask == 1: - off_hmask = 0 - elif nheads_mask == nheads_k: - off_hmask = off_hk - else: - off_hmask = off_hq - if HAS_BIAS: - if nheads_bias == 1: - off_hbbias = 0 - elif nheads_bias == nheads_k: - off_hbbias = off_hk - else: - off_hbbias = off_hq - - # Advance offset pointers for batch and head - Q += off_b * stride_qb + off_hq * stride_qh - K += off_b * stride_kb + off_hk * stride_kh - V += off_b * stride_vb + off_hk * stride_vh - if HAS_MASK: - Mask += off_b * stride_mb + off_hmask * stride_mh - if HAS_BIAS: - Bias += off_b * stride_bb + off_hbbias * stride_bh - DO += off_b * stride_dob + off_hq * stride_doh - DQ += off_b * stride_dqb + off_hq * stride_dqh - DK += off_b * stride_dkb + off_hq * stride_dkh - DV += off_b * stride_dvb + off_hq * stride_dvh - if HAS_BIAS: - DBias += off_b * stride_dbb + off_hq * stride_dbh - # Advance pointer to row-wise quantities in value-like data - D += off_hb * seqlen_q_rounded - LSE += off_hb * seqlen_q_rounded - - if not SEQUENCE_PARALLEL: - num_block_n = tl.cdiv(seqlen_k, BLOCK_N) - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Mask, - Bias, - DO, - DQ, - DK, - DV, - DBias, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_mm, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - stride_dbm, - seqlen_q, - seqlen_k, - headdim, - IS_CAUSAL=IS_CAUSAL, - HAS_MASK=HAS_MASK, - HAS_BIAS=HAS_BIAS, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=False, - ACCUM_DBIAS=ACCUM_DBIAS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - else: - start_n = tl.program_id(0) - _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Mask, - Bias, - DO, - DQ, - DK, - DV, - DBias, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_mm, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - stride_dbm, - seqlen_q, - seqlen_k, - headdim, - IS_CAUSAL=IS_CAUSAL, - HAS_MASK=HAS_MASK, - HAS_BIAS=HAS_BIAS, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=True, - ACCUM_DBIAS=ACCUM_DBIAS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - - -def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): - # shape constraints - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - - assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" - assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" - assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" - assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" - assert q.is_cuda and k.is_cuda and v.is_cuda - - has_mask = mask is not None - if has_mask: - assert mask.dtype == torch.bool, "Only support bool" - assert mask.is_cuda - nheads_mask = mask.shape[1] - else: - nheads_mask = 1 - mask = torch.empty(0, device=q.device, dtype=torch.bool) - - has_bias = bias is not None - if has_bias: - assert bias.dtype == q.dtype, "Only support fp16 and bf16" - assert bias.is_cuda - nheads_bias = bias.shape[1] - else: - nheads_bias = 1 - bias = torch.empty(0, device=q.device, dtype=q.dtype) - - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.empty_like(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _fwd_kernel[grid]( - q, - k, - v, - mask, - bias, - o, - lse, - softmax_scale, - q.stride(0), - q.stride(2), - q.stride(1), - k.stride(0), - k.stride(2), - k.stride(1), - v.stride(0), - v.stride(2), - v.stride(1), - ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), - ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), - ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), - ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), - ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), - ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), - o.stride(0), - o.stride(2), - o.stride(1), - nheads, - nheads_k, - nheads_mask, - nheads_bias, - nheads // nheads_k, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - d, - seqlen_q // 32, - seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=d, - is_causal, - has_mask, - has_bias, - BLOCK_HEADDIM, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - return o, lse, softmax_scale # softmax_scale could have been updated - - -def _flash_dmattn_backward( - do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False -): - # Make sure that the last dimension is contiguous - if do.stride(-1) != 1: - do = do.contiguous() - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, dk = k.shape - - assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" - assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) - - has_mask = mask is not None - if has_mask: - assert mask.dtype == torch.bool, "Only support bool" - nheads_mask = mask.shape[1] - else: - nheads_mask = 1 - mask = torch.empty(0, device=q.device, dtype=torch.bool) - - has_bias = bias is not None - if has_bias: - assert bias.dtype == q.dtype, "Only support fp16 and bf16" - nheads_bias = bias.shape[1] - else: - nheads_bias = 1 - bias = torch.empty(0, device=q.device, dtype=q.dtype) - - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - # dq_accum = torch.zeros_like(q, dtype=torch.float32) - dq_accum = torch.empty_like(q, dtype=torch.float32) - delta = torch.empty_like(lse) - # delta = torch.zeros_like(lse) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - dbias = torch.empty_like(bias) if has_bias else torch.empty(0, device=q.device, dtype=q.dtype) - - dk_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dk - dv_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dv - if has_bias: - if ( - nheads_bias != nheads - or ((bias.shape[0] == 1) and (batch > 1)) - or ((bias.shape[-2] == 1) and (seqlen_q > 1)) - ): - if bias.shape[-2] == 1: - dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) - else: - dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) - else: - dbias_expanded = dbias - else: - dbias_expanded = dbias - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _bwd_preprocess_do_o_dot[grid]( - o, - do, - delta, - o.stride(0), - o.stride(2), - o.stride(1), - do.stride(0), - do.stride(2), - do.stride(1), - nheads, - seqlen_q, - seqlen_q_rounded, - d, - BLOCK_M=64, - BLOCK_HEADDIM=BLOCK_HEADDIM, - ) - - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 - grid = lambda META: ( - triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads, - ) - _bwd_kernel[grid]( - q, - k, - v, - mask, - bias, - do, - dq_accum, - dk_expanded, - dv_expanded, - dbias_expanded, - lse, - delta, - softmax_scale, - q.stride(0), - q.stride(2), - q.stride(1), - k.stride(0), - k.stride(2), - k.stride(1), - v.stride(0), - v.stride(2), - v.stride(1), - ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), - ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), - ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), - ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), - ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), - ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), - do.stride(0), - do.stride(2), - do.stride(1), - dq_accum.stride(0), - dq_accum.stride(2), - dq_accum.stride(1), - dk_expanded.stride(0), - dk_expanded.stride(2), - dk_expanded.stride(1), - dv_expanded.stride(0), - dv_expanded.stride(2), - dv_expanded.stride(1), - (dbias_expanded.stride(0) if has_bias else 0), - (dbias_expanded.stride(1) if has_bias else 0), - ((0 if (has_bias and bias.shape[-2] == 1) else (dbias_expanded.stride(2) if has_bias else 0))), - nheads, - nheads_k, - nheads_mask, - nheads_bias, - nheads // nheads_k, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - d, - seqlen_q // 32, - seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=BLOCK_HEADDIM, - is_causal, - has_mask, - has_bias, - BLOCK_HEADDIM, - # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - dq = dq_accum.to(q.dtype) - if nheads != nheads_k: - dk = dk_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) - dv = dv_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) - if has_bias: - if ( - nheads_bias != nheads - and bias.shape[0] == batch - and bias.shape[-2] == seqlen_q - ): - dbias = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) - else: - if bias.shape[-2] == 1: - dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, 1, seqlen_k_rounded).sum(dim=2) - else: - dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) - if bias.shape[0] == 1: - dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True) - dbias.copy_(dbias_expanded) - return dq, dk, dv, dbias if has_bias else None - - -def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -class FlashDMAttnFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): - """ - query: (batch_size, seqlen_q, nheads, headdim) - key: (batch_size, seqlen_k, nheads, headdim) - value: (batch_size, seqlen_k, nheads, headdim) - attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k) - attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k) - is_causal: bool, whether to apply causal masking - softmax_scale: float, scaling factor for attention scores - """ - - # Make sure that the last dimension is contiguous - query, key, value, attn_mask, attn_bias = [maybe_contiguous(x) for x in [query, key, value, attn_mask, attn_bias]] - - # Padding to multiple of 8 for 16-bit memory allocations - head_size_og = query.size(3) - if head_size_og % 8 != 0: - query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) - key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) - value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) - seqlen_k_rounded = round_multiple(key.shape[1], 128) - if attn_mask is not None and attn_mask.shape[-1] != seqlen_k_rounded: - if attn_mask.shape[-1] == 1: - attn_mask = attn_mask.expand(*attn_mask.shape[:-1], seqlen_k_rounded) - else: - attn_mask = torch.nn.functional.pad(attn_mask, [0, seqlen_k_rounded - attn_mask.shape[-1]]) - if attn_bias is not None and attn_bias.shape[-1] != seqlen_k_rounded: - if attn_bias.shape[-1] == 1: - attn_bias = attn_bias.expand(*attn_bias.shape[:-1], seqlen_k_rounded) - else: - attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) - - o, lse, ctx.softmax_scale = _flash_dmattn_forward( - query, - key, - value, - attn_mask, - attn_bias, - softmax_scale=softmax_scale, - is_causal=is_causal - ) - ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias) - ctx.is_causal = is_causal - ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0 - return o - - @staticmethod - def backward(ctx, do): - query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors - - head_size_og = do.size(3) - do_padded = do - if head_size_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - - dq, dk, dv, dbias = _flash_dmattn_backward( - do_padded, - query, - key, - value, - attn_mask, - attn_bias, - o, - lse, - softmax_scale=ctx.softmax_scale, - is_causal=ctx.is_causal, - ) - - # We could have padded the head dimension - dq = dq[..., : do.shape[-1]] - dk = dk[..., : do.shape[-1]] - dv = dv[..., : do.shape[-1]] - - if dbias is not None: - dbias = dbias[..., :key.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : key.shape[1]] - - return dq, dk, dv, None, dbias, None, None - - -def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): - return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, softmax_scale) diff --git a/flash_dmattn/flash_dmattn_triton_special.py b/flash_dmattn/flash_dmattn_triton_special.py deleted file mode 100644 index 828e80b..0000000 --- a/flash_dmattn/flash_dmattn_triton_special.py +++ /dev/null @@ -1,1244 +0,0 @@ -from typing import Optional -import math - -import torch -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - ], - key=['IS_CAUSAL', 'BLOCK_HEADDIM'] -) -@triton.heuristics( - { - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_preprocess( - K, - V, - B, - I, - CuK, - CuV, - CuB, - CuM, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_bb, - stride_bh, - stride_bn, - stride_ib, - stride_ih, - stride_ik, - stride_ckb, - stride_ckh, - stride_ckk, - stride_cvb, - stride_cvh, - stride_cvk, - stride_cbb, - stride_cbh, - stride_cbk, - stride_cmb, - stride_cmh, - stride_cmm, - stride_cmk, - nheads_k, - seqlen_q, - seqlen_k, - window_size, - headdim, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(0) - off_b = off_hb // nheads_k - off_hk = off_hb % nheads_k - - # Initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize base pointers to K, V, B, I, CuK, CuV, CuB - k_base_ptrs = ( - K + off_b * stride_kb + off_hk * stride_kh - ) - v_base_ptrs = ( - V + off_b * stride_vb + off_hk * stride_vh - ) - b_base_ptrs = ( - B + off_b * stride_bb + off_hk * stride_bh - ) - i_base_ptrs = ( - I + off_b * stride_ib + off_hk * stride_ih - ) - cuk_base_ptrs = ( - CuK + off_b * stride_ckb + off_hk * stride_ckh - ) - cuv_base_ptrs = ( - CuV + off_b * stride_cvb + off_hk * stride_cvh - ) - cub_base_ptrs = ( - CuB + off_b * stride_cbb + off_hk * stride_cbh - ) - cum_base_ptrs = ( - CuM + off_b * stride_cmb + off_hk * stride_cmh - ) - - # Loop over blocks of window_size - for start_k in range(0, window_size, BLOCK_N): - start_k = tl.multiple_of(start_k, BLOCK_N) - offs_k = start_k + offs_n - - # Load I - i_ptrs = ( - i_base_ptrs + offs_k * stride_ik - ) - gather_idx = tl.load(i_ptrs, mask=offs_k < window_size, other=0).to(tl.int64) - valid_idx = (offs_k < window_size) & (gather_idx >= 0) & (gather_idx < seqlen_k) - gather_idx = tl.where(valid_idx, gather_idx, 0) - - # Load K, V, B - k_ptrs = ( - k_base_ptrs + gather_idx[:, None] * stride_kn + offs_d[None, :] - ) - v_ptrs = ( - v_base_ptrs + gather_idx[:, None] * stride_vn + offs_d[None, :] - ) - if EVEN_HEADDIM: - k = tl.load(k_ptrs, mask=valid_idx[:, None], other=0.0) - v = tl.load(v_ptrs, mask=valid_idx[:, None], other=0.0) - else: - k = tl.load( - k_ptrs, - mask=valid_idx[:, None] & (offs_d[None, :] < headdim), - other=0.0 - ) - v = tl.load( - v_ptrs, - mask=valid_idx[:, None] & (offs_d[None, :] < headdim), - other=0.0 - ) - b_ptrs = ( - b_base_ptrs + gather_idx * stride_bn - ) - b = tl.load(b_ptrs, mask=valid_idx, other=0.0) - - # Store to CuK, CuV, CuB - cuk_ptrs = ( - cuk_base_ptrs + offs_k[:, None] * stride_ckk + offs_d[None, :] - ) - cuv_ptrs = ( - cuv_base_ptrs + offs_k[:, None] * stride_cvk + offs_d[None, :] - ) - if EVEN_HEADDIM: - tl.store(cuk_ptrs, k, mask=valid_idx[:, None]) - tl.store(cuv_ptrs, v, mask=valid_idx[:, None]) - else: - tl.store( - cuk_ptrs, k, - mask=valid_idx[:, None] & (offs_d[None, :] < headdim), - ) - tl.store( - cuv_ptrs, v, - mask=valid_idx[:, None] & (offs_d[None, :] < headdim), - ) - cub_ptrs = ( - cub_base_ptrs + offs_k * stride_cbk - ) - tl.store(cub_ptrs, b, mask=valid_idx) - - # Store mask to CuM - for start_m in range(0, seqlen_q, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m = start_m + tl.arange(0, BLOCK_M) - - cum_ptrs = ( - cum_base_ptrs + offs_m[:, None] * stride_cmm + offs_k[None, :] * stride_cmk - ) - - col_mask = offs_k < window_size - row_mask = offs_m[:, None] < seqlen_q - - if IS_CAUSAL: - mask = (offs_m[:, None] >= gather_idx[None, :]) & valid_idx[None, :] - else: - mask = valid_idx[None, :] - - cum = tl.where(row_mask & col_mask[None, :], mask, False) - - tl.store(cum_ptrs, cum, mask=row_mask & col_mask[None, :]) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=4, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, - num_warps=8, - num_stages=1, - ), - ], - key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BLOCK_HEADDIM'] -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["window_size"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_kernel( - Q, - CuK, - CuV, - CuB, - CuM, - Out, - Lse, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_ckb, - stride_ckh, - stride_ckk, - stride_cvb, - stride_cvh, - stride_cvk, - stride_cbb, - stride_cbh, - stride_cbk, - stride_cmb, - stride_cmh, - stride_cmm, - stride_cmk, - stride_ob, - stride_oh, - stride_om, - nheads, - h_h_k_ratio, - seqlen_q, - window_size, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q: tl.constexpr, - CACHE_KEY_SEQLEN_K: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_bh = tl.program_id(1) - off_b = off_bh // nheads - off_hq = off_bh % nheads - off_hk = off_hq // h_h_k_ratio - - # Initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to Q, CuK, CuV, CuM, CuB - q_ptrs = ( - Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - ) - cuk_base_ptrs = ( - CuK + off_b * stride_ckb + off_hk * stride_ckh - ) - cv_base_ptrs = ( - CuV + off_b * stride_cvb + off_hk * stride_cvh - ) - cub_base_ptrs = ( - CuB + off_b * stride_cbb + off_hk * stride_cbh - ) - cum_base_ptrs = ( - CuM + off_b * stride_cmb + off_hk * stride_cmh - ) - - # Initialize pointer to m and l - lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - - # Load q: it will stay in SRAM throughout - if EVEN_M: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0 - ) - - # Scale q - q = (q * softmax_scale).to(q.dtype) - - # Loop over k, v and update accumulator - for start_n in range(0, window_size, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - cum_ptrs = ( - cum_base_ptrs + offs_m[:, None] * stride_cmm + (start_n + offs_n)[None, :] * stride_cmk - ) - # Load mask - if EVEN_M & EVEN_N: - m = tl.load(cum_ptrs) - else: - m = tl.load( - cum_ptrs, - mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < window_size), - other=False, - ) - - # Check if any element in mask is non-zero - any_active = tl.reduce_or(m, axis=None) - - # Skip this iteration if no active elements - if any_active: - - # Load k - cuk_ptrs = ( - cuk_base_ptrs + (start_n + offs_n)[:, None] * stride_ckk + offs_d[None, :] - ) - if EVEN_N: - if EVEN_HEADDIM: - k = tl.load(cuk_ptrs) - else: - k = tl.load(cuk_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - cuk_ptrs, - mask=(start_n + offs_n)[:, None] < window_size, - other=0.0, - ) - else: - k = tl.load( - cuk_ptrs, - mask=((start_n + offs_n)[:, None] < window_size) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Load bias - cub_ptrs = ( - cub_base_ptrs + (start_n + offs_n) * stride_cbk - ) - if EVEN_M & EVEN_N: - b = tl.load(cub_ptrs) - else: - b = tl.load( - cub_ptrs, - mask=(start_n + offs_n) < window_size, - other=0.0, - ) - - # Initialize acc_s - acc_s = b[None, :].to(tl.float32) - - # Compute acc_s - acc_s += tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < window_size, 0, float("-inf")) - acc_s += tl.where(m, 0, float("-inf")) - - # Compute p - m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) - p = tl.exp(acc_s - m_ij[:, None]) - l_ij = tl.sum(p, 1) - - # Scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - - # Update output accumulator - acc_o = acc_o * acc_o_scale[:, None] - - # Load v - cuv_ptrs = ( - cv_base_ptrs + (start_n + offs_n)[:, None] * stride_cvk + offs_d[None, :] - ) - if EVEN_N: - if EVEN_HEADDIM: - v = tl.load(cuv_ptrs) - else: - v = tl.load(cuv_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - cuv_ptrs, - mask=(start_n + offs_n)[:, None] < window_size, - other=0.0, - ) - else: - v = tl.load( - cuv_ptrs, - mask=((start_n + offs_n)[:, None] < window_size) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Compute acc_o - acc_o += tl.dot(p.to(v.dtype), v) - - # Update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - - o_scale = tl.exp(m_i - lse_i) - acc_o = acc_o * o_scale[:, None] - # Rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # Write back l and m - lse_ptrs = Lse + off_bh * seqlen_q_rounded + offs_m - tl.store(lse_ptrs, lse_i) - # Initialize pointers to output - offs_d = tl.arange(0, BLOCK_HEADDIM) - out_ptrs = ( - Out - + off_b * stride_ob - + off_hq * stride_oh - + (offs_m[:, None] * stride_om + offs_d[None, :]) - ) - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store( - out_ptrs, acc_o, - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) - ) - - -@triton.jit -def _bwd_preprocess_do_o_dot( - Out, - DO, - Delta, - stride_ob, - stride_oh, - stride_om, - stride_dob, - stride_doh, - stride_dom, - nheads, - seqlen_q, - seqlen_q_rounded, - headdim, - BLOCK_M: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # Initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Load o - o = tl.load( - Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - do = tl.load( - DO - + off_b * stride_dob - + off_h * stride_doh - + offs_m[:, None] * stride_dom - + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # Write back - tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) - - -@triton.jit -def _bwd_kernel_one_col_block( - start_n, - Q, - CuK, - CuV, - CuB, - CuM, - DO, - DQ, - DCuK, - DCuV, - DCuB, - LSE, - D, - softmax_scale, - stride_qm, - stride_ckk, - stride_cvk, - stride_cbk, - stride_cmm, - stride_cmk, - stride_dom, - stride_dqm, - stride_dckk, - stride_dcvk, - stride_dcbk, - seqlen_q, - window_size, - headdim, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - ATOMIC_ADD: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # Initialize row/col offsets - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to value-like data - q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) - cuk_ptrs = CuK + (offs_n[:, None] * stride_ckk + offs_d[None, :]) - cuv_ptrs = CuV + (offs_n[:, None] * stride_cvk + offs_d[None, :]) - cub_ptrs = CuB + (offs_n * stride_cbk) - do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_d[None, :]) - dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_d[None, :]) - dcuk_ptrs = DCuK + (offs_n[:, None] * stride_dckk + offs_d[None, :]) - dcuv_ptrs = DCuV + (offs_n[:, None] * stride_dcvk + offs_d[None, :]) - dcub_ptrs = DCuB + (offs_n * stride_dcbk) - - # Initialize dv, dk, db accumulators - dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - db = tl.zeros([BLOCK_N], dtype=tl.float32) - - # Load k and v, them will stay in SRAM throughout - if EVEN_N: - if EVEN_HEADDIM: - k = tl.load(cuk_ptrs) - v = tl.load(cuv_ptrs) - else: - k = tl.load(cuk_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - v = tl.load(cuv_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(cuk_ptrs, mask=offs_n[:, None] < window_size, other=0.0) - v = tl.load(cuv_ptrs, mask=offs_n[:, None] < window_size, other=0.0) - else: - k = tl.load( - cuk_ptrs, - mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), - other=0.0 - ) - v = tl.load( - cuv_ptrs, - mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), - other=0.0 - ) - if EVEN_N: - b = tl.load(cub_ptrs) - else: - b = tl.load(cub_ptrs, mask=offs_n < window_size, other=0.0) - - # Scale k - k = (k * softmax_scale).to(k.dtype) - - # Loop over q and update accumulators - num_block_m = tl.cdiv(seqlen_q, BLOCK_M) - for start_m in range(0, num_block_m * BLOCK_M, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m_curr = start_m + offs_m - - # Load mask - cum_ptrs = ( - CuM + offs_m_curr[:, None] * stride_cmm + offs_n[None, :] * stride_cmk - ) - if EVEN_M & EVEN_N: - m = tl.load(cum_ptrs) - else: - m = tl.load( - cum_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < window_size), - other=False, - ) - - # Check if any element in mask is non-zero - any_active = tl.reduce_or(m, axis=None) - - # Skip this iteration if no active elements - if any_active: - # Load q - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Initialize acc_s - acc_s = b[None, :].to(tl.float32) - - # Compute acc_s - acc_s += tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where(offs_n[None, :] < window_size, 0, float("-inf")) - acc_s += tl.where(m, 0, float("-inf")) - - lse_i = tl.load(LSE + offs_m_curr) - # p = tl.exp(acc_s - lse_i[:, None]) - p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) - - # Load do - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # There's a race condition if we just use m_mask and not d_mask. - do = tl.load( - do_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - - # Compute dv - dv += tl.dot(tl.trans(p.to(do.dtype)), do) - - # Compute dp - dp = tl.dot(do, tl.trans(v)) - - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - - # Compute ds - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (p * (dp - Di[:, None])).to(q.dtype) - - # Compute db - db += tl.sum(ds, axis=0) - - # Compute dk - dk += tl.dot(tl.trans(ds), q) - - # Compute dq - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k).to(ds.dtype) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - else: - if EVEN_HEADDIM: - dq = tl.load( - dq_ptrs, - mask=offs_m_curr[:, None] < seqlen_q, - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k).to(ds.dtype) - tl.store( - dq_ptrs, - dq, - mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last", - ) - else: - dq = tl.load( - dq_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k).to(ds.dtype) - tl.store( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last", - ) - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k).to(ds.dtype) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) - else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) - else: - tl.atomic_add( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - ) - - # Increment pointers - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_dom - dq_ptrs += BLOCK_M * stride_dqm - else: - # Increment pointers - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_dom - dq_ptrs += BLOCK_M * stride_dqm - - # Scale dk - dk = (dk * softmax_scale).to(dk.dtype) - - # Write back - if EVEN_N: - if EVEN_HEADDIM: - tl.store(dcuk_ptrs, dk) - tl.store(dcuv_ptrs, dv) - else: - tl.store(dcuk_ptrs, dk, mask=offs_d[None, :] < headdim) - tl.store(dcuv_ptrs, dv, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dcuk_ptrs, dk, mask=offs_n[:, None] < window_size) - tl.store(dcuv_ptrs, dv, mask=offs_n[:, None] < window_size) - else: - tl.store(dcuk_ptrs, dk, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim)) - tl.store(dcuv_ptrs, dv, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim)) - - if EVEN_N: - tl.store(dcub_ptrs, db) - else: - tl.store(dcub_ptrs, db, mask=(offs_n < window_size)) - - -def init_to_zero(names): - if isinstance(names, str): - names = [names] - def init_func(nargs): - for name in names: - nargs[name].zero_() - return init_func - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero(["DQ", "DCuB"]), - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero(["DQ", "DCuB"]), - ), - ], - key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"], -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["window_size"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _bwd_kernel( - Q, - CuK, - CuV, - CuB, - CuM, - DO, - DQ, - DCuK, - DCuV, - DCuB, - LSE, - D, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_ckb, - stride_ckh, - stride_ckk, - stride_cvb, - stride_cvh, - stride_cvk, - stride_cbb, - stride_cbh, - stride_cbk, - stride_cmb, - stride_cmh, - stride_cmm, - stride_cmk, - stride_dob, - stride_doh, - stride_dom, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dckb, - stride_dckh, - stride_dckk, - stride_dcvb, - stride_dcvh, - stride_dcvk, - stride_dcbb, - stride_dcbh, - stride_dcbk, - nheads, - h_h_k_ratio, - seqlen_q, - window_size, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, - BLOCK_HEADDIM: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_hq = off_hb % nheads - off_hk = off_hq // h_h_k_ratio - - # Advance offset pointers for batch and head - Q += off_b * stride_qb + off_hq * stride_qh - CuK += off_b * stride_ckb + off_hk * stride_ckh - CuV += off_b * stride_cvb + off_hk * stride_cvh - CuB += off_b * stride_cbb + off_hk * stride_cbh - CuM += off_b * stride_cmb + off_hk * stride_cmh - DO += off_b * stride_dob + off_hq * stride_doh - DQ += off_b * stride_dqb + off_hq * stride_dqh - DCuK += off_b * stride_dckb + off_hq * stride_dckh - DCuV += off_b * stride_dcvb + off_hq * stride_dcvh - DCuB += off_b * stride_dcbb + off_hq * stride_dcbh - # Advance pointer to row-wise quantities in value-like data - D += off_hb * seqlen_q_rounded - LSE += off_hb * seqlen_q_rounded - - if not SEQUENCE_PARALLEL: - num_block_n = tl.cdiv(window_size, BLOCK_N) - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - start_n, - Q, - CuK, - CuV, - CuB, - CuM, - DO, - DQ, - DCuK, - DCuV, - DCuB, - LSE, - D, - softmax_scale, - stride_qm, - stride_ckk, - stride_cvk, - stride_cbk, - stride_cmm, - stride_cmk, - stride_dom, - stride_dqm, - stride_dckk, - stride_dcvk, - stride_dcbk, - seqlen_q, - window_size, - headdim, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=False, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - else: - start_n = tl.program_id(0) - _bwd_kernel_one_col_block( - start_n, - Q, - CuK, - CuV, - CuB, - CuM, - DO, - DQ, - DCuK, - DCuV, - DCuB, - LSE, - D, - softmax_scale, - stride_qm, - stride_ckk, - stride_cvk, - stride_cbk, - stride_cmm, - stride_cmk, - stride_dom, - stride_dqm, - stride_dckk, - stride_dcvk, - stride_dcbk, - seqlen_q, - window_size, - headdim, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=True, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - - -def _flash_dmattn_forward(q, k, v, b, i, softmax_scale=None, is_causal=False, window_size=None): - # shape constraints - batch, nheads, seqlen_q, d = q.shape - _, nheads_k, seqlen_k, _ = k.shape - - assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" - assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" - assert q.dtype == k.dtype == v.dtype == b.dtype, "All tensors must have the same type" - assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" - assert i.dtype == torch.int64, "Indices must be int64" - assert q.is_cuda and k.is_cuda and v.is_cuda and b.is_cuda, "All tensors must be on GPU" - - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - cu_k = torch.empty((batch, nheads_k, window_size, d), device=q.device, dtype=k.dtype) - cu_v = torch.empty((batch, nheads_k, window_size, d), device=q.device, dtype=v.dtype) - cu_b = torch.empty((batch, nheads_k, window_size), device=q.device, dtype=b.dtype) - cu_m = torch.zeros((batch, nheads_k, seqlen_q, window_size), device=q.device, dtype=torch.bool) - - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.empty_like(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = (batch * nheads_k,) - _fwd_preprocess[grid]( - k, v, b, i, - cu_k, cu_v, cu_b, cu_m, - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - b.stride(0), b.stride(1), b.stride(2), - i.stride(0), i.stride(1), i.stride(2), - cu_k.stride(0), cu_k.stride(1), cu_k.stride(2), - cu_v.stride(0), cu_v.stride(1), cu_v.stride(2), - cu_b.stride(0), cu_b.stride(1), cu_b.stride(2), - cu_m.stride(0), cu_m.stride(1), cu_m.stride(2), cu_m.stride(3), - nheads_k, seqlen_q, seqlen_k, window_size, d, is_causal, BLOCK_HEADDIM - ) - - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _fwd_kernel[grid]( - q, - cu_k, cu_v, cu_b, cu_m, - o, lse, softmax_scale, - q.stride(0), q.stride(1), q.stride(2), - cu_k.stride(0), cu_k.stride(1), cu_k.stride(2), - cu_v.stride(0), cu_v.stride(1), cu_v.stride(2), - cu_b.stride(0), cu_b.stride(1), cu_b.stride(2), - cu_m.stride(0), cu_m.stride(1), cu_m.stride(2), cu_m.stride(3), - o.stride(0), o.stride(1), o.stride(2), - nheads, nheads // nheads_k, seqlen_q, window_size, seqlen_q_rounded, d, - seqlen_q // 32, - window_size // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # BLOCK_HEADDIM=d, - BLOCK_HEADDIM, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - return o, lse, softmax_scale, cu_k, cu_v, cu_b, cu_m - - -def _flash_dmattn_backward( - do, q, cuk, cuv, cub, cum, i, o, lse, softmax_scale, seqlen_q, seqlen_k, window_size -): - # Make sure that the last dimension is contiguous - if do.stride(-1) != 1: - do = do.contiguous() - batch, nheads, _, d = q.shape - _, nheads_k, _, _ = cuk.shape - - assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" - assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) - - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - # dq_accum = torch.zeros_like(q, dtype=torch.float32) - dq_accum = torch.empty_like(q, dtype=torch.float32) - delta = torch.empty_like(lse) - # delta = torch.zeros_like(lse) - dk = torch.zeros(batch, nheads_k, seqlen_k, d, device=q.device, dtype=q.dtype) - dv = torch.zeros(batch, nheads_k, seqlen_k, d, device=q.device, dtype=q.dtype) - db = torch.zeros(batch, nheads_k, seqlen_k, device=q.device, dtype=q.dtype) - - dk_expanded = torch.empty(batch, nheads, window_size, d, device=q.device, dtype=q.dtype) - dv_expanded = torch.empty(batch, nheads, window_size, d, device=q.device, dtype=q.dtype) - db_expanded = torch.empty(batch, nheads, window_size, device=q.device, dtype=q.dtype) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _bwd_preprocess_do_o_dot[grid]( - o, do, delta, - o.stride(0), o.stride(1), o.stride(2), - do.stride(0), do.stride(1), do.stride(2), - nheads, seqlen_q, seqlen_q_rounded, d, - BLOCK_M=64, - BLOCK_HEADDIM=BLOCK_HEADDIM, - ) - - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 - grid = lambda META: ( - triton.cdiv(window_size, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads, - ) - _bwd_kernel[grid]( - q, cuk, cuv, cub, cum, do, - dq_accum, dk_expanded, dv_expanded, db_expanded, - lse, delta, softmax_scale, - q.stride(0), q.stride(1), q.stride(2), - cuk.stride(0), cuk.stride(1), cuk.stride(2), - cuv.stride(0), cuv.stride(1), cuv.stride(2), - cub.stride(0), cub.stride(1), cub.stride(2), - cum.stride(0), cum.stride(1), cum.stride(2), cum.stride(3), - do.stride(0), do.stride(1), do.stride(2), - dq_accum.stride(0), dq_accum.stride(1), dq_accum.stride(2), - dk_expanded.stride(0), dk_expanded.stride(1), dk_expanded.stride(2), - dv_expanded.stride(0), dv_expanded.stride(1), dv_expanded.stride(2), - db_expanded.stride(0), db_expanded.stride(1), db_expanded.stride(2), - nheads, nheads // nheads_k, seqlen_q, window_size, seqlen_q_rounded, d, - seqlen_q // 32, - window_size // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # BLOCK_HEADDIM=BLOCK_HEADDIM, - BLOCK_HEADDIM, - # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - dq = dq_accum.to(q.dtype) - - if nheads != nheads_k: - dk_expanded = dk_expanded.view(batch, nheads_k, nheads // nheads_k, window_size, d).sum(dim=2) - dv_expanded = dv_expanded.view(batch, nheads_k, nheads // nheads_k, window_size, d).sum(dim=2) - db_expanded = db_expanded.view(batch, nheads_k, nheads // nheads_k, window_size).sum(dim=2) - - dk.scatter_add_( - dim=2, - index=i.unsqueeze(-1).expand(-1, -1, -1, d), - src=dk_expanded, - ) - dv.scatter_add_( - dim=2, - index=i.unsqueeze(-1).expand(-1, -1, -1, d), - src=dv_expanded, - ) - db.scatter_add_( - dim=2, - index=i, - src=db_expanded, - ) - - return dq, dk, dv, db - - -def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -class FlashDMAttnFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None): - """ - query: (batch_size, nheads, seqlen_q, headdim) - key: (batch_size, nheads_k, seqlen_k, headdim) - value: (batch_size, nheads_k, seqlen_k, headdim) - attn_bias: (batch_size, nheads_k, seqlen_k) - attn_indices: (batch_size, nheads_k, window_size) - is_causal: bool, whether to apply causal masking - softmax_scale: float, scaling factor for attention scores - """ - - # Make sure that the last dimension is contiguous - query, key, value, attn_bias, attn_indices = [maybe_contiguous(x) for x in [query, key, value, attn_bias, attn_indices]] - - # Padding to multiple of 8 for 16-bit memory allocations - head_size_og = query.size(3) - if head_size_og % 8 != 0: - query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) - key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) - value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) - seqlen_k_rounded = round_multiple(key.shape[2], 128) - if attn_bias.shape[-1] != seqlen_k_rounded: - attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) - window_size = attn_indices.shape[-1] - - o, lse, ctx.softmax_scale, cu_key, cu_value, cu_attn_bias, cu_attn_mask = _flash_dmattn_forward( - query, - key, - value, - attn_bias, - attn_indices, - softmax_scale=softmax_scale, - is_causal=is_causal, - window_size=window_size, - ) - ctx.save_for_backward(query, cu_key, cu_value, cu_attn_bias, cu_attn_mask, attn_indices, o, lse) - ctx.seqlen_q = query.size(2) - ctx.seqlen_k = key.size(2) - ctx.window_size = window_size - - o = o[..., : head_size_og] - return o - - @staticmethod - def backward(ctx, do): - query, cu_key, cu_value, cu_attn_bias, cu_attn_mask, attn_indices, o, lse = ctx.saved_tensors - - head_size_og = do.size(3) - do_padded = do - if head_size_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - - dq, dk, dv, db = _flash_dmattn_backward( - do_padded, - query, - cu_key, - cu_value, - cu_attn_bias, - cu_attn_mask, - attn_indices, - o, - lse, - softmax_scale=ctx.softmax_scale, - seqlen_q=ctx.seqlen_q, - seqlen_k=ctx.seqlen_k, - window_size=ctx.window_size, - ) - - # We could have padded the head dimension - dq = dq[..., : do.shape[-1]] - dk = dk[..., : do.shape[-1]] - dv = dv[..., : do.shape[-1]] - - return dq, dk, dv, db, None, None, None - - -def triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None): - return FlashDMAttnFunc.apply(query, key, value, attn_bias, attn_indices, is_causal, softmax_scale) \ No newline at end of file diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py deleted file mode 100644 index f842ae6..0000000 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Optional - -import torch - -from .modeling_flash_dynamic_mask_attention_utils import _flash_dynamic_mask_attention_forward -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - - -def flash_dynamic_mask_attention_forward( - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - attention_bias: Optional[torch.Tensor], - scaling: Optional[float] = None, - window_size: Optional[int] = None, - softcap: Optional[float] = None, - **kwargs, -) -> tuple[torch.Tensor, None]: - """ - A wrapper around the _flash_dynamic_mask_attention_forward function to be used in - the FlashDynamicMaskAttention class from HuggingFace Transformers. - - Args: - module (torch.nn.Module): The attention module. - query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). - key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape - (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). - attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). - scaling (Optional[float]): The scaling factor for the attention scores. - window_size (Optional[int]): The size of the window to keep. - softcap (Optional[float]): The softcap value for the attention scores. - **kwargs: Additional keyword arguments. - Includes: - - is_causal (bool): Whether to apply a causal mask. - - layer_idx (int): The index of the layer (for logging purposes). - - implementation (str): The implementation to use ("flash_dmattn" or None). - - Returns: - tuple[torch.Tensor, None]: The output tensor of shape (batch_size, seq_len, num_heads, head_dim) - and None (for compatibility with other attention implementations). - """ - - if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None: - logger.warning_once( - "`flash_dynamic_mask_attention` does not support `output_attentions=True` or `head_mask`." - " Please set your attention to `eager` if you want any of these features." - ) - - # This is before the transpose - query_len = query.shape[2] - key_len = key.shape[2] - - if any(dim == 0 for dim in query.shape): - raise ValueError( - "Tensor query has shape with a zero dimension.\n" - "FlashDynamicMaskAttention does not support inputs with dim=0.\n" - "Please check your input shapes or use SDPA instead." - ) - - # FDMA uses non-transposed inputs - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (usually our RMSNorm modules handle it correctly) - target_dtype = None - if query.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(module.config, "_pre_quantization_dtype"): - target_dtype = module.config._pre_quantization_dtype - else: - target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype - - # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented - is_causal = kwargs.pop("is_causal", None) - if is_causal is None: - is_causal = module.is_causal - - attn_output = _flash_dynamic_mask_attention_forward( - query, - key, - value, - attention_mask, - attention_bias, - query_length=query_len, - key_length=key_len, - is_causal=is_causal, - softmax_scale=scaling, - softcap=softcap, - window_size=window_size, - target_dtype=target_dtype, - implementation="flash_dmattn", - layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, - **kwargs, - ) - - return attn_output, None diff --git a/flash_dmattn/integrations/import_utils.py b/flash_dmattn/integrations/import_utils.py deleted file mode 100644 index 583248b..0000000 --- a/flash_dmattn/integrations/import_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Jingze Shi and the HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Import utilities: Utilities related to imports and our lazy inits. -""" - -import importlib.metadata -import importlib.util -from functools import lru_cache -from typing import Union - - -from transformers import is_torch_available -from transformers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. -def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[tuple[bool, str], bool]: - # Check if the package spec exists and grab its version to avoid importing a local directory - package_exists = importlib.util.find_spec(pkg_name) is not None - package_version = "N/A" - if package_exists: - try: - # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()` - # should be used here to map from package name to distribution names - # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu. - # `importlib.metadata.packages_distributions()` is not available in Python 3.9. - - # Primary method to get the package version - package_version = importlib.metadata.version(pkg_name) - except importlib.metadata.PackageNotFoundError: - # Fallback method: Only for "torch" and versions containing "dev" - if pkg_name == "torch": - try: - package = importlib.import_module(pkg_name) - temp_version = getattr(package, "__version__", "N/A") - # Check if the version contains "dev" - if "dev" in temp_version: - package_version = temp_version - package_exists = True - else: - package_exists = False - except ImportError: - # If the package can't be imported, it's not available - package_exists = False - elif pkg_name == "quark": - # TODO: remove once `importlib.metadata.packages_distributions()` is supported. - try: - package_version = importlib.metadata.version("amd-quark") - except Exception: - package_exists = False - elif pkg_name == "triton": - try: - package_version = importlib.metadata.version("pytorch-triton") - except Exception: - package_exists = False - else: - # For packages other than "torch", don't attempt the fallback and set as not available - package_exists = False - logger.debug(f"Detected {pkg_name} version: {package_version}") - if return_version: - return package_exists, package_version - else: - return package_exists - - - -@lru_cache -def is_flash_dmattn_available(): - if not is_torch_available(): - return False - - if not _is_package_available("flash_dmattn"): - return False - - import torch - - if not torch.cuda.is_available(): - return False - - return True \ No newline at end of file diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py deleted file mode 100644 index c2638b8..0000000 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ /dev/null @@ -1,597 +0,0 @@ -# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -import os -from functools import partial -from typing import Optional, TypedDict - -import torch -import torch.nn.functional as F - -from .import_utils import is_flash_dmattn_available -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - - -# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves -_fdma_fn = None -_fdma_varlen_fn = None -_pad_fn = None -_unpad_fn = None -_create_mask_fn = None - -# function that processes kwargs, generalized to handle any supported kwarg within the function -_process_flash_kwargs_fn = None -# exceptions where hf API doesn't match the original FDMA API -_hf_api_to_flash_mapping = { - "dropout": None, - "sliding_window": None, -} - - -def _lazy_imports(implementation: Optional[str]): - """ - Lazy loads the respective flash dynamic mask attention implementations. - - Return: - flash_attn_func: The base flash dynamic mask attention function. - flash_attn_varlen_func: The flash dynamic mask attention function supporting variable sequence lengths, e.g. for padding-free training. - pad_input: The function to pad inputs into one sequence and returning the respective kwargs. - unpad_input: The function to unpad outputs based on the kwargs (from pad_input). - """ - is_fdma = is_flash_dmattn_available() - - if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma): - from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func - from flash_dmattn.utils.padding import pad_input, unpad_input - from flash_dmattn.utils.mask import create_mask - - return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask - - -def _lazy_define_process_function(flash_function): - """ - Depending on the version and kernel some features are not supported. Due to limitations in - `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported - within `_process_flash_dynamic_mask_attention_kwargs`. - - NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. - This might be confusing for kwargs that we use in any case, e.g. `is_causal`. - """ - - flash_parameters = inspect.signature(flash_function).parameters - process_parameters = inspect.signature(_process_flash_dynamic_mask_attention_kwargs).parameters - - supports_mapping = {} - for param in process_parameters: - fdma_param = _hf_api_to_flash_mapping.get(param, param) - supports_mapping[fdma_param] = fdma_param in flash_parameters - - return partial(_process_flash_dynamic_mask_attention_kwargs, supports_mapping=supports_mapping) - - -def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], force_import: Optional[bool] = False): - """ - Lazily import flash dmattn and return the respective functions + flags. - - NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can - work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. - """ - global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn - if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]): - _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation) - - global _process_flash_kwargs_fn - if force_import or _process_flash_kwargs_fn is None: - _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn) - - return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn - - -def _index_first_axis(tensor, indices): - """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. This is functionally equivalent to - FA2's `index_first_axis` and replaces the need to import it. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] - - -def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fdma_kwargs_from_position_ids, with torch compile, - # this might cause a graph break - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - unpad_input_func, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - unpad_input_func: - The function to use for unpadding the input tensors. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage - # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores - if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): - key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] - - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = _index_first_axis(key_layer, indices_k) - value_layer = _index_first_axis(value_layer, indices_k) - if query_length == kv_seq_len: - query_layer = _index_first_axis(query_layer, indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def prepare_fdma_kwargs_from_position_ids(position_ids): - """ - This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids. - - Arguments: - position_ids (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into - ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, - `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} - - position_ids = position_ids.view(-1) - indices_q = (position_ids == 0).nonzero().view(-1) - - cu_seq_lens_q = torch.cat( - ( - indices_q.to(**tensor_kwargs), - torch.tensor(position_ids.size(), **tensor_kwargs), - ) - ) - cu_seq_lens_k = cu_seq_lens_q - - # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 - # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing - # for some models (e.g. qwen2-vl). - max_length_q = cu_seq_lens_q.diff().max() - # NOTE: With torch compile, this will cause a graph break if you don't set - # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call - # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. - # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` - # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. - max_length_q = max_length_q.item() - max_length_k = max_length_q - - return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) - - -def _prepare_from_posids(query, key, value, position_ids): - """ - This function returns necessary arguments to call `flash_attn_varlen_func`. - All three query, key, value states will be flattened. - Cumulative lengths of each examples in the batch will be extracted from position_ids. - NOTE: ideally cumulative lengths should be prepared at the data collator stage - - Arguments: - query (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - position_ids (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - query (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - query = query.contiguous().view(-1, query.size(-2), query.size(-1)) - key = key.contiguous().view(-1, key.size(-2), key.size(-1)) - value = value.contiguous().view(-1, value.size(-2), value.size(-1)) - - (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fdma_kwargs_from_position_ids(position_ids) - - return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) - - -def _is_packed_sequence(position_ids, batch_size): - """ - Check the position ids whether packed sequences are indicated or not - 1. Position ids exist - 2. Flattened sequences only are supported - 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences - """ - if position_ids is None: - return False - - increasing_position_sequences = ( - torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() - ) - return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() - - -def fdma_peft_integration_check( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - bias: Optional[torch.Tensor], - target_dtype: Optional[torch.dtype] = None -): - """ - PEFT usually casts the layer norms in float32 for training stability reasons - therefore the input hidden states gets silently casted in float32. Hence, we need - cast them back in float16 / bfloat16 just to be sure everything works as expected. - This might slowdown training & inference so it is recommended to not cast the LayerNorms! - """ - if target_dtype and q.dtype == torch.float32: - logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.") - q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) - if bias is not None: - bias = bias.to(target_dtype) - return q, k, v, bias - - -class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): - """ - Keyword arguments for Flash Dynamic Mask Attention with Compile. - - Attributes: - cu_seq_lens_q (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for query state. - cu_seq_lens_k (`torch.LongTensor`, *optional*) - Gets cumulative sequence length for key state. - max_length_q (`int`, *optional*): - Maximum sequence length for query state. - max_length_k (`int`, *optional*): - Maximum sequence length for key state. - """ - - cu_seq_lens_q: Optional[torch.LongTensor] - cu_seq_lens_k: Optional[torch.LongTensor] - max_length_q: Optional[int] - max_length_k: Optional[int] - - -def _process_flash_dynamic_mask_attention_kwargs( - query_length: int, - key_length: int, - is_causal: bool, - softmax_scale: Optional[float] = None, - window_size: Optional[int] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - s_aux: Optional[torch.Tensor] = None, - supports_mapping: Optional[dict[str, bool]] = None, - **kwargs, -): - """ - Returns a set of kwargs that are passed down to the according flash attention function based on - requested features and whether it is supported - depends on the version and kernel implementation - which is dynamically configured at `lazy_import_flash_dynamic_mask_attention`. The (un)supported features can be - inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. - - Args: - query_length (`int`): - Length of the query states - key_length (`int`): - Length of the key states - is_causal (`bool`): - Whether we perform causal (decoder) attention or full attention. - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`. - window_size (`int`, *optional*): - If set, only the `window_size` largest key/value pairs per query are kept. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - s_aux (`torch.Tensor`, *optional*): - Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. - Return: - flash_kwargs (`dict`): - A dict of kwargs that are requested and supported. - """ - flash_kwargs = { - "is_causal": is_causal and not query_length == 1, - "softmax_scale": softmax_scale, - } - - if supports_mapping["window_size"] and window_size is not None and key_length > window_size: - flash_kwargs["window_size"] = window_size - - if supports_mapping["deterministic"]: - flash_kwargs["deterministic"] = ( - deterministic if deterministic is not None else os.getenv("FLASH_DMATTN_DETERMINISTIC", "0") == "1" - ) - - if supports_mapping["softcap"] and softcap is not None: - flash_kwargs["softcap"] = softcap - - # Only within kernel implementation atm - if supports_mapping["s_aux"] and s_aux is not None: - flash_kwargs["s_aux"] = s_aux - - return flash_kwargs - - -def _flash_dynamic_mask_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], - attention_bias: Optional[torch.Tensor], - query_length: int, - key_length: int, - is_causal: bool, - position_ids: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - window_size: Optional[int] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - cu_seq_lens_q: Optional[torch.LongTensor] = None, - cu_seq_lens_k: Optional[torch.LongTensor] = None, - max_length_q: Optional[int] = None, - max_length_k: Optional[int] = None, - target_dtype: Optional[torch.dtype] = None, - implementation: Optional[str] = None, - **kwargs, -): - """ - Calls the forward method of Flash Dynamic Mask Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - (Optional) kwargs are described further in `_process_flash_dynamic_mask_attention_kwargs` and `FlashDynamicMaskAttentionKwargs`. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash DMATTN API - key_states (`torch.Tensor`): - Input key states to be passed to Flash DMATTN API - value_states (`torch.Tensor`): - Input value states to be passed to Flash DMATTN API - attention_mask (`torch.Tensor`, *optional*): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - attention_bias (`torch.Tensor`, *optional*): - The attention bias tensor to add to attention scores. - implementation (`str`, *optional*): - The attention implementation to use. If None, will default to the one based on the environment. - """ - - if ( - attention_mask is not None - and attention_mask.dim() == 2 - and attention_bias is not None - ): - raise ValueError( - "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None." - ) - - (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation) - - # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op - query_states, key_states, value_states, attention_bias = fdma_peft_integration_check( - query_states, key_states, value_states, attention_bias, target_dtype - ) - - # Extract the flash attention kwargs that have been requested (and are supported by the implementation) - flash_kwargs = process_flash_kwargs_fn( - query_length=query_length, - key_length=key_length, - is_causal=is_causal, - softmax_scale=softmax_scale, - window_size=window_size, - softcap=softcap, - deterministic=deterministic, - **kwargs, - ) - - # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: - # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. - # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to - # use `fdma_varlen_fn` knowing we already have all necessary the kwargs. - # - # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. - # See #39121 for more information. - is_fdma_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) - is_fdma_with_varlen_kwargs = all( - kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) - ) - - # Contains at least one padding token in the sequence - if attention_mask is not None and attention_mask.dim() == 2: - q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( - query_states, key_states, value_states, attention_mask, query_length, unpad_fn - ) - - # TODO for now this is required to work with - # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py - if "mps" in str(q.device): - cu_seq_lens_k = cu_seq_lens_k.clone() - - out_unpad = fdma_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, - ) - if isinstance(out_unpad, tuple): - out_unpad = out_unpad[0] - - out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) - - # Padding free, i.e. sequences flattened into one total sequence - elif is_fdma_with_varlen_kwargs or is_fdma_with_position_ids: - if cu_seq_lens_q is None or cu_seq_lens_k is None: - q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( - query_states, key_states, value_states, position_ids - ) - else: - q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) - k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) - v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - - # TODO for now this is required to work with - # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py - if "mps" in str(q.device): - cu_seq_lens_k = cu_seq_lens_k.clone() - - out = fdma_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, - ) - if isinstance(out, tuple): - out = out[0] - - out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) - - # No padding - else: - - # Generate a combined attention mask if `attention_bias` are provided - if ( - attention_bias is not None - and window_size is not None - and key_length > window_size - ): - attention_mask = create_mask_fn( - attention_bias, - attention_mask, - batch_size=query_states.size(0), - query_len=query_length, - key_len=key_length, - window_size=window_size, - min_dtype=torch.finfo(attention_bias.dtype).min, - ) - - out = fdma_fn( - query_states, - key_states, - value_states, - attention_mask, - attention_bias, - **flash_kwargs, - ) - if isinstance(out, tuple): - out = out[0] - - return out diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py deleted file mode 100644 index 491e270..0000000 --- a/flash_dmattn/utils/mask.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch - - -def topk_indices( - attention_bias: torch.Tensor, - window_size: int, - **kwargs, -) -> torch.Tensor: - r""" - This function generates top-k indices based on the attention bias. - - Args: - attention_bias (torch.Tensor): The attention bias tensor of - (batch_size, num_kv_heads, key_len). - window_size (int): The number of top elements to consider for the mask. - **kwargs: Additional keyword arguments. - - Returns: - topk_indices (Tensor): The top-k indices tensor of shape - (batch_size, num_kv_heads, window_size). - """ - attention_bias = attention_bias.detach() - topk_indices = torch.topk( - attention_bias, - window_size, dim=-1, largest=True, sorted=False - ).indices - topk_indices = torch.sort(topk_indices, dim=-1).values - return topk_indices - - -def block_smooth( - attention_mask: torch.Tensor, - key_len: int, - block_size: int, -): - if block_size <= 0: - raise ValueError(f"block_size must be a positive integer, got {block_size}.") - - if block_size > 1: - full_len = (key_len // block_size) * block_size - - if full_len: - block_view = attention_mask[..., :full_len] - block_shape = (*block_view.shape[:-1], full_len // block_size, block_size) - blocks = block_view.view(*block_shape) - block_counts = blocks.sum(dim=-1).to(torch.int64) - block_keep = (block_counts * 2) > block_size - blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks)) - - if key_len > full_len: - tail_slice = attention_mask[..., full_len:] - tail_len = tail_slice.shape[-1] - tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64) - tail_keep = (tail_counts * 2) > tail_len - tail_slice.copy_(tail_keep.expand_as(tail_slice)) - - return attention_mask - - -def topk_mask( - attention_bias: torch.Tensor, - attention_mask: Optional[torch.Tensor], - window_size: int, - min_dtype: float, - block_size: Optional[int] = None, - **kwargs, -): - r""" - This function generates a dynamic mask based on the top-k attention bias. - - Args: - attention_bias (torch.Tensor): The attention bias tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - window_size (int): The number of top elements to consider for the mask. - min_dtype (float): The minimum value to use for masking. - block_size (Optional[int]): Optional size of aggregation blocks to smooth the - resulting mask along the key dimension. - - Returns: - attention_mask (Tensor): The attention mask tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - """ - - attention_bias = attention_bias.detach() - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias - topk_values, topk_indices = torch.topk( - attention_bias, - window_size, dim=-1, largest=True, sorted=False - ) - attention_mask = torch.zeros_like( - attention_bias, dtype=torch.bool, device=attention_bias.device - ).scatter_(-1, topk_indices, topk_values != min_dtype) - - if block_size is not None and block_size > 1: - key_len = attention_mask.shape[-1] - attention_mask = block_smooth( - attention_mask=attention_mask, - key_len=key_len, - block_size=block_size - ) - - return attention_mask - - -def relu_mask( - attention_bias: torch.Tensor, - attention_mask: Optional[torch.Tensor], - min_dtype: float, - block_size: Optional[int] = None, - **kwargs -): - r""" - This function generates a dynamic mask based on the ReLU of attention bias. - - Args: - attention_bias (torch.Tensor): The attention bias tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - min_dtype (float): The minimum value to use for masking. - block_size (Optional[int]): Optional size of aggregation blocks to smooth the - resulting mask along the key dimension. - - Returns: - attention_mask (Tensor): The attention mask tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - """ - - attention_bias = attention_bias.detach() - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias - attention_mask = attention_bias > 0 - - if block_size is not None and block_size > 1: - key_len = attention_mask.shape[-1] - attention_mask = block_smooth( - attention_mask=attention_mask, - key_len=key_len, - block_size=block_size - ) - - return attention_mask - - - -def create_mask( - attention_bias: torch.Tensor, - attention_mask: Optional[torch.Tensor], - batch_size: int, - query_len: int, - key_len: int, - window_size: int, - min_dtype: float, - block_size: Optional[int] = None, - type: str = "topk", -) -> torch.Tensor: - r""" - This function creates a mask tensor for Flash Dynamic Mask Attention. - - If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias. - - Args: - attention_bias (torch.Tensor): The attention bias tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape - (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - batch_size (int): The batch size. - query_len (int): The sequence length of the query. - key_len (int): The sequence length of the key. - window_size (int): The number of top elements to consider for the attention mask. - min_dtype (float): The minimum value to use for masking. - block_size (Optional[int]): Optional size of aggregation blocks after top-k masking. - type (str): The type of mask to create. Options are "topk" and "relu". - - Returns: - attention (Tensor): The attention mask tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). - """ - - # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len) - if attention_mask is not None and attention_mask.dim() == 2: - if attention_mask.shape[-1] == key_len: - attention_mask = attention_mask.view(batch_size, 1, 1, key_len) - elif attention_mask.shape[-1] == query_len: - pad_len = key_len - query_len - if pad_len > 0: - pad_mask = torch.ones( - (batch_size, 1, 1, pad_len), - dtype=torch.bool, - device=attention_mask.device, - ) - attention_mask = torch.cat( - [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)], - dim=-1, - ) - else: - attention_mask = attention_mask.view(batch_size, 1, 1, query_len) - else: - raise ValueError( - f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}." - ) - - # Generate dynamic mask based on attention_bias and attention_mask - if type == "topk": - attention_mask = topk_mask( - attention_bias=attention_bias, - attention_mask=attention_mask, - window_size=window_size, - min_dtype=min_dtype, - block_size=block_size, - ) - elif type == "relu": - attention_mask = relu_mask( - attention_bias=attention_bias, - attention_mask=attention_mask, - window_size=window_size, - min_dtype=min_dtype, - block_size=block_size, - ) - else: - raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.") - - return attention_mask diff --git a/flash_dmattn/utils/padding.py b/flash_dmattn/utils/padding.py deleted file mode 100644 index b675af7..0000000 --- a/flash_dmattn/utils/padding.py +++ /dev/null @@ -1,170 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F - - -def index_first_axis(tensor, indices): - """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] - - -def unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - - return ( - index_first_axis(hidden_states, indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[1:] - output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) - output[indices] = hidden_states - return output.view(batch, seqlen, *dim) - - -def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fdma_from_position_ids, with torch compile, - # this might cause a graph break - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - unpad_input_func, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - unpad_input_func: - The function to use for unpadding the input tensors. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) - - # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage - # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores - if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): - key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] - - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis(key_layer, indices_k) - value_layer = index_first_axis(value_layer, indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer, indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) From ea4350a1f2cb3b5030344deb2f293e61773e261f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 22:59:55 +0800 Subject: [PATCH 05/29] Implement feature X to enhance user experience and optimize performance --- flash_sparse_attn/flash_dmattn_triton.py | 1244 ++++++++++++++++++++++ 1 file changed, 1244 insertions(+) create mode 100644 flash_sparse_attn/flash_dmattn_triton.py diff --git a/flash_sparse_attn/flash_dmattn_triton.py b/flash_sparse_attn/flash_dmattn_triton.py new file mode 100644 index 0000000..828e80b --- /dev/null +++ b/flash_sparse_attn/flash_dmattn_triton.py @@ -0,0 +1,1244 @@ +from typing import Optional +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + ], + key=['IS_CAUSAL', 'BLOCK_HEADDIM'] +) +@triton.heuristics( + { + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_preprocess( + K, + V, + B, + I, + CuK, + CuV, + CuB, + CuM, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bn, + stride_ib, + stride_ih, + stride_ik, + stride_ckb, + stride_ckh, + stride_ckk, + stride_cvb, + stride_cvh, + stride_cvk, + stride_cbb, + stride_cbh, + stride_cbk, + stride_cmb, + stride_cmh, + stride_cmm, + stride_cmk, + nheads_k, + seqlen_q, + seqlen_k, + window_size, + headdim, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(0) + off_b = off_hb // nheads_k + off_hk = off_hb % nheads_k + + # Initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize base pointers to K, V, B, I, CuK, CuV, CuB + k_base_ptrs = ( + K + off_b * stride_kb + off_hk * stride_kh + ) + v_base_ptrs = ( + V + off_b * stride_vb + off_hk * stride_vh + ) + b_base_ptrs = ( + B + off_b * stride_bb + off_hk * stride_bh + ) + i_base_ptrs = ( + I + off_b * stride_ib + off_hk * stride_ih + ) + cuk_base_ptrs = ( + CuK + off_b * stride_ckb + off_hk * stride_ckh + ) + cuv_base_ptrs = ( + CuV + off_b * stride_cvb + off_hk * stride_cvh + ) + cub_base_ptrs = ( + CuB + off_b * stride_cbb + off_hk * stride_cbh + ) + cum_base_ptrs = ( + CuM + off_b * stride_cmb + off_hk * stride_cmh + ) + + # Loop over blocks of window_size + for start_k in range(0, window_size, BLOCK_N): + start_k = tl.multiple_of(start_k, BLOCK_N) + offs_k = start_k + offs_n + + # Load I + i_ptrs = ( + i_base_ptrs + offs_k * stride_ik + ) + gather_idx = tl.load(i_ptrs, mask=offs_k < window_size, other=0).to(tl.int64) + valid_idx = (offs_k < window_size) & (gather_idx >= 0) & (gather_idx < seqlen_k) + gather_idx = tl.where(valid_idx, gather_idx, 0) + + # Load K, V, B + k_ptrs = ( + k_base_ptrs + gather_idx[:, None] * stride_kn + offs_d[None, :] + ) + v_ptrs = ( + v_base_ptrs + gather_idx[:, None] * stride_vn + offs_d[None, :] + ) + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=valid_idx[:, None], other=0.0) + v = tl.load(v_ptrs, mask=valid_idx[:, None], other=0.0) + else: + k = tl.load( + k_ptrs, + mask=valid_idx[:, None] & (offs_d[None, :] < headdim), + other=0.0 + ) + v = tl.load( + v_ptrs, + mask=valid_idx[:, None] & (offs_d[None, :] < headdim), + other=0.0 + ) + b_ptrs = ( + b_base_ptrs + gather_idx * stride_bn + ) + b = tl.load(b_ptrs, mask=valid_idx, other=0.0) + + # Store to CuK, CuV, CuB + cuk_ptrs = ( + cuk_base_ptrs + offs_k[:, None] * stride_ckk + offs_d[None, :] + ) + cuv_ptrs = ( + cuv_base_ptrs + offs_k[:, None] * stride_cvk + offs_d[None, :] + ) + if EVEN_HEADDIM: + tl.store(cuk_ptrs, k, mask=valid_idx[:, None]) + tl.store(cuv_ptrs, v, mask=valid_idx[:, None]) + else: + tl.store( + cuk_ptrs, k, + mask=valid_idx[:, None] & (offs_d[None, :] < headdim), + ) + tl.store( + cuv_ptrs, v, + mask=valid_idx[:, None] & (offs_d[None, :] < headdim), + ) + cub_ptrs = ( + cub_base_ptrs + offs_k * stride_cbk + ) + tl.store(cub_ptrs, b, mask=valid_idx) + + # Store mask to CuM + for start_m in range(0, seqlen_q, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + tl.arange(0, BLOCK_M) + + cum_ptrs = ( + cum_base_ptrs + offs_m[:, None] * stride_cmm + offs_k[None, :] * stride_cmk + ) + + col_mask = offs_k < window_size + row_mask = offs_m[:, None] < seqlen_q + + if IS_CAUSAL: + mask = (offs_m[:, None] >= gather_idx[None, :]) & valid_idx[None, :] + else: + mask = valid_idx[None, :] + + cum = tl.where(row_mask & col_mask[None, :], mask, False) + + tl.store(cum_ptrs, cum, mask=row_mask & col_mask[None, :]) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BLOCK_HEADDIM'] +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["window_size"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + CuK, + CuV, + CuB, + CuM, + Out, + Lse, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_ckb, + stride_ckh, + stride_ckk, + stride_cvb, + stride_cvh, + stride_cvk, + stride_cbb, + stride_cbh, + stride_cbk, + stride_cmb, + stride_cmh, + stride_cmm, + stride_cmk, + stride_ob, + stride_oh, + stride_om, + nheads, + h_h_k_ratio, + seqlen_q, + window_size, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q: tl.constexpr, + CACHE_KEY_SEQLEN_K: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_bh = tl.program_id(1) + off_b = off_bh // nheads + off_hq = off_bh % nheads + off_hk = off_hq // h_h_k_ratio + + # Initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize pointers to Q, CuK, CuV, CuM, CuB + q_ptrs = ( + Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + cuk_base_ptrs = ( + CuK + off_b * stride_ckb + off_hk * stride_ckh + ) + cv_base_ptrs = ( + CuV + off_b * stride_cvb + off_hk * stride_cvh + ) + cub_base_ptrs = ( + CuB + off_b * stride_cbb + off_hk * stride_cbh + ) + cum_base_ptrs = ( + CuM + off_b * stride_cmb + off_hk * stride_cmh + ) + + # Initialize pointer to m and l + lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # Load q: it will stay in SRAM throughout + if EVEN_M: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0 + ) + + # Scale q + q = (q * softmax_scale).to(q.dtype) + + # Loop over k, v and update accumulator + for start_n in range(0, window_size, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + cum_ptrs = ( + cum_base_ptrs + offs_m[:, None] * stride_cmm + (start_n + offs_n)[None, :] * stride_cmk + ) + # Load mask + if EVEN_M & EVEN_N: + m = tl.load(cum_ptrs) + else: + m = tl.load( + cum_ptrs, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < window_size), + other=False, + ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(m, axis=None) + + # Skip this iteration if no active elements + if any_active: + + # Load k + cuk_ptrs = ( + cuk_base_ptrs + (start_n + offs_n)[:, None] * stride_ckk + offs_d[None, :] + ) + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(cuk_ptrs) + else: + k = tl.load(cuk_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + cuk_ptrs, + mask=(start_n + offs_n)[:, None] < window_size, + other=0.0, + ) + else: + k = tl.load( + cuk_ptrs, + mask=((start_n + offs_n)[:, None] < window_size) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Load bias + cub_ptrs = ( + cub_base_ptrs + (start_n + offs_n) * stride_cbk + ) + if EVEN_M & EVEN_N: + b = tl.load(cub_ptrs) + else: + b = tl.load( + cub_ptrs, + mask=(start_n + offs_n) < window_size, + other=0.0, + ) + + # Initialize acc_s + acc_s = b[None, :].to(tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < window_size, 0, float("-inf")) + acc_s += tl.where(m, 0, float("-inf")) + + # Compute p + m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) + p = tl.exp(acc_s - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # Scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # Update output accumulator + acc_o = acc_o * acc_o_scale[:, None] + + # Load v + cuv_ptrs = ( + cv_base_ptrs + (start_n + offs_n)[:, None] * stride_cvk + offs_d[None, :] + ) + if EVEN_N: + if EVEN_HEADDIM: + v = tl.load(cuv_ptrs) + else: + v = tl.load(cuv_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + cuv_ptrs, + mask=(start_n + offs_n)[:, None] < window_size, + other=0.0, + ) + else: + v = tl.load( + cuv_ptrs, + mask=((start_n + offs_n)[:, None] < window_size) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute acc_o + acc_o += tl.dot(p.to(v.dtype), v) + + # Update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + acc_o = acc_o * o_scale[:, None] + # Rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # Write back l and m + lse_ptrs = Lse + off_bh * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # Initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_b * stride_ob + + off_hq * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # Initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Load o + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # Write back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + CuK, + CuV, + CuB, + CuM, + DO, + DQ, + DCuK, + DCuV, + DCuB, + LSE, + D, + softmax_scale, + stride_qm, + stride_ckk, + stride_cvk, + stride_cbk, + stride_cmm, + stride_cmk, + stride_dom, + stride_dqm, + stride_dckk, + stride_dcvk, + stride_dcbk, + seqlen_q, + window_size, + headdim, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Initialize row/col offsets + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) + cuk_ptrs = CuK + (offs_n[:, None] * stride_ckk + offs_d[None, :]) + cuv_ptrs = CuV + (offs_n[:, None] * stride_cvk + offs_d[None, :]) + cub_ptrs = CuB + (offs_n * stride_cbk) + do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_d[None, :]) + dcuk_ptrs = DCuK + (offs_n[:, None] * stride_dckk + offs_d[None, :]) + dcuv_ptrs = DCuV + (offs_n[:, None] * stride_dcvk + offs_d[None, :]) + dcub_ptrs = DCuB + (offs_n * stride_dcbk) + + # Initialize dv, dk, db accumulators + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + db = tl.zeros([BLOCK_N], dtype=tl.float32) + + # Load k and v, them will stay in SRAM throughout + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(cuk_ptrs) + v = tl.load(cuv_ptrs) + else: + k = tl.load(cuk_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(cuv_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(cuk_ptrs, mask=offs_n[:, None] < window_size, other=0.0) + v = tl.load(cuv_ptrs, mask=offs_n[:, None] < window_size, other=0.0) + else: + k = tl.load( + cuk_ptrs, + mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), + other=0.0 + ) + v = tl.load( + cuv_ptrs, + mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), + other=0.0 + ) + if EVEN_N: + b = tl.load(cub_ptrs) + else: + b = tl.load(cub_ptrs, mask=offs_n < window_size, other=0.0) + + # Scale k + k = (k * softmax_scale).to(k.dtype) + + # Loop over q and update accumulators + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(0, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + + # Load mask + cum_ptrs = ( + CuM + offs_m_curr[:, None] * stride_cmm + offs_n[None, :] * stride_cmk + ) + if EVEN_M & EVEN_N: + m = tl.load(cum_ptrs) + else: + m = tl.load( + cum_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < window_size), + other=False, + ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(m, axis=None) + + # Skip this iteration if no active elements + if any_active: + # Load q + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Initialize acc_s + acc_s = b[None, :].to(tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where(offs_n[None, :] < window_size, 0, float("-inf")) + acc_s += tl.where(m, 0, float("-inf")) + + lse_i = tl.load(LSE + offs_m_curr) + # p = tl.exp(acc_s - lse_i[:, None]) + p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) + + # Load do + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # There's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute dv + dv += tl.dot(tl.trans(p.to(do.dtype)), do) + + # Compute dp + dp = tl.dot(do, tl.trans(v)) + + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + + # Compute ds + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None])).to(q.dtype) + + # Compute db + db += tl.sum(ds, axis=0) + + # Compute dk + dk += tl.dot(tl.trans(ds), q) + + # Compute dq + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k).to(ds.dtype) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k).to(ds.dtype) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + + # Increment pointers + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + else: + # Increment pointers + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + + # Scale dk + dk = (dk * softmax_scale).to(dk.dtype) + + # Write back + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dcuk_ptrs, dk) + tl.store(dcuv_ptrs, dv) + else: + tl.store(dcuk_ptrs, dk, mask=offs_d[None, :] < headdim) + tl.store(dcuv_ptrs, dv, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dcuk_ptrs, dk, mask=offs_n[:, None] < window_size) + tl.store(dcuv_ptrs, dv, mask=offs_n[:, None] < window_size) + else: + tl.store(dcuk_ptrs, dk, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim)) + tl.store(dcuv_ptrs, dv, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim)) + + if EVEN_N: + tl.store(dcub_ptrs, db) + else: + tl.store(dcub_ptrs, db, mask=(offs_n < window_size)) + + +def init_to_zero(names): + if isinstance(names, str): + names = [names] + def init_func(nargs): + for name in names: + nargs[name].zero_() + return init_func + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DCuB"]), + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DCuB"]), + ), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["window_size"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + CuK, + CuV, + CuB, + CuM, + DO, + DQ, + DCuK, + DCuV, + DCuB, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_ckb, + stride_ckh, + stride_ckk, + stride_cvb, + stride_cvh, + stride_cvk, + stride_cbb, + stride_cbh, + stride_cbk, + stride_cmb, + stride_cmh, + stride_cmm, + stride_cmk, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dckb, + stride_dckh, + stride_dckk, + stride_dcvb, + stride_dcvh, + stride_dcvk, + stride_dcbb, + stride_dcbh, + stride_dcbk, + nheads, + h_h_k_ratio, + seqlen_q, + window_size, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_hq = off_hb % nheads + off_hk = off_hq // h_h_k_ratio + + # Advance offset pointers for batch and head + Q += off_b * stride_qb + off_hq * stride_qh + CuK += off_b * stride_ckb + off_hk * stride_ckh + CuV += off_b * stride_cvb + off_hk * stride_cvh + CuB += off_b * stride_cbb + off_hk * stride_cbh + CuM += off_b * stride_cmb + off_hk * stride_cmh + DO += off_b * stride_dob + off_hq * stride_doh + DQ += off_b * stride_dqb + off_hq * stride_dqh + DCuK += off_b * stride_dckb + off_hq * stride_dckh + DCuV += off_b * stride_dcvb + off_hq * stride_dcvh + DCuB += off_b * stride_dcbb + off_hq * stride_dcbh + # Advance pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(window_size, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + CuK, + CuV, + CuB, + CuM, + DO, + DQ, + DCuK, + DCuV, + DCuB, + LSE, + D, + softmax_scale, + stride_qm, + stride_ckk, + stride_cvk, + stride_cbk, + stride_cmm, + stride_cmk, + stride_dom, + stride_dqm, + stride_dckk, + stride_dcvk, + stride_dcbk, + seqlen_q, + window_size, + headdim, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=False, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + CuK, + CuV, + CuB, + CuM, + DO, + DQ, + DCuK, + DCuV, + DCuB, + LSE, + D, + softmax_scale, + stride_qm, + stride_ckk, + stride_cvk, + stride_cbk, + stride_cmm, + stride_cmk, + stride_dom, + stride_dqm, + stride_dckk, + stride_dcvk, + stride_dcbk, + seqlen_q, + window_size, + headdim, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=True, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def _flash_dmattn_forward(q, k, v, b, i, softmax_scale=None, is_causal=False, window_size=None): + # shape constraints + batch, nheads, seqlen_q, d = q.shape + _, nheads_k, seqlen_k, _ = k.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype == b.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert i.dtype == torch.int64, "Indices must be int64" + assert q.is_cuda and k.is_cuda and v.is_cuda and b.is_cuda, "All tensors must be on GPU" + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + cu_k = torch.empty((batch, nheads_k, window_size, d), device=q.device, dtype=k.dtype) + cu_v = torch.empty((batch, nheads_k, window_size, d), device=q.device, dtype=v.dtype) + cu_b = torch.empty((batch, nheads_k, window_size), device=q.device, dtype=b.dtype) + cu_m = torch.zeros((batch, nheads_k, seqlen_q, window_size), device=q.device, dtype=torch.bool) + + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = (batch * nheads_k,) + _fwd_preprocess[grid]( + k, v, b, i, + cu_k, cu_v, cu_b, cu_m, + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + b.stride(0), b.stride(1), b.stride(2), + i.stride(0), i.stride(1), i.stride(2), + cu_k.stride(0), cu_k.stride(1), cu_k.stride(2), + cu_v.stride(0), cu_v.stride(1), cu_v.stride(2), + cu_b.stride(0), cu_b.stride(1), cu_b.stride(2), + cu_m.stride(0), cu_m.stride(1), cu_m.stride(2), cu_m.stride(3), + nheads_k, seqlen_q, seqlen_k, window_size, d, is_causal, BLOCK_HEADDIM + ) + + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + cu_k, cu_v, cu_b, cu_m, + o, lse, softmax_scale, + q.stride(0), q.stride(1), q.stride(2), + cu_k.stride(0), cu_k.stride(1), cu_k.stride(2), + cu_v.stride(0), cu_v.stride(1), cu_v.stride(2), + cu_b.stride(0), cu_b.stride(1), cu_b.stride(2), + cu_m.stride(0), cu_m.stride(1), cu_m.stride(2), cu_m.stride(3), + o.stride(0), o.stride(1), o.stride(2), + nheads, nheads // nheads_k, seqlen_q, window_size, seqlen_q_rounded, d, + seqlen_q // 32, + window_size // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # BLOCK_HEADDIM=d, + BLOCK_HEADDIM, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + return o, lse, softmax_scale, cu_k, cu_v, cu_b, cu_m + + +def _flash_dmattn_backward( + do, q, cuk, cuv, cub, cum, i, o, lse, softmax_scale, seqlen_q, seqlen_k, window_size +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, nheads, _, d = q.shape + _, nheads_k, _, _ = cuk.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + dk = torch.zeros(batch, nheads_k, seqlen_k, d, device=q.device, dtype=q.dtype) + dv = torch.zeros(batch, nheads_k, seqlen_k, d, device=q.device, dtype=q.dtype) + db = torch.zeros(batch, nheads_k, seqlen_k, device=q.device, dtype=q.dtype) + + dk_expanded = torch.empty(batch, nheads, window_size, d, device=q.device, dtype=q.dtype) + dv_expanded = torch.empty(batch, nheads, window_size, d, device=q.device, dtype=q.dtype) + db_expanded = torch.empty(batch, nheads, window_size, device=q.device, dtype=q.dtype) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, do, delta, + o.stride(0), o.stride(1), o.stride(2), + do.stride(0), do.stride(1), do.stride(2), + nheads, seqlen_q, seqlen_q_rounded, d, + BLOCK_M=64, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(window_size, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, cuk, cuv, cub, cum, do, + dq_accum, dk_expanded, dv_expanded, db_expanded, + lse, delta, softmax_scale, + q.stride(0), q.stride(1), q.stride(2), + cuk.stride(0), cuk.stride(1), cuk.stride(2), + cuv.stride(0), cuv.stride(1), cuv.stride(2), + cub.stride(0), cub.stride(1), cub.stride(2), + cum.stride(0), cum.stride(1), cum.stride(2), cum.stride(3), + do.stride(0), do.stride(1), do.stride(2), + dq_accum.stride(0), dq_accum.stride(1), dq_accum.stride(2), + dk_expanded.stride(0), dk_expanded.stride(1), dk_expanded.stride(2), + dv_expanded.stride(0), dv_expanded.stride(1), dv_expanded.stride(2), + db_expanded.stride(0), db_expanded.stride(1), db_expanded.stride(2), + nheads, nheads // nheads_k, seqlen_q, window_size, seqlen_q_rounded, d, + seqlen_q // 32, + window_size // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # BLOCK_HEADDIM=BLOCK_HEADDIM, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq = dq_accum.to(q.dtype) + + if nheads != nheads_k: + dk_expanded = dk_expanded.view(batch, nheads_k, nheads // nheads_k, window_size, d).sum(dim=2) + dv_expanded = dv_expanded.view(batch, nheads_k, nheads // nheads_k, window_size, d).sum(dim=2) + db_expanded = db_expanded.view(batch, nheads_k, nheads // nheads_k, window_size).sum(dim=2) + + dk.scatter_add_( + dim=2, + index=i.unsqueeze(-1).expand(-1, -1, -1, d), + src=dk_expanded, + ) + dv.scatter_add_( + dim=2, + index=i.unsqueeze(-1).expand(-1, -1, -1, d), + src=dv_expanded, + ) + db.scatter_add_( + dim=2, + index=i, + src=db_expanded, + ) + + return dq, dk, dv, db + + +def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def round_multiple(x, m): + return (x + m - 1) // m * m + + +class FlashDMAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None): + """ + query: (batch_size, nheads, seqlen_q, headdim) + key: (batch_size, nheads_k, seqlen_k, headdim) + value: (batch_size, nheads_k, seqlen_k, headdim) + attn_bias: (batch_size, nheads_k, seqlen_k) + attn_indices: (batch_size, nheads_k, window_size) + is_causal: bool, whether to apply causal masking + softmax_scale: float, scaling factor for attention scores + """ + + # Make sure that the last dimension is contiguous + query, key, value, attn_bias, attn_indices = [maybe_contiguous(x) for x in [query, key, value, attn_bias, attn_indices]] + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = query.size(3) + if head_size_og % 8 != 0: + query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) + key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) + value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) + seqlen_k_rounded = round_multiple(key.shape[2], 128) + if attn_bias.shape[-1] != seqlen_k_rounded: + attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) + window_size = attn_indices.shape[-1] + + o, lse, ctx.softmax_scale, cu_key, cu_value, cu_attn_bias, cu_attn_mask = _flash_dmattn_forward( + query, + key, + value, + attn_bias, + attn_indices, + softmax_scale=softmax_scale, + is_causal=is_causal, + window_size=window_size, + ) + ctx.save_for_backward(query, cu_key, cu_value, cu_attn_bias, cu_attn_mask, attn_indices, o, lse) + ctx.seqlen_q = query.size(2) + ctx.seqlen_k = key.size(2) + ctx.window_size = window_size + + o = o[..., : head_size_og] + return o + + @staticmethod + def backward(ctx, do): + query, cu_key, cu_value, cu_attn_bias, cu_attn_mask, attn_indices, o, lse = ctx.saved_tensors + + head_size_og = do.size(3) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + + dq, dk, dv, db = _flash_dmattn_backward( + do_padded, + query, + cu_key, + cu_value, + cu_attn_bias, + cu_attn_mask, + attn_indices, + o, + lse, + softmax_scale=ctx.softmax_scale, + seqlen_q=ctx.seqlen_q, + seqlen_k=ctx.seqlen_k, + window_size=ctx.window_size, + ) + + # We could have padded the head dimension + dq = dq[..., : do.shape[-1]] + dk = dk[..., : do.shape[-1]] + dv = dv[..., : do.shape[-1]] + + return dq, dk, dv, db, None, None, None + + +def triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None): + return FlashDMAttnFunc.apply(query, key, value, attn_bias, attn_indices, is_causal, softmax_scale) \ No newline at end of file From 6bf01c47e1477250fc10867ed6e8b1c448d8cb84 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:00:40 +0800 Subject: [PATCH 06/29] Introduces Triton sparse attention kernels Adds fused forward/backward kernels in Triton to accelerate sparse attention with masking, bias, and GQA support for PyTorch integration. --- flash_sparse_attn/flash_sparse_attn_triton.py | 1246 +++++++++++++++++ 1 file changed, 1246 insertions(+) create mode 100644 flash_sparse_attn/flash_sparse_attn_triton.py diff --git a/flash_sparse_attn/flash_sparse_attn_triton.py b/flash_sparse_attn/flash_sparse_attn_triton.py new file mode 100644 index 0000000..eefbf76 --- /dev/null +++ b/flash_sparse_attn/flash_sparse_attn_triton.py @@ -0,0 +1,1246 @@ +from typing import Optional +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'HAS_MASK', 'HAS_BIAS', 'BLOCK_HEADDIM'] +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Mask, + Bias, + Out, + Lse, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_mb, + stride_mh, + stride_mm, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q: tl.constexpr, + CACHE_KEY_SEQLEN_K: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_hq = off_hb % nheads + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + + # Initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize pointers to Q, K, V, Mask, Bias + q_ptrs = ( + Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_hk * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_hk * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + m_ptrs = ( + Mask + off_b * stride_mb + off_hmask * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) + ) if HAS_MASK else None + b_ptrs = ( + Bias + off_b * stride_bb + off_hbbias * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) if HAS_BIAS else None + + # Initialize pointer to m and l + lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # Load q: it will stay in SRAM throughout + if EVEN_M: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + + # Scale q + q = (q * softmax_scale).to(q.dtype) + + # Loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + if HAS_MASK: + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs + start_n) + else: + mask = tl.load( + m_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=False + ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) + else: + any_active = True + + # Skip this iteration if no active elements + if any_active: + + # Load k + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + acc_s = bias + else: + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + + # Compute p + m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) + p = tl.exp(acc_s - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # Scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # Update output accumulator + acc_o = acc_o * acc_o_scale[:, None] + + # Load v + if EVEN_N: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute acc_o + acc_o += tl.dot(p.to(v.dtype), v) + + # Update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + acc_o = acc_o * o_scale[:, None] + # Rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # Write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # Initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_b * stride_ob + + off_hq * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # Initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Load o + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # Write back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # Initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if HAS_MASK: + m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) + if HAS_BIAS: + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :]) + # Initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step, + # and pipelining with the bias matrix could screw it up. So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + return + + # Load k and v, them will stay in SRAM throughout + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load( + k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + v = tl.load( + v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + + # Scale k + k = (k * softmax_scale).to(k.dtype) + + # Initialize accumulator for dbias if needed + acc_dbias = tl.zeros([BLOCK_N], dtype=tl.float32) if (HAS_BIAS and ACCUM_DBIAS) else None + + # Loop over q and update accumulators + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + + if HAS_MASK: + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs) + else: + mask = tl.load( + m_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=False, + ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) + else: + any_active = True + + # Skip this iteration if no active elements + if any_active: + # Load q + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + acc_s = bias + else: + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + # Compute acc_s + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + + lse_i = tl.load(LSE + offs_m_curr) + # p = tl.exp(acc_s - lse_i[:, None]) + p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) + + # Load do + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # There's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute dv + dv += tl.dot(tl.trans(p.to(do.dtype)), do) + + # Compute dp + dp = tl.dot(do, tl.trans(v)) + + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + + # Compute ds + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None])).to(q.dtype) + + # Write back + if not (EVEN_M & EVEN_N): + tl.debug_barrier() + if HAS_BIAS: + if ACCUM_DBIAS: + acc_dbias += tl.sum(ds, axis=0) + else: + if EVEN_M & EVEN_N: + tl.store( + db_ptrs, + ds, + ) + else: + tl.store( + db_ptrs, + ds, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + ) + + # Compute dk + dk += tl.dot(tl.trans(ds), q) + + # Compute dq + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k).to(ds.dtype) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k).to(ds.dtype) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k).to(ds.dtype) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + + # Increment pointers + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + if HAS_BIAS: + db_ptrs += BLOCK_M * stride_dbm + q_ptrs += BLOCK_M * stride_qm + if HAS_MASK: + m_ptrs += BLOCK_M * stride_mm + if HAS_BIAS: + b_ptrs += BLOCK_M * stride_bm + + # Scale dk + dk = (dk * softmax_scale).to(dk.dtype) + + # Write back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + if HAS_BIAS and ACCUM_DBIAS: + if EVEN_N: + tl.store(DBias + offs_n, acc_dbias) + else: + tl.store(DBias + offs_n, acc_dbias, mask=(offs_n < seqlen_k)) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +def init_to_zero(names): + if isinstance(names, str): + names = [names] + def init_func(nargs): + for name in names: + nargs[name].zero_() + return init_func + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DBias"]), + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DBias"]), + ), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "HAS_INDICE", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + "ACCUM_DBIAS": lambda args: args["HAS_BIAS"] and (args["stride_dbm"] == 0) and (args["seqlen_q"] > 1), + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_mb, + stride_mh, + stride_mm, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dbb, + stride_dbh, + stride_dbm, + nheads, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_hq = off_hb % nheads + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq + + # Advance offset pointers for batch and head + Q += off_b * stride_qb + off_hq * stride_qh + K += off_b * stride_kb + off_hk * stride_kh + V += off_b * stride_vb + off_hk * stride_vh + if HAS_MASK: + Mask += off_b * stride_mb + off_hmask * stride_mh + if HAS_BIAS: + Bias += off_b * stride_bb + off_hbbias * stride_bh + DO += off_b * stride_dob + off_hq * stride_doh + DQ += off_b * stride_dqb + off_hq * stride_dqh + DK += off_b * stride_dkb + off_hq * stride_dkh + DV += off_b * stride_dvb + off_hq * stride_dvh + if HAS_BIAS: + DBias += off_b * stride_dbb + off_hq * stride_dbh + # Advance pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=False, + ACCUM_DBIAS=ACCUM_DBIAS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=True, + ACCUM_DBIAS=ACCUM_DBIAS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def _flash_sparse_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + has_mask = mask is not None + if has_mask: + assert mask.dtype == torch.bool, "Only support bool" + assert mask.is_cuda + nheads_mask = mask.shape[1] + else: + nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) + + has_bias = bias is not None + if has_bias: + assert bias.dtype == q.dtype, "Only support fp16 and bf16" + assert bias.is_cuda + nheads_bias = bias.shape[1] + else: + nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + mask, + bias, + o, + lse, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + nheads_k, + nheads_mask, + nheads_bias, + nheads // nheads_k, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=d, + is_causal, + has_mask, + has_bias, + BLOCK_HEADDIM, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_sparse_attn_backward( + do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, dk = k.shape + + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + + has_mask = mask is not None + if has_mask: + assert mask.dtype == torch.bool, "Only support bool" + nheads_mask = mask.shape[1] + else: + nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) + + has_bias = bias is not None + if has_bias: + assert bias.dtype == q.dtype, "Only support fp16 and bf16" + nheads_bias = bias.shape[1] + else: + nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbias = torch.empty_like(bias) if has_bias else torch.empty(0, device=q.device, dtype=q.dtype) + + dk_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dk + dv_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dv + if has_bias: + if ( + nheads_bias != nheads + or ((bias.shape[0] == 1) and (batch > 1)) + or ((bias.shape[-2] == 1) and (seqlen_q > 1)) + ): + if bias.shape[-2] == 1: + dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = dbias + else: + dbias_expanded = dbias + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=64, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + mask, + bias, + do, + dq_accum, + dk_expanded, + dv_expanded, + dbias_expanded, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk_expanded.stride(0), + dk_expanded.stride(2), + dk_expanded.stride(1), + dv_expanded.stride(0), + dv_expanded.stride(2), + dv_expanded.stride(1), + (dbias_expanded.stride(0) if has_bias else 0), + (dbias_expanded.stride(1) if has_bias else 0), + ((0 if (has_bias and bias.shape[-2] == 1) else (dbias_expanded.stride(2) if has_bias else 0))), + nheads, + nheads_k, + nheads_mask, + nheads_bias, + nheads // nheads_k, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=BLOCK_HEADDIM, + is_causal, + has_mask, + has_bias, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq = dq_accum.to(q.dtype) + if nheads != nheads_k: + dk = dk_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + dv = dv_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + if has_bias: + if ( + nheads_bias != nheads + and bias.shape[0] == batch + and bias.shape[-2] == seqlen_q + ): + dbias = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + else: + if bias.shape[-2] == 1: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, 1, seqlen_k_rounded).sum(dim=2) + else: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + if bias.shape[0] == 1: + dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True) + dbias.copy_(dbias_expanded) + return dq, dk, dv, dbias if has_bias else None + + +def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def round_multiple(x, m): + return (x + m - 1) // m * m + + +class FlashDMAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): + """ + query: (batch_size, seqlen_q, nheads, headdim) + key: (batch_size, seqlen_k, nheads, headdim) + value: (batch_size, seqlen_k, nheads, headdim) + attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k) + attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k) + is_causal: bool, whether to apply causal masking + softmax_scale: float, scaling factor for attention scores + """ + + # Make sure that the last dimension is contiguous + query, key, value, attn_mask, attn_bias = [maybe_contiguous(x) for x in [query, key, value, attn_mask, attn_bias]] + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = query.size(3) + if head_size_og % 8 != 0: + query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) + key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) + value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) + seqlen_k_rounded = round_multiple(key.shape[1], 128) + if attn_mask is not None and attn_mask.shape[-1] != seqlen_k_rounded: + if attn_mask.shape[-1] == 1: + attn_mask = attn_mask.expand(*attn_mask.shape[:-1], seqlen_k_rounded) + else: + attn_mask = torch.nn.functional.pad(attn_mask, [0, seqlen_k_rounded - attn_mask.shape[-1]]) + if attn_bias is not None and attn_bias.shape[-1] != seqlen_k_rounded: + if attn_bias.shape[-1] == 1: + attn_bias = attn_bias.expand(*attn_bias.shape[:-1], seqlen_k_rounded) + else: + attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) + + o, lse, ctx.softmax_scale = _flash_sparse_attn_forward( + query, + key, + value, + attn_mask, + attn_bias, + softmax_scale=softmax_scale, + is_causal=is_causal + ) + ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias) + ctx.is_causal = is_causal + ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0 + return o + + @staticmethod + def backward(ctx, do): + query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors + + head_size_og = do.size(3) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + + dq, dk, dv, dbias = _flash_sparse_attn_backward( + do_padded, + query, + key, + value, + attn_mask, + attn_bias, + o, + lse, + softmax_scale=ctx.softmax_scale, + is_causal=ctx.is_causal, + ) + + # We could have padded the head dimension + dq = dq[..., : do.shape[-1]] + dk = dk[..., : do.shape[-1]] + dv = dv[..., : do.shape[-1]] + + if dbias is not None: + dbias = dbias[..., :key.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : key.shape[1]] + + return dq, dk, dv, None, dbias, None, None + + +def triton_sparse_attn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): + return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, softmax_scale) From 152c73af65f291e0bc05848e60d216598afd3756 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:01:47 +0800 Subject: [PATCH 07/29] Adds flash sparse attention interface Enables calling sparse Flash attention CUDA kernels through custom autograd helpers. Registers fake implementations and padding logic so torch.compile stays compatible with varying head shapes. --- .../flash_sparse_attn_interface.py | 760 ++++++++++++++++++ 1 file changed, 760 insertions(+) create mode 100644 flash_sparse_attn/flash_sparse_attn_interface.py diff --git a/flash_sparse_attn/flash_sparse_attn_interface.py b/flash_sparse_attn/flash_sparse_attn_interface.py new file mode 100644 index 0000000..9f8e676 --- /dev/null +++ b/flash_sparse_attn/flash_sparse_attn_interface.py @@ -0,0 +1,760 @@ +# Copyright (c) 2025, Jingze Shi. + +from typing import Optional, Tuple, Any +from packaging import version +import torch + +import flash_sparse_attn_cuda as flash_sparse_attn_gpu # type: ignore + + +def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None: + for t in tensors: + if t is not None and isinstance(t, torch.Tensor): + torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf) + + +def _get_block_size_n(device, head_dim, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128 + if head_dim <= 64: + return 128 + elif head_dim <= 96: + return 64 + elif head_dim <= 128: + if is_sm8x: + return 64 if (is_causal) else 32 + else: + return 64 + elif head_dim <= 192: + return 64 + elif head_dim <= 224: + return 64 + elif head_dim <= 256: + return 64 + + +def round_multiple(x, m): + return (x + m - 1) // m * m + + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if version.parse(torch.__version__) >= version.parse("2.4.0"): + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_forward", mutates_args=(), device_types="cuda") +def _flash_sparse_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + softcap: float, + return_softmax: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + out, softmax_lse, S_dmask = flash_sparse_attn_gpu.fwd( + q, + k, + v, + mask, + bias, + None, + softmax_scale, + is_causal, + softcap, + return_softmax, + ) + # _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) + return out, softmax_lse, S_dmask + + +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_forward") +def _flash_sparse_attn_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + softcap: float, + return_softmax: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + batch_size, seqlen_q, num_heads, head_size = q.shape + seqlen_k = k.shape[1] + out = torch.empty_like(q) + softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + if return_softmax: + p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) + + return out, softmax_lse, p + + +_wrapped_flash_sparse_attn_forward = _flash_sparse_attn_forward + + +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_sparse_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float = 0.0, + return_softmax: bool = False, + block_table: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse, S_dmask = flash_sparse_attn_gpu.varlen_fwd( + q, + k, + v, + None, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + zero_tensors, + is_causal, + softcap, + return_softmax, + ) + # _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) + return out, softmax_lse, S_dmask + + +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_forward") +def _flash_sparse_attn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float = 0.0, + return_softmax: bool = False, + block_table: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + paged_kv = block_table is not None + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + return out, softmax_lse, p + + +_wrapped_flash_sparse_attn_varlen_forward = _flash_sparse_attn_varlen_forward + + +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") +def _flash_sparse_attn_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dbias: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, +) -> torch.Tensor: + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + ( + dq, + dk, + dv, + dbias, + softmax_d, + ) = flash_sparse_attn_gpu.bwd( + dout, + q, + k, + v, + mask, + bias, + out, + softmax_lse, + dq, + dk, + dv, + dbias, + softmax_scale, + is_causal, + softcap, + deterministic, + ) + # _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0) + return softmax_d + + +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_backward") +def _flash_sparse_attn_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dbias: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, +) -> torch.Tensor: + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + if dbias is None: + dbias = torch.empty_like(bias) + batch_size, seqlen_q, num_heads, _ = q.shape + softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) + + return softmax_d + + +_wrapped_flash_sparse_attn_backward = _flash_sparse_attn_backward + + +@_torch_custom_op_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_sparse_attn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_sparse_attn_gpu.varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + zero_tensors, + is_causal, + softcap, + deterministic, + ) + # _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0) + return softmax_d + + +@_torch_register_fake_wrapper("flash_sparse_attn::_flash_sparse_attn_varlen_backward") +def _flash_sparse_attn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dbias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + + +_wrapped_flash_sparse_attn_varlen_backward = _flash_sparse_attn_varlen_backward + + +class FlashDMAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + softmax_scale: Optional[float], + is_causal: Optional[bool], + softcap: Optional[float], + deterministic: Optional[bool], + return_softmax: Optional[bool], + is_grad_enabled: bool = True, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if is_causal is None: + is_causal = False + if softcap is None: + softcap = 0.0 + if deterministic is None: + deterministic = False + if return_softmax is None: + return_softmax = False + seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0 + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + seqlen_k_rounded = round_multiple(k.shape[1], 128) + if mask is not None and mask.shape[-1] != seqlen_k_rounded: + if mask.shape[-1] == 1: + mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded) + else: + mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]]) + if bias is not None and bias.shape[-1] != seqlen_k_rounded: + if bias.shape[-1] == 1: + bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded) + else: + bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]]) + + out_padded, softmax_lse, S_dmask = _wrapped_flash_sparse_attn_forward( + q, + k, + v, + mask, + bias, + softmax_scale, + is_causal=is_causal, + softcap=softcap, + return_softmax=return_softmax, + ) + + if is_grad: + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) + ctx.softmax_scale = softmax_scale + ctx.is_causal = is_causal + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.seqlen_k_bias_og = seqlen_k_bias_og + + out = out_padded[..., :head_size_og] + + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + dout: torch.Tensor, + *args: Any, + ): + q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) + dbias = torch.zeros_like(bias).contiguous() if bias is not None else None + + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + _wrapped_flash_sparse_attn_backward( + dout_padded, + q, + k, + v, + mask, + bias, + out, + softmax_lse, + dq, + dk, + dv, + dbias, + ctx.softmax_scale, + ctx.is_causal, + ctx.softcap, + ctx.deterministic, + ) + + # We could have padded the head dimension + dq = dq[..., : dout.shape[-1]] + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + + if dbias is not None: + dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]] + + return dq, dk, dv, None, dbias, None, None, None, None, None, None + + +class FlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: Optional[bool], + softcap: Optional[float], + deterministic: Optional[bool], + return_softmax: Optional[bool], + block_table: Optional[torch.Tensor], + is_grad_enabled: bool = True, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if is_causal is None: + is_causal = False + if softcap is None: + softcap = 0.0 + if deterministic is None: + deterministic = False + if return_softmax is None: + return_softmax = False + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + out_padded, softmax_lse, S_dmask = _wrapped_flash_sparse_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal=is_causal, + softcap=softcap, + return_softmax=return_softmax, + block_table=block_table, + ) + + if is_grad: + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + ) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.is_causal = is_causal + ctx.softcap = softcap + ctx.deterministic = deterministic + + out = out_padded[..., :head_size_og] + + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + _wrapped_flash_sparse_attn_varlen_backward( + dout_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.softmax_scale, + ctx.is_causal, + ctx.softcap, + ctx.deterministic, + ) + + # We could have padded the head dimension + dq = dq[..., : dout.shape[-1]] + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + + +def flash_sparse_attn_func( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + return_attn_probs: Optional[bool] = None, +): + """ + Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA. + For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension + of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 + of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. + + If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + Arguments: + query: torch.Tensor. The query tensor of shape (batch_size, seqlen, nheads, headdim) + key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) + value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) + attn_mask: torch.Tensor, optional. The attention mask boolean tensor of + shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores. + If None, no mask is applied. + attn_bias: torch.Tensor, optional. The attention bias float tensor of + shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores. + If None, no bias is applied. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). + """ + return FlashDMAttnFunc.apply( + query, + key, + value, + attn_mask, + attn_bias, + softmax_scale, + is_causal, + softcap, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + + +def flash_sparse_attn_varlen_func( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + return_attn_probs: Optional[bool] = None, + block_table: Optional[torch.Tensor] = None, +): + """ + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + Arguments: + query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q. + cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). + """ + return FlashAttnVarlenFunc.apply( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal, + softcap, + deterministic, + return_attn_probs, + block_table, + torch.is_grad_enabled(), + ) From 77e4e61ddebfd60faaa3763b3bb62fc4328a8840 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:12:37 +0800 Subject: [PATCH 08/29] Clarifies install docs and performance layout Updates package and repo naming so installation commands match the published distribution. Repositions performance benchmarks after usage guidance for both languages and aligns tensor examples to current API expectations. --- README.md | 184 +++++++++++++++++++++++++------------------------- README_zh.md | 186 +++++++++++++++++++++++++-------------------------- 2 files changed, 185 insertions(+), 185 deletions(-) diff --git a/README.md b/README.md index 736c184..03a817e 100644 --- a/README.md +++ b/README.md @@ -45,95 +45,6 @@ Thus, a more effective approach is sparse attention: interacting each query with - Further performance improvements for skipping memory access and computation -## Performance - -We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions. - -![FSA Performance Overview](assets/performance_overview.png) - ---- - -### Forward Pass Performance - -The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. - -| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | -|--------|-------|--------|----------|-----------|-----------|---------| -| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | -| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | -| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | -| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | -| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | -| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | -| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | -| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | -| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | -| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | -| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | -| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | -| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | -| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | -| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | -| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | -| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | -| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | -| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | -| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | -| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | -| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | -| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | -| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | -| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | -| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | -| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | -| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | -| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | -| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | -| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | -| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | -| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | -| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | -| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | -| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | -| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | - ---- - -### Backward Pass Performance - -The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. - -| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | -|-------|-------|--------|----------|---------------|---------------|---------| -| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | -| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | -| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | -| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | -| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | -| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | -| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | -| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | -| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | -| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | -| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | -| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | -| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | -| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | -| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | -| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | -| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | -| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | -| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | - ---- - - ## Installation ### Requirements @@ -150,14 +61,14 @@ The following table shows the backward pass performance comparison between FSA a You can install FSA via pre-compiled wheels: ```bash -pip install flash_sparse_attn --no-build-isolation +pip install flash-sparse-attn --no-build-isolation ``` Alternatively, you can compile and install from source: ```bash -git clone https://github.com/SmallDoges/flash_sparse_attn.git -cd flash_sparse_attn +git clone https://github.com/SmallDoges/flash-sparse-attn.git +cd flash-sparse-attn pip install . --no-build-isolation ``` @@ -245,6 +156,95 @@ print(f"Bias gradient shape: {attn_bias.grad.shape}") ``` +## Performance + +We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions. + +![FSA Performance Overview](assets/performance_overview.png) + +--- + +### Forward Pass Performance + +The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. + +| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | +|--------|-------|--------|----------|-----------|-----------|---------| +| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | +| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | +| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | +| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | +| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | +| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | +| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | +| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | +| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | +| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | +| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | +| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | +| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | +| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | +| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | +| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | +| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | +| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | +| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | +| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | +| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | +| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | +| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | +| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | +| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | +| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | +| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | +| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | +| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | +| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | +| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | +| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | +| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | +| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | +| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | +| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | +| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | + +--- + +### Backward Pass Performance + +The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs. + +| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | +|-------|-------|--------|----------|---------------|---------------|---------| +| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | +| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | +| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | +| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | +| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | +| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | +| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | +| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | +| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | +| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | +| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | +| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | +| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | +| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | +| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | +| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | +| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | +| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | +| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | + +--- + + ## Benchmarking FSA provides comprehensive benchmarking tools to evaluate performance across different configurations: diff --git a/README_zh.md b/README_zh.md index 149b108..8bea29c 100644 --- a/README_zh.md +++ b/README_zh.md @@ -45,95 +45,6 @@ Flash-Sparse-Attention 是一个高性能的可训练稀疏注意力实现, 将 - 进一步提升跳过访存与计算的性能 -## 性能 - -我们展示了带有mask与bias条件下 FSA 相对于标准 PyTorch SDPA 的预期加速效果. - -![FSA Performance Overview](assets/performance_overview.png) - ---- - -### 前向传播性能 - -以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的前向性能对比测试结果. 结果为预热两次, 运行三次的平均值. - -| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | -|--------|-------|--------|----------|-----------|-----------|---------| -| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | -| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | -| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | -| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | -| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | -| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | -| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | -| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | -| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | -| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | -| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | -| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | -| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | -| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | -| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | -| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | -| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | -| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | -| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | -| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | -| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | -| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | -| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | -| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | -| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | -| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | -| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | -| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | -| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | -| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | -| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | -| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | -| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | -| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | -| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | -| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | -| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | -| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | -| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | -| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | - ---- - -### 反向传播性能 - -以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的反向性能对比测试结果. 结果为预热两次, 运行三次的平均值. - -| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | -|-------|-------|--------|----------|---------------|---------------|---------| -| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | -| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | -| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | -| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | -| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | -| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | -| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | -| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | -| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | -| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | -| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | -| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | -| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | -| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | -| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | -| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | -| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | -| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | -| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | - ---- - - ## 安装 ### 依赖 @@ -150,14 +61,14 @@ Flash-Sparse-Attention 是一个高性能的可训练稀疏注意力实现, 将 您可以通过预编译的轮子安装 FSA: ```bash -pip install flash_sparse_attn --no-build-isolation +pip install flash-sparse-attn --no-build-isolation ``` 或者, 您可以从源代码编译和安装: ```bash -git clone https://github.com/SmallDoges/flash_sparse_attn.git -cd flash_sparse_attn +git clone https://github.com/SmallDoges/flash-sparse-attn.git +cd flash-sparse-attn pip install . --no-build-isolation ``` @@ -185,7 +96,7 @@ key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dt value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) # 为稀疏注意力创建 bias -attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) +attn_bias = torch.randn(batch_size, num_kv_heads, 1, seq_len, device=device, dtype=dtype) # 基于 bias 生成动态 mask if seq_len > window_size: @@ -245,6 +156,95 @@ print(f"Bias 梯度形状: {attn_bias.grad.shape}") ``` +## 性能 + +我们展示了带有mask与bias条件下 FSA 相对于标准 PyTorch SDPA 的预期加速效果. + +![FSA Performance Overview](assets/performance_overview.png) + +--- + +### 前向传播性能 + +以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的前向性能对比测试结果. 结果为预热两次, 运行三次的平均值. + +| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup | +|--------|-------|--------|----------|-----------|-----------|---------| +| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x | +| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x | +| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x | +| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x | +| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x | +| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x | +| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x | +| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x | +| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x | +| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x | +| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x | +| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x | +| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x | +| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x | +| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x | +| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x | +| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x | +| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x | +| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x | +| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x | +| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x | +| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x | +| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x | +| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x | +| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x | +| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x | +| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x | +| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x | +| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x | +| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x | +| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x | +| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x | +| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x | +| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x | +| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x | +| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x | +| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x | +| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x | +| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x | +| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x | + +--- + +### 反向传播性能 + +以下表格是我们在NVIDIA A100-SXM4-80GB上对FSA与标准PyTorch SDPA在不同配置下的反向性能对比测试结果. 结果为预热两次, 运行三次的平均值. + +| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup | +|-------|-------|--------|----------|---------------|---------------|---------| +| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x | +| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x | +| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x | +| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x | +| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x | +| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x | +| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x | +| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x | +| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x | +| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x | +| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x | +| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x | +| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x | +| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x | +| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x | +| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x | +| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x | +| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x | +| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x | + +--- + + ## 基准测试 FSA 提供全面的基准测试工具, 用于评估不同配置下的性能: From 6a2993192df616a3916382b36b5e0b54f398c5b1 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:13:07 +0800 Subject: [PATCH 09/29] Renames project to flash-sparse-attn Aligns packaging metadata with new repository identity. --- pyproject.toml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 97ffc48..1158eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ requires = [ build-backend = "setuptools.build_meta" [project] -name = "flash-dmattn" +name = "flash-sparse-attn" dynamic = ["version"] -description = "Flash Dynamic Mask Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention" +description = "Flash Sparse Attention: Fast and Memory-Efficient Trainable Dynamic Mask Sparse Attention" readme = "README.md" license = { file = "LICENSE" } authors = [ @@ -40,9 +40,9 @@ classifiers = [ ] [project.urls] -Homepage = "https://github.com/SmallDoges/flash-dmattn" -Source = "https://github.com/SmallDoges/flash-dmattn" -Issues = "https://github.com/SmallDoges/flash-dmattn/issues" +Homepage = "https://github.com/SmallDoges/flash-sparse-attention" +Source = "https://github.com/SmallDoges/flash-sparse-attention" +Issues = "https://github.com/SmallDoges/flash-sparse-attention/issues" [project.optional-dependencies] triton = [ @@ -69,11 +69,11 @@ dev = [ ] [tool.setuptools.dynamic] -version = { attr = "flash_dmattn.__version__" } +version = { attr = "flash_sparse_attn.__version__" } [tool.setuptools.packages.find] where = ["."] -include = ["flash_dmattn*"] +include = ["flash_sparse_attn*"] exclude = [ "build", "csrc", @@ -82,10 +82,10 @@ exclude = [ "dist", "docs", "benchmarks", - "flash_dmattn.egg-info" + "flash_sparse_attn.egg-info" ] [tool.setuptools.package-data] -flash_dmattn = ["*.py"] +flash_sparse_attn = ["*.py"] [tool.setuptools] From 612b85c82087c8e36940a464e125cf08930f3831 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:14:25 +0800 Subject: [PATCH 10/29] Aligns security docs with FSA naming Clarifies security instructions under the Flash Sparse Attention brand so users follow the right guidance for install, reporting, and support --- SECURITY.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 020430e..1abb585 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -13,7 +13,7 @@ We actively maintain and provide security updates for the following versions: ### CUDA Code Execution -Flash Dynamic Mask Attention includes CUDA kernels and C++ extensions that execute on your GPU. When using this library: +Flash Sparse Attention includes CUDA kernels and C++ extensions that execute on your GPU. When using this library: - Only install from trusted sources (official PyPI releases or verified builds) - Be cautious when building from source with modifications @@ -46,11 +46,11 @@ If you discover a security vulnerability, please report it responsibly: **For security issues:** - Email: losercheems@gmail.com -- Subject: [SECURITY] Flash-DMA Vulnerability Report +- Subject: [SECURITY] FSA Vulnerability Report - Include: Detailed description, reproduction steps, and potential impact **For general bugs:** -- Use our [GitHub Issues](https://github.com/SmallDoges/flash-dmattn/issues) +- Use our [GitHub Issues](https://github.com/SmallDoges/flash-sparse-attention/issues) - Follow our [contributing guidelines](CONTRIBUTING.md) ## Response Timeline @@ -63,21 +63,21 @@ Critical security issues will be prioritized and may result in emergency release ## Security Best Practices -When using Flash Dynamic Mask Attention: +When using Flash Sparse Attention: 1. **Environment Isolation** ```bash # Use virtual environments - python -m venv flash_dma_env - source flash_dma_env/bin/activate # Linux/Mac + python -m venv fsa_env + source fsa_env/bin/activate # Linux/Mac # or - flash_dma_env\Scripts\activate # Windows + fsa_env\Scripts\activate # Windows ``` 2. **Dependency Management** ```bash # Keep dependencies updated - pip install --upgrade torch flash-dmattn + pip install --upgrade torch flash_sparse_attn ``` 3. **Input Validation** @@ -108,5 +108,5 @@ For security-related questions or concerns: - Project maintainers: See [AUTHORS](AUTHORS) file For general support: -- GitHub Issues: https://github.com/SmallDoges/flash-dmattn/issues -- Documentation: https://github.com/SmallDoges/flash-dmattn/tree/main/docs/ +- GitHub Issues: https://github.com/SmallDoges/flash-sparse-attention/issues +- Documentation: https://github.com/SmallDoges/flash-sparse-attention/tree/main/docs/ \ No newline at end of file From 9f5d48d135fd6bbcea9cde9efe873c9b7dfaa2eb Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:15:32 +0800 Subject: [PATCH 11/29] Renames package to flash_sparse_attn Aligns packaging metadata and build hooks with the flash_sparse_attn name so prebuilt wheels, env vars, and CUDA builds resolve correctly. --- setup.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 6bda267..6c2a8b0 100644 --- a/setup.py +++ b/setup.py @@ -34,19 +34,19 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -PACKAGE_NAME = "flash_dmattn" +PACKAGE_NAME = "flash_sparse_attn" BASE_WHEEL_URL = ( - "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}" + "https://github.com/SmallDoges/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}" ) # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation # Also useful when user only wants Triton/Flex backends without CUDA compilation -FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +FORCE_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE" +FORCE_CXX11_ABI = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" # Auto-detect if user wants only Triton/Flex backends based on pip install command # This helps avoid unnecessary CUDA compilation when user only wants Python backends @@ -69,7 +69,7 @@ def should_skip_cuda_build(): if has_triton_or_flex and not has_all_or_dev: print("Detected Triton/Flex-only installation. Skipping CUDA compilation.") - print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.") + print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.") return True return False @@ -79,7 +79,7 @@ def should_skip_cuda_build(): @functools.lru_cache(maxsize=None) def cuda_archs(): - return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100").split(";") + return os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100").split(";") def detect_preferred_sm_arch() -> Optional[str]: @@ -154,14 +154,14 @@ def append_nvcc_threads(nvcc_extra_args): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - check_if_cuda_home_none("flash_dmattn") + check_if_cuda_home_none("flash_sparse_attn") # Check, if CUDA11 is installed for compute capability 8.0 cc_flag = [] if CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.7"): raise RuntimeError( - "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. " + "Flash Sparse Attention is only supported on CUDA 11.7 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) @@ -218,20 +218,20 @@ def append_nvcc_threads(nvcc_extra_args): ext_modules.append( CUDAExtension( - name="flash_dmattn_cuda", + name="flash_sparse_attn_cuda", sources=( [ - "csrc/flash_dmattn/flash_api.cpp", + "csrc/flash_sparse_attn/flash_api.cpp", ] - + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu")) + + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu")) ), extra_compile_args={ "cxx": compiler_c17_flag, "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ - Path(this_dir) / "csrc" / "flash_dmattn", - Path(this_dir) / "csrc" / "flash_dmattn" / "src", + Path(this_dir) / "csrc" / "flash_sparse_attn", + Path(this_dir) / "csrc" / "flash_sparse_attn" / "src", Path(this_dir) / "csrc" / "cutlass" / "include", ], ) @@ -239,10 +239,10 @@ def append_nvcc_threads(nvcc_extra_args): def get_package_version(): - with open(Path(this_dir) / "flash_dmattn" / "__init__.py", "r") as f: + with open(Path(this_dir) / "flash_sparse_attn" / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("FLASH_DMATTN_LOCAL_VERSION") + local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION") if local_version: return f"{public_version}+{local_version}" else: From 13a0db0de13bf8e66f55696128a00f22288b148f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:17:56 +0800 Subject: [PATCH 12/29] Aligns repo links with new name Points contribution guide links at flash-sparse-attention to avoid outdated references. --- CONTRIBUTING.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d0368d1..7271d04 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ Everyone is welcome to contribute, and we value everybody's contribution. Code c It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. -However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/SmallDoges/flash-dmattn/blob/main/CODE_OF_CONDUCT.md). +However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/SmallDoges/flash-sparse-attention/blob/main/CODE_OF_CONDUCT.md). ## Ways to contribute @@ -16,7 +16,7 @@ There are several ways you can contribute to Flash-DMA: * Contribute to the examples, benchmarks, or documentation. * Improve CUDA kernel performance. -If you don't know where to start, there is a special [Good First Issue](https://github.com/SmallDoges/flash-dmattn/contribute) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. +If you don't know where to start, there is a special [Good First Issue](https://github.com/SmallDoges/flash-sparse-attention/contribute) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. > All contributions are equally valuable to the community. 🥰 @@ -81,14 +81,14 @@ You will need basic `git` proficiency to contribute to Flash-DMA. You'll need ** ### Development Setup -1. Fork the [repository](https://github.com/SmallDoges/flash-dmattn) by clicking on the **Fork** button. +1. Fork the [repository](https://github.com/SmallDoges/flash-sparse-attention) by clicking on the **Fork** button. 2. Clone your fork to your local disk, and add the base repository as a remote: ```bash - git clone https://github.com//flash-dmattn.git - cd flash-dmattn - git remote add upstream https://github.com/SmallDoges/flash-dmattn.git + git clone https://github.com//flash-sparse-attention.git + cd flash-sparse-attention + git remote add upstream https://github.com/SmallDoges/flash-sparse-attention.git ``` 3. Create a new branch to hold your development changes: @@ -157,7 +157,7 @@ You will need basic `git` proficiency to contribute to Flash-DMA. You'll need ** ### Tests -An extensive test suite is included to test the library behavior and performance. Tests can be found in the [tests](https://github.com/SmallDoges/flash-dmattn/tree/main/tests) folder and benchmarks in the [benchmarks](https://github.com/SmallDoges/flash-dmattn/tree/main/benchmarks) folder. +An extensive test suite is included to test the library behavior and performance. Tests can be found in the [tests](https://github.com/SmallDoges/flash-sparse-attention/tree/main/tests) folder and benchmarks in the [benchmarks](https://github.com/SmallDoges/flash-sparse-attention/tree/main/benchmarks) folder. We use `pytest` for testing. From the root of the repository, run: @@ -200,6 +200,6 @@ If you discover a security vulnerability, please send an e-mail to the maintaine ## Questions? -If you have questions about contributing, feel free to ask in the [GitHub Discussions](https://github.com/SmallDoges/flash-dmattn/discussions) or open an issue. +If you have questions about contributing, feel free to ask in the [GitHub Discussions](https://github.com/SmallDoges/flash-sparse-attention/discussions) or open an issue. Thank you for contributing to Flash Dynamic Mask Attention! 🚀 From 307a50ec73d4a403f1877e9591df9cb2e3ffd225 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:18:08 +0800 Subject: [PATCH 13/29] Aligns citation with repo rename Reflects updated project title and repository location to keep citation metadata current. --- CITATION.cff | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index d8f3d0e..4aaeee9 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,8 +1,8 @@ cff-version: "1.2.0" date-released: 2025-06 message: "If you use this software, please cite it using these metadata." -title: "Flash Dynamic Mask Attention: Trainable Dynamic Mask Sparse Attention" -url: "https://github.com/SmallDoges/flash-dmattn" +title: "Flash Sparse Attention: Trainable Dynamic Mask Sparse Attention" +url: "https://github.com/SmallDoges/flash-sparse-attention" authors: - family-names: Shi given-names: Jingze From 186c7257ab16c401908a894d3b6f79f0d9cfe303 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:26:56 +0800 Subject: [PATCH 14/29] Adds import helpers for sparse attention Introduces cached availability checks so integrations can detect flash sparse attention without importing local modules and ensures CUDA backed torch is present before enabling features. --- .../integrations/import_utils.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 flash_sparse_attn/integrations/import_utils.py diff --git a/flash_sparse_attn/integrations/import_utils.py b/flash_sparse_attn/integrations/import_utils.py new file mode 100644 index 0000000..70a91fe --- /dev/null +++ b/flash_sparse_attn/integrations/import_utils.py @@ -0,0 +1,95 @@ +# Copyright 2025 Jingze Shi and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.metadata +import importlib.util +from functools import lru_cache +from typing import Union + + +from transformers import is_torch_available +from transformers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[tuple[bool, str], bool]: + # Check if the package spec exists and grab its version to avoid importing a local directory + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()` + # should be used here to map from package name to distribution names + # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu. + # `importlib.metadata.packages_distributions()` is not available in Python 3.9. + + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + elif pkg_name == "quark": + # TODO: remove once `importlib.metadata.packages_distributions()` is supported. + try: + package_version = importlib.metadata.version("amd-quark") + except Exception: + package_exists = False + elif pkg_name == "triton": + try: + package_version = importlib.metadata.version("pytorch-triton") + except Exception: + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logger.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + + +@lru_cache +def is_flash_sparse_attn_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_sparse_attn"): + return False + + import torch + + if not torch.cuda.is_available(): + return False + + return True \ No newline at end of file From 0402b3969f9355ccf06bc7bd8ddcc74fc82b893e Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:27:31 +0800 Subject: [PATCH 15/29] Adds flash sparse attention wrapper Supports future HF integration by routing calls through flash sparse attention logic and normalizing autocast, causal, and dtype handling --- .../integrations/flash_sparse_attention.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 flash_sparse_attn/integrations/flash_sparse_attention.py diff --git a/flash_sparse_attn/integrations/flash_sparse_attention.py b/flash_sparse_attn/integrations/flash_sparse_attention.py new file mode 100644 index 0000000..1dd6550 --- /dev/null +++ b/flash_sparse_attn/integrations/flash_sparse_attention.py @@ -0,0 +1,111 @@ +from typing import Optional + +import torch + +from .modeling_flash_sparse_attention_utils import _flash_sparse_attention_forward +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +def flash_sparse_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + attention_bias: Optional[torch.Tensor], + scaling: Optional[float] = None, + window_size: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, None]: + """ + A wrapper around the _flash_sparse_attention_forward function to be used in + the FlashSparseAttention class from HuggingFace Transformers. + + Args: + module (torch.nn.Module): The attention module. + query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). + key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). + value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + scaling (Optional[float]): The scaling factor for the attention scores. + window_size (Optional[int]): The size of the window to keep. + softcap (Optional[float]): The softcap value for the attention scores. + **kwargs: Additional keyword arguments. + Includes: + - is_causal (bool): Whether to apply a causal mask. + - layer_idx (int): The index of the layer (for logging purposes). + - implementation (str): The implementation to use ("flash_sparse_attn" or None). + + Returns: + tuple[torch.Tensor, None]: The output tensor of shape (batch_size, seq_len, num_heads, head_dim) + and None (for compatibility with other attention implementations). + """ + + if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None: + logger.warning_once( + "`flash_sparse_attention` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + + # This is before the transpose + query_len = query.shape[2] + key_len = key.shape[2] + + if any(dim == 0 for dim in query.shape): + raise ValueError( + "Tensor query has shape with a zero dimension.\n" + "FlashSparseAttention does not support inputs with dim=0.\n" + "Please check your input shapes or use SDPA instead." + ) + + # FSA uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (usually our RMSNorm modules handle it correctly) + target_dtype = None + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(module.config, "_pre_quantization_dtype"): + target_dtype = module.config._pre_quantization_dtype + else: + target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype + + # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented + is_causal = kwargs.pop("is_causal", None) + if is_causal is None: + is_causal = module.is_causal + + attn_output = _flash_sparse_attention_forward( + query, + key, + value, + attention_mask, + attention_bias, + query_length=query_len, + key_length=key_len, + is_causal=is_causal, + softmax_scale=scaling, + softcap=softcap, + window_size=window_size, + target_dtype=target_dtype, + implementation="flash_sparse_attn", + layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, + **kwargs, + ) + + return attn_output, None From 6bb896fbbbe35534139e7b18afab81f4e07b8502 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:27:53 +0800 Subject: [PATCH 16/29] Adds flash sparse attention utils Introduces lazy import plumbing for flash sparse attention kernels to streamline future integrations. Prepares padding-aware helpers and kwarg validation so padding-free flows and PEFT casting stay compatible with the kernels. --- .../modeling_flash_sparse_attention_utils.py | 596 ++++++++++++++++++ 1 file changed, 596 insertions(+) create mode 100644 flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py diff --git a/flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py b/flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py new file mode 100644 index 0000000..b62b34e --- /dev/null +++ b/flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py @@ -0,0 +1,596 @@ +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from functools import partial +from typing import Optional, TypedDict + +import torch +import torch.nn.functional as F + +from .import_utils import is_flash_sparse_attn_available +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves +_fsa_fn = None +_fsa_varlen_fn = None +_pad_fn = None +_unpad_fn = None +_create_mask_fn = None + +# function that processes kwargs, generalized to handle any supported kwarg within the function +_process_flash_kwargs_fn = None +# exceptions where hf API doesn't match the original FSA API +_hf_api_to_flash_mapping = { + "dropout": None, + "sliding_window": None, +} + + +def _lazy_imports(implementation: Optional[str]): + """ + Lazy loads the respective flash sparse attention implementations. + + Return: + flash_sparse_attn_func: The base flash sparse attention function. + flash_sparse_attn_varlen_func: The flash sparse attention function supporting variable sequence lengths, e.g. for padding-free training. + pad_input: The function to pad inputs into one sequence and returning the respective kwargs. + unpad_input: The function to unpad outputs based on the kwargs (from pad_input). + """ + is_fsa = is_flash_sparse_attn_available() + + if (implementation == "flash_sparse_attn" and is_fsa) or (implementation is None and is_fsa): + from flash_sparse_attn import flash_sparse_attn_func, flash_sparse_attn_varlen_func + from flash_sparse_attn.utils.padding import pad_input, unpad_input + from flash_sparse_attn.utils.mask import create_mask + + return flash_sparse_attn_func, flash_sparse_attn_varlen_func, pad_input, unpad_input, create_mask + + +def _lazy_define_process_function(flash_function): + """ + Depending on the version and kernel some features are not supported. Due to limitations in + `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported + within `_process_flash_sparse_attention_kwargs`. + + NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. + This might be confusing for kwargs that we use in any case, e.g. `is_causal`. + """ + + flash_parameters = inspect.signature(flash_function).parameters + process_parameters = inspect.signature(_process_flash_sparse_attention_kwargs).parameters + + supports_mapping = {} + for param in process_parameters: + fsa_param = _hf_api_to_flash_mapping.get(param, param) + supports_mapping[fsa_param] = fsa_param in flash_parameters + + return partial(_process_flash_sparse_attention_kwargs, supports_mapping=supports_mapping) + + +def lazy_import_flash_sparse_attention(implementation: Optional[str], force_import: Optional[bool] = False): + """ + Lazily import flash sparse attention and return the respective functions + flags. + + NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can + work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. + """ + global _fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn + if force_import or any(k is None for k in [_fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]): + _fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation) + + global _process_flash_kwargs_fn + if force_import or _process_flash_kwargs_fn is None: + _process_flash_kwargs_fn = _lazy_define_process_function(_fsa_varlen_fn) + return (_fsa_fn, _fsa_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn + + +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fsa_kwargs_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fsa_kwargs_from_position_ids(position_ids): + """ + This function returns all the necessary kwargs to call `flash_sparse_attn_varlen_func` extracted from position_ids. + + Arguments: + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) + + cu_seq_lens_q = torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + cu_seq_lens_k = cu_seq_lens_q + + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length_q = cu_seq_lens_q.diff().max() + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + max_length_q = max_length_q.item() + max_length_k = max_length_q + + return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) + + +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_sparse_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fsa_kwargs_from_position_ids(position_ids) + + return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = ( + torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + ) + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() + + +def fsa_peft_integration_check( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bias: Optional[torch.Tensor], + target_dtype: Optional[torch.dtype] = None +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + """ + if target_dtype and q.dtype == torch.float32: + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash_sparse_attn compatibility.") + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + if bias is not None: + bias = bias.to(target_dtype) + return q, k, v, bias + + +class FlashSparseAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Sparse Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] + + +def _process_flash_sparse_attention_kwargs( + query_length: int, + key_length: int, + is_causal: bool, + softmax_scale: Optional[float] = None, + window_size: Optional[int] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + s_aux: Optional[torch.Tensor] = None, + supports_mapping: Optional[dict[str, bool]] = None, + **kwargs, +): + """ + Returns a set of kwargs that are passed down to the according flash attention function based on + requested features and whether it is supported - depends on the version and kernel implementation + which is dynamically configured at `lazy_import_flash_sparse_attention`. The (un)supported features can be + inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. + + Args: + query_length (`int`): + Length of the query states + key_length (`int`): + Length of the key states + is_causal (`bool`): + Whether we perform causal (decoder) attention or full attention. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`. + window_size (`int`, *optional*): + If set, only the `window_size` largest key/value pairs per query are kept. + softcap (`float`, *optional*): + Softcap for the attention logits, used e.g. in gemma2. + deterministic (`bool`, *optional*): + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + s_aux (`torch.Tensor`, *optional*): + Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. + Return: + flash_kwargs (`dict`): + A dict of kwargs that are requested and supported. + """ + flash_kwargs = { + "is_causal": is_causal and not query_length == 1, + "softmax_scale": softmax_scale, + } + + if supports_mapping["window_size"] and window_size is not None and key_length > window_size: + flash_kwargs["window_size"] = window_size + + if supports_mapping["deterministic"]: + flash_kwargs["deterministic"] = ( + deterministic if deterministic is not None else os.getenv("FLASH_SPARSE_ATTENTION_DETERMINISTIC", "0") == "1" + ) + + if supports_mapping["softcap"] and softcap is not None: + flash_kwargs["softcap"] = softcap + + # Only within kernel implementation atm + if supports_mapping["s_aux"] and s_aux is not None: + flash_kwargs["s_aux"] = s_aux + + return flash_kwargs + + +def _flash_sparse_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + attention_bias: Optional[torch.Tensor], + query_length: int, + key_length: int, + is_causal: bool, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + window_size: Optional[int] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[torch.dtype] = None, + implementation: Optional[str] = None, + **kwargs, +): + """ + Calls the forward method of Flash Sparse Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + (Optional) kwargs are described further in `_process_flash_sparse_attention_kwargs` and `FlashSparseAttentionKwargs`. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to FSA API + key_states (`torch.Tensor`): + Input key states to be passed to FSA API + value_states (`torch.Tensor`): + Input value states to be passed to FSA API + attention_mask (`torch.Tensor`, *optional*): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + attention_bias (`torch.Tensor`, *optional*): + The attention bias tensor to add to attention scores. + implementation (`str`, *optional*): + The attention implementation to use. If None, will default to the one based on the environment. + """ + + if ( + attention_mask is not None + and attention_mask.dim() == 2 + and attention_bias is not None + ): + raise ValueError( + "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None." + ) + + (fsa_fn, fsa_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_sparse_attention(implementation) + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states, attention_bias = fsa_peft_integration_check( + query_states, key_states, value_states, attention_bias, target_dtype + ) + + # Extract the flash attention kwargs that have been requested (and are supported by the implementation) + flash_kwargs = process_flash_kwargs_fn( + query_length=query_length, + key_length=key_length, + is_causal=is_causal, + softmax_scale=softmax_scale, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + **kwargs, + ) + + # We will use `fsa_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: + # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. + # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to + # use `fsa_varlen_fn` knowing we already have all necessary the kwargs. + # + # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. + # See #39121 for more information. + is_fsa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) + is_fsa_with_varlen_kwargs = all( + kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) + ) + + # Contains at least one padding token in the sequence + if attention_mask is not None and attention_mask.dim() == 2: + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn + ) + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out_unpad = fsa_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + + out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) + + # Padding free, i.e. sequences flattened into one total sequence + elif is_fsa_with_varlen_kwargs or is_fsa_with_position_ids: + if cu_seq_lens_q is None or cu_seq_lens_k is None: + q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( + query_states, key_states, value_states, position_ids + ) + else: + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out = fsa_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] + + out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) + + # No padding + else: + + # Generate a combined attention mask if `attention_bias` are provided + if ( + attention_bias is not None + and window_size is not None + and key_length > window_size + ): + attention_mask = create_mask_fn( + attention_bias, + attention_mask, + batch_size=query_states.size(0), + query_len=query_length, + key_len=key_length, + window_size=window_size, + min_dtype=torch.finfo(attention_bias.dtype).min, + ) + + out = fsa_fn( + query_states, + key_states, + value_states, + attention_mask, + attention_bias, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] + + return out From 11a0862e82fb12cb95cf6a1d2aaad833d19e5709 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:28:35 +0800 Subject: [PATCH 17/29] Adds dynamic mask helpers Introduces mask utilities for top-k and relu masking to support flash sparse attention. Enables optional block smoothing to stabilize dynamic sparsity patterns. --- flash_sparse_attn/utils/mask.py | 240 ++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 flash_sparse_attn/utils/mask.py diff --git a/flash_sparse_attn/utils/mask.py b/flash_sparse_attn/utils/mask.py new file mode 100644 index 0000000..4905835 --- /dev/null +++ b/flash_sparse_attn/utils/mask.py @@ -0,0 +1,240 @@ +# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +def topk_indices( + attention_bias: torch.Tensor, + window_size: int, + **kwargs, +) -> torch.Tensor: + r""" + This function generates top-k indices based on the attention bias. + + Args: + attention_bias (torch.Tensor): The attention bias tensor of + (batch_size, num_kv_heads, key_len). + window_size (int): The number of top elements to consider for the mask. + **kwargs: Additional keyword arguments. + + Returns: + topk_indices (Tensor): The top-k indices tensor of shape + (batch_size, num_kv_heads, window_size). + """ + attention_bias = attention_bias.detach() + topk_indices = torch.topk( + attention_bias, + window_size, dim=-1, largest=True, sorted=False + ).indices + topk_indices = torch.sort(topk_indices, dim=-1).values + return topk_indices + + +def block_smooth( + attention_mask: torch.Tensor, + key_len: int, + block_size: int, +): + if block_size <= 0: + raise ValueError(f"block_size must be a positive integer, got {block_size}.") + + if block_size > 1: + full_len = (key_len // block_size) * block_size + + if full_len: + block_view = attention_mask[..., :full_len] + block_shape = (*block_view.shape[:-1], full_len // block_size, block_size) + blocks = block_view.view(*block_shape) + block_counts = blocks.sum(dim=-1).to(torch.int64) + block_keep = (block_counts * 2) > block_size + blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks)) + + if key_len > full_len: + tail_slice = attention_mask[..., full_len:] + tail_len = tail_slice.shape[-1] + tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64) + tail_keep = (tail_counts * 2) > tail_len + tail_slice.copy_(tail_keep.expand_as(tail_slice)) + + return attention_mask + + +def topk_mask( + attention_bias: torch.Tensor, + attention_mask: Optional[torch.Tensor], + window_size: int, + min_dtype: float, + block_size: Optional[int] = None, + **kwargs, +): + r""" + This function generates a dynamic mask based on the top-k attention bias. + + Args: + attention_bias (torch.Tensor): The attention bias tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + window_size (int): The number of top elements to consider for the mask. + min_dtype (float): The minimum value to use for masking. + block_size (Optional[int]): Optional size of aggregation blocks to smooth the + resulting mask along the key dimension. + + Returns: + attention_mask (Tensor): The attention mask tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + """ + + attention_bias = attention_bias.detach() + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias + topk_values, topk_indices = torch.topk( + attention_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attention_mask = torch.zeros_like( + attention_bias, dtype=torch.bool, device=attention_bias.device + ).scatter_(-1, topk_indices, topk_values != min_dtype) + + if block_size is not None and block_size > 1: + key_len = attention_mask.shape[-1] + attention_mask = block_smooth( + attention_mask=attention_mask, + key_len=key_len, + block_size=block_size + ) + + return attention_mask + + +def relu_mask( + attention_bias: torch.Tensor, + attention_mask: Optional[torch.Tensor], + min_dtype: float, + block_size: Optional[int] = None, + **kwargs +): + r""" + This function generates a dynamic mask based on the ReLU of attention bias. + + Args: + attention_bias (torch.Tensor): The attention bias tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + min_dtype (float): The minimum value to use for masking. + block_size (Optional[int]): Optional size of aggregation blocks to smooth the + resulting mask along the key dimension. + + Returns: + attention_mask (Tensor): The attention mask tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + """ + + attention_bias = attention_bias.detach() + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias + attention_mask = attention_bias > 0 + + if block_size is not None and block_size > 1: + key_len = attention_mask.shape[-1] + attention_mask = block_smooth( + attention_mask=attention_mask, + key_len=key_len, + block_size=block_size + ) + + return attention_mask + + + +def create_mask( + attention_bias: torch.Tensor, + attention_mask: Optional[torch.Tensor], + batch_size: int, + query_len: int, + key_len: int, + window_size: int, + min_dtype: float, + block_size: Optional[int] = None, + type: str = "topk", +) -> torch.Tensor: + r""" + This function creates a mask tensor for Flash Sparse Attention. + + If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias. + + Args: + attention_bias (torch.Tensor): The attention bias tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + batch_size (int): The batch size. + query_len (int): The sequence length of the query. + key_len (int): The sequence length of the key. + window_size (int): The number of top elements to consider for the attention mask. + min_dtype (float): The minimum value to use for masking. + block_size (Optional[int]): Optional size of aggregation blocks after top-k masking. + type (str): The type of mask to create. Options are "topk" and "relu". + + Returns: + attention (Tensor): The attention mask tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + """ + + # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len) + if attention_mask is not None and attention_mask.dim() == 2: + if attention_mask.shape[-1] == key_len: + attention_mask = attention_mask.view(batch_size, 1, 1, key_len) + elif attention_mask.shape[-1] == query_len: + pad_len = key_len - query_len + if pad_len > 0: + pad_mask = torch.ones( + (batch_size, 1, 1, pad_len), + dtype=torch.bool, + device=attention_mask.device, + ) + attention_mask = torch.cat( + [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)], + dim=-1, + ) + else: + attention_mask = attention_mask.view(batch_size, 1, 1, query_len) + else: + raise ValueError( + f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}." + ) + + # Generate dynamic mask based on attention_bias and attention_mask + if type == "topk": + attention_mask = topk_mask( + attention_bias=attention_bias, + attention_mask=attention_mask, + window_size=window_size, + min_dtype=min_dtype, + block_size=block_size, + ) + elif type == "relu": + attention_mask = relu_mask( + attention_bias=attention_bias, + attention_mask=attention_mask, + window_size=window_size, + min_dtype=min_dtype, + block_size=block_size, + ) + else: + raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.") + + return attention_mask From 3dd3392c28f31f83ace2f4bb696e73a21b981fa0 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:28:52 +0800 Subject: [PATCH 18/29] Adds shared unpadding utilities Introduces reusable padding helpers to consolidate ragged tensor handling and avoid recomputing per layer indices. Addresses static-cache overflow by slicing KV states and provides local indexing to keep graph-friendly. --- flash_sparse_attn/utils/padding.py | 170 +++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 flash_sparse_attn/utils/padding.py diff --git a/flash_sparse_attn/utils/padding.py b/flash_sparse_attn/utils/padding.py new file mode 100644 index 0000000..b675af7 --- /dev/null +++ b/flash_sparse_attn/utils/padding.py @@ -0,0 +1,170 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F + + +def index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fdma_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer, indices_k) + value_layer = index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) From df69839ff2fd96c9a961be0e83f63a6f78e9fc83 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:29:42 +0800 Subject: [PATCH 19/29] Updates flash attention integration Points the integration to the renamed sparse attention package so setup guidance stays accurate. --- examples/modeling/modeling_doge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index d350f71..af1f791 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -45,9 +45,9 @@ from .configuration_doge import DogeConfig try: - from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + from flash_sparse_attn.integrations.flash_sparse_attention import flash_dynamic_mask_attention_forward except ImportError: - print("Please install flash_dmattn to use this model: pip install flash-dmattn") + print("Please install flash_sparse_attn to use this model: pip install flash-sparse-attn") if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask From 7e3faab11b1bd42a5c4f598832b158656c069c9f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:30:25 +0800 Subject: [PATCH 20/29] Remove outdated documentation files for Flash Dynamic Mask Attention integration and v1.0.0 technical report. These files have been superseded by updated documentation reflecting recent changes and improvements in the codebase. --- docs/integration.md | 2872 ------------------------------- docs/integration_zh.md | 522 ------ docs/v1.0.0_technical_report.md | 299 ---- 3 files changed, 3693 deletions(-) delete mode 100644 docs/integration.md delete mode 100644 docs/integration_zh.md delete mode 100644 docs/v1.0.0_technical_report.md diff --git a/docs/integration.md b/docs/integration.md deleted file mode 100644 index 31a57d1..0000000 --- a/docs/integration.md +++ /dev/null @@ -1,2872 +0,0 @@ -# Flash Dynamic Mask Attention Integration Guide - -## Overview - -This document describes the integration of Dynamic Mask Attention into the Flash Attention framework. The integration enables efficient sparse attention computation by combining Flash Attention's memory-efficient approach with dynamic masking capabilities for handling extremely long sequences. - -The integration implements a unified sparse computation approach with block-level skip logic: Python frontend pre-computes Attention Mask and Attention Bias tensors, while the CUDA backend performs block-level skip decisions and sparse attention computation for both forward and backward passes. - -## Table of Contents - -1. [Integration Architecture](#integration-architecture) -2. [Core Modifications](#core-modifications) -3. [Implementation Details](#implementation-details) -4. [Sparse Computation Strategy](#sparse-computation-strategy) -5. [Memory Layout](#memory-layout) -6. [Performance Considerations](#performance-considerations) -7. [API Changes](#api-changes) - -## Integration Architecture - -### High-Level Design - -The Dynamic Mask Attention integration implements a unified sparse computation approach with block-level skip logic for both forward and backward passes: - -1. **Dynamic Mask Computation**: Python frontend pre-computes Attention Mask and Attention Bias tensors -2. **Unified Sparse Execution**: CUDA backend performs block-level skip decisions for both forward and backward passes -3. **Memory Optimization**: Smart shared memory aliasing and barrier synchronization - - -### Key Components - -- **Attention Mask**: Binary mask `(batch, num_kv_heads, query_len, key_len)` indicating which positions should be computed (1.0) or skipped (0.0) -- **Attention Bias**: Dynamic attention bias values `(batch, num_kv_heads, query_len, key_len)` applied to attention scores before softmax -- **Block-level Skip Logic**: Unified OR-reduction over (BlockM × BlockN) tiles to determine if computation should be performed -- **LSE Caching**: Log-sum-exp values cached during forward pass for numerically stable backward recomputation -- **Shared Memory Aliasing**: Smart memory reuse with explicit barrier synchronization -- **Complete Gradient Chain**: Full gradient computation pipeline with sparse skip capability -- **Memory Optimization**: Reduced shared memory footprint enabling larger tile sizes and higher occupancy - -## Core Modifications - -### 1. Parameter Structure Extensions (`flash.h`) - -**Purpose**: Extended parameter structures to support dynamic masking tensors with proper memory layout information. - -**Changes Made**: -```cpp -struct QKV_params { - // The QKV matrices. - void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim] - void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim] - void *__restrict__ v_ptr; // Value tensor [batch_size, num_kv_heads, key_len, head_dim] - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride, k_batch_stride, v_batch_stride; - index_t q_row_stride, k_row_stride, v_row_stride; - index_t q_head_stride, k_head_stride, v_head_stride; - - // The number of heads. - int h, h_k; - int h_h_k_ratio; // precompute h / h_k -}; - -struct Mask_params { - void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the attention mask tensors. - index_t mask_batch_stride; // Stride between batches of attention mask - index_t mask_head_stride; // Stride between heads of attention mask - index_t mask_row_stride; // Stride between rows of attention mask -}; - -struct Bias_params { - void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the attention bias tensor. - index_t bias_batch_stride; // Stride between batches of attention bias - index_t bias_head_stride; // Stride between heads of attention bias - index_t bias_row_stride; // Stride between rows of attention bias -}; - -struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - - // The O matrix (output). - void * __restrict__ o_ptr; - void * __restrict__ oaccum_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the P matrix. - void * __restrict__ p_ptr; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - void * __restrict__ softmax_lseaccum_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; - - // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; - float softcap; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - int * __restrict__ leftpad_k; - - // If provided, the actual length of each k sequence. - int * __restrict__ seqused_k; - - int *__restrict__ blockmask; - - // The K_new and V_new matrices. - void * __restrict__ knew_ptr; - void * __restrict__ vnew_ptr; - - // The stride between rows of the K_new and V_new matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; - - // The cos and sin matrices for rotary embedding. - void * __restrict__ rotary_cos_ptr; - void * __restrict__ rotary_sin_ptr; - - // The indices to index into the KV cache. - int * __restrict__ cache_batch_idx; - - // Paged KV cache - int * __restrict__ block_table; - index_t block_table_batch_stride; - int page_block_size; - - bool is_bf16; - bool is_causal; - - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; - - int num_splits; // For split-KV version - - bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. - bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). -}; -``` - -**Rationale**: -- **Multiple Inheritance Design**: Cleanly separates QKV parameters from Mask/Bias parameters while maintaining unified access -- **Comprehensive Stride Information**: Provides all necessary stride information for efficient tensor indexing in CUDA kernels -- **Memory Layout Optimization**: Enables optimal memory access patterns for both regular and sparse tensors - -### 2. Kernel Traits and Memory Layout (`kernel_traits.h`) - -**Purpose**: Define kernel characteristics and memory layouts optimized for dynamic masking operations, supporting both SM75 and SM80+ architectures. - -**Changes Made**: -```cpp -template -struct Flash_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - - static constexpr int kHeadDim = kHeadDim_; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kNWarps = kNWarps_; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static constexpr bool Has_cp_async = true; - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; -#else - static constexpr bool Has_cp_async = false; - using MMA_Atom_Arch = MMA_Atom; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif - - // Specialized traits for mask and bias operations - using SmemCopyAtomMask = SmemCopyAtom; - using SmemCopyAtomBias = SmemCopyAtom; -}; -``` - -**Rationale**: -- **Architecture Adaptation**: Automatically selects optimal MMA atoms and copy operations based on GPU architecture -- **Type Safety**: Template-based design ensures type consistency across mask, bias, and attention operations -- **Performance Optimization**: Leverages specialized load/store instructions (LDSM) for maximum memory bandwidth - -### 3. Block Information Extension (`block_info.h`) - -**Purpose**: Calculate memory offsets for attention bias and attention masks within thread blocks, enabling efficient global memory access. - -**Changes Made**: -```cpp -template -struct BlockInfo { - template - __device__ BlockInfo(const Params ¶ms, const int bidb) - : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) - , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) - , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : - (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) - , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : - seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) - { - } - - template - __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - } - - template - __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; - } - - template - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - template - __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - const int sum_s_q, sum_s_k; - const int actual_seqlen_q; - const int leftpad_k; - const int seqlen_k_cache; - const int actual_seqlen_k; -}; -``` - -**Rationale**: -- **Unified Offset Calculation**: Provides dedicated methods for calculating mask and bias tensor offsets -- **Variable Length Support**: Handles both fixed and variable length sequences through template specialization -- **Memory Access Optimization**: Encapsulates complex address arithmetic for efficient global memory access - -### 4. Memory Copy Operations (`utils.h`) - -**Purpose**: Implement efficient tensor operations and layout conversions optimized for Flash Attention's memory hierarchy. - -**Changes Made**: -```cpp -namespace FLASH_NAMESPACE { - -// Convert accumulator layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); -}; - -// Type conversion utilities for different precisions -template -__forceinline__ __device__ T convert_type(float x) { - return T(x); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> -__forceinline__ __device__ cutlass::bfloat16_t convert_type(float x) { - return cutlass::bfloat16_t(x); -} -#endif - -// Warp-level reduction operations -template -__forceinline__ __device__ float warp_reduce_sum(float x) { -#pragma unroll - for (int mask = THREADS / 2; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask); - } - return x; -} - -// GEMM operations with register and shared memory variants -template < - bool A_in_regs=false, bool B_in_regs=false, - typename Tensor0, typename Tensor1, typename Tensor2, - typename Tensor3, typename Tensor4, - typename TiledMma, typename TiledCopyA, typename TiledCopyB, - typename ThrCopyA, typename ThrCopyB -> -__forceinline__ __device__ void gemm( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, - Tensor3 &tCsA, Tensor4 &tCsB, - TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B -) { - if constexpr (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA, tCrA); - } - if constexpr (!B_in_regs) { - copy(smem_tiled_copy_B, tCsB, tCrB); - } - - // Perform matrix multiplication - gemm(tiled_mma, acc, tCrA, tCrB, acc); -} - -} // namespace FLASH_NAMESPACE -``` - -**Rationale**: -- **Layout Conversion**: Efficient transformation between MMA and row-column layouts for easier tensor manipulation -- **Multi-Precision Support**: Proper type conversion utilities for FP16 and BF16 operations -- **Memory Hierarchy Management**: Flexible GEMM operations supporting different data residency patterns -- **Performance Optimization**: Warp-level reductions and vectorized operations for maximum throughput - -### 5. Dynamic Masking Logic (`mask.h`) - -**Purpose**: Implement the core dynamic masking functionality that applies attention bias and attention masks during attention computation. - -**Changes Made**: -```cpp -template -__forceinline__ __device__ void apply_mask( - TensorType &tensor, - MaskType &mask, - BiasType &bias, - const float scale_softmax, - const int col_idx_offset_, - const int max_seqlen_k, - const int row_idx_offset, - const int max_seqlen_q, - const int warp_row_stride -) { - // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) - static_assert(TensorType::rank == 2, "Only support 2D Tensor"); - static_assert(MaskType::rank == 2, "Only support 2D Mask"); - static_assert(BiasType::rank == 2, "Only support 2D Bias"); - - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? - std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : - max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } -} - -template -struct Mask { - const int max_seqlen_k, max_seqlen_q; - - __forceinline__ __device__ Mask( - const int max_seqlen_k, - const int max_seqlen_q - ) // Constructor - : max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) { - }; - - template - __forceinline__ __device__ void apply_mask( - TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) - MaskType &tSrMask, // Attention Mask (MMA=4, MMA_M, MMA_N) - BiasType &tSrBias, // Attention Bias (MMA=4, MMA_M, MMA_N) - const float scale_softmax, // Scale for softmax - const int col_idx_offset_, // Column index offset - const int row_idx_offset, // Row index offset - const int warp_row_stride // Warp row stride - ) { - // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? - std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : - max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } - } -}; -``` - -**Rationale**: -- **Register-Level Operations**: All masking operations performed in registers for maximum efficiency -- **Unified Masking Logic**: Combines causal masking, boundary checking, and dynamic masking in a single pass -- **Layout Conversion**: Properly handles MMA tensor layout conversion for efficient indexing -- **Numerical Stability**: Proper handling of infinity values for masked positions ensures stable softmax computation - -### 6. Backward Pass Integration (`flash_bwd_kernel.h`) - -**Purpose**: Extend backward pass computation to support dynamic masking with proper gradient computation for masked positions. - -**Changes Made**: -```cpp -struct Flash_bwd_params : public Flash_fwd_params { - - // The dO and dQKV and dBias matrices. - void *__restrict__ do_ptr; - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - void *__restrict__ dbias_ptr; - - // To accumulate dQ, dK, dV - void *__restrict__ dq_accum_ptr; - void *__restrict__ dk_accum_ptr; - void *__restrict__ dv_accum_ptr; - - // The stride between rows of the dO, dQ, dK and dV matrices. - index_t do_batch_stride; - index_t do_row_stride; - index_t do_head_stride; - index_t dq_batch_stride; - index_t dk_batch_stride; - index_t dv_batch_stride; - index_t dq_row_stride; - index_t dk_row_stride; - index_t dv_row_stride; - index_t dq_head_stride; - index_t dk_head_stride; - index_t dv_head_stride; - index_t dbias_batch_stride; - index_t dbias_head_stride; - index_t dbias_row_stride; - - // The pointer to the softmax d sum. - void *__restrict__ dsoftmax_sum; - - bool deterministic; - index_t dq_accum_split_stride; -}; - -template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { - // Backward pass computation with dynamic masking support - // Includes proper gradient computation through masked attention scores - // Maintains numerical stability for masked positions - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Initialize block information and tensor views - const BlockInfo binfo(params, bidb); - - // Set up gradient computation with masking awareness - // Load bias and mask gradients when computing dBias - // Apply masking logic consistently with forward pass -} -``` - -**Rationale**: -- **Gradient Consistency**: Ensures gradients are computed consistently with forward pass masking logic -- **Memory Layout Preservation**: Maintains the same memory layout and stride patterns as forward pass -- **Numerical Stability**: Proper handling of gradients at masked positions to prevent NaN propagation - -### 7. Attention Kernel Modifications (`flash_fwd_kernel.h`) - -**Purpose**: Integrate dynamic masking into the core attention computation kernels while maintaining Flash Attention's memory efficiency and optimization strategies. - -**Changes Made**: -```cpp -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Initialize block information - const BlockInfo binfo(params, bidb); - - // Set up tensor views for Q, K, V matrices - Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, Int{}), - make_stride(params.q_row_stride, _1{})); - - Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, Int{}), - make_stride(params.k_row_stride, _1{})); - - // Set up mask and bias tensor views if available - Tensor mMask, mBias; - if (params.mask_ptr != nullptr) { - mMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.mask_row_stride, _1{})); - } - - if (params.bias_ptr != nullptr) { - mBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.bias_row_stride, _1{})); - } - - // Main computation loop with dynamic masking integration - for (int n_block = n_block_min; n_block < n_block_max; ++n_block) { - // Standard Flash Attention computation: Q*K^T - gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, - smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - - // Apply dynamic masking if mask/bias tensors are provided - if (params.mask_ptr != nullptr || params.bias_ptr != nullptr) { - Mask mask(params.seqlen_k, params.seqlen_q); - mask.apply_mask(acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * Kernel_traits::kBlockN, m_block * Kernel_traits::kBlockM, - Kernel_traits::kBlockM); - } - - // Continue with softmax computation - softmax.template softmax_rescale_o( - acc_s, acc_o, params.scale_softmax_log2 - ); - - // Attention * V computation - gemm(acc_o, acc_s, tSrV, acc_s, tSsV, tiled_mma, - smem_tiled_copy_S, smem_tiled_copy_V, - smem_thr_copy_S, smem_thr_copy_V); - } -} -``` - -**Rationale**: -- **Seamless Integration**: Dynamic masking logic integrated into existing Flash Attention computation flow without affecting core performance -- **Memory Efficiency Preservation**: Maintains Flash Attention's tiling and shared memory optimization strategies -- **Conditional Execution**: Only applies masking operations when mask/bias tensors are actually provided -- **Template Specialization**: Compile-time optimization eliminates runtime branching for better performance - -### 8. Launch Template Updates (`flash_fwd_launch_template.h`) - -**Purpose**: Update kernel launch templates to support dynamic masking functionality with proper template instantiation and dispatch logic. - -**Changes Made**: -```cpp -// Determine if the architecture supports FLASH and define parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define unsupported architecture error handling -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); - -// Kernel definition macro for cleaner code -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { - #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { - #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - - // Handle different precision types and head dimensions - BOOL_SWITCH(params.is_bf16, Is_Bf16, [&] { - using elem_type = std::conditional_t; - HEADDIM_SWITCH(params.d, [&] { - BOOL_SWITCH(params.seqlen_k % Kernel_traits::kBlockN == 0, Is_even_N, [&] { - BOOL_SWITCH(params.d == kHeadDim, Is_even_K, [&] { - SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - auto kernel = &flash_fwd_kernel; - // Launch kernel with appropriate grid and block dimensions - kernel<<>>(params); - }); - }); - }); - }); - }); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -// Template instantiations for different configurations -template -void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream); -``` - -**Rationale**: -- **Template Dispatch**: Efficient compile-time branching based on runtime parameters for optimal performance -- **Architecture Support**: Proper handling of different GPU architectures with appropriate error messages -- **Memory Management**: Correct shared memory allocation based on kernel requirements -- **Type Safety**: Strong typing through template parameters ensures correctness across different precisions - -**Purpose**: Update kernel launch functions to properly configure and validate dynamic masking parameters, ensuring correct shared memory allocation and kernel selection. - -**Changes Made**: -```cpp -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // Calculate shared memory requirements - constexpr size_t smem_size = Kernel_traits::kSmemSize; - - // Set up grid dimensions - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.b, params.h); - - // Determine kernel variant based on sequence lengths and alignment - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && - params.seqlen_k % Kernel_traits::kBlockN == 0 && - params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - const bool return_softmax = params.p_ptr != nullptr; - - // Launch appropriate kernel variant with dynamic masking support - BOOL_SWITCH(is_even_MN, IsEvenMN, [&] { - BOOL_SWITCH(is_even_K, IsEvenK, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmax, [&] { - auto kernel = &flash_fwd_kernel; - - // Configure dynamic shared memory if needed - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - // Launch kernel with extended parameter set - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); -} - -template -void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // Split-K variant launch with dynamic masking support - // Handles cases where sequence length exceeds single kernel capacity - static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); - static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); - - // Configure split parameters based on sequence length and hardware capabilities - const int num_splits = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - // ... split-K launch logic with dynamic masking support -} -``` - -**Rationale**: -- **Resource Management**: Proper shared memory allocation and validation for extended tensor requirements -- **Kernel Selection**: Intelligent kernel variant selection based on problem size and hardware capabilities -- **Error Handling**: Comprehensive validation of parameters and device limits -- **Performance Optimization**: Compile-time optimizations through template specialization - -### 9. API Interface Extensions (`flash_api.cpp`) - -**Purpose**: Extend the Python-facing API to support dynamic masking tensors with comprehensive validation and backward compatibility. - -**Changes Made**: -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - // ... existing parameters ... - const at::Tensor mask, // Attention mask tensor - const at::Tensor bias, // Attention bias tensor - // ... other parameters ... -) { - // Reset parameters and set basic properties - params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set attention mask pointers and strides - params.mask_ptr = mask.data_ptr(); - params.mask_batch_stride = mask.stride(-4); - params.mask_head_stride = mask.stride(-3); - params.mask_row_stride = mask.stride(-2); - - // Set attention bias pointers and strides - params.bias_ptr = bias.data_ptr(); - params.bias_batch_stride = bias.stride(-4); - params.bias_head_stride = bias.stride(-3); - params.bias_row_stride = bias.stride(-2); - - // ... existing parameter setup ... -} - -std::vector mha_fwd( - at::Tensor &q, // Query tensor - const at::Tensor &k, // Key tensor - const at::Tensor &v, // Value tensor - const at::Tensor &mask, // Attention mask tensor - const at::Tensor &bias, // Attention bias tensor - std::optional &out_, // Optional output tensor - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -) { - // Comprehensive input validation - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v); - CHECK_CONTIGUOUS(mask); CHECK_CONTIGUOUS(bias); - - // Validate tensor shapes - auto batch_size = q.size(0); - auto seqlen_q = q.size(1); - auto num_heads = q.size(2); - auto head_dim = q.size(3); - auto seqlen_k = k.size(1); - auto num_heads_k = k.size(2); - - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - - // Validate data types consistency - TORCH_CHECK(q.dtype() == k.dtype() && k.dtype() == v.dtype(), - "All QKV tensors must have the same dtype"); - TORCH_CHECK(mask.dtype() == q.dtype(), - "Attention mask must have the same dtype as QKV tensors"); - TORCH_CHECK(bias.dtype() == q.dtype(), - "Attention bias must have the same dtype as QKV tensors"); - - // Set up parameters and launch computation - Flash_fwd_params params; - set_params_fprop(params, batch_size, seqlen_q, seqlen_k, /* ... */, - q, k, v, mask, bias, /* ... */); - - // Launch kernel with appropriate configuration - run_mha_fwd(params, at::cuda::getCurrentCUDAStream()); - - // Return results - return {out, softmax_lse}; -} - -// Python binding -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass with dynamic masking", - py::arg("q"), py::arg("k"), py::arg("v"), - py::arg("mask"), py::arg("bias"), // Updated arguments - py::arg("out") = py::none(), - py::arg("softmax_scale") = 0.0f, - py::arg("is_causal") = false, - py::arg("softcap") = 0.0f, - py::arg("return_softmax") = false); -} -``` - -**Rationale**: -- **Comprehensive Validation**: Thorough validation of all input tensors for shape, type, and device consistency -- **Backward Compatibility**: Maintains existing parameter order while adding new functionality -- **Error Handling**: Clear error messages for common usage mistakes -- **Type Safety**: Strict type checking to prevent runtime errors -- **Documentation**: Clear parameter documentation for Python users - -## Implementation Details - -### C++ API Interface (`flash_api.cpp`) - -The core C++ API provides the following main functions for Dynamic Mask Attention: - -```cpp -namespace FLASH_NAMESPACE { - -std::vector mha_fwd( - at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -); - -std::vector mha_varlen_fwd( - at::Tensor &q, // total_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // total_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // total_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - std::optional &out_, // total_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &cu_seqlens_q, // batch_size + 1 - const at::Tensor &cu_seqlens_k, // batch_size + 1 - std::optional &seqused_k, - std::optional &leftpad_k, - const int max_seqlen_q, - const int max_seqlen_k, - const float softmax_scale, - bool is_causal, - const float softcap, - const bool return_softmax -); - -std::vector mha_bwd( - const at::Tensor &dout, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &out, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &softmax_lse, // batch_size x num_heads x seqlen_q - std::optional &dq_, - std::optional &dk_, - std::optional &dv_, - std::optional &dbias_, - const float softmax_scale, - bool is_causal, - const float softcap, - bool deterministic, - std::optional gen_ -); - -} // namespace FLASH_NAMESPACE -``` - -### Parameter Setup and Validation - -The implementation includes comprehensive parameter validation and setup: - -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - const size_t b, const size_t seqlen_q, const size_t seqlen_k, - const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, - const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, - const at::Tensor q, const at::Tensor k, const at::Tensor v, - const at::Tensor mask, const at::Tensor bias, at::Tensor out, - void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_k, - void *p_d, void *softmax_lse_d, float softmax_scale, bool is_causal, - const float softcap, bool seqlenq_ngroups_swapped=false, - const bool unpadded_lse=false -) { - // Reset parameters - params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set tensor pointers - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - params.mask_ptr = mask.data_ptr(); - params.bias_ptr = bias.data_ptr(); - params.o_ptr = out.data_ptr(); - - // Set stride information (all strides are in elements, not bytes) - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.mask_row_stride = mask.stride(-2); - params.bias_row_stride = bias.stride(-2); - params.o_row_stride = out.stride(-3); - - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.mask_head_stride = mask.stride(-3); - params.bias_head_stride = bias.stride(-3); - params.o_head_stride = out.stride(-2); - - // Set batch stride information - if (cu_seqlens_q_d == nullptr) { - params.q_batch_stride = q.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - params.mask_batch_stride = mask.stride(0); - params.bias_batch_stride = bias.stride(0); - params.o_batch_stride = out.stride(0); - } - - // Set sequence length and dimension parameters - params.b = b; params.h = h; params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; params.d_rounded = d_rounded; - - // Set scaling and control parameters - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; - params.softcap = softcap; - params.is_causal = is_causal; - params.unpadded_lse = unpadded_lse; - params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; -} -``` - -### Python Binding and Interface - -The C++ functions are exposed to Python through PyBind11: - -```cpp -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); - m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); - m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); - m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); -} -``` - -### Python Frontend Integration Example - -Dynamic Mask Attention can be integrated into transformer models as follows: - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call Flash Dynamic Mask Attention - output, _ = flash_dmattn.fwd( - query_states, key_states, value_states, - attention_mask, attention_bias, - None, # out - self.scaling, # softmax_scale - False, # is_causal - 0.0, # softcap - False # return_softmax - ) - - # Output projection - output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) - return self.o_proj(output) -``` - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attn_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Updated Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Updated Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - # Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - # Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model (Updated) - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ Mask: [batch, num_heads_k, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_heads_k, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), - make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. - -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call Flash Dynamic Mask Attention - output, _ = flash_dmattn.fwd( - query_states, key_states, value_states, - attention_mask, attention_bias, - None, # out - self.scaling, # softmax_scale - False, # is_causal - 0.0, # softcap - False # return_softmax - ) - - # Output projection - output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) - return self.o_proj(output) -``` - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attn_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Updated Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Updated Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - # Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - # Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model (Updated) - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), - make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. - -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call Flash Dynamic Mask Attention - output, _ = flash_dmattn.fwd( - query_states, key_states, value_states, - attention_mask, attention_bias, - None, # out - self.scaling, # softmax_scale - False, # is_causal - 0.0, # softcap - False # return_softmax - ) - - # Output projection - output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) - return self.o_proj(output) -``` - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attn_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Updated Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Updated Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - # Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - # Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model (Updated) - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), - make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. - -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -from flash_dmattn.integration.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Prepare mask and bias tensors with proper shapes - if attention_mask is None: - attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), - dtype=query_states.dtype, device=query_states.device) - - # Call attention implementation - attn_output, attn_weights = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attention_bias, - scaling=self.scaling, - ) - - return attn_output, attn_weights -``` - -The attention bias generation process: - -1. **Value-based Dynamic States**: - ```python - dt_states = self.dt_proj(value_states_flattened) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - ``` - -2. **Bias Expansion**: - ```python - attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) - ``` - -3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` - - -### CUDA Backend: Sparse Attention Computation - -The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: - -```python -def _flash_dynamic_mask_attention_forward( - query_states, key_states, value_states, - attention_mask, attention_bias, - query_length, key_length, - is_causal, softmax_scale=None, softcap=None, - target_dtype=None, implementation=None, **kwargs -): - dtype = query_states.dtype - min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape - - # Initialize attention bias if not provided - if attention_bias is None: - attention_bias = torch.zeros( - (batch_size, num_kv_heads, query_length, key_length), - dtype=dtype, device=query_states.device - ) - - # Apply attention mask to bias - if attention_mask is not None: - attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) - attention_mask = attention_mask.to(dtype) - - # Call Flash Attention with dynamic masking - out = flash_dmattn_func( - query_states, key_states, value_states, - attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=softmax_scale, is_causal=is_causal - ) - - return out[0] if isinstance(out, tuple) else out -``` - -The backend processing stages: - -1. **Bias Initialization**: Create zero bias tensor if not provided -2. **Mask Application**: Apply boolean attention mask to bias tensor -3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns - -#### Forward Algorithm - -The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: - -```cpp -// Forward pass with unified skip logic -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Block-level skip decision - if !any_active: - advance_pointers() // Skip computation, advance to next tile - continue - - // Only execute for active tiles - load K_tile, V_tile // Load data only when needed - S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM - S_masked = apply_mask(S, mask_block) // Apply dynamic masking - P = softmax(S_masked, LSE_cache) // Softmax with LSE caching - O_partial += P @ V_tile // Sparse Score*V GEMM -write O -``` - -Key improvements: -- **Block-level Skip Logic**: OR-reduction over entire (BlockM × BlockN) tile determines if computation is needed -- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation -- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles - -#### Backward Algorithm - -The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: - -```cpp -// Backward pass with unified skip logic -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) // Same skip decision as forward - if !any_active: - advance_pointers_zero_side_outputs() // Skip computation, zero side outputs - continue - - // Only execute for active tiles - load K_tile, V_tile - - // Recompute (identical to forward for active tiles) - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) // Use cached LSE for stability - - // Gradient computation chain (5 GEMMs) - dV += P^T @ dO_tile // Accumulate dV - dP = dO_tile @ V_tile^T // Compute dP - dS = g(P, dP) // dS = (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile // Accumulate dQ - dK += dS^T @ Q_tile // Accumulate dK - write dQ, accumulate dK, dV -``` - -Key features: -- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision -- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation -- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness -- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation - -#### Skip Logic Correctness - -The mathematical correctness of the skip logic relies on the following principles: - -1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: - ``` - O_contribution = P @ V = 0 @ V = 0 - ``` - -2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: - ``` - P = 0 ⟹ dS = 0 ⟹ dQ = dK = dV = 0 (from this tile) - ``` - -3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. - -### Sparse Computation Strategy - -### Block-level Skip Logic - -The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: - -1. **Tile-level Active Detection**: - ```cpp - any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active - ``` - -2. **Skip Decision**: Binary branch based on tile activity: - ```cpp - if (!any_active) { - advance_pointers(); // Forward: skip all computation - advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs - continue; - } - ``` - -3. **Computational Benefits**: - - Skip entire K/V loads for inactive tiles - - Eliminate all 5 GEMMs in backward pass for inactive tiles - - Reduce memory bandwidth and arithmetic operations proportional to sparsity - -### Sparsity Pattern Recognition - -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: - -1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors - - Learned projection matrices map value features to importance scores - - Coefficient parameters control the dynamic range of importance values - - Activation functions ensure appropriate bias magnitude - -2. **Binary Attention Mask**: - - 1.0 for positions that should be computed - - 0.0 for positions that should be skipped - -### Performance Model - -For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} -$$ - -Where: -- $p$: fraction of active tiles -- $\varepsilon$: skip branching overhead -- $\eta$: efficiency of early memory load exit -- $\text{LoadOverhead}$: relative cost of K/V loading vs computation - -Upper bound as $\varepsilon, \eta \to 0$: $1/p$ - -### Shared Memory Aliasing - -The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: - -1. **sMask ↔ sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption -2. **sBias ↔ sdS Aliasing**: Bias shared memory region is reused for gradient computations dS -3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage - -```cpp -// Example aliasing pattern -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - compute S - __syncthreads() // ensure mask fully consumed - softmax -> write P into aliased region (sP) // reuse sMask region as sP - ... -__syncthreads() // ensure dS consumed -// reuse sBias region as sdS in next iteration -``` - -### Memory Efficiency Optimizations - -1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask ↔ sP, sBias ↔ sdS) with explicit barrier synchronization -2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles -3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory -5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy -6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead - -## Memory Layout - -### Tensor Memory Organization - -The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: - -``` -Global Memory Layout: -┌─────────────────────────────────────────────────────────────────┐ -│ Q: [batch, seqlen_q, num_heads, head_dim] │ -│ K: [batch, seqlen_k, num_heads_k, head_dim] │ -│ V: [batch, seqlen_k, num_heads_k, head_dim] │ -│ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] │ -│ Output: [batch, seqlen_q, num_heads, head_dim] │ -└─────────────────────────────────────────────────────────────────┘ - -Shared Memory Layout (per thread block): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ -│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ -│ AM Tile: [kBlockM, kBlockN] │ Bias Tile: [kBlockM, kBlockN] │ -└─────────────────────────────────────────────────────────────────────┘ - -Register Memory (per thread): -┌─────────────────────────────────────────────────────────────────────┐ -│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ -│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ -│ AM Frag: [MMA_M, MMA_N] │ Bias Frag: [MMA_M, MMA_N] │ -│ Acc Frag: [MMA_M, head_dim/N] │ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -### Memory Access Patterns - -#### Attention Mask and Attention Bias Loading -```cpp -// Global to Shared Memory (coalesced access) -Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); -Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); - -// Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_Bias, tSgBias, tSsBias); - -// Shared to Register Memory (bank-conflict-free) -Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); -copy(smem_thr_copy_Bias, tSsBias, tSrBias); -``` - -#### Memory Layout Transformations -```cpp -// Convert MMA accumulator layout to row-column layout for masking -// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) -auto convert_layout_acc_rowcol = [](auto layout) { - return make_layout( - make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), - make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), - make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), - make_stride(Int<1>{}, Int<2>{})) - ); -}; -``` - -### Shared Memory Optimization - -#### Bank Conflict Avoidance -- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts -- Padding added when necessary to ensure 128-bit aligned access -- Thread block size chosen to maximize occupancy while maintaining memory efficiency - -#### Memory Coalescing -```cpp -// Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomBias = Copy_Atom; // 128-bit loads -using SmemCopyAtomAttnMask = Copy_Atom; -``` - -## Performance Considerations - -### Memory Efficiency -- **Shared Memory Aliasing**: Smart memory reuse (sMask ↔ sP, sBias ↔ sdS) reduces footprint by ~30% -- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles -- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability -- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy - -### Computational Efficiency -- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping -- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles -- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads -- **Warp-Level Optimization**: Operations optimized for GPU warp execution model - -### Scalability -- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences -- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns -- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies - -### Performance Model - -Expected speedup for various sparsity levels: -- **50% sparsity**: ~1.8x speedup -- **75% sparsity**: ~3.2x speedup -- **90% sparsity**: ~6.5x speedup - -Performance factors: -- Skip overhead typically <5% of dense computation time -- Memory bandwidth reduction scales linearly with sparsity -- Shared memory aliasing enables 20-30% larger tile sizes - -## API Changes - -### New Required Parameters - -The Dynamic Mask Attention integration introduces new required parameters to the forward pass: - -- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - - Determines the sparsity pattern for computational efficiency - -- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values applied to attention scores before softmax - - Must have the same dtype and device as Q/K/V tensors - -### Updated Function Signature - -```python -def fwd( - q: torch.Tensor, # Query tensor - k: torch.Tensor, # Key tensor - v: torch.Tensor, # Value tensor - attn_mask: torch.Tensor, # Attention mask (REQUIRED) - attn_bias: torch.Tensor, # Attention bias (REQUIRED) - out: Optional[torch.Tensor] = None, # Pre-allocated output - softmax_scale: float = None, # Attention scaling - is_causal: bool = False, # Causal masking - softcap: float = 0.0, # Soft capping - return_softmax: bool = False, # Return attention weights -) -> List[torch.Tensor] -``` - -### Backward Compatibility - -**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. - -**Migration Path**: Users need to: -1. Add attention mask and bias generation logic to attention modules -2. Implement appropriate mask and bias computation within the attention forward pass -3. Ensure proper tensor shapes and dtypes for mask and bias tensors - -### Complete Usage Example - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = 1.0 / math.sqrt(self.head_dim) - - # Standard attention projections - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - - def forward(self, hidden_states, attention_mask=None, attention_bias=None): - batch_size, seq_len, _ = hidden_states.shape - - # Project to Q, K, V - query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - - # Generate attention bias from value states - dt_states = self.dt_proj( - value_states.transpose(1, 2).reshape(batch_size, seq_len, -1) - ) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1).to(hidden_states.dtype) - - # Prepare attention mask for multi-head - if attention_mask is not None: - attention_mask = attention_mask.expand(-1, self.num_kv_heads, -1, -1) - - # Flash Dynamic Mask Attention - attn_output, _ = flash_dynamic_mask_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - attention_bias=attention_bias, - scaling=self.scaling, - ) - - # Output projection - attn_output = attn_output.reshape(batch_size, seq_len, -1) - return self.o_proj(attn_output) - -# Usage example -config = type('Config', (), { - 'hidden_size': 768, - 'num_attention_heads': 12, - 'num_key_value_heads': 12, -})() - -attention = DynamicMaskAttention(config) -hidden_states = torch.randn(2, 4096, 768, device='cuda', dtype=torch.bfloat16) -output = attention(hidden_states) -print(f"Output shape: {output.shape}") # [2, 4096, 768] -``` - -### Integration with Existing Codebases - -For users migrating from Flash Attention, the typical changes required are: - -```python -# Before (Flash Attention) -class StandardAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) - - def forward(self, hidden_states): - q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) - output = flash_attn_func(q, k, v, dropout_p=0.1, softmax_scale=self.scaling, causal=True) - return self.o_proj(output) - -# After (Dynamic Mask Attention) -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - # Same standard projections - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) - - # Add dynamic mask parameters - self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) - self.dt_proj = nn.Linear(config.num_key_value_heads * self.head_dim, config.num_key_value_heads) - self.keep_window_size = config.keep_window_size - - def forward(self, hidden_states): - # Standard Q, K, V projections - query_states = self.q_proj(hidden_states).view(...).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(...).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(...).transpose(1, 2) - - # Generate attention bias from value states - dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(...)) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1) - - # Use Flash Dynamic Mask Attention - attn_output, _ = flash_dynamic_mask_attention_forward( - self, query_states, key_states, value_states, - attention_mask=attention_mask, attention_bias=attention_bias, - scaling=self.scaling - ) - - return self.o_proj(attn_output.reshape(...)) -``` \ No newline at end of file diff --git a/docs/integration_zh.md b/docs/integration_zh.md deleted file mode 100644 index 56ed0c1..0000000 --- a/docs/integration_zh.md +++ /dev/null @@ -1,522 +0,0 @@ -# Flash 动态掩码注意力集成指南 - -## 概述 - -本文档阐述了如何在 Flash Attention 框架中集成 Dynamic Mask Attention(动态掩码注意力)。通过将 Flash Attention 的高效显存利用方式与动态稀疏掩码结合,这一集成能够在极长序列场景下实现稀疏注意力的高效计算。 - -该集成方案采用统一的稀疏计算路径:Python 端负责预计算注意力掩码与偏置张量,CUDA 后端在前向与反向两个阶段执行基于块的跳过逻辑与稀疏算子调度。 - -## 目录 - -1. [集成架构](#集成架构) -2. [核心改动](#核心改动) -3. [实现细节](#实现细节) -4. [稀疏计算策略](#稀疏计算策略) -5. [内存布局](#内存布局) -6. [性能考量](#性能考量) -7. [API 变化](#api-变化) - -## 集成架构 - -### 高层设计 - -动态掩码注意力的集成在前向与反向过程中统一采用块级稀疏执行路径: - -1. **动态掩码计算**:Python 端预先生成注意力掩码(mask)与注意力偏置(bias)张量。 -2. **统一稀疏执行**:CUDA 后端在块粒度上决定是否跳过计算,并执行稀疏化的注意力与梯度算子。 -3. **内存优化**:通过共享内存别名与显式同步实现更高的共享内存复用率。 - -### 关键组件 - -- **注意力掩码**:形状为 `(batch, num_kv_heads, query_len, key_len)` 的二值张量(1.0 表示保留,0.0 表示跳过)。 -- **注意力偏置**:与掩码形状一致的张量,在 Softmax 前加性注入。 -- **块级跳过逻辑**:对 `(BlockM × BlockN)` tile 做 OR 归约判断是否执行计算。 -- **LSE 缓存**:前向阶段缓存 log-sum-exp 结果,反向阶段复用以保持数值稳定。 -- **共享内存别名**:动态复用共享内存缓冲区,配合 `__syncthreads()` 控制生命周期。 -- **完备梯度链路**:在保留稀疏跳过能力的同时,确保梯度流动正确。 - -## 核心改动 - -### 1. 参数结构扩展(`flash.h`) - -**目的**:扩展参数结构体以支持动态掩码与偏置信息,同时保留对 QKV 的统一访问接口。 - -```cpp -struct QKV_params { - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - index_t q_batch_stride, k_batch_stride, v_batch_stride; - index_t q_row_stride, k_row_stride, v_row_stride; - index_t q_head_stride, k_head_stride, v_head_stride; - int h, h_k; - int h_h_k_ratio; -}; - -struct Mask_params { - void *__restrict__ mask_ptr; - index_t mask_batch_stride; - index_t mask_head_stride; - index_t mask_row_stride; -}; - -struct Bias_params { - void *__restrict__ bias_ptr; - index_t bias_batch_stride; - index_t bias_head_stride; - index_t bias_row_stride; -}; - -struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - // ...existing code... - bool seqlenq_ngroups_swapped; -}; -``` - -**设计要点**: -- 多重继承将 QKV、掩码、偏置的参数维度拆分,保持接口清晰。 -- 为掩码与偏置提供完整的 stride 信息,以便在 CUDA 中高效寻址。 -- 与原有 Flash Attention 的内存布局保持兼容,避免性能回退。 - -### 2. 内核特性与内存布局(`kernel_traits.h`) - -**目的**:根据架构(SM75 / SM80+)选择合适的 MMA 原子与内存拷贝路径,为动态掩码操作提供最佳性能。 - -```cpp -template -struct Flash_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - static constexpr int kHeadDim = kHeadDim_; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kNWarps = kNWarps_; - // ...existing code... - using SmemCopyAtomMask = SmemCopyAtom; - using SmemCopyAtomBias = SmemCopyAtom; -}; -``` - -**设计要点**: -- 根据编译目标自动选择 `cp.async` 与 LDSM 指令路径。 -- 统一掩码与偏置的共享内存加载策略,避免额外的 bank conflict。 -- 模板化的类型安全保证不同精度(FP16/BF16)路径一致。 - -### 3. 块级信息扩展(`block_info.h`) - -**目的**:在可变长度场景下计算掩码与偏置的块级偏移量,保证全局内存访问有序。 - -```cpp -template -struct BlockInfo { - template - __device__ BlockInfo(const Params ¶ms, const int bidb) { - // ...existing code... - } - - template - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); - return offset; - } - - // ...existing code... -}; -``` - -**设计要点**: -- 提供统一的偏移量计算方法,简化内核中的地址计算。 -- 同时支持固定长度与可变长度两种输入形式。 -- 将左侧填充(left pad)纳入偏移量,保证稀疏掩码与 KV 缓存对齐。 - -### 4. 内存拷贝与算子工具(`utils.h`) - -**目的**:提供布局转换、类型转换、warp 归约与通用 GEMM 包装,适配 Flash Attention 的内存层次结构。 - -```cpp -namespace FLASH_NAMESPACE { - -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - // ...existing code... - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); -}; - -// ...existing code... - -template -__forceinline__ __device__ void gemm(/* ... */) { - // ...existing code... -} - -} // namespace FLASH_NAMESPACE -``` - -**设计要点**: -- 通过布局转换统一 MMA 累加器的访问方式,方便掩码逻辑在寄存器中操作。 -- 提供针对 BF16 的专用类型转换,避免额外的精度损耗。 -- Warp 归约与 GEMM 包装均支持将数据留在寄存器中,降低共享内存压力。 - -### 5. 动态掩码核心逻辑(`mask.h`) - -**目的**:在寄存器层面将掩码与偏置应用到注意力得分上,同时处理因果掩码与边界情况。 - -```cpp -template -__forceinline__ __device__ void apply_mask( - TensorType &tensor, - MaskType &mask, - BiasType &bias, - const float scale_softmax, - const int col_idx_offset_, - const int max_seqlen_k, - const int row_idx_offset, - const int max_seqlen_q, - const int warp_row_stride) { - // ...existing code... -} -``` - -**设计要点**: -- 在 `tensor` 保持 MMA 布局的情况下,逐元素应用掩码、偏置与缩放因子。 -- 因果掩码通过列索引上限裁剪实现,与动态掩码兼容。 -- 被掩盖的位置直接写入 `-INFINITY`,防止 Softmax 后出现数值污染。 - -### 6. 反向链路扩展(`flash_bwd_kernel.h`) - -**目的**:在反向传播中复用动态掩码逻辑,确保梯度仅在活跃 tile 上计算。 - -```cpp -struct Flash_bwd_params : public Flash_fwd_params { - // ...existing code... -}; - -template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, - const int bidh, const int n_block) { - // ...existing code... -} -``` - -**设计要点**: -- 反向路径沿用前向阶段的 tile 活跃性判断,跳过完全被掩码的块。 -- 结合 LSE 缓存,重算前向 Softmax 时保持数值稳定。 -- 保证五个梯度 GEMM 在活跃 tile 上依旧串联执行,避免梯度缺失。 - -### 7. 前向内核改造(`flash_fwd_kernel.h`) - -**目的**:在主注意力内核中插入动态掩码流程,同时保持 Flash Attention 的高并发与共享内存利用率。 - -```cpp -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, - const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - // ...existing code... -} -``` - -**设计要点**: -- 按 tile 裁剪逻辑提前判断是否加载 K/V,降低无效内存访问。 -- 仅在提供掩码/偏置时启用相应的分支,保持向后兼容。 -- 通过模板参数在编译期裁剪分支,减少运行期开销。 - -### 8. 启动模板更新(`flash_fwd_launch_template.h`) - -**目的**:在 kernel launch 阶段配置共享内存需求、模板实例化与错误处理,适配动态掩码的新资源需求。 - -```cpp -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, - bool Is_causal, bool Is_even_MN, bool Is_even_K, - bool Is_softcap, bool Return_softmax) { - // ...existing code... -} - -// ...existing code... - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - // ...existing code... -} -``` - -**设计要点**: -- 统一宏定义减少重复代码,便于扩展到新的 kernel 变体。 -- 针对不支持的架构给出明确的构建期/运行期错误提示。 -- 在 launch 前计算共享内存需求,必要时启用 `cudaFuncSetAttribute` 进行配置。 - -### 9. Python 接口扩展(`flash_api.cpp`) - -**目的**:扩展 C++/PyBind11 接口以接受掩码与偏置张量,并提供全面的数据校验。 - -```cpp -void set_params_fprop( - Flash_fwd_params ¶ms, - // ...existing code... -) { - // ...existing code... -} - -std::vector mha_fwd( - at::Tensor &q, - // ...existing code... - const bool return_softmax) { - // ...existing code... - return {out, softmax_lse}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; - // ...existing code... -} -``` - -**设计要点**: -- 对输入张量的形状、dtype、device 进行全面校验。 -- 保持原有参数顺序,新增参数保持向后兼容的默认行为。 -- 当掩码或偏置未提供时,自动填充零值张量以保证接口易用性。 - -## 实现细节 - -### C++ API 接口 - -C++ 端对外暴露如下核心函数,用于前向、可变长度前向与反向计算: - -```cpp -namespace FLASH_NAMESPACE { - -std::vector mha_fwd( - at::Tensor &q, - at::Tensor &k, - at::Tensor &v, - // ...existing code... - const bool return_softmax); - -std::vector mha_varlen_fwd(/* ... */); - -std::vector mha_bwd(/* ... */); - -} // namespace FLASH_NAMESPACE -``` - -- `mha_fwd`:标准批量前向,支持稀疏掩码与偏置。 -- `mha_varlen_fwd`:支持变长序列并使用累计长度数组。 -- `mha_bwd`:完成梯度计算,返回 dQ / dK / dV / dBias / dMask 等张量。 - -### 参数设置与校验 - -`set_params_fprop` 会在调用前: - -- 重置 `Flash_fwd_params` 并写入基本维度信息。 -- 将掩码与偏置的设备指针、stride、批次数等全部注册。 -- 基于输入 `dtype` 设置缩放因子与 `softcap`,同时准备缓存指针。 - -### Python 绑定与接口 - -PyBind11 模块对外暴露 `mha_fwd`、`mha_bwd`、`varlen_fwd` 等接口,文档字符串说明了参数要求与返回值。用户可通过 Python 直接调用 C++/CUDA 实现。 - -### Python 前端集成示例 - -```python -import torch -import torch.nn as nn -import flash_dmattn_cuda as flash_dmattn - -class DynamicMaskAttention(nn.Module): - def __init__(self, config): - super().__init__() - # ...existing code... - - def forward(self, query_states, key_states, value_states, attn_mask, attn_bias): - out, softmax_lse = flash_dmattn.fwd( - query_states, key_states, value_states, - attn_mask=attn_mask, - attn_bias=attn_bias, - return_softmax=True, - ) - return out, softmax_lse -``` - -- 前端模块负责生成 `attn_mask`(布尔)与 `attn_bias`(与 Q/K/V dtype 相同)。 -- 内部 `_flash_dynamic_mask_attention_forward` 会根据需要补零偏置并调用后端。 -- 输入张量默认为 `(batch, seq_len, num_heads, head_dim)` 排列,内部会自动转置到后端期望格式。 - -## 稀疏计算策略 - -### 块级跳过逻辑 - -- 在加载 Q tile 后,先将掩码 tile 拷贝到共享内存并执行 OR 归约。 -- 若整块被掩盖,则跳过 K/V 加载与后续计算,只推进指针。 -- 对活跃块执行常规注意力流程,并复用共享内存保存 Softmax 结果。 - -### 前向算法 - -```pseudo -for m_block in M_tiles: - load Q_tile - load mask_tile -> shared - any_active = or_reduce(mask_tile) - if not any_active: - continue - load K_tile, V_tile - compute scaled dot product - apply mask & bias in registers - softmax -> write O_tile -``` - -- 掩码裁剪保证 Tile 内所有无效位置直接输出 `-INF`。 -- Softmax 前的缩放与偏置添加与密集版本保持一致。 -- 通过共享内存别名(sMask ↔ sP)减少显存占用。 - -### 反向算法 - -```pseudo -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - load mask_tile -> shared - if tile inactive: - continue - recompute scores with cached LSE - propagate gradients for dS, dV, dK, dQ -``` - -- 仅对活跃块执行五个 GEMM 组合,减少稀疏场景下的冗余计算。 -- 使用前向缓存的 LSE 确保 Softmax 反向的数值稳定性。 -- 对被跳过的块梯度自然为零,避免写入污染。 - -### 跳过逻辑正确性 - -- 若 tile 全部被掩码,输出必为零,跳过计算不会影响结果。 -- 反向阶段活跃性与前向保持一致,保证梯度对应关系不被破坏。 -- 由于被掩盖位置在 Softmax 前已写入 `-INF`,LSE 亦不受影响。 - -## 内存布局 - -### 全局内存组织 - -``` -Q: [batch, seqlen_q, num_heads, head_dim] -K: [batch, seqlen_k, num_kv_heads, head_dim] -V: [batch, seqlen_k, num_kv_heads, head_dim] -Mask: [batch, num_kv_heads, seqlen_q, seqlen_k] -Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] -Output: [batch, seqlen_q, num_heads, head_dim] -``` - -### 共享内存布局(每个线程块) - -``` -Q Tile : [kBlockM, head_dim] -K Tile : [kBlockN, head_dim] -V Tile : [kBlockN, head_dim] -S Tile : [kBlockM, kBlockN] -Mask Tile: [kBlockM, kBlockN] -Bias Tile: [kBlockM, kBlockN] -``` - -### 寄存器布局(每个线程) - -``` -Q Frag : [MMA_M, head_dim / N] -K Frag : [MMA_N, head_dim / N] -V Frag : [MMA_N, head_dim / N] -S Frag : [MMA_M, MMA_N] -Mask Frag: [MMA_M, MMA_N] -Bias Frag: [MMA_M, MMA_N] -Acc Frag : [MMA_M, head_dim / N] -``` - -### 内存访问模式 - -- 掩码与偏置与 K/V 共享相同的 `Copy_Atom` 配置,确保 128-bit 对齐、最大化带宽。 -- 共享内存拷贝后通过 `local_partition` 分配给线程,避免 bank conflict。 -- `convert_layout_acc_rowcol` 将 MMA 布局转换为行/列布局,方便寄存器操作。 - -### 共享内存优化 - -- **别名复用**:`sMask` 在使用后可重用为 `sP`(Softmax 输出),`sBias` 可重用为 `sdS`。 -- **同步屏障**:在重用前使用 `__syncthreads()` 确保所有线程完成对旧数据的使用。 -- **块尺寸选择**:根据稀疏度与共享内存限制调整 tile 尺寸,提高 SM 占用率。 - -## 性能考量 - -- **共享内存复用**:别名策略可将共享内存占用削减约 30%。 -- **块级跳过**:当稀疏度为 75% 时,可获得约 3× 的前向提速;稀疏度 90% 时可达到 ~6×。 -- **带宽优化**:跳过无效 tile 可以线性降低全局内存带宽需求。 -- **同步开销**:跳过路径的额外 OR 归约占总时间 <5%,可忽略不计。 -- **硬件自适应**:针对 SM75/SM80+ 的不同指令集做了专门优化,确保跨架构稳定收益。 - -## API 变化 - -### 新增必要参数 - -- `attn_mask` (`torch.Tensor`): 形状 `(batch, num_kv_heads, seqlen_q, seqlen_k)` 的布尔张量,决定稀疏模式。 -- `attn_bias` (`torch.Tensor`): 形状与掩码一致的加性偏置张量,dtype 与 Q/K/V 保持一致。 - -### 更新的函数签名 - -```python -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: torch.Tensor, - attn_bias: torch.Tensor, - is_causal: bool = False, - return_softmax: bool = False, - **kwargs -) -> List[torch.Tensor]: - ... -``` - -### 向后兼容说明 - -- 这是一个破坏性更新,旧的 Flash Attention 调用需显式提供掩码与偏置。 -- 若业务场景不需要稀疏掩码,可传入全 1 掩码与全 0 偏置实现与旧版一致的行为。 -- 缺省值在 Python 前端会自动补齐,降低迁移的代码改动。 - -### 完整用法示例 - -```python -import torch -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward, -) - -batch, seq_q, seq_k, n_heads, head_dim = 2, 4096, 4096, 16, 128 -q = torch.randn(batch, seq_q, n_heads, head_dim, device="cuda", dtype=torch.float16) -k = torch.randn_like(q) -v = torch.randn_like(q) -mask = torch.ones(batch, n_heads, seq_q, seq_k, device=q.device, dtype=torch.bool) -bias = torch.zeros(batch, n_heads, seq_q, seq_k, device=q.device, dtype=q.dtype) - -out = flash_dynamic_mask_attention_forward( - query_states=q, - key_states=k, - value_states=v, - attention_mask=mask, - attention_bias=bias, - return_attn_probs=False, -) -``` - -- `flash_dynamic_mask_attention_forward` 会自动完成张量转置、补零偏置等准备工作。 -- 若指定 `return_attn_probs=True`,将返回经过 Softmax 的注意力概率,用于调试或可视化。 -- 稀疏模式的 mask 可通过 `flash_dmattn.utils.mask.MaskMod` 组合生成。 - -## 附加建议 - -- 修改 CUDA 核心代码后,至少运行 `benchmarks/forward_equivalence.py` 与 `benchmarks/grad_equivalence.py` 进行回归验证。 -- 构建扩展时可使用 `pip install -e . --no-build-isolation`,必要时设置 `FLASH_DMATTN_CUDA_ARCHS` 指定目标架构。 -- 若仅依赖 Triton/Flex 后端,可通过环境变量 `FLASH_DMATTN_SKIP_CUDA_BUILD=1` 跳过 CUDA 构建。 diff --git a/docs/v1.0.0_technical_report.md b/docs/v1.0.0_technical_report.md deleted file mode 100644 index f6ddf40..0000000 --- a/docs/v1.0.0_technical_report.md +++ /dev/null @@ -1,299 +0,0 @@ -# flash-dmattn v1.0.0 Technical Report - -## 1. Overview -flash-dmattn is a high-performance FlashAttention-style implementation optimized for large sequence lengths and structured sparsity via Dynamic Masks. It provides: -- Unified block-level dynamic mask (block-sparse) skip logic in both forward and backward passes. -- Fused softmax, normalization, and recomputation-friendly backward pipeline. -- Smart shared memory aliasing to reduce footprint and enhance occupancy. -- Support for bias, Log-Sum-Exp (LSE) caching, and optional softcap. -- PyTorch Autograd compatibility and downstream model integration (example: Doge model, HuggingFace-style interface). - -v1.0.0 Highlights: -1. Unified sparse skip logic for both forward and backward (eliminates redundant compute on fully masked tiles). -2. Improved numerical and performance consistency: coherent shared memory layout, aliasing, and barrier sequencing. -3. Documentation, API stabilization, and extensibility groundwork for finer-grained sparsity (bit-packed, fragment-level) later. - -Differences vs v0.3.0: -- v0.3.0 only considered backward skip conceptually; v1.0.0 fully unifies forward + backward skip execution. -- Added strict barrier ordering to prevent NaNs (notably in dK path) when reusing aliased shared memory regions. -- Enhanced documentation, tests, and benchmarking. - -## 2. Architecture -Layers: -1. Python Integration: `flash_dmattn_interface.py` exposing user-friendly APIs (mirroring standard attention calls). -2. Kernel Dispatch Layer: `flash_dmattn_flex.py` / `flash_dmattn_triton.py` selecting CUDA / Triton / hybrid code paths. -3. C++/CUDA Core: flash_api.cpp + `src/*.h` (core kernels: `flash_fwd_kernel.h`, `flash_bwd_kernel.h`). -4. Dynamic Mask Integration: `integrations/flash_dynamic_mask_attention.py` and helpers. -5. Benchmarks & Validation: `benchmarks/*_equivalence.py`, `*_performance.py`. - -Backward dataflow: -Q,K,V,dO (+ mask, bias, LSE) → block streaming → (block-sparse skip decision) → if active: recompute scores & softmax(P) → accumulate dV,dP,dQ,dK → write back. - -## 3. Key Features -- Block-level Dynamic Mask: - - OR-reduction over (BlockM × BlockN) tile; if all zeros → skip. -- Unified Skip (Forward + Backward): - - Forward: skip QK^T, softmax, and P·V for fully masked tiles; safely advances pointers / outputs zeros. - - Backward: skip recompute + the chain of 5 GEMMs (QK^T, dO·V^T, P^T·dO→dV, dP·K→dQ, dP^T·Q→dK). -- LSE Caching: - - Ensures numerical stability: P derived via stored log-sum-exp. -- Optional Softcap: - - Scaling / clamping scores pre-softmax. -- Shared Memory Aliasing: - - sMask ↔ sP; sBias ↔ sdS with explicit barriers. -- Mixed Precision: - - FP16/BF16 inputs, FP32 accumulation. -- Modular KernelTraits: - - Controls block sizes, pipeline depth (double buffering), layouts. -- Extensible Sparsity: - - Design leaves room for bit-packed masks and fragment gating. - -## 4. Algorithms & Kernels - -### 4.1 Forward (Pseudo-code) -``` -for m_block in M_tiles: - load Q_tile - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) - if !any_active: - advance_pointers() - continue - load K_tile, V_tile - S = Q_tile @ K_tile^T + bias_block - S_masked = apply_mask(S, mask_block) - P = softmax(S_masked, LSE_cache) - O_partial += P @ V_tile -write O -``` - -### 4.2 Backward (Pseudo-code) -``` -for m_block in reversed(M_tiles): - load Q_tile, dO_tile - init accum_dQ - for n_block in N_tiles_stream: - load mask_block - any_active = OR(mask_block) - if !any_active: - advance_pointers_zero_side_outputs() - continue - load K_tile, V_tile - # Recompute - S = Q_tile @ K_tile^T + bias_block - P = softmax(S, LSE_cache) - # Grad chain - dV += P^T @ dO_tile - dP = dO_tile @ V_tile^T - dS = g(P, dP) # (dP - (P ⊙ dP).sum(axis)) * P - dQ += dS @ K_tile - dK += dS^T @ Q_tile - write dQ, accumulate dK, dV -``` - -### 4.3 Softmax & Gradient -Given $S_{ij}$ and $LSE_i = \log \sum_k e^{S_{ik}}$, - -$$ -P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}} -$$ - -Backward: - -$$ -\frac{\partial \mathcal{L}}{\partial S_{ij}} = \left( \frac{\partial \mathcal{L}}{\partial P_{ij}} - \sum_{k} \frac{\partial \mathcal{L}}{\partial P_{ik}} P_{ik} \right) P_{ij} -$$ - -Fully masked tile: $P=0 \Rightarrow dS=0$, all dependent GEMMs yield zero → safe to skip. - -### 4.4 Correctness of Skip -If a tile is entirely masked: -- Forward contributions vanish (outputs zero block). -- Backward intermediate tensors (S,P,dS,dP) logically zero; linear GEMMs on zero give zero. -Therefore removing those computations preserves gradients. - -## 5. Sparsity Logic & Performance - -### 5.1 Active Tile Detection -- Load mask tile into shared memory. -- Parallel OR reduction across threads / warps. -- any_active=false triggers skip branch. - -### 5.2 Performance Model -Let active fraction $p$, skip overhead ratio $\varepsilon$: - -$$ -\text{Speedup} \approx \frac{1}{p + (1-p)\varepsilon} -$$ - -Upper bound as $\varepsilon \to 0$: $1/p$. - -### 5.3 Influencing Factors -- Reduction latency vs early placement. -- Pipeline bubbles due to frequent divergent skip branches. -- Memory bandwidth—mask format (bit-packed future) reduces load footprint. - -### 5.4 Future Enhancements -- Earlier gating (before K/V loads). -- Adaptive density threshold. -- Bit-packed + warp ballot fast OR. -- Persistent CTA / work queue for load balancing. - -## 6. API Summary -Primary function: -`flash_dynamic_mask_attention(q, k, v, attn_mask=None, bias=None, softcap=None, causal=False, return_lse=False, ...)` - -Inputs: -- q/k/v: [B, H, L, D] (k/v possibly different length) -- attn_mask: block-aligned or internally sliced dynamic mask -- bias: optional additive bias -- softcap: optional scaling/clamp -Outputs: -- O (and optionally LSE when requested). - -Config: -- Block sizes (e.g., 64×64) via traits -- dtype: fp16 / bf16 (fp32 accum) -- enable_skip (default on) -- softcap scalar - -## 7. Memory & Synchronization -- Double buffering for streaming Q/K/V with `cp.async` fences. -- Aliasing: - - sMask reused as sP after consumption. - - sBias reused as sdS after gradient consumption. -- Critical barriers: - 1. Ensure mask fully read before overwriting region with P. - 2. Ensure dS fully consumed (dK finished) before alias region becomes bias. -Goal: minimize shared memory to enable larger tiles and higher occupancy. - -## 8. Numerical Stability -- LSE caching prevents overflow. -- FP16/BF16 inputs + FP32 accumulation. -- Skip path doesn't touch LSE entries of masked tiles. -- Validation scripts: forward/backward/grad equivalence across lengths, densities. - -## 9. Backward Compatibility & Upgrade -- Same Python API; upgrading from v0.3.0 requires no code changes for standard use. -- Internal layout symbols not part of public contract—custom kernels should revalidate alias expectations. -- Future runtime stats API planned (non-breaking). - -## 10. Known Limitations -- Only block-aligned sparsity (no arbitrary coordinate compression yet). -- Skip decision not yet moved ahead of K/V/dO loads. -- No fragment-level (Tensor Core tile) sparsity gating yet. -- No built-in distributed (multi-GPU) attention aggregation logic. -- Triton path feature parity still evolving. - -## 11. Testing & Validation -- Numerical: compare to dense `scaled_dot_product_attention`. -- Sparsity: random masks of varying density; compare skip vs forced-dense output. -- Regression: multi-block scenarios to guard prior dK NaN issue. -- Benchmarks: measure kernel time vs density p. - -## 12. Roadmap -1. Early mask gating pre-load. -2. Bit-packed mask + warp ballot OR. -3. Adaptive skip threshold (disable when p high). -4. Fragment-level MMA gating. -5. Persistent CTA + work queue. -6. Runtime counters: active/skipped tile counts, effective density. -7. Distributed integration examples. - -## 13. Safety & Robustness -- Input validation: shapes / dtypes / device alignment. -- Mask alignment and slicing. -- LSE + FP32 mitigate overflow. -- Barriers enforce safe alias lifecycle. -- Future fallback path for anomaly detection (planned). - -## 14. Acknowledgements -- Inspired by FlashAttention research and community. -- Contributors: core maintainers & commit authors (see git history). -- Ecosystem: PyTorch / CUTLASS / Triton. - -## 15. Version Delta Summary -Changes vs v0.3.0: -- Added forward skip bringing full forward/backward symmetry. -- Fixed block size condition + enhanced documentation. -- Shared memory alias + barrier ordering refinements (resolved dK NaNs). -- Skip branch pointer advancement semantics aligned with dense path. -- Comprehensive technical documentation and math derivations. - -## 16. Formula Quick Reference -1. Softmax: - -$$ -P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}}, \quad LSE_i = \log \sum_k e^{S_{ik}} -$$ - -2. dS: - -$$ -dS_{ij} = \left(dP_{ij} - \sum_k dP_{ik} P_{ik}\right) P_{ij} -$$ - -3. Grad propagation: - -$$ -dQ = dS K,\quad dK = dS^T Q,\quad dV = P^T dO -$$ - -4. Skip predicate: - -$$ -any\_active = \bigvee_{(i,j)\in tile} mask_{ij} -$$ - -## 17. Alias & Barrier Snippet -``` -load mask -> sMask -any_active = or_reduce(sMask) -if any_active: - # reuse sMask region as sP after consumption - compute S - softmax -> write P into aliased region (sP) - ... -__syncthreads() # ensure dS consumed -# reuse sBias region as sdS in next iteration -``` - -## 18. Glossary -- Block / Tile: matrix sub-region processed per step. -- Skip: branch eliminating compute for fully masked tile. -- LSE: log-sum-exp cache for stability. -- Aliasing: reusing shared memory region across disjoint lifetimes. -- Fragment-level: granularity of Tensor Core MMA fragments. - -## 19. Integration -- HuggingFace-style example: modeling_doge.py -- Drop-in custom attention module inside transformer blocks. -- Planned: wrapper matching `scaled_dot_product_attention` signature for rapid substitution. - -## 20. Debug & Diagnostics -Planned env toggles to print: -- Active vs skipped tile counts -- Skip hit rate -- Average tile density -Common issues: -- NaNs: verify barrier / alias ordering not altered. -- Poor speedup: density p too high; disable skip to compare. - -## 21. Release Guidance -- Users gain block-sparse skip automatically after upgrade. -- For custom builds: confirm target GPU arch (sm80+) for Tensor Core efficiency. - -## 22. References -- Dao et al., FlashAttention series -- CUTLASS docs -- PyTorch Autograd internals - ---- - -## What's Changed -* Optimize sparse logic by @LoserCheems in https://github.com/SmallDoges/flash-dmattn/pull/131 -* Fix block size condition and enhance documentation by @LoserCheems in https://github.com/SmallDoges/flash-dmattn/pull/134 - - -**Full Changelog**: https://github.com/SmallDoges/flash-dmattn/compare/v0.3.0...v1.0.0 From 554e7e0c7cad1c2483cd9c26661bae8998595c6e Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:34:29 +0800 Subject: [PATCH 21/29] Align docs with sparse attention rename Updates API reference to reflect the flash_sparse_attn branding so installation instructions, imports, and backend descriptions stay consistent with the renamed package. --- docs/api_reference.md | 80 +++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index 65bbebf..f6b4316 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -1,9 +1,9 @@ -# Flash Dynamic Mask Attention API Reference +# Flash Sparse Attention API Reference ## Overview -Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. +Flash Sparse Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. ## Table of Contents @@ -12,9 +12,9 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that 2. [Quick Start](#quick-start) 3. [Backend Selection and Comparison](#backend-selection-and-comparison) 4. [API Reference](#api-reference) - - [CUDA Backend: flash_dmattn_func](#flash_dmattn_func-cuda-backend) - - [Triton Backend: triton_dmattn_func](#triton_dmattn_func-triton-backend) - - [Flex Backend: flex_dmattn_func](#flex_dmattn_func-flex-backend) + - [CUDA Backend: flash_sparse_attn_func](#flash_sparse_attn_func-cuda-backend) + - [Triton Backend: triton_sparse_attn_func](#triton_sparse_attn_func-triton-backend) + - [Flex Backend: flex_sparse_attn_func](#flex_sparse_attn_func-flex-backend) 5. [Integrations](#integrations) - [Transformers Integration](#transformers-integration) 6. [Common Issues and Solutions](#common-issues-and-solutions) @@ -22,27 +22,27 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that ## Installation -Please refer to the [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README.md#install) for detailed installation instructions. +Please refer to the [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README.md#install) for detailed installation instructions. ```bash # With CUDA backend -pip install flash-dmattn +pip install flash-sparse-attn # Or install from source pip install -e . # Triton/Flex only -FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . +FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e . ``` ## Quick Start -Use `flash_dmattn_func_auto` to automatically select the best available backend without manual checking. +Use `flash_sparse_attn_func_auto` to automatically select the best available backend without manual checking. ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto # Prepare input tensors batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 @@ -51,19 +51,19 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') # Get attention function (auto-select backend, priority: cuda > triton > flex) -attn_func = flash_dmattn_func_auto() +attn_func = flash_sparse_attn_func_auto() # Compute attention output = attn_func(q, k, v, is_causal=True) print(f"Output shape: {output.shape}") # (2, 1024, 8, 64) # Or force a specific backend -attn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" +attn_func = flash_sparse_attn_func_auto(backend="cuda") # or "triton", "flex" output = attn_func(q, k, v, is_causal=True) ``` > [!NOTE] -> `flash_dmattn_func_auto` returns a callable attention function, not the attention output. +> `flash_sparse_attn_func_auto` returns a callable attention function, not the attention output. ## Backend Selection and Comparison @@ -71,7 +71,7 @@ output = attn_func(q, k, v, is_causal=True) ### Check Available Backends ```python -from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE +from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE # List all available backends print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] @@ -101,19 +101,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### When to Use Each Backend -**CUDA Backend** ([details](#flash_dmattn_func-cuda-backend)) +**CUDA Backend** ([details](#flash_sparse_attn_func-cuda-backend)) - ✅ Training workloads requiring full gradient support - ✅ Production inference requiring maximum performance - ✅ Applications needing deterministic behavior - ❌ Avoid: when custom CUDA extensions cannot be built -**Triton Backend** ([details](#triton_dmattn_func-triton-backend)) +**Triton Backend** ([details](#triton_sparse_attn_func-triton-backend)) - ✅ Training when CUDA extension unavailable - ✅ Development and prototyping - ✅ Cross-platform compatibility needs - ✅ Good balance of performance and ease of installation -**Flex Backend** ([details](#flex_dmattn_func-flex-backend)) +**Flex Backend** ([details](#flex_sparse_attn_func-flex-backend)) - ✅ Inference-only applications - ✅ Research with latest PyTorch features - ✅ Quick experimentation without custom builds @@ -123,15 +123,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### Import Available Functions ```python -from flash_dmattn import ( +from flash_sparse_attn import ( # Automatic backend selection get_available_backends, - flash_dmattn_func_auto, + flash_sparse_attn_func_auto, # Backend-specific functions - flash_dmattn_func, # CUDA backend - triton_dmattn_func, # Triton backend - flex_dmattn_func, # Flex backend + flash_sparse_attn_func, # CUDA backend + triton_sparse_attn_func, # Triton backend + flex_sparse_attn_func, # Flex backend # Backend availability flags CUDA_AVAILABLE, @@ -140,20 +140,20 @@ from flash_dmattn import ( ) # Transformers integration -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import ( + flash_sparse_attention_forward ) ``` ## API Reference -### flash_dmattn_func (CUDA backend) +### flash_sparse_attn_func (CUDA backend) Main attention function. Supports multi-head and grouped-query attention (when the number of KV heads is smaller than the number of Q heads). Requires the CUDA extension to be built and available. ```python -def flash_dmattn_func( +def flash_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) @@ -182,12 +182,12 @@ def flash_dmattn_func( - output: (B, Q, H, D) -### triton_dmattn_func (Triton backend) +### triton_sparse_attn_func (Triton backend) Triton-based implementation that provides good performance without requiring custom CUDA kernels. ```python -def triton_dmattn_func( +def triton_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -198,12 +198,12 @@ def triton_dmattn_func( ) -> torch.Tensor ``` -### flex_dmattn_func (Flex Attention backend) +### flex_sparse_attn_func (Flex Attention backend) Flex Attention-based implementation using PyTorch's native flex attention with dynamic masking support. ```python -def flex_dmattn_func( +def flex_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -221,13 +221,13 @@ def flex_dmattn_func( Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support. -#### flash_dynamic_mask_attention_forward +#### flash_sparse_attention_forward ```python -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward -def flash_dynamic_mask_attention_forward( +def flash_sparse_attention_forward( module: torch.nn.Module, # The attention module query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) @@ -254,7 +254,7 @@ def flash_dynamic_mask_attention_forward( - is_causal: Whether to apply causal mask - window_size: Size of window to keep - layer_idx: Layer index for logging - - implementation: Implementation to use ("flash_dmattn" or None) + - implementation: Implementation to use ("flash_sparse_attn" or None) #### Returns @@ -268,7 +268,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Optional, Callable, tuple from transformers.cache_utils import Cache -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward class DynamicMaskAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): @@ -332,7 +332,7 @@ class DynamicMaskAttention(nn.Module): attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) # Choose attention implementation - attention_interface: Callable = flash_dynamic_mask_attention_forward + attention_interface: Callable = flash_sparse_attention_forward attn_output, attn_weights = attention_interface( self, @@ -362,7 +362,7 @@ This example shows: ```python try: - from flash_dmattn import flash_dmattn_func_auto, get_available_backends + from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends print("✅ Imported successfully", get_available_backends()) except ImportError as e: print(f"❌ Import failed: {e}") @@ -385,10 +385,10 @@ except ImportError as e: ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto torch.autograd.set_detect_anomaly(True) -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) if torch.isnan(output).any(): print("⚠️ NaN detected in attention output") @@ -404,7 +404,7 @@ def print_memory_stats(): print(f"max alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") print_memory_stats() -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v) print_memory_stats() ``` From ac95f256f81a613e7e7d634f3e481a6906e7c643 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:34:53 +0800 Subject: [PATCH 22/29] Aligns Chinese doc with sparse attention Updates terminology to reflect the flash sparse attention rebranding so readers follow accurate package names, imports, and integration guidance. --- docs/api_reference_zh.md | 83 ++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/docs/api_reference_zh.md b/docs/api_reference_zh.md index 83d1617..3676d16 100644 --- a/docs/api_reference_zh.md +++ b/docs/api_reference_zh.md @@ -1,9 +1,9 @@ -# Flash Dynamic Mask Attention API 参考文档 +# Flash Sparse Attention API 参考文档 ## 概述 -Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。 +Flash Sparse Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。 ## 目录 @@ -12,9 +12,9 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash 2. [快速开始](#快速开始) 3. [后端选择与比较](#后端选择与比较) 4. [接口函数详解](#接口函数详解) - - [CUDA 后端:flash_dmattn_func](#flash_dmattn_func-cuda-后端) - - [Triton 后端:triton_dmattn_func](#triton_dmattn_func-triton-后端) - - [Flex 后端:flex_dmattn_func](#flex_dmattn_func-flex-后端) + - [CUDA 后端:flash_sparse_attn_func](#flash_sparse_attn_func-cuda-后端) + - [Triton 后端:triton_sparse_attn_func](#triton_sparse_attn_func-triton-后端) + - [Flex 后端:flex_sparse_attn_func](#flex_sparse_attn_func-flex-后端) 5. [集成](#集成) - [Transformers 集成](#transformers-集成) 6. [常见问题与解决方案](#常见问题与解决方案) @@ -22,27 +22,26 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash ## 安装 -请参考 [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。 +请参考 [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。 ```bash # 使用 CUDA 后端 -pip install flash-dmattn - +pip install flash-sparse-attn # 或从源码安装 pip install -e . # 仅使用 Triton/Flex 后端 -FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . +FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e . ``` ## 快速开始 -使用 `flash_dmattn_func_auto` 可以自动选择最佳可用后端,无需手动判断。 +使用 `flash_sparse_attn_func_auto` 可以自动选择最佳可用后端,无需手动判断。 ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto # 准备输入张量 batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 @@ -51,19 +50,19 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') # 获取注意力函数(自动选择后端,优先级: cuda > triton > flex) -attn_func = flash_dmattn_func_auto() +attn_func = flash_sparse_attn_func_auto() # 调用注意力计算 output = attn_func(q, k, v, is_causal=True) print(f"输出形状: {output.shape}") # (2, 1024, 8, 64) # 也可以强制使用特定后端 -attn_func = flash_dmattn_func_auto(backend="cuda") # 或 "triton", "flex" +attn_func = flash_sparse_attn_func_auto(backend="cuda") # 或 "triton", "flex" output = attn_func(q, k, v, is_causal=True) ``` > [!NOTE] -> `flash_dmattn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。 +> `flash_sparse_attn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。 ## 后端选择与比较 @@ -71,7 +70,7 @@ output = attn_func(q, k, v, is_causal=True) ### 可用后端检查 ```python -from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE +from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE # 查看所有可用后端 print(get_available_backends()) # 例如:["cuda", "triton", "flex"] @@ -101,19 +100,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### 何时使用各个后端 -**CUDA 后端** ([详细说明](#flash_dmattn_func-cuda-后端)) +**CUDA 后端** ([详细说明](#flash_sparse_attn_func-cuda-后端)) - ✅ 完整梯度支持的训练工作负载 - ✅ 最大性能生产推理 - ✅ 需要确定性行为的应用 - ❌ 避免:无法构建自定义 CUDA 扩展时 -**Triton 后端** ([详细说明](#triton_dmattn_func-triton-后端)) +**Triton 后端** ([详细说明](#triton_sparse_attn_func-triton-后端)) - ✅ CUDA 扩展不可用时的训练工作负载 - ✅ 开发和原型设计 - ✅ 跨平台兼容性需求 - ✅ 性能和易安装性的良好平衡 -**Flex 后端** ([详细说明](#flex_dmattn_func-flex-后端)) +**Flex 后端** ([详细说明](#flex_sparse_attn_func-flex-后端)) - ✅ 仅推理应用 - ✅ 使用最新 PyTorch 特性的研究 - ✅ 无需自定义构建的快速实验 @@ -123,15 +122,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL ### 导入可用函数 ```python -from flash_dmattn import ( +from flash_sparse_attn import ( # 自动后端选择 get_available_backends, - flash_dmattn_func_auto, + flash_sparse_attn_func_auto, # 后端特定函数 - flash_dmattn_func, # CUDA 后端 - triton_dmattn_func, # Triton 后端 - flex_dmattn_func, # Flex 后端 + flash_sparse_attn_func, # CUDA 后端 + triton_sparse_attn_func, # Triton 后端 + flex_sparse_attn_func, # Flex 后端 # 后端可用性标志 CUDA_AVAILABLE, @@ -140,20 +139,20 @@ from flash_dmattn import ( ) # Transformers 集成 -from flash_dmattn.integrations.flash_dynamic_mask_attention import ( - flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import ( + flash_sparse_attention_forward ) ``` ## 接口函数详解 -### flash_dmattn_func (CUDA 后端) +### flash_sparse_attn_func (CUDA 后端) 主要的注意力函数。支持多头注意力和分组查询注意力(当 KV 头数少于 Q 头数时)。需要 CUDA 扩展已构建并可用。 ```python -def flash_dmattn_func( +def flash_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) @@ -182,12 +181,12 @@ def flash_dmattn_func( - output: (B, Q, H, D) -### triton_dmattn_func (Triton 后端) +### triton_sparse_attn_func (Triton 后端) 基于 Triton 的实现,无需自定义 CUDA 内核即可提供良好性能。 ```python -def triton_dmattn_func( +def triton_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -198,12 +197,12 @@ def triton_dmattn_func( ) -> torch.Tensor ``` -### flex_dmattn_func (Flex Attention 后端) +### flex_sparse_attn_func (Flex Attention 后端) 基于 Flex Attention 的实现,使用 PyTorch 原生 flex attention 并支持动态掩码。 ```python -def flex_dmattn_func( +def flex_sparse_attn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) @@ -219,14 +218,14 @@ def flex_dmattn_func( ### Transformers 集成 -为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash dynamic mask attention 支持。 +为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash sparse attention 支持。 -#### flash_dynamic_mask_attention_forward +#### flash_sparse_attention_forward ```python -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward -def flash_dynamic_mask_attention_forward( +def flash_sparse_attention_forward( module: torch.nn.Module, # 注意力模块 query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) @@ -253,7 +252,7 @@ def flash_dynamic_mask_attention_forward( - is_causal: 是否应用因果掩码 - window_size: 保持的窗口大小 - layer_idx: 用于日志的层索引 - - implementation: 使用的实现("flash_dmattn" 或 None) + - implementation: 使用的实现("flash_sparse_attn" 或 None) #### 返回值 @@ -267,7 +266,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Optional, Callable, tuple from transformers.cache_utils import Cache -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward class DynamicMaskAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): @@ -331,7 +330,7 @@ class DynamicMaskAttention(nn.Module): attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) # 选择注意力实现 - attention_interface: Callable = flash_dynamic_mask_attention_forward + attention_interface: Callable = flash_sparse_attention_forward attn_output, attn_weights = attention_interface( self, @@ -361,7 +360,7 @@ class DynamicMaskAttention(nn.Module): ```python try: - from flash_dmattn import flash_dmattn_func_auto, get_available_backends + from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends print("✅ 导入成功", get_available_backends()) except ImportError as e: print(f"❌ 导入失败: {e}") @@ -384,10 +383,10 @@ except ImportError as e: ```python import torch -from flash_dmattn import flash_dmattn_func_auto +from flash_sparse_attn import flash_sparse_attn_func_auto torch.autograd.set_detect_anomaly(True) -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) if torch.isnan(output).any(): print("⚠️ 注意力输出中检测到 NaN") @@ -403,7 +402,7 @@ def print_memory_stats(): print(f"最大分配: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") print_memory_stats() -attn = flash_dmattn_func_auto() +attn = flash_sparse_attn_func_auto() output = attn(q, k, v) print_memory_stats() ``` From a0ed87d916ec43046c4a527e2d8f65f1ed70bbac Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:36:20 +0800 Subject: [PATCH 23/29] Aligns benchmarks with sparse attn imports Updates benchmark integrations to load the flash_sparse_attn implementations so the renamed package continues to back the CUDA, Triton, and Flex runs. Renames the availability guards and status messages to keep diagnostic output aligned with the new module namespace. --- benchmarks/backward_equivalence.py | 42 +++++++++++++-------------- benchmarks/backward_performance.py | 36 +++++++++++------------ benchmarks/forward_equivalence.py | 46 +++++++++++++++--------------- benchmarks/forward_performance.py | 36 +++++++++++------------ 4 files changed, 80 insertions(+), 80 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 7c34ba2..a10da1e 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -21,33 +21,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -189,7 +189,7 @@ def dynamic_mask_attention_cuda( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: raise ImportError("CUDA implementation not available") query_states_leaf = query_states @@ -210,8 +210,8 @@ def dynamic_mask_attention_cuda( key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - # Call the flash_dmattn_func interface - attn_outputs = flash_dmattn_func( + # Call the flash_sparse_attn_func interface + attn_outputs = flash_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -256,7 +256,7 @@ def dynamic_mask_attention_triton( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: raise RuntimeError("Triton implementation not available") _, num_heads, _, _ = query_states.shape @@ -288,7 +288,7 @@ def dynamic_mask_attention_triton( value_states = value_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] # Call the Triton implementation - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -330,7 +330,7 @@ def dynamic_mask_attention_flex( Returns: tuple: (attn_outputs, dq, dk, dv, dbias) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -359,7 +359,7 @@ def dynamic_mask_attention_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, @@ -474,7 +474,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if CUDA implementation is available - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: print("❌ CUDA implementation not available, skipping test.") return False @@ -734,7 +734,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if Triton implementation is available - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: print("❌ Triton implementation not available, skipping test.") return False diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 82deb8c..59daf16 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -28,33 +28,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -207,7 +207,7 @@ def dynamic_mask_attention_backward_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: return "Not Available", 0 attn_bias, attn_mask = prepare_mask( @@ -223,7 +223,7 @@ def dynamic_mask_attention_backward_cuda( value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] try: - attn_outputs = flash_dmattn_func( + attn_outputs = flash_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -277,7 +277,7 @@ def dynamic_mask_attention_backward_triton( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -305,7 +305,7 @@ def dynamic_mask_attention_backward_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query=query_states, key=key_states, value=value_states, @@ -356,7 +356,7 @@ def dynamic_mask_attention_backward_flex( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -384,7 +384,7 @@ def dynamic_mask_attention_backward_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] try: - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 8baff70..9b05ba3 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -21,33 +21,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -181,8 +181,8 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flash_dmattn_func is None: - raise RuntimeError("flash_dmattn_func not available") + if flash_sparse_attn_func is None: + raise RuntimeError("flash_sparse_attn_func not available") attn_bias, attn_mask = prepare_mask( query_states, @@ -196,8 +196,8 @@ def dynamic_mask_attention_cuda( key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] - # Call the flash_dmattn_func interface - attn_outputs = flash_dmattn_func( + # Call the flash_sparse_attn_func interface + attn_outputs = flash_sparse_attn_func( query_states, key_states, value_states, @@ -239,7 +239,7 @@ def dynamic_mask_attention_triton( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: raise RuntimeError("Triton implementation not available") _, num_heads, _, _ = query_states.shape @@ -267,7 +267,7 @@ def dynamic_mask_attention_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Triton implementation - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query_states, key_states, value_states, @@ -306,7 +306,7 @@ def dynamic_mask_attention_flex( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -334,7 +334,7 @@ def dynamic_mask_attention_flex( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, @@ -446,7 +446,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): print("🚀" + "=" * 76 + "🚀") # Check if CUDA implementation is available - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: print("❌ CUDA implementation not available, skipping test.") return False @@ -653,7 +653,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬") print("🔥" + "=" * 76 + "🔥") - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: print("❌ Triton implementation not available, skipping Triton tests") return False @@ -859,7 +859,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬") print("🌟" + "=" * 76 + "🌟") - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: print("❌ Flex Attention implementation not available, skipping Flex Attention tests") return False diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 5730e0e..05e75c4 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -28,33 +28,33 @@ # Import the compiled CUDA extension try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func - print("✅ Successfully imported flash_dmattn interface") + from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn interface") except ImportError as e: - print(f"❌ Failed to import flash_dmattn interface: {e}") + print(f"❌ Failed to import flash_sparse_attn interface: {e}") print("Please make sure the package is properly installed with: pip install .") # Don't exit here, just warn - flash_dmattn_func = None + flash_sparse_attn_func = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import triton_dmattn_func - print("✅ Successfully imported flash_dmattn_triton") + from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_triton") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_triton: {e}") + print(f"❌ Failed to import flash_sparse_attn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - triton_dmattn_func = None + triton_sparse_attn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_dmattn_func - print("✅ Successfully imported flash_dmattn_flex") + from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func + print("✅ Successfully imported flash_sparse_attn_flex") except ImportError as e: - print(f"❌ Failed to import flash_dmattn_flex: {e}") + print(f"❌ Failed to import flash_sparse_attn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_dmattn_func = None + flex_sparse_attn_func = None def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -203,7 +203,7 @@ def dynamic_mask_attention_cuda( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flash_dmattn_func is None: + if flash_sparse_attn_func is None: return "Not Available", 0 attn_bias, attn_mask = prepare_mask( @@ -222,7 +222,7 @@ def dynamic_mask_attention_cuda( torch.cuda.synchronize() start_time = time.time() - attn_outputs = flash_dmattn_func( + attn_outputs = flash_sparse_attn_func( query_states, key_states, value_states, @@ -269,7 +269,7 @@ def dynamic_mask_attention_triton( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if triton_dmattn_func is None: + if triton_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -300,7 +300,7 @@ def dynamic_mask_attention_triton( torch.cuda.synchronize() start_time = time.time() - attn_outputs = triton_dmattn_func( + attn_outputs = triton_sparse_attn_func( query_states, key_states, value_states, @@ -344,7 +344,7 @@ def dynamic_mask_attention_flex( Returns: tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - if flex_dmattn_func is None: + if flex_sparse_attn_func is None: return "Not Available", 0 _, num_heads, _, _ = query_states.shape @@ -376,7 +376,7 @@ def dynamic_mask_attention_flex( start_time = time.time() # Call the Flex Attention implementation - attn_outputs = flex_dmattn_func( + attn_outputs = flex_sparse_attn_func( query_states, key_states, value_states, From b3ac56f400b94524a466f8cb46d2023bf8e2d30e Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:41:01 +0800 Subject: [PATCH 24/29] Renames flash attention variant Updates the sparse attention backend to drop the old dynamic mask name so future errors and docs consistently refer to FlashSparseAttention. --- .../flash_api.cpp | 34 +++++++++---------- .../src/block_info.h | 0 .../src/flash.h | 0 .../src/flash_bwd_kernel.h | 0 .../src/flash_bwd_launch_template.h | 2 +- .../src/flash_bwd_preprocess_kernel.h | 0 .../src/flash_fwd_kernel.h | 0 .../src/flash_fwd_launch_template.h | 2 +- .../src/generate_kernels.py | 0 .../src/hardware_info.h | 0 ...h_bwd_hdim128_bf16_causal_has_bias_sm80.cu | 0 ...m128_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim128_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim128_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim128_bf16_has_bias_sm80.cu | 0 ...bwd_hdim128_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim128_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim128_bf16_sm80.cu | 0 ...h_bwd_hdim128_fp16_causal_has_bias_sm80.cu | 0 ...m128_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim128_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim128_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim128_fp16_has_bias_sm80.cu | 0 ...bwd_hdim128_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim128_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim128_fp16_sm80.cu | 0 ...h_bwd_hdim192_bf16_causal_has_bias_sm80.cu | 0 ...m192_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim192_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim192_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim192_bf16_has_bias_sm80.cu | 0 ...bwd_hdim192_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim192_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim192_bf16_sm80.cu | 0 ...h_bwd_hdim192_fp16_causal_has_bias_sm80.cu | 0 ...m192_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim192_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim192_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim192_fp16_has_bias_sm80.cu | 0 ...bwd_hdim192_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim192_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim192_fp16_sm80.cu | 0 ...h_bwd_hdim256_bf16_causal_has_bias_sm80.cu | 0 ...m256_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim256_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim256_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim256_bf16_has_bias_sm80.cu | 0 ...bwd_hdim256_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim256_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim256_bf16_sm80.cu | 0 ...h_bwd_hdim256_fp16_causal_has_bias_sm80.cu | 0 ...m256_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_bwd_hdim256_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim256_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim256_fp16_has_bias_sm80.cu | 0 ...bwd_hdim256_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim256_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim256_fp16_sm80.cu | 0 ...sh_bwd_hdim32_bf16_causal_has_bias_sm80.cu | 0 ...im32_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim32_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim32_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim32_bf16_has_bias_sm80.cu | 0 ..._bwd_hdim32_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim32_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim32_bf16_sm80.cu | 0 ...sh_bwd_hdim32_fp16_causal_has_bias_sm80.cu | 0 ...im32_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim32_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim32_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim32_fp16_has_bias_sm80.cu | 0 ..._bwd_hdim32_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim32_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim32_fp16_sm80.cu | 0 ...sh_bwd_hdim64_bf16_causal_has_bias_sm80.cu | 0 ...im64_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim64_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim64_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim64_bf16_has_bias_sm80.cu | 0 ..._bwd_hdim64_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim64_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim64_bf16_sm80.cu | 0 ...sh_bwd_hdim64_fp16_causal_has_bias_sm80.cu | 0 ...im64_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim64_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim64_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim64_fp16_has_bias_sm80.cu | 0 ..._bwd_hdim64_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim64_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim64_fp16_sm80.cu | 0 ...sh_bwd_hdim96_bf16_causal_has_bias_sm80.cu | 0 ...im96_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim96_bf16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim96_bf16_causal_sm80.cu | 0 .../flash_bwd_hdim96_bf16_has_bias_sm80.cu | 0 ..._bwd_hdim96_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim96_bf16_has_mask_sm80.cu | 0 .../flash_bwd_hdim96_bf16_sm80.cu | 0 ...sh_bwd_hdim96_fp16_causal_has_bias_sm80.cu | 0 ...im96_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_bwd_hdim96_fp16_causal_has_mask_sm80.cu | 0 .../flash_bwd_hdim96_fp16_causal_sm80.cu | 0 .../flash_bwd_hdim96_fp16_has_bias_sm80.cu | 0 ..._bwd_hdim96_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_bwd_hdim96_fp16_has_mask_sm80.cu | 0 .../flash_bwd_hdim96_fp16_sm80.cu | 0 ...h_fwd_hdim128_bf16_causal_has_bias_sm80.cu | 0 ...m128_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim128_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim128_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim128_bf16_has_bias_sm80.cu | 0 ...fwd_hdim128_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim128_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim128_bf16_sm80.cu | 0 ...h_fwd_hdim128_fp16_causal_has_bias_sm80.cu | 0 ...m128_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim128_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim128_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim128_fp16_has_bias_sm80.cu | 0 ...fwd_hdim128_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim128_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim128_fp16_sm80.cu | 0 ...h_fwd_hdim192_bf16_causal_has_bias_sm80.cu | 0 ...m192_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim192_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim192_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim192_bf16_has_bias_sm80.cu | 0 ...fwd_hdim192_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim192_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim192_bf16_sm80.cu | 0 ...h_fwd_hdim192_fp16_causal_has_bias_sm80.cu | 0 ...m192_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim192_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim192_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim192_fp16_has_bias_sm80.cu | 0 ...fwd_hdim192_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim192_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim192_fp16_sm80.cu | 0 ...h_fwd_hdim256_bf16_causal_has_bias_sm80.cu | 0 ...m256_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim256_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim256_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim256_bf16_has_bias_sm80.cu | 0 ...fwd_hdim256_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim256_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim256_bf16_sm80.cu | 0 ...h_fwd_hdim256_fp16_causal_has_bias_sm80.cu | 0 ...m256_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...h_fwd_hdim256_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim256_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim256_fp16_has_bias_sm80.cu | 0 ...fwd_hdim256_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim256_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim256_fp16_sm80.cu | 0 ...sh_fwd_hdim32_bf16_causal_has_bias_sm80.cu | 0 ...im32_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim32_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim32_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim32_bf16_has_bias_sm80.cu | 0 ..._fwd_hdim32_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim32_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim32_bf16_sm80.cu | 0 ...sh_fwd_hdim32_fp16_causal_has_bias_sm80.cu | 0 ...im32_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim32_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim32_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim32_fp16_has_bias_sm80.cu | 0 ..._fwd_hdim32_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim32_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim32_fp16_sm80.cu | 0 ...sh_fwd_hdim64_bf16_causal_has_bias_sm80.cu | 0 ...im64_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim64_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim64_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim64_bf16_has_bias_sm80.cu | 0 ..._fwd_hdim64_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim64_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim64_bf16_sm80.cu | 0 ...sh_fwd_hdim64_fp16_causal_has_bias_sm80.cu | 0 ...im64_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim64_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim64_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim64_fp16_has_bias_sm80.cu | 0 ..._fwd_hdim64_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim64_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim64_fp16_sm80.cu | 0 ...sh_fwd_hdim96_bf16_causal_has_bias_sm80.cu | 0 ...im96_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim96_bf16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim96_bf16_causal_sm80.cu | 0 .../flash_fwd_hdim96_bf16_has_bias_sm80.cu | 0 ..._fwd_hdim96_bf16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim96_bf16_has_mask_sm80.cu | 0 .../flash_fwd_hdim96_bf16_sm80.cu | 0 ...sh_fwd_hdim96_fp16_causal_has_bias_sm80.cu | 0 ...im96_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_hdim96_fp16_causal_has_mask_sm80.cu | 0 .../flash_fwd_hdim96_fp16_causal_sm80.cu | 0 .../flash_fwd_hdim96_fp16_has_bias_sm80.cu | 0 ..._fwd_hdim96_fp16_has_mask_has_bias_sm80.cu | 0 .../flash_fwd_hdim96_fp16_has_mask_sm80.cu | 0 .../flash_fwd_hdim96_fp16_sm80.cu | 0 ...split_hdim128_bf16_causal_has_bias_sm80.cu | 0 ...m128_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim128_bf16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim128_bf16_causal_sm80.cu | 0 ...sh_fwd_split_hdim128_bf16_has_bias_sm80.cu | 0 ...lit_hdim128_bf16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim128_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim128_bf16_sm80.cu | 0 ...split_hdim128_fp16_causal_has_bias_sm80.cu | 0 ...m128_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim128_fp16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim128_fp16_causal_sm80.cu | 0 ...sh_fwd_split_hdim128_fp16_has_bias_sm80.cu | 0 ...lit_hdim128_fp16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim128_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim128_fp16_sm80.cu | 0 ...split_hdim192_bf16_causal_has_bias_sm80.cu | 0 ...m192_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim192_bf16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim192_bf16_causal_sm80.cu | 0 ...sh_fwd_split_hdim192_bf16_has_bias_sm80.cu | 0 ...lit_hdim192_bf16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim192_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim192_bf16_sm80.cu | 0 ...split_hdim192_fp16_causal_has_bias_sm80.cu | 0 ...m192_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim192_fp16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim192_fp16_causal_sm80.cu | 0 ...sh_fwd_split_hdim192_fp16_has_bias_sm80.cu | 0 ...lit_hdim192_fp16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim192_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim192_fp16_sm80.cu | 0 ...split_hdim256_bf16_causal_has_bias_sm80.cu | 0 ...m256_bf16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim256_bf16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim256_bf16_causal_sm80.cu | 0 ...sh_fwd_split_hdim256_bf16_has_bias_sm80.cu | 0 ...lit_hdim256_bf16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim256_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim256_bf16_sm80.cu | 0 ...split_hdim256_fp16_causal_has_bias_sm80.cu | 0 ...m256_fp16_causal_has_mask_has_bias_sm80.cu | 0 ...split_hdim256_fp16_causal_has_mask_sm80.cu | 0 ...lash_fwd_split_hdim256_fp16_causal_sm80.cu | 0 ...sh_fwd_split_hdim256_fp16_has_bias_sm80.cu | 0 ...lit_hdim256_fp16_has_mask_has_bias_sm80.cu | 0 ...sh_fwd_split_hdim256_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim256_fp16_sm80.cu | 0 ..._split_hdim32_bf16_causal_has_bias_sm80.cu | 0 ...im32_bf16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim32_bf16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim32_bf16_causal_sm80.cu | 0 ...ash_fwd_split_hdim32_bf16_has_bias_sm80.cu | 0 ...plit_hdim32_bf16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim32_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim32_bf16_sm80.cu | 0 ..._split_hdim32_fp16_causal_has_bias_sm80.cu | 0 ...im32_fp16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim32_fp16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim32_fp16_causal_sm80.cu | 0 ...ash_fwd_split_hdim32_fp16_has_bias_sm80.cu | 0 ...plit_hdim32_fp16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim32_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim32_fp16_sm80.cu | 0 ..._split_hdim64_bf16_causal_has_bias_sm80.cu | 0 ...im64_bf16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim64_bf16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim64_bf16_causal_sm80.cu | 0 ...ash_fwd_split_hdim64_bf16_has_bias_sm80.cu | 0 ...plit_hdim64_bf16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim64_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim64_bf16_sm80.cu | 0 ..._split_hdim64_fp16_causal_has_bias_sm80.cu | 0 ...im64_fp16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim64_fp16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim64_fp16_causal_sm80.cu | 0 ...ash_fwd_split_hdim64_fp16_has_bias_sm80.cu | 0 ...plit_hdim64_fp16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim64_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim64_fp16_sm80.cu | 0 ..._split_hdim96_bf16_causal_has_bias_sm80.cu | 0 ...im96_bf16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim96_bf16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim96_bf16_causal_sm80.cu | 0 ...ash_fwd_split_hdim96_bf16_has_bias_sm80.cu | 0 ...plit_hdim96_bf16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim96_bf16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim96_bf16_sm80.cu | 0 ..._split_hdim96_fp16_causal_has_bias_sm80.cu | 0 ...im96_fp16_causal_has_mask_has_bias_sm80.cu | 0 ..._split_hdim96_fp16_causal_has_mask_sm80.cu | 0 ...flash_fwd_split_hdim96_fp16_causal_sm80.cu | 0 ...ash_fwd_split_hdim96_fp16_has_bias_sm80.cu | 0 ...plit_hdim96_fp16_has_mask_has_bias_sm80.cu | 0 ...ash_fwd_split_hdim96_fp16_has_mask_sm80.cu | 0 .../flash_fwd_split_hdim96_fp16_sm80.cu | 0 .../src/kernel_traits.h | 0 .../src/mask.h | 0 .../src/namespace_config.h | 0 .../src/softmax.h | 0 .../src/static_switch.h | 0 .../src/utils.h | 0 304 files changed, 19 insertions(+), 19 deletions(-) rename csrc/{flash_dmattn => flash_sparse_attn}/flash_api.cpp (97%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/block_info.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash_bwd_kernel.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash_bwd_launch_template.h (98%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash_bwd_preprocess_kernel.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash_fwd_kernel.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/flash_fwd_launch_template.h (99%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/generate_kernels.py (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/hardware_info.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/kernel_traits.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/mask.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/namespace_config.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/softmax.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/static_switch.h (100%) rename csrc/{flash_dmattn => flash_sparse_attn}/src/utils.h (100%) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_sparse_attn/flash_api.cpp similarity index 97% rename from csrc/flash_dmattn/flash_api.cpp rename to csrc/flash_sparse_attn/flash_api.cpp index 4a67ec1..2bfa20f 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_sparse_attn/flash_api.cpp @@ -126,7 +126,7 @@ void set_params_fprop( // Set the different scale values. #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap."); + TORCH_CHECK(softcap <= 0.0, "This flash sparse attention build does not support softcap."); #endif if (softcap > 0.0) { params.softcap = softmax_scale / softcap; @@ -145,7 +145,7 @@ void set_params_fprop( params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K - TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32."); + TORCH_CHECK(d == d_rounded, "This flash sparse attention build does not support headdim not being a multiple of 32."); #endif params.unpadded_lse = unpadded_lse; @@ -366,10 +366,10 @@ mha_fwd( at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -420,7 +420,7 @@ mha_fwd( const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -577,10 +577,10 @@ mha_varlen_fwd( at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); @@ -644,7 +644,7 @@ mha_varlen_fwd( const int total_q = q.sizes()[0]; TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -810,19 +810,19 @@ mha_bwd( ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); + TORCH_CHECK(false, "This flash sparse attention build does not support backward."); #endif // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); @@ -881,7 +881,7 @@ mha_bwd( TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (has_mask) { @@ -1072,19 +1072,19 @@ mha_varlen_bwd( ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); + TORCH_CHECK(false, "This flash sparse attention build does not support backward."); #endif // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer."); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); @@ -1124,7 +1124,7 @@ mha_varlen_bwd( const int num_heads_bias = has_bias ? bias.size(1) : 1; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); + TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -1268,7 +1268,7 @@ mha_varlen_bwd( } // namespace FLASH_NAMESPACE PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashDynamicMaskAttention"; + m.doc() = "FlashSparseAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); diff --git a/csrc/flash_dmattn/src/block_info.h b/csrc/flash_sparse_attn/src/block_info.h similarity index 100% rename from csrc/flash_dmattn/src/block_info.h rename to csrc/flash_sparse_attn/src/block_info.h diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_sparse_attn/src/flash.h similarity index 100% rename from csrc/flash_dmattn/src/flash.h rename to csrc/flash_sparse_attn/src/flash.h diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_sparse_attn/src/flash_bwd_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_bwd_kernel.h rename to csrc/flash_sparse_attn/src/flash_bwd_kernel.h diff --git a/csrc/flash_dmattn/src/flash_bwd_launch_template.h b/csrc/flash_sparse_attn/src/flash_bwd_launch_template.h similarity index 98% rename from csrc/flash_dmattn/src/flash_bwd_launch_template.h rename to csrc/flash_sparse_attn/src/flash_bwd_launch_template.h index 00712b8..a6a3717 100644 --- a/csrc/flash_dmattn/src/flash_bwd_launch_template.h +++ b/csrc/flash_sparse_attn/src/flash_bwd_launch_template.h @@ -24,7 +24,7 @@ namespace FLASH_NAMESPACE { #endif // Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); // Use a macro to clean up kernel definitions #define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \ diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_sparse_attn/src/flash_bwd_preprocess_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h rename to csrc/flash_sparse_attn/src/flash_bwd_preprocess_kernel.h diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_sparse_attn/src/flash_fwd_kernel.h similarity index 100% rename from csrc/flash_dmattn/src/flash_fwd_kernel.h rename to csrc/flash_sparse_attn/src/flash_fwd_kernel.h diff --git a/csrc/flash_dmattn/src/flash_fwd_launch_template.h b/csrc/flash_sparse_attn/src/flash_fwd_launch_template.h similarity index 99% rename from csrc/flash_dmattn/src/flash_fwd_launch_template.h rename to csrc/flash_sparse_attn/src/flash_fwd_launch_template.h index 9c3d94b..412db39 100644 --- a/csrc/flash_dmattn/src/flash_fwd_launch_template.h +++ b/csrc/flash_sparse_attn/src/flash_fwd_launch_template.h @@ -23,7 +23,7 @@ namespace FLASH_NAMESPACE { #endif // Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); // Use a macro to clean up kernel definitions #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ diff --git a/csrc/flash_dmattn/src/generate_kernels.py b/csrc/flash_sparse_attn/src/generate_kernels.py similarity index 100% rename from csrc/flash_dmattn/src/generate_kernels.py rename to csrc/flash_sparse_attn/src/generate_kernels.py diff --git a/csrc/flash_dmattn/src/hardware_info.h b/csrc/flash_sparse_attn/src/hardware_info.h similarity index 100% rename from csrc/flash_dmattn/src/hardware_info.h rename to csrc/flash_sparse_attn/src/hardware_info.h diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu similarity index 100% rename from csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu rename to csrc/flash_sparse_attn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_sparse_attn/src/kernel_traits.h similarity index 100% rename from csrc/flash_dmattn/src/kernel_traits.h rename to csrc/flash_sparse_attn/src/kernel_traits.h diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_sparse_attn/src/mask.h similarity index 100% rename from csrc/flash_dmattn/src/mask.h rename to csrc/flash_sparse_attn/src/mask.h diff --git a/csrc/flash_dmattn/src/namespace_config.h b/csrc/flash_sparse_attn/src/namespace_config.h similarity index 100% rename from csrc/flash_dmattn/src/namespace_config.h rename to csrc/flash_sparse_attn/src/namespace_config.h diff --git a/csrc/flash_dmattn/src/softmax.h b/csrc/flash_sparse_attn/src/softmax.h similarity index 100% rename from csrc/flash_dmattn/src/softmax.h rename to csrc/flash_sparse_attn/src/softmax.h diff --git a/csrc/flash_dmattn/src/static_switch.h b/csrc/flash_sparse_attn/src/static_switch.h similarity index 100% rename from csrc/flash_dmattn/src/static_switch.h rename to csrc/flash_sparse_attn/src/static_switch.h diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_sparse_attn/src/utils.h similarity index 100% rename from csrc/flash_dmattn/src/utils.h rename to csrc/flash_sparse_attn/src/utils.h From 92c0fad06b91624a37fe0a6c18edc5da6f79259d Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:43:01 +0800 Subject: [PATCH 25/29] Aligns issue templates with FSA Maintains naming consistency after the FSA rebrand. --- .github/ISSUE_TEMPLATE/bug_report.md | 4 ++-- .github/ISSUE_TEMPLATE/bug_report.yml | 2 +- .github/ISSUE_TEMPLATE/feature_request.md | 4 ++-- .github/ISSUE_TEMPLATE/feature_request.yml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 46b7d9f..bcc17cf 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,6 +1,6 @@ --- name: Bug report -about: Create a report to help us improve Flash-DMA +about: Create a report to help us improve FSA title: '[BUG REPORT] ' labels: ["bug"] assignees: @@ -39,7 +39,7 @@ python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: { **Additional context** - OS: [e.g. Ubuntu 20.04, Windows 10, macOS 12] - Python version: [e.g. 3.9.7] -- Flash-DMA version: [e.g. 0.1.0] +- FSA version: [e.g. 0.1.0] - CUDA Compute Capability: [e.g. 8.6] **Error traceback** diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index dfb007e..65c74de 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,5 @@ name: Bug report -description: Create a report to help us improve Flash-DMA +description: Create a report to help us improve FSA title: "[BUG REPORT] " labels: - bug diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 2816d8f..1db7b8e 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,6 +1,6 @@ --- name: Feature request -about: Suggest an idea for Flash-DMA +about: Suggest an idea for FSA title: '[FEATURE REQUEST] ' labels: ["feature"] assignees: @@ -44,4 +44,4 @@ Add any other context or screenshots about the feature request here. If this feature is inspired by a paper or existing implementation, please provide: - Link to paper/implementation - Brief explanation of the technique -- Why it would be valuable for Flash-DMA users +- Why it would be valuable for FSA users diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 46a7d39..8ac591e 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,5 +1,5 @@ name: Feature request -description: Suggest an idea for FDMA +description: Suggest an idea for FSA title: "[FEATURE REQUEST] " labels: - feature @@ -16,7 +16,7 @@ body: - type: markdown attributes: value: | - Help us understand the feature you are proposing and why it matters for Flash-DMA workflows. + Help us understand the feature you are proposing and why it matters for FSA workflows. - type: textarea id: problem attributes: From 4aa0153fb7c25a1e1f0fe890c7d92c78462cc248 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:48:22 +0800 Subject: [PATCH 26/29] Renames environment variables for sparse attention build configuration --- .github/workflows/_build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 8ad6f3f..80e09a9 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -172,12 +172,12 @@ jobs: export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) export NVCC_THREADS=2 - export FLASH_DMATTN_FORCE_BUILD="TRUE" - export FLASH_DMATTN_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + export FLASH_SPARSE_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} # If specified, limit to a single compute capability to speed up build if [ -n "${MATRIX_ARCH}" ]; then - export FLASH_DMATTN_CUDA_ARCHS="${MATRIX_ARCH}" + export FLASH_SPARSE_ATTENTION_CUDA_ARCHS="${MATRIX_ARCH}" fi # GH allows max 6h From fc64149492703925743156b602c6f5f019c39794 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:48:37 +0800 Subject: [PATCH 27/29] Renames the kernel generation script description to reflect sparse attention context --- csrc/flash_sparse_attn/src/generate_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_sparse_attn/src/generate_kernels.py b/csrc/flash_sparse_attn/src/generate_kernels.py index 54d5a72..3626243 100644 --- a/csrc/flash_sparse_attn/src/generate_kernels.py +++ b/csrc/flash_sparse_attn/src/generate_kernels.py @@ -113,7 +113,7 @@ def main(output_dir: Optional[str]) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate_kernels", - description="Generate the flash_dmattn kernels template instantiations", + description="Generate the flash_sparse_attn kernels template instantiations", ) parser.add_argument( "-o", From 1211c5b3e41a5ccabc49f03bfd8ed58555e6dd9b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Sun, 9 Nov 2025 23:51:25 +0800 Subject: [PATCH 28/29] Renames environment variable for sparse attention build configuration --- .github/workflows/manual_publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/manual_publish.yml b/.github/workflows/manual_publish.yml index c1dae8e..258bfae 100644 --- a/.github/workflows/manual_publish.yml +++ b/.github/workflows/manual_publish.yml @@ -38,7 +38,7 @@ jobs: - name: Build core package env: - FLASH_DMATTN_SKIP_CUDA_BUILD: "TRUE" + FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist ls -l dist From 8695288b40eb9f7ecf424eace4a46695ad2fb55d Mon Sep 17 00:00:00 2001 From: Jingze Date: Sun, 9 Nov 2025 23:53:28 +0800 Subject: [PATCH 29/29] Update CONTRIBUTING.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7271d04..ba79358 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -202,4 +202,4 @@ If you discover a security vulnerability, please send an e-mail to the maintaine If you have questions about contributing, feel free to ask in the [GitHub Discussions](https://github.com/SmallDoges/flash-sparse-attention/discussions) or open an issue. -Thank you for contributing to Flash Dynamic Mask Attention! 🚀 +Thank you for contributing to Flash Sparse Attention! 🚀