Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 47 additions & 27 deletions csrc/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t col_offset_zoh = block_table == nullptr
? binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb_cache)
+ (bidh / params.h_h_k_ratio) * params.zoh_head_stride + m_block * kBlockM * params.zoh_row_stride + (n_block_max - 1) * kBlockN * params.zoh_col_stride
: block_table[block_table_idx] * params.zoh_batch_stride + (bidh / params.h_h_k_ratio) * params.zoh_head_stride + m_block * kBlockM * params.zoh_row_stride + block_table_offset * params.zoh_col_stride;
const index_t col_offset_am = block_table == nullptr
? binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb_cache)
+ (bidh / params.h_h_k_ratio) * params.active_mask_head_stride + m_block * kBlockM * params.active_mask_row_stride + (n_block_max - 1) * kBlockN * params.active_mask_col_stride
: block_table[block_table_idx] * params.active_mask_batch_stride + (bidh / params.h_h_k_ratio) * params.active_mask_head_stride + m_block * kBlockM * params.active_mask_row_stride + block_table_offset * params.active_mask_col_stride;

// Global memory tensor configuration
Tensor mQ = make_tensor(
Expand All @@ -790,26 +798,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{})
);
Tensor mZOH = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)),
make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{})
);
Tensor gZOH = local_tile(
mZOH(bidh / params.h_h_k_ratio, _, _),
Tensor gZOH = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.zoh_ptr) + col_offset_zoh),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
Tensor mActiveMask = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)),
make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{})
make_stride(params.zoh_row_stride, params.zoh_col_stride)
);
Tensor gActiveMask = local_tile(
mActiveMask(bidh / params.h_h_k_ratio, _, _),
Tensor gActiveMask = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.active_mask_ptr) + col_offset_am),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
make_stride(params.active_mask_row_stride, params.active_mask_col_stride)
);
Copy link

Copilot AI Jul 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider wrapping the debug print statements in a compile-time guard (e.g., `#ifdef DEBUG` … `#endif`) so they produce no output and add no device-side `printf` overhead in production builds.

Suggested change
);
);
#ifdef DEBUG

Copilot uses AI. Check for mistakes.
if (tidx == 0 && bidh == 0 && bidb == 0 && m_block == 0) {
printf("SplitKV: m_block=%d, n_block_max=%d, n_block_min=%d, n_split_idx=%d\n",
m_block, n_block_max, n_block_min, n_split_idx);
printf("col_offset_zoh=%ld, row_offset=%ld, zoh_ptr=%p, final_ptr=%p\n",
col_offset_zoh, m_block * kBlockM * params.zoh_row_stride,
params.zoh_ptr,
reinterpret_cast<Element *>(params.zoh_ptr) + col_offset_zoh + m_block * kBlockM * params.zoh_row_stride);
}

// Shared memory layout configuration
Tensor sQ = make_tensor(
Expand Down Expand Up @@ -863,12 +869,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Matrix Multiply Accumulate
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
Tensor tSrAM = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)

// Copy Atom retiling
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
Expand Down Expand Up @@ -1062,17 +1068,23 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
if (tidx == 0 && bidh == 0 && bidb == 0) {
printf("Before copy_ZOH: n_block=%d, seqlen_q_offset=%d, seqlen_k_offset=%d\n",
n_block,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN);
}
Comment on lines 1070 to +1076
Copy link

Copilot AI Jul 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, consider guarding these debug prints with a compile-time condition to prevent extra overhead in production runs.

Suggested change
binfo.actual_seqlen_k - n_block * kBlockN);
if (tidx == 0 && bidh == 0 && bidb == 0) {
printf("Before copy_ZOH: n_block=%d, seqlen_q_offset=%d, seqlen_k_offset=%d\n",
n_block,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN);
}
binfo.actual_seqlen_k - n_block * kBlockN);
#ifdef DEBUG
if (tidx == 0 && bidh == 0 && bidb == 0) {
printf("Before copy_ZOH: n_block=%d, seqlen_q_offset=%d, seqlen_k_offset=%d\n",
n_block,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN);
}
#endif

Copilot uses AI. Check for mistakes.
FLASH_NAMESPACE::copy_ZOH<Is_even_MN>(
gmem_tiled_copy_ZOH,
tZOHgZOH(_, _, _, n_block),
tZOHgZOH,
tZOHsZOH,
tZOHcZOH,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN
);
FLASH_NAMESPACE::copy_ZOH<Is_even_MN>(
gmem_tiled_copy_AM,
tAMgAM(_, _, _, n_block),
tAMgAM,
tAMsAM,
tAMcAM,
binfo.actual_seqlen_q - m_block * kBlockM,
Expand Down Expand Up @@ -1167,27 +1179,31 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Advance gK
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tZOHgZOH.data() = tZOHgZOH.data() + (-int(kBlockN * params.zoh_col_stride));
tAMgAM.data() = tAMgAM.data() + (-int(kBlockN * params.active_mask_col_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tZOHgZOH.data() = tZOHgZOH.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.zoh_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.zoh_col_stride;
tAMgAM.data() = tAMgAM.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.active_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.active_mask_col_stride;
}
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
FLASH_NAMESPACE::copy_ZOH</*Is_even_MN=*/true>(
gmem_tiled_copy_ZOH,
tZOHgZOH(_, _, _, n_block - 1),
tZOHgZOH,
tZOHsZOH,
tZOHcZOH,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_ZOH</*Is_even_MN=*/true>(
gmem_tiled_copy_AM,
tAMgAM(_, _, _, n_block - 1),
tAMgAM,
tAMsAM,
tAMcAM,
binfo.actual_seqlen_q - m_block * kBlockM,
Expand Down Expand Up @@ -1267,25 +1283,29 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Advance gK
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tZOHgZOH.data() = tZOHgZOH.data() + (-int(kBlockN * params.zoh_col_stride));
tAMgAM.data() = tAMgAM.data() + (-int(kBlockN * params.active_mask_col_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tZOHgZOH.data() = tZOHgZOH.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.zoh_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.zoh_col_stride;
tAMgAM.data() = tAMgAM.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.active_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.active_mask_col_stride;
}
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
FLASH_NAMESPACE::copy_ZOH</*Is_even_MN=*/true>(
gmem_tiled_copy_ZOH,
tZOHgZOH(_, _, _, n_block - 1),
tZOHgZOH,
tZOHsZOH,
tZOHcZOH,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_ZOH</*Is_even_MN=*/true>(
gmem_tiled_copy_AM,
tAMgAM(_, _, _, n_block - 1),
tAMgAM,
tAMsAM,
tAMcAM,
binfo.actual_seqlen_q - m_block * kBlockM,
Expand Down