From 6f036c10577ea1f910a0f5f37c679d1825d446eb Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 7 Aug 2025 22:53:47 +0800 Subject: [PATCH 1/4] Standardizes parameter naming and ordering across attention functions Renames `softmax_scale` to `scale` and `q/k/v` to `query/key/value` for consistency across all flash attention function variants. Reorders parameters to place `is_causal` before `scale` in function signatures, improving API consistency and alignment with common attention interface patterns. Updates all function calls, documentation strings, and parameter passing to reflect the standardized naming convention. --- flash_dmattn/flash_dmattn_flex.py | 2 +- flash_dmattn/flash_dmattn_interface.py | 84 +++++++++++++------------- flash_dmattn/flash_dmattn_triton.py | 8 +-- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py index dfaa74a..1c160ad 100644 --- a/flash_dmattn/flash_dmattn_flex.py +++ b/flash_dmattn/flash_dmattn_flex.py @@ -10,8 +10,8 @@ def flex_attention_forward( value: torch.Tensor, attn_mask: torch.Tensor, attn_bias: torch.Tensor, - scale: Optional[float] = None, is_causal: bool = True, + scale: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 95fce09..f91a014 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -1151,8 +1151,8 @@ def flash_dmattn_qkvpacked_func( attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, is_causal: Optional[bool] = None, + scale: Optional[float] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, @@ -1174,9 +1174,9 @@ def flash_dmattn_qkvpacked_func( attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores. If None, no bias is applied. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. @@ -1197,7 +1197,7 @@ def flash_dmattn_qkvpacked_func( attn_mask, attn_bias, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, @@ -1212,7 +1212,7 @@ def flash_dmattn_kvpacked_func( attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, + scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -1247,9 +1247,9 @@ def flash_dmattn_kvpacked_func( attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. If None, no bias is applied. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. @@ -1271,7 +1271,7 @@ def flash_dmattn_kvpacked_func( attn_mask, attn_bias, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, @@ -1281,13 +1281,13 @@ def flash_dmattn_kvpacked_func( def flash_dmattn_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, + scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -1312,17 +1312,17 @@ def flash_dmattn_func( If the row of the mask is all zero, the output will be zero. Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) + query: (batch_size, seqlen, nheads, headdim) + key: (batch_size, seqlen, nheads_k, headdim) + value: (batch_size, seqlen, nheads_k, headdim) attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores. If None, no mask is applied. attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. If None, no bias is applied. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1338,13 +1338,13 @@ def flash_dmattn_func( pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashDMAttnFunc.apply( - q, - k, - v, + query, + key, + value, attn_mask, attn_bias, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, @@ -1360,7 +1360,7 @@ def flash_dmattn_varlen_qkvpacked_func( cu_seqlens: torch.Tensor = None, max_seqlen: int = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, + scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -1383,9 +1383,9 @@ def flash_dmattn_varlen_qkvpacked_func( of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. @@ -1408,7 +1408,7 @@ def flash_dmattn_varlen_qkvpacked_func( cu_seqlens, max_seqlen, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, @@ -1427,7 +1427,7 @@ def flash_dmattn_varlen_kvpacked_func( max_seqlen_q: int = None, max_seqlen_k: int = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, + scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -1468,9 +1468,9 @@ def flash_dmattn_varlen_kvpacked_func( max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. @@ -1496,7 +1496,7 @@ def flash_dmattn_varlen_kvpacked_func( max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, @@ -1506,9 +1506,9 @@ def flash_dmattn_varlen_kvpacked_func( def flash_dmattn_varlen_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, cu_seqlens_q: torch.Tensor = None, @@ -1516,7 +1516,7 @@ def flash_dmattn_varlen_func( max_seqlen_q: int = None, max_seqlen_k: int = None, dropout_p: Optional[float] = None, - softmax_scale: Optional[float] = None, + scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -1542,9 +1542,9 @@ def flash_dmattn_varlen_func( If the row of the mask is all zero, the output will be zero. Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores. If None, no mask is applied. attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. @@ -1556,9 +1556,9 @@ def flash_dmattn_varlen_func( max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. @@ -1575,9 +1575,9 @@ def flash_dmattn_varlen_func( pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashDMAttnVarlenFunc.apply( - q, - k, - v, + query, + key, + value, attn_mask, attn_bias, cu_seqlens_q, @@ -1585,7 +1585,7 @@ def flash_dmattn_varlen_func( max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, + scale, is_causal, softcap, deterministic, diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 7a05066..85a072b 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1052,15 +1052,15 @@ def _flash_attn_backward( class FlashDMAttnFunc(torch.autograd.Function): @staticmethod - def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False): + def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): """ query: (batch_size, seqlen_q, nheads, headdim) key: (batch_size, seqlen_k, nheads, headdim) value: (batch_size, seqlen_k, nheads, headdim) attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k) attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k) - softmax_scale: float, scaling factor for attention scores is_causal: bool, whether to apply causal masking + softmax_scale: float, scaling factor for attention scores """ batch, seqlen_q, nheads, _ = query.shape _, seqlen_k, _, _ = key.shape @@ -1111,5 +1111,5 @@ def backward(ctx, do): return dq, dk, dv, None, dbias, None, None -def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False): - return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal) +def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, scale=None): + return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, scale) From e5eb029cb3a037f2b7871b7d0d3e9ac86a9eccb5 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 7 Aug 2025 22:57:20 +0800 Subject: [PATCH 2/4] Updates API to use auto backend selection function Replaces direct function import with auto backend selection approach for better flexibility. Changes parameter names from abbreviated forms to full descriptive names for improved clarity. Updates num_heads from 12 to 16 in examples to reflect more common model configurations. Renames softmax_scale parameter to scale for consistency with standard naming conventions. --- README.md | 13 ++++++++----- README_zh.md | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 47ac2d1..f88a67d 100644 --- a/README.md +++ b/README.md @@ -72,11 +72,11 @@ pip install . ```python import torch -from flash_dmattn import flash_dmattn_func +from flash_dmattn import flash_dmattn_func_auto import math # Setup -batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128 +batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128 device = torch.device('cuda') dtype = torch.bfloat16 @@ -103,6 +103,9 @@ if seq_len > keep_window_size: attention_mask.zero_() attention_mask.scatter(-1, topk_indices, 1.0) +# Select backend +flash_dmattn_func = flash_dmattn_func_auto(backend="cuda") + # Run Flash Dynamic Mask Attention output = flash_dmattn_func( q=query, @@ -110,11 +113,11 @@ output = flash_dmattn_func( v=value, attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=1.0/math.sqrt(head_dim), - is_causal=True + is_causal=True, + scale=1.0/math.sqrt(head_dim), ) -print(f"Output shape: {output.shape}") # [2, 4096, 12, 128] +print(f"Output shape: {output.shape}") # [2, 4096, 16, 128] ``` diff --git a/README_zh.md b/README_zh.md index 8fc6d40..a4551c2 100644 --- a/README_zh.md +++ b/README_zh.md @@ -72,11 +72,11 @@ pip install . ```python import torch -from flash_dmattn import flash_dmattn_func +from flash_dmattn import flash_dmattn_func_auto import math # 设置 -batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128 +batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128 device = torch.device('cuda') dtype = torch.bfloat16 @@ -103,18 +103,21 @@ if seq_len > keep_window_size: attention_mask.zero_() attention_mask.scatter(-1, topk_indices, 1.0) +# 选择后端 +flash_dmattn_func = flash_dmattn_func_auto(backend="cuda") + # 运行 Flash 动态掩码注意力 output = flash_dmattn_func( - q=query, - k=key, - v=value, + query=query, + key=key, + value=value, attn_mask=attention_mask, attn_bias=attention_bias, - softmax_scale=1.0/math.sqrt(head_dim), - is_causal=True + is_causal=True, + scale=1.0/math.sqrt(head_dim), ) -print(f"输出形状: {output.shape}") # [2, 4096, 12, 128] +print(f"输出形状: {output.shape}") # [2, 4096, 16, 128] ``` From f6e910a0a6b046d25eb0d040b0337a8feb811006 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 7 Aug 2025 23:09:40 +0800 Subject: [PATCH 3/4] Standardizes parameter naming across attention functions Updates parameter names to use consistent naming conventions across CUDA, Triton, and Flex attention implementations. Changes 'softmax_scale' to 'scale' and converts positional arguments to keyword arguments for better API consistency and clarity. Fixes tensor dimension ordering in Flex attention by adding transpose operations to match expected input format. --- benchmarks/benchmark_forward_equivalence.py | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 6d5177b..e3e9918 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -257,8 +257,8 @@ def dynamic_mask_attention_cuda( attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] dropout_p=0.0, - softmax_scale=scaling, is_causal=is_causal, + scale=scaling, softcap=0.0, deterministic=True, return_attn_probs=return_softmax @@ -331,10 +331,10 @@ def dynamic_mask_attention_triton( query_states, # q: [batch, seqlen_q, num_heads, head_dim] key_states, # k: [batch, seqlen_k, num_heads, head_dim] value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal, # causal masking - scaling # scaling factor + attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal=is_causal, # causal masking + scale=scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -396,14 +396,14 @@ def dynamic_mask_attention_flex( # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format # Call the Flex Attention implementation - attn_outputs, _ = flex_dmattn_func( - query_states, # q: [batch, num_heads, query_len, head_dim] - key_states, # k: [batch, num_heads, key_len, head_dim] - value_states, # v: [batch, num_heads, key_len, head_dim] - attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] - attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - scaling=scaling # scaling factor + attn_outputs = flex_dmattn_func( + query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] + key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] + value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] + attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] + attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: whether to apply causal masking + scale=scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] From 39ddb403ef73c3570ed0a57acd7e7aea75d8cf34 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 7 Aug 2025 23:12:13 +0800 Subject: [PATCH 4/4] Standardizes parameter naming across attention functions Changes `softmax_scale` to `scale` parameter name for consistency across CUDA, Triton, and Flex attention implementations. Updates Flex attention to use keyword arguments and adds tensor transposes to match expected input format. Removes unused return value from Flex attention call to align with other implementations. --- benchmarks/benchmark_forward_performance.py | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index e0f7ed3..759668a 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -266,8 +266,8 @@ def dynamic_mask_attention_cuda( attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] dropout_p=0.0, - softmax_scale=scaling, is_causal=is_causal, + scale=scaling, softcap=0.0, deterministic=True, return_attn_probs=return_softmax @@ -350,10 +350,10 @@ def dynamic_mask_attention_triton( query_states, # q: [batch, seqlen_q, num_heads, head_dim] key_states, # k: [batch, seqlen_k, num_heads, head_dim] value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal, # causal masking - scaling # scaling factor + attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal=is_causal, # causal masking + scale=scaling # scaling factor ) torch.cuda.synchronize() @@ -425,14 +425,14 @@ def dynamic_mask_attention_flex( start_time = time.time() # Call the Flex Attention implementation - attn_outputs, _ = flex_dmattn_func( - query_states, # q: [batch, num_heads, query_len, head_dim] - key_states, # k: [batch, num_heads, key_len, head_dim] - value_states, # v: [batch, num_heads, key_len, head_dim] - attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] - attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: Whether to apply causal masking - scaling=scaling # scaling factor + attn_outputs = flex_dmattn_func( + query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] + key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] + value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] + attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] + attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: whether to apply causal masking + scale=scaling # scaling factor ) torch.cuda.synchronize()