From 65f80867174d2e14e1acc98bc44792e85e355e02 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 22 Jul 2025 02:06:51 +0000 Subject: [PATCH 1/2] Initial plan From efd2f288d584ea338e694b7cd9ecfb9e8ba6ed78 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 22 Jul 2025 02:25:15 +0000 Subject: [PATCH 2/2] Implement dBias computation in backward kernel Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- csrc/src/flash_bwd_kernel.h | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index 44352fd..ef25be5 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -659,6 +659,48 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + + // Compute dBias before modifying dS + // Following Triton logic: dbias = p * (dp - Di[:, None]) + Tensor dBias = make_tensor_like(scores); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dBias(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); + } + } + + // Copy dBias to shared memory and then to global memory + Tensor dBias_reshaped = make_tensor(dBias.data(), acc_dp.layout()); + // Convert dBias from fp32 to fp16/bf16 + Tensor rdBias = FLASH_NAMESPACE::convert_type(dBias_reshaped); + Tensor tadBiasdBias = smem_thr_copy_PdS.retile_S(rdBias); + cute::copy(smem_tiled_copy_PdS, tadBiasdBias, smem_thr_copy_PdS.partition_D(sdBias)); + + __syncthreads(); // Synchronize before copying to global memory + + // Copy dBias from shared memory to global memory + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_bias; + auto gmem_thr_copy_bias 
= gmem_tiled_copy_bias.get_thread_slice(tidx); + Tensor tBiasgdBias = gmem_thr_copy_bias.partition_D(gdBias); + Tensor tBiassdBias = gmem_thr_copy_bias.partition_S(sdBias); + Tensor tBiasrdBias = make_tensor(shape(tBiasgdBias)); + cute::copy(gmem_tiled_copy_bias, tBiassdBias, tBiasrdBias); + + // Write to global memory with bounds checking + Tensor cBias = make_identity_tensor(Shape, Int>{}); + Tensor tBiascBias = gmem_thr_copy_bias.partition_D(cBias); + Tensor tBiaspBias = make_tensor(make_shape(size<1>(tBiasgdBias))); + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { + tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; + } + // Copy with bounds checking similar to dK/dV pattern + FLASH_NAMESPACE::copy( + gmem_tiled_copy_bias, tBiasrdBias, tBiasgdBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM + ); + auto pointwise_mult = [](float p, float dp, float d) { return p * (!Is_dropout || p >= 0 ? dp - d : d); }; @@ -739,6 +781,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + // Advance gdBias pointer for next iteration + gdBias.data() = gdBias.data() + (-int(kBlockM * params.dbias_row_stride)); } if (!Is_last) {