From 40bfdb3ab5ae8b1d4ad9a8a81f9abe7ed55ff702 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:47:58 +0800 Subject: [PATCH 1/9] Renames offset calculation methods for clarity Improves code readability by renaming function methods to better reflect their purpose: - zoh_offset becomes attn_mask_offset - active_mask_offset becomes attn_bias_offset Makes the codebase more self-documenting by using descriptive names that clearly indicate the functions calculate offsets for attention masks and bias tensors respectively. --- csrc/src/block_info.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/src/block_info.h b/csrc/src/block_info.h index b0ce7c5..d0758eb 100644 --- a/csrc/src/block_info.h +++ b/csrc/src/block_info.h @@ -36,15 +36,15 @@ struct BlockInfo { } template - __forceinline__ __device__ index_t zoh_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb - ) const { + __forceinline__ __device__ index_t attn_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; } template - __forceinline__ __device__ index_t active_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { + __forceinline__ __device__ index_t attn_bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb + ) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; From 0581b7d3817679d4d086487098e8375c59cdc436 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:48:44 +0800 Subject: [PATCH 2/9] Refactors ZOH params into separate mask and bias structures Splits the monolithic ZOH_params struct into two focused components: Mask_params for attention masking operations and Bias_params for attention bias handling. Simplifies parameter management by grouping related functionality and improves code organization for flash attention operations. --- csrc/src/flash.h | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 277d5a3..64302af 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -45,27 +45,34 @@ struct QKV_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -struct ZOH_params { - void *__restrict__ zoh_ptr; // ZOH states tensor [batch_size, num_kv_heads, query_len, key_len] - void * __restrict__ active_mask_ptr; // Active mask tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the zero-hold states and active mask tensors. 
- index_t zoh_batch_stride; // Stride between batches of ZOH states - index_t active_mask_batch_stride; // Stride between batches of active mask - index_t zoh_head_stride; // Stride between heads of ZOH states - index_t active_mask_head_stride; // Stride between heads of active mask - index_t zoh_row_stride; // Stride between rows of ZOH states - index_t active_mask_row_stride; // Stride between rows of active mask - index_t zoh_col_stride; // Stride between columns of ZOH states - index_t active_mask_col_stride; // Stride between columns of active mask +struct Mask_params { + void * __restrict__ attn_mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention mask tensors. + index_t attn_mask_batch_stride; // Stride between batches of attention mask + index_t attn_mask_head_stride; // Stride between heads of attention mask + index_t attn_mask_row_stride; // Stride between rows of attention mask + index_t attn_mask_col_stride; // Stride between columns of attention mask // The keep window size. - int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k) + int keep_window_size; // Number of tokens to keep in top-k +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Bias_params { + void *__restrict__ attn_bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention bias tensor. + index_t attn_bias_batch_stride; // Stride between batches of attention bias + index_t attn_bias_head_stride; // Stride between heads of attention bias + index_t attn_bias_row_stride; // Stride between rows of attention bias + index_t attn_bias_col_stride; // Stride between columns of attention bias }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Flash_fwd_params : public QKV_params, public ZOH_params { +struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { // The O matrix (output). void * __restrict__ o_ptr; From 352f4355cb2f2bd77c36171db17ae302c6891660 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:48:54 +0800 Subject: [PATCH 3/9] Renames function to reflect mask-based behavior Updates function name from copy_ZOH to copy_Mask to better describe its actual functionality and improve code clarity. --- csrc/src/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 2e2df1b..3203b23 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -500,7 +500,7 @@ __forceinline__ __device__ void copy( template -__forceinline__ __device__ void copy_ZOH( +__forceinline__ __device__ void copy_Mask( TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, const int max_M=0, const int max_N=0 From 4377e0ec1933da1250419597a441f972822a3128 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:49:11 +0800 Subject: [PATCH 4/9] Renames DynamicMask to Mask and clarifies parameter names Simplifies the mask struct naming by removing "Dynamic" prefix for better clarity. Updates parameter names from ZOH-related terminology to more descriptive "Mask" and "Bias" names, making the code more readable and self-documenting. Changes affect function signatures, variable names, and comments to reflect the new terminology while maintaining the same functionality. 
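For reference, the per-element rule that apply_mask implements after this rename:
a position is inactive when its column index falls outside the causal/sequence-length
limit for its row or when its mask value is zero; inactive scores become -inf, and all
other scores are scaled by scale_softmax and shifted by the bias. A minimal host-side
C++ sketch of that rule, with illustrative names and plain row-major arrays rather
than the kernel's CUTE register tensors:

    #include <limits>

    // One attention-score tile of shape [rows x cols]; scores, mask and bias are
    // row-major float arrays, col_limit is the causal / actual-seqlen_k bound
    // for this tile.
    void apply_mask_reference(float *scores, const float *mask, const float *bias,
                              float scale_softmax, int rows, int cols, int col_limit) {
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                const int i = r * cols + c;
                const bool inactive = (c >= col_limit) || (mask[i] == 0.0f);
                scores[i] = inactive ? -std::numeric_limits<float>::infinity()
                                     : scores[i] * scale_softmax + bias[i];
            }
        }
    }

Entries written as -inf contribute nothing after the subsequent softmax.
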
--- csrc/src/mask.h | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/src/mask.h b/csrc/src/mask.h index 849d6e1..8a59444 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -21,11 +21,11 @@ namespace FLASH_NAMESPACE { using namespace cute; template -struct DynamicMask { +struct Mask { const int max_seqlen_k, max_seqlen_q; const int keep_window_size; - __forceinline__ __device__ DynamicMask( + __forceinline__ __device__ Mask( const int max_seqlen_k, const int max_seqlen_q, const int keep_window_size @@ -35,25 +35,25 @@ struct DynamicMask { , keep_window_size(keep_window_size) { }; - template + template __forceinline__ __device__ void apply_mask( TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) - ZOHType &tSrZOH, // ZOH states (MMA=4, MMA_M, MMA_N) - ActiveMaskType &tSrAM, // Active Mask (MMA=4, MMA_M, MMA_N) + MaskType &Mask, // Attention Mask (MMA=4, MMA_M, MMA_N) + BiasType &Bias, // Attention Bias (MMA=4, MMA_M, MMA_N) const float scale_softmax, // Scale for softmax const int col_idx_offset_, // Column index offset const int row_idx_offset, // Row index offset const int warp_row_stride // Warp row stride ) { static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor"); - static_assert(ZOHType::rank == 3, "tZOH must be 3D Tensor"); - static_assert(ActiveMaskType::rank == 3, "tActiveMask must be 3D Tensor"); + static_assert(MaskType::rank == 3, "Mask must be 3D Tensor"); + static_assert(BiasType::rank == 3, "Bias must be 3D Tensor"); static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k); // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); - Tensor zoh = make_tensor(tSrZOH.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrZOH.layout())); - Tensor active_mask = make_tensor(tSrAM.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrAM.layout())); + Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout())); + Tensor bias = make_tensor(Bias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Bias.layout())); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -72,19 +72,19 @@ struct DynamicMask { for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - bool inactive = (col_idx >= col_idx_limit) || (active_mask(coord) <= 0.0f); + bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f); if (inactive) { tensor(coord) = -INFINITY; } else { - // Apply scaling and zoh - tensor(coord) = tensor(coord) * scale_softmax + zoh(coord); + // Apply scaling and bias + tensor(coord) = tensor(coord) * scale_softmax + bias(coord); } } } } } } else { - // If no masking is needed, just scale the tensor and add zoh + // If no masking is needed, just scale the tensor and add bias #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { // const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -98,7 +98,7 @@ struct DynamicMask { for (int j = 0; j < size<1, 0>(tensor); ++j) { // const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - tensor(coord) = tensor(coord) * scale_softmax + zoh(coord); + tensor(coord) = 
tensor(coord) * scale_softmax + bias(coord); } } } From 01a8d6304cb18af02517b1e57b0a06b597b181b2 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:49:28 +0800 Subject: [PATCH 5/9] Separates shared memory layouts for mask and bias Refactors shared memory layout definitions to distinguish between mask and bias operations, replacing the combined ZOH/ActiveMask approach with separate SmemLayoutMask and SmemLayoutBias structures. Updates memory size calculations to account for both mask and bias components independently, improving clarity and maintainability of the memory management system. --- csrc/src/kernel_traits.h | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 19dc319..38b885d 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -82,11 +82,11 @@ struct Flash_fwd_kernel_traits : public Base { // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 Layout>, Stride, _1>>{})); - using SmemLayoutAtomZOH = decltype( + using SmemLayoutAtomMask = decltype( composition(Swizzle{}, Layout, Stride<_8, _1>>{})); - using SmemLayoutAtomActiveMask = decltype( + using SmemLayoutAtomBias = decltype( composition(Swizzle{}, Layout, Stride<_8, _1>>{})); @@ -104,11 +104,11 @@ struct Flash_fwd_kernel_traits : public Base { composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - using SmemLayoutZOH = decltype(tile_to_shape( - SmemLayoutAtomZOH{}, + using SmemLayoutMask = decltype(tile_to_shape( + SmemLayoutAtomMask{}, Shape, Int>{})); - using SmemLayoutActiveMask = decltype(tile_to_shape( - SmemLayoutAtomActiveMask{}, + using SmemLayoutBias = decltype(tile_to_shape( + SmemLayoutAtomBias{}, Shape, Int>{})); // Shared memory layout for output @@ -125,10 +125,11 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemMaskSize = size(SmemLayoutZOH{}) * sizeof(Element) + size(SmemLayoutActiveMask{}) * sizeof(Element); + static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element); + static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element); - // Shared memory size with QKV matrices - static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize; + // Shared memory size with QKV matrices and mask/bias matrices + static constexpr int kSmemSize = (Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); @@ -153,11 +154,11 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyZOH = decltype( + using GmemTiledCopyMask = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 4 vals per read - using GmemTiledCopyActiveMask = decltype( + using GmemTiledCopyBias = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 4 vals per read From 60d3a2446a8cb1161d9b372db9741a7bf29d899b Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:49:47 +0800 Subject: [PATCH 6/9] Renames attention tensor parameters for clarity Updates parameter names from generic `zoh` and `active_mask` to more descriptive `attn_mask` and `attn_bias` throughout the flash attention API. Improves code readability and aligns naming conventions with standard attention mechanism terminology. --- csrc/flash_api.cpp | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 1d948ea..2db3e99 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -40,8 +40,8 @@ void set_params_fprop( const at::Tensor q, const at::Tensor k, const at::Tensor v, - const at::Tensor zoh, - const at::Tensor active_mask, + const at::Tensor attn_mask, + const at::Tensor attn_bias, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, @@ -65,32 +65,32 @@ void set_params_fprop( params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); - params.zoh_ptr = zoh.data_ptr(); - params.active_mask_ptr = active_mask.data_ptr(); + params.attn_mask_ptr = attn_mask.data_ptr(); + params.attn_bias_ptr = attn_bias.data_ptr(); params.o_ptr = out.data_ptr(); // All stride are in elements, not bytes. 
params.q_row_stride = q.stride(-3); params.k_row_stride = k.stride(-3); params.v_row_stride = v.stride(-3); - params.zoh_row_stride = zoh.stride(-2); - params.active_mask_row_stride = active_mask.stride(-2); + params.attn_mask_row_stride = attn_mask.stride(-2); + params.attn_bias_row_stride = attn_bias.stride(-2); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); - params.zoh_head_stride = zoh.stride(-3); - params.active_mask_head_stride = active_mask.stride(-3); + params.attn_mask_head_stride = attn_mask.stride(-3); + params.attn_bias_head_stride = attn_bias.stride(-3); params.o_head_stride = out.stride(-2); - params.zoh_col_stride = zoh.stride(-1); - params.active_mask_col_stride = active_mask.stride(-1); + params.attn_mask_col_stride = attn_mask.stride(-1); + params.attn_bias_col_stride = attn_bias.stride(-1); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.zoh_batch_stride = zoh.stride(0); - params.active_mask_batch_stride = active_mask.stride(0); + params.attn_mask_batch_stride = attn_mask.stride(0); + params.attn_bias_batch_stride = attn_bias.stride(0); params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; @@ -271,8 +271,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &zoh, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &active_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &attn_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &attn_bias, // batch_size x num_heads_k x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float p_dropout, const float softmax_scale, @@ -295,10 +295,10 @@ mha_fwd( "FlashAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(zoh.dtype() == q_dtype, "zoh must have the same dtype as inputs"); - TORCH_CHECK(active_mask.dtype() == q_dtype, "active_mask must have the same dtype as inputs"); + TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs"); + TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -335,8 +335,8 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(attn_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + 
CHECK_SHAPE(attn_bias, batch_size, num_heads_k, seqlen_q, seqlen_k); at::Tensor out; if (out_.has_value()) { @@ -379,7 +379,7 @@ mha_fwd( num_heads, num_heads_k, head_size, head_size_rounded, keep_window_size, - q, k, v, zoh, active_mask, out, + q, k, v, attn_mask, attn_bias, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, From 4625f64748b8fc9d756b4b38f23df07bc5a1b26b Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:56:42 +0800 Subject: [PATCH 7/9] Renames mask and bias variables for clarity Improves code readability by replacing confusing ZOH/ActiveMask naming with clearer Mask/Bias terminology throughout the attention kernel. Updates variable names, tensor declarations, and function calls to use consistent naming conventions that better reflect the actual purpose of these components in the attention computation. --- csrc/src/flash_fwd_kernel.h | 198 ++++++++++++++++++------------------ 1 file changed, 99 insertions(+), 99 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 42b78e0..f1823a1 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -179,23 +179,23 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Shape, Int>{}, make_stride(params.seqlen_k_rounded, _1{}) ); - Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb)), + Tensor mMask = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.zoh_head_stride, params.zoh_row_stride, params.zoh_col_stride) + make_stride(params.attn_mask_head_stride, params.attn_mask_row_stride, params.attn_mask_col_stride) ); - Tensor gZOH = local_tile( - mZOH(bidh / params.h_h_k_ratio, _, _), + Tensor gMask = local_tile( + mMask(bidh / params.h_h_k_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) - Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb)), + Tensor mBias = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.active_mask_head_stride, params.active_mask_row_stride, params.active_mask_col_stride) + make_stride(params.attn_bias_head_stride, params.attn_bias_row_stride, params.attn_bias_col_stride) ); - Tensor gActiveMask = local_tile( - mActiveMask(bidh / params.h_h_k_ratio, _, _), + Tensor gBias = local_tile( + mBias(bidh / params.h_h_k_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -222,22 +222,22 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sZOH = make_tensor( + Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutZOH{} + typename Kernel_traits::SmemLayoutMask{} ); - Tensor sActiveMask = make_tensor( - sZOH.data() + size(sZOH), - typename 
Kernel_traits::SmemLayoutActiveMask{} + Tensor sBias = make_tensor( + sMask.data() + size(sMask), + typename Kernel_traits::SmemLayoutBias{} ); // Golobal to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyZOH gmem_tiled_copy_ZOH; - auto gmem_thr_copy_ZOH = gmem_tiled_copy_ZOH.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyActiveMask gmem_tiled_copy_AM; - auto gmem_thr_copy_AM = gmem_tiled_copy_AM.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -245,19 +245,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN) - Tensor tZOHsZOH = gmem_thr_copy_ZOH.partition_D(sZOH); - Tensor tAMgAM = gmem_thr_copy_AM.partition_S(gActiveMask); // (AMCPY, AMCPY_M, AMCPY_N, nblocksN) - Tensor tAMsAM = gmem_thr_copy_AM.partition_D(sActiveMask); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K) - Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) - Tensor tSrAM = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tSrMask = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) + Tensor tSrBias = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_K) @@ -270,26 +270,26 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); - auto smem_tiled_copy_ZOH = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_ZOH = smem_tiled_copy_ZOH.get_thread_slice(tidx); - Tensor tSsZOH = smem_thr_copy_ZOH.partition_S(sZOH); - auto smem_tiled_copy_AM = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_AM = smem_tiled_copy_AM.get_thread_slice(tidx); - Tensor tSsAM = smem_thr_copy_AM.partition_S(sActiveMask); auto smem_tiled_copy_V = make_tiled_copy_B(typename 
Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); + Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); // PREDICATES // // Allocate predicate tensors for m and n // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor cZOH = make_identity_tensor(make_shape(size<0>(sZOH), size<1>(sZOH))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cAM = make_identity_tensor(make_shape(size<0>(sActiveMask), size<1>(sActiveMask))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M, BLK_K) -> (blk_m, blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N, BLK_K) -> (blk_n, blk_k) + Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask))); // (BLK_M, BLK_N) -> (blk_m, blk_n) + Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias))); // (BLK_M, BLK_N) -> (blk_m, blk_n) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA, MMA_M, MMA_K) // if (cute::thread0()) { // print(tScQ.layout()); printf("\n"); // for (int i = 0; i < size(tScQ); ++i) { @@ -302,10 +302,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("\n"); // } // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tZOHcZOH = gmem_thr_copy_ZOH.partition_S(cZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N) -> (blk_m, blk_n) - Tensor tAMcAM = gmem_thr_copy_AM.partition_S(cAM); // (AMCPY, AMCPY_M, AMCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); @@ -351,19 +351,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, 
n_block), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block), + tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); @@ -384,7 +384,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; // Init dynamic mask processor - FLASH_NAMESPACE::DynamicMask dynamic_mask( + FLASH_NAMESPACE::Mask mask( binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.keep_window_size ); @@ -407,11 +407,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - // Copy ZOH and ActiveMask from smem to registers - Tensor tSrZOH_copy_view = smem_thr_copy_ZOH.retile_D(tSrZOH); - cute::copy(smem_tiled_copy_ZOH, tSsZOH, tSrZOH_copy_view); - Tensor tSrAM_copy_view = smem_thr_copy_AM.retile_D(tSrAM); - cute::copy(smem_tiled_copy_AM, tSsAM, tSrAM_copy_view); + // Copy Mask and Bias from smem to registers + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); // Advance gV if (masking_step > 0) { @@ -428,7 +428,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, - tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication + tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); @@ -437,9 +437,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Scale attention scores and apply dynamic mask - dynamic_mask.template apply_mask( - acc_s, tSrZOH, tSrAM, params.scale_softmax, + // Scale attention scores and apply mask/bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); @@ -449,19 +449,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
- FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block - 1), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block - 1), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), + tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); @@ -500,7 +500,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Use sparse general matrix multiplication with register accumulation for V as well FLASH_NAMESPACE::sparse_gemm_rs( acc_o, - tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication + tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication tiled_mma, smem_tiled_copy_V, smem_thr_copy_V ); // if (cute::thread0()) { print(scores); } @@ -519,11 +519,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - // Copy ZOH and ActiveMask from smem to registers - Tensor tSrZOH_copy_view = smem_thr_copy_ZOH.retile_D(tSrZOH); - cute::copy(smem_tiled_copy_ZOH, tSsZOH, tSrZOH_copy_view); - Tensor tSrAM_copy_view = smem_thr_copy_AM.retile_D(tSrAM); - cute::copy(smem_tiled_copy_AM, tSsAM, tSrAM_copy_view); + // Copy Mask and Bias from smem to registers + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, - tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication + tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); @@ -540,8 +540,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // Scale attention scores and apply dynamic mask - dynamic_mask.template apply_mask( - acc_s, tSrZOH, tSrAM, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); @@ -549,19 +549,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); if (n_block > n_block_min) { FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block - 1), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block - 1), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), + 
tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); @@ -597,7 +597,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Use sparse general matrix multiplication with register accumulation for V as well FLASH_NAMESPACE::sparse_gemm_rs( acc_o, - tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication + tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication tiled_mma, smem_tiled_copy_V, smem_thr_copy_V ); } @@ -608,12 +608,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); - Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M, SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom, AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom, AtomNum), PIPE_M, PIPE_N) // sO has the same size as sQ, so we don't need to sync here. if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } @@ -634,7 +634,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom, AtomNum), ATOM_M, ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); @@ -642,12 +642,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tOrO = make_tensor(shape(tOgO)); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M, BLK_K) -> (blk_m, blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA, MMA_M, MMA_K) static_assert(decltype(size<0>(taccOcO))::value == 4); // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { @@ -657,9 +657,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M, BLK_K) -> (blk_m, blk_k) // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { #pragma unroll From 41376a6242cd0dd5c8968e6f2b4d12e05c035d2b Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 13:57:03 +0800 Subject: [PATCH 8/9] Fixes parameter order in CUDA attention function call Corrects the argument mapping in the flash_dma_cuda.fwd call by swapping the order of zero_hold_states and active_mask parameters to match the expected function signature. The change ensures proper parameter alignment where active_mask is passed as the attn_mask argument and attn_mask is passed as the bias argument. --- benchmarks/benchmark_forward_equivalence.py | 28 ++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 5d89419..78ab9c1 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -232,20 +232,20 @@ def dynamic_mask_attention_cuda( # Call the CUDA implementation using the mha_fwd function signature out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + result = flash_dma_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + active_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_mask, # bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] From aef047f34d5366a0502a77111c4007919e2cef60 Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Wed, 2 Jul 2025 14:01:37 +0800 Subject: [PATCH 9/9] Update benchmarks/benchmark_forward_equivalence.py Co-authored-by: Copilot 
<175728472+Copilot@users.noreply.github.com> --- benchmarks/benchmark_forward_equivalence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 78ab9c1..3f4a86b 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -236,8 +236,8 @@ def dynamic_mask_attention_cuda( query_states, # q: [batch, seqlen_q, num_heads, head_dim] key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - active_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - attn_mask, # bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + active_mask, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] out_tensor, # out: None to auto-allocate 0.0, # p_dropout scaling, # softmax_scale