From 60a7fde7667f614b15b8d17045f607573a069f66 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:11:20 +0800 Subject: [PATCH 01/12] Adds boolean-to-element conversion support in copy function Introduces template parameters to enable converting boolean values to numeric elements during copy operations. Adds conditional logic that converts true values to 1.0f and false values to 0.0f when the Bool_to_Element flag is enabled, allowing for more flexible data type transformations in memory copy routines. --- csrc/flash_dmattn/src/utils.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 81c716a..a93e2ac 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -521,7 +521,7 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=true, + bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2 > @@ -543,7 +543,16 @@ __forceinline__ __device__ void copy_MN( #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + if constexpr (Bool_to_Element) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = static_cast(S(i, m, n)) + ? static_cast(1.0f) + : static_cast(0.0f); + } + } else { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); } From ce71d0dd4a8ef41fb3bfaa34672eacbc8cf9a507 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:12:16 +0800 Subject: [PATCH 02/12] Fixes mask tensor type safety and copy parameters Changes mask pointer casting from generic Element to const bool for type safety. Updates copy_MN template calls to include Clear_OOB_MN and Bool_to_Element parameters for proper mask handling. Comments out async fence and wait operations, likely for debugging or performance optimization. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 41 +++++++++++++----------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index c98e3b1..79dcbfb 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(_, 0) ); // (kBlockN, kHeadDim, nblocksN) Tensor mMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); @@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // Reverse iteration over N blocks int n_block = n_block_max - 1; - - FLASH_NAMESPACE::copy_MN( + + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads @@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons make_stride(params.v_row_stride, _1{}) ); Tensor gMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, make_stride(params.mask_row_stride, _1{}) ); @@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -1146,14 +1146,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration any_active_local_next = false; From de54bfdb8be2aadaf22679710ac7b89edb3fbf9b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:12:29 +0800 Subject: [PATCH 03/12] Fixes mask data type handling and synchronization Changes mask pointer casting from generic Element to const bool type for proper type safety. Updates copy operations to include Bool_to_Element template parameter for correct boolean-to-element conversion. Replaces asynchronous copy fences with synchronous thread synchronization to ensure proper data consistency before mask operations. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 643cfcd..723265b 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -144,7 +144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in make_stride(params.v_row_stride, _1{}) ); Tensor gMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.mask_ptr) + row_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + row_offset_mask), Shape, Int>{}, make_stride(params.mask_row_stride, _1{}) ); @@ -552,14 +552,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); // Do OR-reduce on the mask to see if any active threads Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask); @@ -807,14 +808,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (m_block > m_block_min) { // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // FLASH_NAMESPACE::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration any_active_local_next = false; From 9751d0f1cd0ea25e133b8efc732f0d24697b90e7 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:12:46 +0800 Subject: [PATCH 04/12] Enforces bool dtype for mask parameter and disables varlen functions Changes mask validation to require boolean dtype instead of matching query dtype across all attention functions. Comments out variable length forward and backward pass functions in the Python binding module. --- csrc/flash_dmattn/flash_api.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 93cde0e..b4a657e 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -361,7 +361,7 @@ mha_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs"); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); @@ -512,7 +512,7 @@ mha_varlen_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs"); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); @@ -749,7 +749,7 @@ mha_bwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype"); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); @@ -951,7 +951,7 @@ mha_varlen_bwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype"); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); @@ -1136,7 +1136,7 @@ mha_varlen_bwd( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashDynamicMaskAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); - m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); + // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); - m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); + // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); } From ab4513ac93a268e8e811110862204fbafabde75a Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:13:18 +0800 Subject: [PATCH 05/12] Refactors attention mask to use boolean dtype Improves type consistency and performance by using torch.bool instead of generic dtype for attention masks. Eliminates unnecessary type conversions and simplifies mask comparison logic by using False instead of 0.0 comparisons. --- benchmarks/backward_equivalence.py | 8 ++++---- benchmarks/backward_performance.py | 8 ++++---- benchmarks/forward_equivalence.py | 8 ++++---- benchmarks/forward_performance.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index da66343..2aaa0bb 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -87,11 +87,11 @@ def prepare_dynamic_mask( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ) valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) + attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index c08600b..39bfca4 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -109,11 +109,11 @@ def prepare_dynamic_mask( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ) valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) + attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 97e80f9..cac014d 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -87,11 +87,11 @@ def prepare_dynamic_mask( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ) valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) + attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 967e009..ea5bcf0 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -109,11 +109,11 @@ def prepare_dynamic_mask( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ) valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) + attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask From a0d6ee5bd64d6a39c3f0ce0bb7206cf450243b19 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:13:40 +0800 Subject: [PATCH 06/12] Clarifies attention bias parameter type in docstring Specifies that attention_bias parameter expects a float tensor to improve API documentation clarity and help developers understand the expected data type. --- flash_dmattn/integrations/flash_dynamic_mask_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index bf3ff7d..7d718e2 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -30,7 +30,7 @@ def flash_dynamic_mask_attention_forward( key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len). - attention_bias (Optional[torch.Tensor]): The attention bias tensor of shape (batch_size, num_kv_heads, query_len, key_len). + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len). scaling (Optional[float]): The scaling factor for the attention scores. softcap (Optional[float]): The softcap value for the attention scores. **kwargs: Additional keyword arguments. From 9b401770bdfd9822711ec303061f40348b5f8394 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:13:57 +0800 Subject: [PATCH 07/12] Optimizes attention mask handling with boolean dtype Replaces float-based attention mask operations with boolean dtype for improved memory efficiency and cleaner logic. Removes unnecessary dtype conversion and simplifies mask creation by using boolean tensors directly instead of converting comparison results to float values. --- .../modeling_flash_dynamic_mask_attention_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index 8f86542..3c36e61 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -93,17 +93,15 @@ def _flash_dynamic_mask_attention_forward( ~attention_mask, min_dtype ) - attention_mask = attention_mask.to(dtype) if keep_window_size is not None: if key_length > keep_window_size: topk_values, topk_indices = torch.topk( attention_bias, keep_window_size, dim=-1, largest=True, sorted=False ) - valid_topk = (topk_values != min_dtype).to(dtype) - attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device) - attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk) - attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype) + attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device) + attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype) + attention_bias = attention_bias.masked_fill(attention_mask == False, min_dtype) out = flash_fn( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal From 2bedcbe44be3025cc8f60d7d3d80665228f1c6e0 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:14:53 +0800 Subject: [PATCH 08/12] Removes varlen flash attention functions Eliminates variable-length sequence support to simplify the codebase and focus on standard batch-based attention operations. Removes forward and backward implementations for variable-length sequences along with their fake wrappers, reducing code complexity and maintenance overhead. Fixes mask and bias handling in the remaining implementation to properly handle None values during padding operations. --- flash_dmattn/flash_dmattn_interface.py | 457 +------------------------ 1 file changed, 14 insertions(+), 443 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 1acdfdd..ea2ffcf 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -126,89 +126,6 @@ def _flash_dmattn_forward_fake( _wrapped_flash_dmattn_forward = _flash_dmattn_forward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( - q, - k, - v, - mask, - bias, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - return_softmax, - ) - _sanitize_tensors(out) - return out, softmax_lse, S_dmask - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") -def _flash_dmattn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - return out, softmax_lse, p - - -_wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward - - @_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") def _flash_dmattn_backward( dout: torch.Tensor, @@ -294,108 +211,6 @@ def _flash_dmattn_backward_fake( _wrapped_flash_dmattn_backward = _flash_dmattn_backward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") -def _flash_dmattn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] - ( - dq, - dk, - dv, - dbias, - softmax_d, - ) = flash_dmattn_gpu.varlen_bwd( - dout, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - deterministic, - ) - _sanitize_tensors(dq, dk, dv, dbias) - return softmax_d - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_backward") -def _flash_dmattn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - if dbias is None: - dbias = torch.empty_like(bias) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -_wrapped_flash_dmattn_varlen_backward = _flash_dmattn_varlen_backward - - class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward( @@ -413,17 +228,10 @@ def forward( is_grad_enabled: bool = True, ): # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size) - batch_size, seqlen_k, num_heads_k, _ = k.shape - seqlen_q = q.shape[1] + seqlen_k = k.shape[1] is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] ) - if mask is None: - mask = torch.ones((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = True - if bias is None: - return_dbias = False - bias = torch.zeros((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -444,9 +252,10 @@ def forward( if seqlen_k % 128 != 0: k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=0.0) - bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) - + if mask is None: + mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) + if bias is None: + bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -467,9 +276,9 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic - ctx.return_dbias = return_dbias out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -513,162 +322,8 @@ def backward( dk = dk[:, : ctx.seqlen_k, :, :] dv = dv[:, : ctx.seqlen_k, :, :] dbias = dbias[..., : ctx.seqlen_k] - if ctx.return_dbias: - return dq, dk, dv, None, dbias, None, None, None, None, None, None - return dq, dk, dv, None, None, None, None, None, None, None, None - -class FlashDMAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float], - is_causal: Optional[bool], - softcap: Optional[float], - deterministic: Optional[bool], - return_softmax: Optional[bool], - block_table: Optional[torch.Tensor] = None, - is_grad_enabled: bool = True, - ): - dtype = q.dtype - min_dtype = torch.finfo(dtype).min - # q, k, v are expected to be of shape (total, num_heads, head_size) - total_q = q.shape[0] - num_heads_k = k.shape[1] - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if mask is None: - mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = True - if bias is None: - bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = False - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if is_causal is None: - is_causal = False - if softcap is None: - softcap = 0.0 - if deterministic is None: - deterministic = True - if return_softmax is None: - return_softmax = False - - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - # FDMA requires the max sequence length to be a multiple of 128 - max_seqlen_q_og = max_seqlen_q - max_seqlen_k_og = max_seqlen_k - aligned_max_seqlen_q = ((max_seqlen_q + 128 - 1) // 128) * 128 - aligned_max_seqlen_k = ((max_seqlen_k + 128 - 1) // 128) * 128 - need_pad_q = aligned_max_seqlen_q != max_seqlen_q - need_pad_k = aligned_max_seqlen_k != max_seqlen_k - if need_pad_k: - pad_cols = aligned_max_seqlen_k - max_seqlen_k - mask = torch.nn.functional.pad(mask, [0, pad_cols], value=0.0) - bias = torch.nn.functional.pad(bias, [0, pad_cols], value=min_dtype) - max_seqlen_k = aligned_max_seqlen_k - if need_pad_q: - max_seqlen_q = aligned_max_seqlen_q - - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( - q, - k, - v, - mask, - bias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - is_causal=is_causal, - softcap=softcap, - return_softmax=return_softmax, - block_table=block_table, - ) - - if is_grad: - ctx.save_for_backward( - q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k - ) - ctx.seqlen_q_og = max_seqlen_q_og - ctx.seqlen_k_og = max_seqlen_k_og - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.is_causal = is_causal - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.return_dbias = return_dbias - - out = out_padded[..., :head_size_og] - if return_softmax: - if max_seqlen_k != max_seqlen_k_og: - S_dmask = S_dmask[..., :max_seqlen_k_og] - return out, softmax_lse, S_dmask - return out - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - dout: torch.Tensor, - *args: Any, - ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) - - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - _wrapped_flash_dmattn_varlen_backward( - dout_padded, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.softmax_scale, - ctx.is_causal, - ctx.softcap, - ctx.deterministic, - ) - - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - - if ctx.seqlen_k_og != ctx.max_seqlen_k: - dbias = dbias[:, :, :ctx.seqlen_k_og] - - if ctx.return_dbias: - return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, dbias, None, None, None, None, None, None def flash_dmattn_func( @@ -702,12 +357,14 @@ def flash_dmattn_func( If the row of the mask is all zero, the output will be zero. Arguments: - query: (batch_size, seqlen, nheads, headdim) - key: (batch_size, seqlen, nheads_k, headdim) - value: (batch_size, seqlen, nheads_k, headdim) - attn_mask: (batch_size, nheads_k, seqlen_q, seqlen_k). Attention mask to apply to the attention scores. + query: torch.Tensor. The query tensor of shape (batch_size, seqlen, nheads, headdim) + key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) + value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) + attn_mask: torch.Tensor, optional. The attention mask boolean tensor of + shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores. If None, no mask is applied. - attn_bias: (batch_size, nheads_k, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. + attn_bias: torch.Tensor, optional. The attention bias float tensor of + shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores. If None, no bias is applied. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. @@ -738,89 +395,3 @@ def flash_dmattn_func( return_attn_probs, torch.is_grad_enabled(), ) - - -def flash_dmattn_varlen_func( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - cu_seqlens_q: torch.Tensor = None, - cu_seqlens_k: torch.Tensor = None, - max_seqlen_q: int = None, - max_seqlen_k: int = None, - scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - return_attn_probs: Optional[bool] = None, - block_table: Optional[torch.Tensor] = None, -): - """ - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - Arguments: - query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores. - If None, no mask is applied. - attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores. - If None, no bias is applied. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - softcap: float. Anything > 0 activates softcapping attention. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). - """ - return FlashDMAttnVarlenFunc.apply( - query, - key, - value, - attn_mask, - attn_bias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - scale, - is_causal, - softcap, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) \ No newline at end of file From 0aa8f9f36f89d2308dcdc1a5337cb29504449aab Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:26:56 +0800 Subject: [PATCH 09/12] Simplifies boolean to numeric type conversion Replaces verbose static_cast operations with more concise To_type constructor calls when converting boolean values to numeric types. Improves code readability while maintaining the same functionality of converting true to 1 and false to 0. --- csrc/flash_dmattn/src/utils.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index a93e2ac..be28c10 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -546,9 +546,7 @@ __forceinline__ __device__ void copy_MN( if constexpr (Bool_to_Element) { #pragma unroll for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = static_cast(S(i, m, n)) - ? static_cast(1.0f) - : static_cast(0.0f); + D(i, m, n) = static_cast(S(i, m, n)) ? To_type(1) : To_type(0); } } else { cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); From b238dd0a53b27bdd85984e1ef8be42483c138538 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:27:06 +0800 Subject: [PATCH 10/12] Fix padding logic to handle non-null mask and bias in FlashDMAttnFunc --- flash_dmattn/flash_dmattn_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index ea2ffcf..c4ac38d 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -252,9 +252,9 @@ def forward( if seqlen_k % 128 != 0: k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - if mask is None: + if mask is not None: mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) - if bias is None: + if bias is not None: bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( From 2aba69741862b3cc58124ec12af388476fe41038 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 12 Sep 2025 22:31:02 +0800 Subject: [PATCH 11/12] Simplifies boolean mask negation logic Replaces `== False` comparison with the more idiomatic `~` operator for boolean negation, improving code readability and following Python best practices. --- .../integrations/modeling_flash_dynamic_mask_attention_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index 3c36e61..e3a8a3b 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -101,7 +101,7 @@ def _flash_dynamic_mask_attention_forward( ) attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device) attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype) - attention_bias = attention_bias.masked_fill(attention_mask == False, min_dtype) + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) out = flash_fn( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal From b98f2a97e806758fb1ce6b0c904212db90c29434 Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Fri, 12 Sep 2025 22:32:22 +0800 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- benchmarks/backward_equivalence.py | 2 +- benchmarks/backward_performance.py | 2 +- benchmarks/forward_equivalence.py | 2 +- benchmarks/forward_performance.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 2aaa0bb..024674f 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -89,7 +89,7 @@ def prepare_dynamic_mask( valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) + attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 39bfca4..afccef9 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -111,7 +111,7 @@ def prepare_dynamic_mask( valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) + attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index cac014d..94fc0cf 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -89,7 +89,7 @@ def prepare_dynamic_mask( valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) + attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index ea5bcf0..b540bf6 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -111,7 +111,7 @@ def prepare_dynamic_mask( valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype) + attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask