From 62a0d148bf4fc62eff6e31fbb4d016fe6d8e6443 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 11 Jul 2025 18:45:30 +0800
Subject: [PATCH 1/2] Adds support for bias gradient computation in backward
 kernel

Introduces dedicated offset calculation and tensor configuration for
bias gradient computation. Adds row_offset_dbias calculation using
dbias-specific stride parameters and creates gdBias tensor with proper
memory layout.

Reorganizes tensor declarations with improved formatting and adds
shared memory tensors for mask, bias, and bias gradient operations to
support the enhanced backward pass functionality.
---
 csrc/src/flash_bwd_kernel.h | 200 ++++++++++++++++++++++++++----------
 1 file changed, 145 insertions(+), 55 deletions(-)

diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h
index 9dee833..6013ad6 100644
--- a/csrc/src/flash_bwd_kernel.h
+++ b/csrc/src/flash_bwd_kernel.h
@@ -112,6 +112,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)
         + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride
         + n_block * kBlockN * params.attn_bias_col_stride;
+    const index_t row_offset_dbias = binfo.attn_bias_offset(params.dbias_batch_stride, params.dbias_row_stride, params.dbias_col_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride
+        + n_block * kBlockN * params.dbias_col_stride;
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
@@ -126,63 +129,150 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
     const index_t row_offset_dpsum = (params.unpadded_lse?
         bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb:
         (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
-    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.k_row_stride, _1{}));
-    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.v_row_stride, _1{}));
-    Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
-                               Shape<Int<kBlockM>, Int<kBlockN>>{},
-                               make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride));
-    Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
-                               Shape<Int<kBlockM>, Int<kBlockN>>{},
-                               make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride));
-    Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_bias),
-                                Shape<Int<kBlockM>, Int<kBlockN>>{},
-                                make_stride(params.dbias_row_stride, params.dbias_col_stride));
-    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.do_row_stride, _1{}));
-    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.o_row_stride, _1{}));
-    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
-                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                             make_stride(params.dq_row_stride, _1{}));
-    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  make_stride(params.h * params.d_rounded, _1{}));
-    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                              Shape<Int<kBlockM>>{}, Stride<_1>{});
-    Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
-                                Shape<Int<kBlockM>>{}, Stride<_1>{});
-
-    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
-                            typename Kernel_traits::SmemLayoutQdO{});
-    Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
-    Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
+    // Golobal memory tensor configuration
+    Tensor gQ = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(params.q_row_stride, _1{})
+    );
+    Tensor gK = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.k_row_stride, _1{})
+    );
+    Tensor gV = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.v_row_stride, _1{})
+    );
+    Tensor gMask = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
+        Shape<Int<kBlockM>, Int<kBlockN>>{},
+        make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride)
+    );
+    Tensor gBias = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
+        Shape<Int<kBlockM>, Int<kBlockN>>{},
+        make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride)
+    );
+    Tensor gdO = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(params.do_row_stride, _1{})
+    );
+    Tensor gO = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(params.o_row_stride, _1{})
+    );
+    Tensor gdQ = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(params.dq_row_stride, _1{})
+    );
+    Tensor gdQaccum = make_tensor(
+        make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(params.h * params.d_rounded, _1{})
+    );
+    Tensor gdBias = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_dbias),
+        Shape<Int<kBlockM>, Int<kBlockN>>{},
+        make_stride(params.dbias_row_stride, params.dbias_col_stride)
+    );
+    Tensor gLSE = make_tensor(
+        make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+        Shape<Int<kBlockM>>{}, Stride<_1>{}
+    );
+    Tensor gdPsum = make_tensor(
+        make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
+        Shape<Int<kBlockM>>{}, Stride<_1>{}
+    );
+
+    // Shared memory layout configuration
+    Tensor sQ = make_tensor(
+        make_smem_ptr(reinterpret_cast<Element *>(smem_)),
+        typename Kernel_traits::SmemLayoutQdO{}
+    );
+    Tensor sQt = make_tensor(
+        sQ.data(),
+        typename Kernel_traits::SmemLayoutQdOtransposed{}
+    );
+    Tensor sQtNoSwizzle = make_tensor(
+        sQ.data(),
+        typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}
+    );
     // Double buffer for sQ
-    Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
-    Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
-    Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
-                                                typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
-    Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
-    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
-    Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
-    Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
-    Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
-                             typename Kernel_traits::SmemLayoutPdS{});
-    Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
-    Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
-    Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
-    Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
-    Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
+    Tensor sdO = make_tensor(
+        sQ.data() + (Double_buffer ? 2 : 1) * size(sQ),
+        typename Kernel_traits::SmemLayoutQdO{}
+    );
+    Tensor sdOt = make_tensor(
+        sdO.data(),
+        typename Kernel_traits::SmemLayoutQdOtransposed{}
+    );
+    Tensor sdOtransposedNoSwizzle = make_tensor(
+        sdO.data(),
+        typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}
+    );
+    Tensor sK = make_tensor(
+        sdO.data() + size(sdO),
+        typename Kernel_traits::SmemLayoutKV{}
+    );
+    Tensor sKt = make_tensor(
+        sK.data(),
+        typename Kernel_traits::SmemLayoutKtransposed{}
+    );
+    Tensor sKtNoSwizzle = make_tensor(
+        sK.data(),
+        typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}
+    );
+    Tensor sV = make_tensor(
+        sK.data() + size(sK),
+        typename Kernel_traits::SmemLayoutKV{}
+    );
+    Tensor sMask = make_tensor(
+        !Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
+        typename Kernel_traits::SmemLayoutMask{}
+    );
+    Tensor sBias = make_tensor(
+        sMask.data() + size(sMask),
+        typename Kernel_traits::SmemLayoutBias{}
+    );
+    Tensor sdBias = make_tensor(
+        sBias.data(),
+        typename Kernel_traits::SmemLayoutBias{}
+    );
+    Tensor sdS = make_tensor(
+        sBias.data() + size(sBias),
+        typename Kernel_traits::SmemLayoutPdS{}
+    );
+    Tensor sdSt = make_tensor(
+        sdS.data(),
+        typename Kernel_traits::SmemLayoutPdStransposed{}
+    );
+    Tensor sdStNoSwizzle = make_tensor(
+        sdS.data(),
+        typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
+    );
+    Tensor sP = make_tensor(
+        sdS.data() + size(sdS),
+        typename Kernel_traits::SmemLayoutPdS{}
+    );
+    Tensor sPt = make_tensor(
+        sP.data(),
+        typename Kernel_traits::SmemLayoutPdStransposed{}
+    );
+    Tensor sPtNoSwizzle = make_tensor(
+        sP.data(),
+        typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
+    );
     // sP and sdQ share the same memory so be careful
-    Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
+    Tensor sdQ = make_tensor(
+        sP.data(),
+        typename Kernel_traits::SmemLayoutdQ{}
+    );
+
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
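
A note on the indexing added above: row_offset_dbias reproduces the existing
row_offset_bias arithmetic with the dbias-specific strides, and gdBias now
points at params.dbias_ptr plus this dedicated offset instead of reusing
row_offset_bias. The sketch below restates that arithmetic in isolation; it
is illustrative only, with a toy struct standing in for the kernel's Params
and the batch-origin term taken as an argument in place of
binfo.attn_bias_offset(...).

    #include <cstdint>

    using index_t = int64_t;

    // Hypothetical stand-in for the dbias-related fields of Params.
    struct DbiasStrides {
        index_t dbias_head_stride;
        index_t dbias_row_stride;
        index_t dbias_col_stride;
        int h_h_k_ratio;  // query heads per key/value head (grouped-query attention)
    };

    // Mirrors the patch: start from the batch origin, step to the (grouped)
    // head, then to the last kBlockM-row block the column pass begins from,
    // then to the current kBlockN-column block.
    index_t row_offset_dbias(const DbiasStrides &p, index_t batch_offset,
                             int bidh, int m_block_max, int n_block,
                             int kBlockM, int kBlockN) {
        return batch_offset
            + (bidh / p.h_h_k_ratio) * p.dbias_head_stride
            + static_cast<index_t>(m_block_max - 1) * kBlockM * p.dbias_row_stride
            + static_cast<index_t>(n_block) * kBlockN * p.dbias_col_stride;
    }

The shared-memory side follows the same staging pattern as the existing
tensors: sMask is placed after sV (or after sK when V is kept in registers),
sBias after sMask, and sdBias aliases sBias, so the bias tile and its
gradient reuse one buffer, just as sdQ aliases sP.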
From d2debb48b94ead8f6b7f38342481a797cf1c4945 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 11 Jul 2025 18:48:44 +0800
Subject: [PATCH 2/2] Fixes typo in comment from "Golobal" to "Global"

Corrects spelling error in comment describing global memory tensor
configuration across multiple kernel files.
---
 csrc/src/flash_bwd_kernel.h | 2 +-
 csrc/src/flash_fwd_kernel.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h
index 6013ad6..44352fd 100644
--- a/csrc/src/flash_bwd_kernel.h
+++ b/csrc/src/flash_bwd_kernel.h
@@ -129,7 +129,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
     const index_t row_offset_dpsum = (params.unpadded_lse?
         bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb:
         (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
-    // Golobal memory tensor configuration
+    // Global memory tensor configuration
     Tensor gQ = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
         make_stride(params.q_row_stride, _1{})
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index aea9604..e0176bc 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -143,7 +143,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
         + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

-    // Golobal memory tensor configuration
+    // Global memory tensor configuration
     Tensor mQ = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
         make_shape(binfo.actual_seqlen_q, params.h, params.d),
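
Both files build these views with the same CuTe pattern: a global-memory
pointer offset to the tile origin, a compile-time shape, and runtime strides.
A minimal self-contained sketch of that pattern for the bias-gradient tile,
with a hypothetical free function and parameter names (the kernel builds the
view inline from Kernel_traits and Params rather than through a helper):

    #include <cstdint>
    #include <cute/tensor.hpp>

    using namespace cute;

    // A kBlockM x kBlockN view of the bias-gradient tile in global memory,
    // in the style of gdBias above: static shape, runtime row/column strides.
    template <typename Element, int kBlockM, int kBlockN>
    __device__ auto make_gdbias_view(Element *dbias_ptr, int64_t row_offset_dbias,
                                     int64_t dbias_row_stride, int64_t dbias_col_stride) {
        return make_tensor(make_gmem_ptr(dbias_ptr + row_offset_dbias),
                           Shape<Int<kBlockM>, Int<kBlockN>>{},
                           make_stride(dbias_row_stride, dbias_col_stride));
    }

Unlike gQ or gO, whose last dimension is contiguous (stride _1{}), the bias
and bias-gradient tiles keep both strides dynamic, which allows the bias to
be laid out or broadcast flexibly along either the row or column dimension.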