From b0054e21a0001ba03a510258912a0c9785abfd60 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 28 Jul 2025 12:33:41 +0800 Subject: [PATCH 1/5] Simplifies attention mask and bias parameter naming Removes redundant "attn_" prefixes from mask and bias parameter names to improve code readability and consistency. Also removes unused keep_window_size field from Mask_params struct. --- csrc/src/flash.h | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 1ccfd2a..95d7f52 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -46,28 +46,25 @@ struct QKV_params { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Mask_params { - void * __restrict__ attn_mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] // The stride of the attention mask tensors. - index_t attn_mask_batch_stride; // Stride between batches of attention mask - index_t attn_mask_head_stride; // Stride between heads of attention mask - index_t attn_mask_row_stride; // Stride between rows of attention mask - index_t attn_mask_col_stride; // Stride between columns of attention mask - - // The keep window size. - int keep_window_size; // Number of tokens to keep in top-k + index_t mask_batch_stride; // Stride between batches of attention mask + index_t mask_head_stride; // Stride between heads of attention mask + index_t mask_row_stride; // Stride between rows of attention mask + index_t mask_col_stride; // Stride between columns of attention mask }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Bias_params { - void *__restrict__ attn_bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] // The stride of the attention bias tensor. - index_t attn_bias_batch_stride; // Stride between batches of attention bias - index_t attn_bias_head_stride; // Stride between heads of attention bias - index_t attn_bias_row_stride; // Stride between rows of attention bias - index_t attn_bias_col_stride; // Stride between columns of attention bias + index_t bias_batch_stride; // Stride between batches of attention bias + index_t bias_head_stride; // Stride between heads of attention bias + index_t bias_row_stride; // Stride between rows of attention bias + index_t bias_col_stride; // Stride between columns of attention bias }; //////////////////////////////////////////////////////////////////////////////////////////////////// From 770c3e6d943acdbaf494a48d119b3e80223d3800 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 28 Jul 2025 12:34:45 +0800 Subject: [PATCH 2/5] Refactors parameter names and removes unused parameter Renames attention mask and bias parameters from `attn_mask`/`attn_bias` to `mask`/`bias` for improved clarity and consistency throughout the flash attention API. Removes the unused `keep_window_size` parameter from function signatures and parameter structures to clean up the interface. 
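For reference, a minimal sketch (not part of the patch) of how the renamed Mask_params fields map onto a [batch_size, num_kv_heads, query_len, key_len] mask tensor, mirroring the non-varlen branch of set_params_fprop shown in the next patch. The fill_mask_params helper and the include are illustrative assumptions; only the field names and stride choices come from the patches.

    #include <ATen/ATen.h>
    #include "flash.h"  // assumed include path for the Mask_params struct above

    // Hypothetical helper, for illustration only: populate the renamed
    // Mask_params fields from a [batch, num_kv_heads, query_len, key_len]
    // mask tensor, the same way the non-varlen branch of set_params_fprop does.
    void fill_mask_params(Mask_params &params, const at::Tensor &mask) {
        params.mask_ptr          = mask.data_ptr();
        params.mask_batch_stride = mask.stride(0);   // stride between batches
        params.mask_head_stride  = mask.stride(-3);  // stride between kv heads
        params.mask_row_stride   = mask.stride(-2);  // stride between query rows
        params.mask_col_stride   = mask.stride(-1);  // stride between key columns
    }

    // Bias_params is filled analogously from the bias tensor.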
--- csrc/flash_api.cpp | 66 +++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 48a71ea..c9bc93a 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -35,13 +35,12 @@ void set_params_fprop( const size_t h_k, const size_t d, const size_t d_rounded, - const size_t keep_window_size, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, - const at::Tensor attn_mask, - const at::Tensor attn_bias, + const at::Tensor mask, + const at::Tensor bias, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, @@ -65,32 +64,32 @@ void set_params_fprop( params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); - params.attn_mask_ptr = attn_mask.data_ptr(); - params.attn_bias_ptr = attn_bias.data_ptr(); + params.mask_ptr = mask.data_ptr(); + params.bias_ptr = bias.data_ptr(); params.o_ptr = out.data_ptr(); // All stride are in elements, not bytes. params.q_row_stride = q.stride(-3); params.k_row_stride = k.stride(-3); params.v_row_stride = v.stride(-3); - params.attn_mask_row_stride = attn_mask.stride(-2); - params.attn_bias_row_stride = attn_bias.stride(-2); + params.mask_row_stride = mask.stride(-2); + params.bias_row_stride = bias.stride(-2); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); - params.attn_mask_head_stride = attn_mask.stride(-3); - params.attn_bias_head_stride = attn_bias.stride(-3); + params.mask_head_stride = mask.stride(-3); + params.bias_head_stride = bias.stride(-3); params.o_head_stride = out.stride(-2); - params.attn_mask_col_stride = attn_mask.stride(-1); - params.attn_bias_col_stride = attn_bias.stride(-1); + params.mask_col_stride = mask.stride(-1); + params.bias_col_stride = bias.stride(-1); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.attn_mask_batch_stride = attn_mask.stride(0); - params.attn_bias_batch_stride = attn_bias.stride(0); + params.mask_batch_stride = mask.stride(0); + params.bias_batch_stride = bias.stride(0); params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; @@ -119,7 +118,6 @@ void set_params_fprop( params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; - params.keep_window_size = keep_window_size; // Set the different scale values. 
#ifdef FLASHATTENTION_DISABLE_SOFTCAP @@ -271,13 +269,12 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &attn_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &attn_bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float p_dropout, const float softmax_scale, bool is_causal, - const int keep_window_size, const float softcap, const bool return_softmax, std::optional gen_ @@ -294,10 +291,10 @@ mha_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs"); - TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs"); + TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs"); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -334,8 +331,8 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(attn_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(attn_bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); at::Tensor out; if (out_.has_value()) { @@ -377,8 +374,7 @@ mha_fwd( seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - keep_window_size, - q, k, v, attn_mask, attn_bias, out, + q, k, v, mask, bias, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, @@ -436,8 +432,8 @@ mha_varlen_fwd( at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
- const at::Tensor &attn_mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &attn_bias, // total_q x num_heads_k x max_seqlen_k + const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k + const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -450,7 +446,6 @@ mha_varlen_fwd( const float softmax_scale, const bool zero_tensors, bool is_causal, - const int keep_window_size, const float softcap, const bool return_softmax, std::optional gen_ @@ -465,12 +460,12 @@ mha_varlen_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs"); - TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs"); + TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs"); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); @@ -487,8 +482,8 @@ mha_varlen_fwd( TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(attn_mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(attn_bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); @@ -533,8 +528,8 @@ mha_varlen_fwd( const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(attn_mask, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(attn_bias, total_q, num_heads_k, max_seqlen_k); + CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); + CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); } else { CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); @@ -596,8 +591,7 @@ mha_varlen_fwd( seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - keep_window_size, - q, k, v, attn_mask, attn_bias, out, + q, k, v, mask, bias, out, cu_seqlens_q_d, cu_seqlens_k.data_ptr(), seqused_k.has_value() ? 
seqused_k.value().data_ptr() : nullptr, From 90b9ccfb2285466c11661e85cfbc9e95815528d4 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 28 Jul 2025 12:35:18 +0800 Subject: [PATCH 3/5] Simplifies method names by removing attn prefix Renames attn_mask_offset to mask_offset and attn_bias_offset to bias_offset to improve code readability and reduce verbosity while maintaining the same functionality. --- csrc/src/block_info.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/src/block_info.h b/csrc/src/block_info.h index d0758eb..59ce6e2 100644 --- a/csrc/src/block_info.h +++ b/csrc/src/block_info.h @@ -36,14 +36,14 @@ struct BlockInfo { } template <typename index_t> - __forceinline__ __device__ index_t attn_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { + __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; } template <typename index_t> - __forceinline__ __device__ index_t attn_bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb + __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb ) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; From 41ecfd76d627c13d86acef2a4588f2b791a473f5 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 28 Jul 2025 12:37:14 +0800 Subject: [PATCH 4/5] Renames attention mask and bias parameters for consistency Simplifies parameter naming by removing the "attn_" prefix from mask and bias related variables throughout the flash attention kernel. Updates all references to use the shorter naming convention: - attn_mask_* becomes mask_* - attn_bias_* becomes bias_* Improves code readability and maintains consistency across parameter names while preserving all existing functionality.
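Because the renamed offset helpers in patch 3 pack two cases into comma-ternaries, here is a plain restatement (not part of the patch) of what mask_offset, and the identical bias_offset, compute. The standalone function and the 64-bit index_t alias are assumptions for the sketch; the arithmetic is taken directly from block_info.h.

    #include <cstdint>

    using index_t = int64_t;  // assumption: the kernels' index_t is a wide integer type

    // Sketch of the offset arithmetic in BlockInfo::mask_offset / bias_offset.
    // sum_s_q / sum_s_k are -1 in the fixed-length (batched) case and hold the
    // running varlen offsets otherwise; leftpad_k skips left-padded keys.
    index_t mask_offset_sketch(index_t batch_stride, index_t row_stride, index_t col_stride,
                               int bidb, int sum_s_q, int sum_s_k, int leftpad_k) {
        index_t offset = (sum_s_q == -1)
            ? index_t(bidb) * batch_stride              // batched: jump to this batch
            : index_t(uint32_t(sum_s_q)) * row_stride;  // varlen: jump to this sequence's rows
        offset += (sum_s_k == -1)
            ? index_t(leftpad_k) * col_stride                       // batched: skip left padding
            : index_t(uint32_t(sum_s_k + leftpad_k)) * col_stride;  // varlen: packed keys plus padding
        return offset;
    }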
--- csrc/src/flash_fwd_kernel.h | 44 ++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 15e53df..33b911f 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -175,9 +175,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(_, 0) ); // (kBlockN, kHeadDim, nblocksN) Tensor mMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, params.mask_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.attn_mask_head_stride, params.attn_mask_row_stride, params.attn_mask_col_stride) + make_stride(params.mask_head_stride, params.mask_row_stride, params.mask_col_stride) ); Tensor gMask = local_tile( mMask(bidh / params.h_h_k_ratio, _, _), @@ -185,9 +185,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mBias = make_tensor( - make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, params.bias_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.attn_bias_head_stride, params.attn_bias_row_stride, params.attn_bias_col_stride) + make_stride(params.bias_head_stride, params.bias_row_stride, params.bias_col_stride) ); Tensor gBias = local_tile( mBias(bidh / params.h_h_k_ratio, _, _), @@ -774,13 +774,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t col_offset_mask = block_table == nullptr - ? binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.attn_mask_head_stride + m_block * kBlockM * params.attn_mask_row_stride + (n_block_max - 1) * kBlockN * params.attn_mask_col_stride - : block_table[block_table_idx] * params.attn_mask_batch_stride + (bidh / params.h_h_k_ratio) * params.attn_mask_head_stride + m_block * kBlockM * params.attn_mask_row_stride + block_table_offset * params.attn_mask_col_stride; + ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, params.mask_col_stride, bidb_cache) + + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN * params.mask_col_stride + : block_table[block_table_idx] * params.mask_batch_stride + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table_offset * params.mask_col_stride; const index_t col_offset_bias = block_table == nullptr - ? 
binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + m_block * kBlockM * params.attn_bias_row_stride + (n_block_max - 1) * kBlockN * params.attn_bias_col_stride - : block_table[block_table_idx] * params.attn_bias_batch_stride + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + m_block * kBlockM * params.attn_bias_row_stride + block_table_offset * params.attn_bias_col_stride; + ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, params.bias_col_stride, bidb_cache) + + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN * params.bias_col_stride + : block_table[block_table_idx] * params.bias_batch_stride + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table_offset * params.bias_col_stride; // Global memory tensor configuration Tensor mQ = make_tensor( @@ -804,14 +804,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons make_stride(params.v_row_stride, _1{}) ); Tensor gMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + col_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, - make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride) + make_stride(params.mask_row_stride, params.mask_col_stride) ); Tensor gBias = make_tensor( - make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + col_offset_bias), + make_gmem_ptr(reinterpret_cast(params.bias_ptr) + col_offset_bias), Shape, Int>{}, - make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride) + make_stride(params.bias_row_stride, params.bias_col_stride) ); // Shared memory layout configuration @@ -1037,16 +1037,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN * params.attn_mask_col_stride)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN * params.attn_bias_col_stride)); + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN * params.mask_col_stride)); + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN * params.bias_col_stride)); } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_mask_col_stride; - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_bias_col_stride; + tMaskgMask.data() = tMaskgMask.data() + 
(block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.mask_col_stride; + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.bias_col_stride; } FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization @@ -1147,16 +1147,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN * params.attn_mask_col_stride)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN * params.attn_bias_col_stride)); + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN * params.mask_col_stride)); + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN * params.bias_col_stride)); } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_mask_col_stride; - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_bias_col_stride; + tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.mask_col_stride; + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.bias_col_stride; } FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); FLASH_NAMESPACE::copy_MN( From 1efb25e6746c2fd07796f8e7d4418e9c7edf9d66 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 28 Jul 2025 12:41:44 +0800 Subject: [PATCH 5/5] Standardizes parameter types for consistency Changes row_stride and col_stride parameters from int to index_t template type in mask_offset and bias_offset methods. Ensures type consistency across all stride parameters and eliminates potential type mismatches in offset calculations. 
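To make the motivation concrete: with 32-bit stride parameters, the mixed int/uint32_t products inside the offset helpers are evaluated in 32 bits and can wrap for large packed sequence lengths before the result is widened to index_t; with index_t strides the multiplication happens at full width. A small self-contained demonstration follows, where the concrete numbers and the int64_t alias are illustrative assumptions. Whether such magnitudes arise depends on batch and sequence sizes, but making every stride parameter index_t removes the possibility.

    #include <cstdint>
    #include <cstdio>

    int main() {
        using index_t = int64_t;     // assumption: a 64-bit index type
        uint32_t sum_s_q  = 70000;   // packed query offset in a large varlen batch
        int      stride32 = 70000;   // row stride as a 32-bit int (pre-patch parameter type)
        index_t  stride64 = 70000;   // row stride as index_t (post-patch parameter type)

        index_t wrapped = sum_s_q * stride32;  // multiplied in 32 bits: 4.9e9 wraps modulo 2^32
        index_t exact   = sum_s_q * stride64;  // widened first, multiplied in 64 bits: exact

        std::printf("32-bit stride: %lld\n64-bit stride: %lld\n",
                    (long long)wrapped, (long long)exact);
        return 0;
    }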
--- csrc/src/block_info.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/src/block_info.h b/csrc/src/block_info.h index 59ce6e2..46102f9 100644 --- a/csrc/src/block_info.h +++ b/csrc/src/block_info.h @@ -36,15 +36,14 @@ struct BlockInfo { } template <typename index_t> - __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { + __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; } template <typename index_t> - __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb - ) const { + __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset;