From 9a6b826b0a8ca255ad2b630355927d8db68e7f5a Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 23 Jun 2025 13:59:16 +0800 Subject: [PATCH] Updates tensor offset calculations and formatting Adds missing row stride parameters to offset calculations for ZOH and active mask tensors, ensuring proper memory layout access. Improves code readability by standardizing comment formatting and alignment for tensor partition declarations. Removes template parameter from DynamicMask class instantiation, simplifying the interface. --- csrc/src/flash_fwd_kernel.h | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 42e352f..99cb5e4 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -19,6 +19,7 @@ #include "softmax.h" #include "mask.h" #include "dropout.h" +#include "rotary.h" namespace FLASH_NAMESPACE { @@ -179,7 +180,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi make_stride(params.seqlen_k_rounded, _1{}) ); Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{}) ); @@ -189,7 +190,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, 
binfo.actual_seqlen_k), make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{}) ); @@ -240,11 +241,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN) + Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN) Tensor tZOHsZOH = gmem_thr_copy_ZOH.partition_D(sZOH); Tensor tAMgAM = gmem_thr_copy_AM.partition_S(gActiveMask); // (AMCPY, AMCPY_M, AMCPY_N, nblocksN) Tensor tAMsAM = gmem_thr_copy_AM.partition_D(sActiveMask); @@ -252,13 +253,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K) Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) Tensor tSrAM = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, 
MMA_K, MMA_N) Tensor tSgS = thr_mma.partition_C(gP); - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_K) // Copy Atom retiling auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); @@ -383,7 +384,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; // Init dynamic mask processor - FLASH_NAMESPACE::DynamicMask dynamic_mask( + FLASH_NAMESPACE::DynamicMask dynamic_mask( binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.keep_window_size );