From 8effe3cdf33a0a77eb2998f94156e086a40d7943 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 23:55:30 +0800 Subject: [PATCH 1/2] Refactors mha_bwd to use torch::zeros for bias initialization and removes unnecessary zeroing of dbias_expanded --- csrc/flash_dmattn/flash_api.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index e6a8b7a..03068ab 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -979,13 +979,10 @@ mha_bwd( dbias_expanded = has_bias ? ( (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q - ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) + ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) : dbias ) : torch::empty({0}, opts); - if (has_bias) { - dbias_expanded.zero_(); - } Flash_bwd_params params; @@ -1050,7 +1047,7 @@ mha_bwd( } } - return { dq, dk, dv, dbias, softmax_d }; + return {dq, dk, dv, dbias, softmax_d}; } std::vector From 502a1a43e117379c79112ac71f3873f913e6c5ce Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 12 Oct 2025 23:55:57 +0800 Subject: [PATCH 2/2] Enhances FlashDMAttnFunc to track the bias's original key sequence length and sums dbias over the key dimension when that length is 1 --- flash_dmattn/flash_dmattn_interface.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 663dc2e..4ee10eb 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -409,6 +409,7 @@ def forward( deterministic = False if return_softmax is None: return_softmax = False + seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0 # Padding to multiple of 8 for 16-bit memory allocations head_size_og 
= q.size(3) @@ -446,6 +447,7 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic + ctx.seqlen_k_bias_og = seqlen_k_bias_og out = out_padded[..., :head_size_og] @@ -491,7 +493,7 @@ def backward( dv = dv[..., : dout.shape[-1]] if dbias is not None: - dbias = dbias[..., : k.shape[1]] + dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]] return dq, dk, dv, None, dbias, None, None, None, None, None, None