diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index 9dee833..44352fd 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -112,6 +112,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb) + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride + n_block * kBlockN * params.attn_bias_col_stride; + const index_t row_offset_dbias = binfo.attn_bias_offset(params.dbias_batch_stride, params.dbias_row_stride, params.dbias_col_stride, bidb) + + (bidh / params.h_h_k_ratio) * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + + n_block * kBlockN * params.dbias_col_stride; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) @@ -126,63 +129,150 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + row_offset_mask), - Shape, Int>{}, - make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride)); - Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + row_offset_bias), - Shape, Int>{}, - make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride)); - Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.dbias_ptr) + row_offset_bias), - Shape, Int>{}, - make_stride(params.dbias_row_stride, params.dbias_col_stride)); - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), - Shape>{}, Stride<_1>{}); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQdO{}); - Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + // Global memory tensor configuration + Tensor gQ = make_tensor( + make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{}) + ); + Tensor gK = make_tensor( + make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{}) + ); + Tensor gV = make_tensor( + make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{}) + ); + Tensor gMask = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + row_offset_mask), + Shape, Int>{}, + make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride) + ); + Tensor gBias = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + row_offset_bias), + Shape, Int>{}, + make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride) + ); + Tensor gdO = make_tensor( + make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{}) + ); + Tensor gO = make_tensor( + make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + ); + Tensor gdQ = make_tensor( + make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{}) + ); + Tensor gdQaccum = make_tensor( + make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{}) + ); + Tensor gdBias = make_tensor( + make_gmem_ptr(reinterpret_cast(params.dbias_ptr) + row_offset_dbias), + Shape, Int>{}, + make_stride(params.dbias_row_stride, params.dbias_col_stride) + ); + Tensor gLSE = make_tensor( + make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{} + ); + Tensor gdPsum = make_tensor( + make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{} + ); + + // Shared memory layout configuration + Tensor sQ = make_tensor( + make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQdO{} + ); + Tensor sQt = make_tensor( + sQ.data(), + typename Kernel_traits::SmemLayoutQdOtransposed{} + ); + Tensor sQtNoSwizzle = make_tensor( + sQ.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{} + ); // Double buffer for sQ - Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); - Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), - typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); - Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); - Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK), - typename Kernel_traits::SmemLayoutPdS{}); - Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); - Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); - Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sdO = make_tensor( + sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), + typename Kernel_traits::SmemLayoutQdO{} + ); + Tensor sdOt = make_tensor( + sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposed{} + ); + Tensor sdOtransposedNoSwizzle = make_tensor( + sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{} + ); + Tensor sK = make_tensor( + sdO.data() + size(sdO), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sKt = make_tensor( + sK.data(), + typename Kernel_traits::SmemLayoutKtransposed{} + ); + Tensor sKtNoSwizzle = make_tensor( + sK.data(), + typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{} + ); + Tensor sV = make_tensor( + sK.data() + size(sK), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sMask = make_tensor( + !Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK), + typename Kernel_traits::SmemLayoutMask{} + ); + Tensor sBias = make_tensor( + sMask.data() + size(sMask), + typename Kernel_traits::SmemLayoutBias{} + ); + Tensor sdBias = make_tensor( + sBias.data(), + typename Kernel_traits::SmemLayoutBias{} + ); + Tensor sdS = make_tensor( + sBias.data() + size(sBias), + typename Kernel_traits::SmemLayoutPdS{} + ); + Tensor sdSt = make_tensor( + sdS.data(), + typename Kernel_traits::SmemLayoutPdStransposed{} + ); + Tensor sdStNoSwizzle = make_tensor( + sdS.data(), + typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{} + ); + Tensor sP = make_tensor( + sdS.data() + size(sdS), + typename Kernel_traits::SmemLayoutPdS{} + ); + Tensor sPt = make_tensor( + sP.data(), + typename Kernel_traits::SmemLayoutPdStransposed{} + ); + Tensor sPtNoSwizzle = make_tensor( + sP.data(), + typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{} + ); // sP and sdQ share the same memory so be careful - Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); + Tensor sdQ = make_tensor( + sP.data(), + typename Kernel_traits::SmemLayoutdQ{} + ); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index aea9604..e0176bc 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -143,7 +143,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - // Golobal memory tensor configuration + // Global memory tensor configuration Tensor mQ = make_tensor( make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d),