diff --git a/README.md b/README.md
index 47ac2d1..f88a67d 100644
--- a/README.md
+++ b/README.md
@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # Setup
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# Select backend
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # Run Flash Dynamic Mask Attention
 output = flash_dmattn_func(
-    q=query,
-    k=key,
-    v=value,
+    query=query,
+    key=key,
+    value=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"Output shape: {output.shape}")  # [2, 4096, 16, 128]
+print(f"Output shape: {output.shape}")  # [2, 4096, 16, 128]
 ```
 
 
diff --git a/README_zh.md b/README_zh.md
index 8fc6d40..a4551c2 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # 设置
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# 选择后端
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # 运行 Flash 动态掩码注意力
 output = flash_dmattn_func(
-    q=query,
-    k=key,
-    v=value,
+    query=query,
+    key=key,
+    value=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"输出形状: {output.shape}")  # [2, 4096, 12, 128]
+print(f"输出形状: {output.shape}")  # [2, 4096, 16, 128]
 ```
 
 
diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index 6d5177b..e3e9918 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -257,8 +257,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,        # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,        # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -331,10 +331,10 @@ def dynamic_mask_attention_triton(
         query_states,            # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,              # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,            # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,               # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,               # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,               # causal masking
-        scaling                  # scaling factor
+        attn_mask=attn_mask,     # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,     # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,     # causal masking
+        scale=scaling            # scaling factor
     )
 
     return attn_outputs          # [batch, query_len, num_heads, head_dim]
@@ -396,14 +396,14 @@ def dynamic_mask_attention_flex(
     # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,               # q: [batch, num_heads, query_len, head_dim]
-        key_states,                 # k: [batch, num_heads, key_len, head_dim]
-        value_states,               # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,   # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,   # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,        # is_causal: whether to apply causal masking
-        scaling=scaling             # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,            # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,            # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
+        scale=scaling                   # scaling factor
     )
 
     return attn_outputs             # [batch, query_len, num_heads, head_dim]
diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index e0f7ed3..759668a 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -266,8 +266,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,        # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,        # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -350,10 +350,10 @@ def dynamic_mask_attention_triton(
         query_states,            # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,              # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,            # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,               # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,               # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,               # causal masking
-        scaling                  # scaling factor
+        attn_mask=attn_mask,     # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,     # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,     # causal masking
+        scale=scaling            # scaling factor
     )
 
     torch.cuda.synchronize()
@@ -425,14 +425,14 @@ def dynamic_mask_attention_flex(
     start_time = time.time()
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,               # q: [batch, num_heads, query_len, head_dim]
-        key_states,                 # k: [batch, num_heads, key_len, head_dim]
-        value_states,               # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,   # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,   # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,        # is_causal: Whether to apply causal masking
-        scaling=scaling             # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,            # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,            # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
+        scale=scaling                   # scaling factor
     )
 
     torch.cuda.synchronize()
diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py
index dfaa74a..1c160ad 100644
--- a/flash_dmattn/flash_dmattn_flex.py
+++ b/flash_dmattn/flash_dmattn_flex.py
@@ -10,8 +10,8 @@ def flex_attention_forward(
     value: torch.Tensor,
     attn_mask: torch.Tensor,
     attn_bias: torch.Tensor,
-    scale: Optional[float] = None,
     is_causal: bool = True,
+    scale: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index 95fce09..f91a014 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -1151,8 +1151,8 @@ def flash_dmattn_qkvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
+    scale: Optional[float] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
@@ -1174,9 +1174,9 @@ def flash_dmattn_qkvpacked_func(
         attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1197,7 +1197,7 @@ def flash_dmattn_qkvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1212,7 +1212,7 @@ def flash_dmattn_kvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1247,9 +1247,9 @@ def flash_dmattn_kvpacked_func(
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1271,7 +1271,7 @@ def flash_dmattn_kvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1281,13 +1281,13 @@ def flash_dmattn_kvpacked_func(
 
 
 def flash_dmattn_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1312,17 +1312,17 @@ def flash_dmattn_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (batch_size, seqlen, nheads, headdim)
-        k: (batch_size, seqlen, nheads_k, headdim)
-        v: (batch_size, seqlen, nheads_k, headdim)
+        query: (batch_size, seqlen, nheads, headdim)
+        key: (batch_size, seqlen, nheads_k, headdim)
+        value: (batch_size, seqlen, nheads_k, headdim)
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
@@ -1338,13 +1338,13 @@ def flash_dmattn_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashDMAttnFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1360,7 +1360,7 @@ def flash_dmattn_varlen_qkvpacked_func(
     cu_seqlens: torch.Tensor = None,
     max_seqlen: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1383,9 +1383,9 @@ def flash_dmattn_varlen_qkvpacked_func(
            of the sequences in the batch, used to index into qkv.
         max_seqlen: int. Maximum sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1408,7 +1408,7 @@ def flash_dmattn_varlen_qkvpacked_func(
         cu_seqlens,
         max_seqlen,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1427,7 +1427,7 @@ def flash_dmattn_varlen_kvpacked_func(
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1468,9 +1468,9 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1496,7 +1496,7 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1506,9 +1506,9 @@ def flash_dmattn_varlen_kvpacked_func(
 
 
 def flash_dmattn_varlen_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     cu_seqlens_q: torch.Tensor = None,
@@ -1516,7 +1516,7 @@ def flash_dmattn_varlen_func(
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
    is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1542,9 +1542,9 @@ def flash_dmattn_varlen_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
@@ -1556,9 +1556,9 @@ def flash_dmattn_varlen_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1575,9 +1575,9 @@ def flash_dmattn_varlen_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashDMAttnVarlenFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         cu_seqlens_q,
@@ -1585,7 +1585,7 @@ def flash_dmattn_varlen_func(
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py
index 7a05066..85a072b 100644
--- a/flash_dmattn/flash_dmattn_triton.py
+++ b/flash_dmattn/flash_dmattn_triton.py
@@ -1052,15 +1052,15 @@ def _flash_attn_backward(
 
 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None):
         """
         query: (batch_size, seqlen_q, nheads, headdim)
         key: (batch_size, seqlen_k, nheads, headdim)
         value: (batch_size, seqlen_k, nheads, headdim)
         attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
         attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        softmax_scale: float, scaling factor for attention scores
         is_causal: bool, whether to apply causal masking
+        softmax_scale: float, scaling factor for attention scores
         """
         batch, seqlen_q, nheads, _ = query.shape
         _, seqlen_k, _, _ = key.shape
@@ -1111,5 +1111,5 @@ def backward(ctx, do):
         return dq, dk, dv, None, dbias, None, None
 
 
-def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
-    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, scale=None):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, scale)
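
Below is a minimal smoke test distilled from the updated README usage, for exercising the renamed keyword arguments introduced by this patch (`scale` instead of `softmax_scale`, `query`/`key`/`value` instead of `q`/`k`/`v`, and backend selection via `flash_dmattn_func_auto`). It is a sketch, not part of the patch: it assumes a CUDA build of flash_dmattn is installed, and the tensor shapes are illustrative only.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto

# Illustrative sizes only; any valid configuration should work the same way.
batch_size, seq_len, num_heads, head_dim = 1, 256, 2, 64
device, dtype = torch.device("cuda"), torch.bfloat16

query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

# Full-visibility mask and zero bias, both [batch, nheads, seqlen_q, seqlen_k].
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_bias = torch.zeros(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)

# Backend selection, as shown in the updated README.
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")

output = flash_dmattn_func(
    query=query,                      # was q=
    key=key,                          # was k=
    value=value,                      # was v=
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    is_causal=True,
    scale=1.0 / math.sqrt(head_dim),  # was softmax_scale=
)
assert output.shape == (batch_size, seq_len, num_heads, head_dim)
```

The Triton and Flex wrappers touched by this patch accept the same `attn_mask`/`attn_bias`/`is_causal`/`scale` keywords, so the call above should carry over to those backends where they are available.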