diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 0e1a15e..42b78e0 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -768,6 +768,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
         : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t col_offset_zoh = block_table == nullptr
+        ? binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb_cache)
+          + (bidh / params.h_h_k_ratio) * params.zoh_head_stride + m_block * kBlockM * params.zoh_row_stride + (n_block_max - 1) * kBlockN * params.zoh_col_stride
+        : block_table[block_table_idx] * params.zoh_batch_stride + (bidh / params.h_h_k_ratio) * params.zoh_head_stride + m_block * kBlockM * params.zoh_row_stride + block_table_offset * params.zoh_col_stride;
+    const index_t col_offset_am = block_table == nullptr
+        ? binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb_cache)
+          + (bidh / params.h_h_k_ratio) * params.active_mask_head_stride + m_block * kBlockM * params.active_mask_row_stride + (n_block_max - 1) * kBlockN * params.active_mask_col_stride
+        : block_table[block_table_idx] * params.active_mask_batch_stride + (bidh / params.h_h_k_ratio) * params.active_mask_head_stride + m_block * kBlockM * params.active_mask_row_stride + block_table_offset * params.active_mask_col_stride;
 
     // Global memory tensor configuration
     Tensor mQ = make_tensor(
@@ -790,26 +798,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Shape<Int<kBlockN>, Int<kHeadDim>>{},
         make_stride(params.v_row_stride, _1{})
     );
-    Tensor mZOH = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-        make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{})
-    );
-    Tensor gZOH = local_tile(
-        mZOH(bidh / params.h_h_k_ratio, _, _),
+    Tensor gZOH = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.zoh_ptr) + col_offset_zoh),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
-        make_coord(m_block, _)
-    );  // (kBlockM, kBlockN, nblocksN)
-    Tensor mActiveMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-        make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{})
+        make_stride(params.zoh_row_stride, params.zoh_col_stride)
     );
-    Tensor gActiveMask = local_tile(
-        mActiveMask(bidh / params.h_h_k_ratio, _, _),
+    Tensor gActiveMask = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element *>(params.active_mask_ptr) + col_offset_am),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
-        make_coord(m_block, _)
-    );  // (kBlockM, kBlockN, nblocksN)
+        make_stride(params.active_mask_row_stride, params.active_mask_col_stride)
+    );
+    if (tidx == 0 && bidh == 0 && bidb == 0 && m_block == 0) {
+        printf("SplitKV: m_block=%d, n_block_max=%d, n_block_min=%d, n_split_idx=%d\n",
+               m_block, n_block_max, n_block_min, n_split_idx);
+        printf("col_offset_zoh=%ld, row_offset=%ld, zoh_ptr=%p, final_ptr=%p\n",
+               col_offset_zoh, m_block * kBlockM * params.zoh_row_stride,
+               params.zoh_ptr,
+               reinterpret_cast<Element *>(params.zoh_ptr) + col_offset_zoh + m_block * kBlockM * params.zoh_row_stride);
+    }
 
     // Shared memory layout configuration
     Tensor sQ = make_tensor(
@@ -863,12 +869,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
-    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
-    Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA, MMA_M, MMA_K)
+    Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA, MMA_N, MMA_K)
     Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
     Tensor tSrAM = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
-    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K, MMA_N)
+    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
@@ -1062,9 +1068,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                           binfo.actual_seqlen_k - n_block * kBlockN);
+    if (tidx == 0 && bidh == 0 && bidb == 0) {
+        printf("Before copy_ZOH: n_block=%d, seqlen_q_offset=%d, seqlen_k_offset=%d\n",
+               n_block,
+               binfo.actual_seqlen_q - m_block * kBlockM,
+               binfo.actual_seqlen_k - n_block * kBlockN);
+    }
     FLASH_NAMESPACE::copy_ZOH(
         gmem_tiled_copy_ZOH,
-        tZOHgZOH(_, _, _, n_block),
+        tZOHgZOH,
         tZOHsZOH,
         tZOHcZOH,
         binfo.actual_seqlen_q - m_block * kBlockM,
@@ -1072,7 +1084,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     );
     FLASH_NAMESPACE::copy_ZOH(
         gmem_tiled_copy_AM,
-        tAMgAM(_, _, _, n_block),
+        tAMgAM,
         tAMsAM,
         tAMcAM,
         binfo.actual_seqlen_q - m_block * kBlockM,
@@ -1167,19 +1179,23 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Advance gK
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tZOHgZOH.data() = tZOHgZOH.data() + (-int(kBlockN * params.zoh_col_stride));
+            tAMgAM.data() = tAMgAM.data() + (-int(kBlockN * params.active_mask_col_stride));
         } else {
             const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
             const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
             const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
             const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
             tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+            tZOHgZOH.data() = tZOHgZOH.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.zoh_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.zoh_col_stride;
+            tAMgAM.data() = tAMgAM.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.active_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.active_mask_col_stride;
         }
         FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
         // This cp_async_fence needs to be in the if block, otherwise the synchronization
         // isn't right and we get race conditions.
         FLASH_NAMESPACE::copy_ZOH(
             gmem_tiled_copy_ZOH,
-            tZOHgZOH(_, _, _, n_block - 1),
+            tZOHgZOH,
             tZOHsZOH,
             tZOHcZOH,
             binfo.actual_seqlen_q - m_block * kBlockM,
@@ -1187,7 +1203,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         FLASH_NAMESPACE::copy_ZOH(
             gmem_tiled_copy_AM,
-            tAMgAM(_, _, _, n_block - 1),
+            tAMgAM,
             tAMsAM,
             tAMcAM,
             binfo.actual_seqlen_q - m_block * kBlockM,
@@ -1267,17 +1283,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             // Advance gK
             if (block_table == nullptr) {
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+                tZOHgZOH.data() = tZOHgZOH.data() + (-int(kBlockN * params.zoh_col_stride));
+                tAMgAM.data() = tAMgAM.data() + (-int(kBlockN * params.active_mask_col_stride));
             } else {
                 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+                tZOHgZOH.data() = tZOHgZOH.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.zoh_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.zoh_col_stride;
+                tAMgAM.data() = tAMgAM.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.active_mask_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.active_mask_col_stride;
             }
             FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             FLASH_NAMESPACE::copy_ZOH(
                 gmem_tiled_copy_ZOH,
-                tZOHgZOH(_, _, _, n_block - 1),
+                tZOHgZOH,
                 tZOHsZOH,
                 tZOHcZOH,
                 binfo.actual_seqlen_q - m_block * kBlockM,
@@ -1285,7 +1305,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             );
             FLASH_NAMESPACE::copy_ZOH(
                 gmem_tiled_copy_AM,
-                tAMgAM(_, _, _, n_block - 1),
+                tAMgAM,
                 tAMsAM,
                 tAMcAM,
                 binfo.actual_seqlen_q - m_block * kBlockM,
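
Note for reviewers: the paged-KV advance logic above is duplicated four times (for the ZOH and active-mask tensors, in both loop bodies). A minimal standalone sketch of that index arithmetic is below, assuming `block_table` maps logical pages to physical pages and each physical page is laid out with `batch_stride` between pages; `paged_step_delta` is a hypothetical helper name, not part of the patch.

```cpp
#include <cstdint>

// Hypothetical re-derivation (not in the patch) of the pointer delta the
// kernel applies to tZOHgZOH/tAMgAM when stepping from key block n_block
// to n_block - 1 with a paged KV cache.
int64_t paged_step_delta(const int32_t* block_table, int n_block, int kBlockN,
                         int page_block_size, int64_t batch_stride, int64_t col_stride) {
    // Split each key block's starting element index into (page, in-page offset).
    const int idx_cur     = n_block * kBlockN / page_block_size;
    const int offset_cur  = n_block * kBlockN - idx_cur * page_block_size;
    const int idx_next    = (n_block - 1) * kBlockN / page_block_size;
    const int offset_next = (n_block - 1) * kBlockN - idx_next * page_block_size;
    // Cross-page moves use batch_stride (physical pages are laid out like
    // batch entries of page_block_size rows); in-page moves use col_stride.
    return (int64_t(block_table[idx_next]) - int64_t(block_table[idx_cur])) * batch_stride
         + int64_t(offset_next - offset_cur) * col_stride;
}
```

In the non-paged branch the same step reduces to a constant `-kBlockN * col_stride`, which is exactly the `(-int(kBlockN * params.zoh_col_stride))` and `(-int(kBlockN * params.active_mask_col_stride))` adjustments in the diff.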