From 2e69c3da745019e22036e068fc75b7cc14012e5f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 16 Oct 2025 12:35:47 +0800 Subject: [PATCH 1/4] Adds dbias accumulation flag to bwd params Introduces a toggle to optionally accumulate bias gradients during backward attention. Enables skipping unnecessary dbias work when unused and provides clearer control for kernels, aiding performance and configurability. --- csrc/flash_dmattn/src/flash.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a1c9bf1..29c342f 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -195,6 +195,8 @@ struct Flash_bwd_params : public Flash_fwd_params { bool deterministic; index_t dq_accum_split_stride; + + bool accum_dbias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// From e5b045d95edb169bf02d096858ab1aafaf289ddf Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 16 Oct 2025 12:36:43 +0800 Subject: [PATCH 2/4] Add atomic bias-grad accumulation option Adds an optional accumulation path for bias gradients using atomic updates when accumulation is enabled, avoiding overwrites when multiple tiles contribute. Keeps the existing fast write path when accumulation is disabled, respects sequence bounds, and correctly tracks the accumulation pointer across tile steps. Improves correctness for split/streamed backward passes where bias gradients must be aggregated across blocks. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 44 ++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index e8578e2..7c4056e 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -159,6 +159,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Shape, Int>{}, make_stride(params.dbias_row_stride, _1{}) ); + [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr; + if constexpr (Has_bias) { + gdBias_accum_ptr = reinterpret_cast(params.dbias_ptr) + row_offset_dbias; + } Tensor gdO = make_tensor( make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), Shape, Int>{}, @@ -848,12 +852,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in __syncthreads(); if constexpr (Has_bias) { // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_dBias, - tBiassBias, tdBiasgdBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); + if (!params.accum_dbias) { + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_dBias, + tBiassBias, tdBiasgdBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + } else { + #pragma unroll + for (int m = 0; m < size<1>(tBiassBias); ++m) { + if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { + #pragma unroll + for (int n = 0; n < size<2>(tBiassBias); ++n) { + if (Is_even_MN || tBiaspBias(n)) { + #pragma unroll + for (int i = 0; i < size<0>(tBiassBias); ++i) { + const auto coord = tBiascBias(i, m, n); + const int row = get<0>(coord); + const int col = get<1>(coord); + if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) { + atomicAdd( + gdBias_accum_ptr + row * params.dbias_row_stride + col, + static_cast(tBiassBias(i, m, n)) + ); + } + } + } + } + } + } + } } // if (cute::thread0()) { print(tPrP); } @@ -994,6 +1023,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Advance gBias and gdBias tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); + if (params.accum_dbias) { + gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride); + } if (any_active_next) { FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, From 9181f7020226679b0302dd72ed5d4038c05d3013 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 16 Oct 2025 12:37:33 +0800 Subject: [PATCH 3/4] Fixes dbias accumulation for broadcasted bias Improves backward handling when bias is broadcast across sequence or batch by allocating correctly shaped scratch buffers and adjusting reduction paths. Adds a kernel parameter to accumulate along sequence for S=1 bias, and uses fp32 buffers for numerically stable accumulation. Corrects the previous over-eager scratch allocation on batch-size mismatch to only trigger for shared (B=1) or head-grouped cases, aligning with broadcasting semantics (incl. MQA/GQA). Leaves the variable-length path unchanged (no accumulation). Results in correct dbias reductions and gradients for broadcasted bias with better numerical stability. --- csrc/flash_dmattn/flash_api.cpp | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 03068ab..7d5bf4f 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -190,6 +190,7 @@ void set_params_dgrad( const float softcap, bool has_mask, bool has_bias, + bool accum_dbias, bool deterministic, const bool unpadded_lse ) { @@ -245,6 +246,8 @@ void set_params_dgrad( // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; + params.accum_dbias = accum_dbias; + params.deterministic = deterministic; } @@ -977,12 +980,13 @@ mha_bwd( ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) : dv; dbias_expanded = has_bias - ? ( - (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q - ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) - : dbias - ) + ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1) // MQA / GQA or dbias has different batch size or seqlen_q + ? (seqlen_q_bias == 1) + ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat)) + : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) + : dbias : torch::empty({0}, opts); + bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1; Flash_bwd_params params; @@ -1009,6 +1013,7 @@ mha_bwd( softcap, has_mask, has_bias, + accum_dbias, deterministic, /*unpadded_lse*/false ); @@ -1036,9 +1041,10 @@ mha_bwd( if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) { at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); } else { - dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); - if (seqlen_q_bias == 1) { - dbias_expanded = at::sum(dbias_expanded, {2}, true); + if (accum_dbias) { + dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2}); + } else { + dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2}); } if (batch_size_bias == 1) { dbias_expanded = at::sum(dbias_expanded, {0}, true); @@ -1238,6 +1244,7 @@ mha_varlen_bwd( softcap, has_mask, has_bias, + /*accum_dbias*/false, deterministic, /*unpadded_lse*/true ); From 7f1511873caabc4aef7397899e4a1386a2ad2fee Mon Sep 17 00:00:00 2001 From: Jingze <3314685395@qq.com> Date: Thu, 16 Oct 2025 12:44:32 +0800 Subject: [PATCH 4/4] Update csrc/flash_dmattn/flash_api.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- csrc/flash_dmattn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 7d5bf4f..040eba1 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -986,7 +986,7 @@ mha_bwd( : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts) : dbias : torch::empty({0}, opts); - bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1; + bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1); Flash_bwd_params params;