diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 42e352f..99cb5e4 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -19,6 +19,7 @@
 #include "softmax.h"
 #include "mask.h"
 #include "dropout.h"
+#include "rotary.h"
 
 namespace FLASH_NAMESPACE {
 
@@ -179,7 +180,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_stride(params.seqlen_k_rounded, _1{})
     );
     Tensor mZOH = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<Element *>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{})
     );
@@ -189,7 +190,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
     Tensor mActiveMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<Element *>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{})
     );
@@ -240,11 +241,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);        // (KCPY, KCPY_N, KCPY_K, nblocksN)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);        // (VCPY, VCPY_N, VCPY_K, nblocksN)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-    Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN)
+    Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH);  // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN)
     Tensor tZOHsZOH = gmem_thr_copy_ZOH.partition_D(sZOH);
     Tensor tAMgAM = gmem_thr_copy_AM.partition_S(gActiveMask); // (AMCPY, AMCPY_M, AMCPY_N, nblocksN)
     Tensor tAMsAM = gmem_thr_copy_AM.partition_D(sActiveMask);
@@ -252,13 +253,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
-    Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
-    Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
+    Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
+    Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
     Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
     Tensor tSrAM = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
     Tensor tSgS = thr_mma.partition_C(gP);
-    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
+    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
@@ -383,7 +384,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     // Init dynamic mask processor
-    FLASH_NAMESPACE::DynamicMask dynamic_mask(
+    FLASH_NAMESPACE::DynamicMask dynamic_mask(
         binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.keep_window_size
     );
 
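Both `zoh_offset` and `active_mask_offset` now take the row stride alongside the batch stride, the same calling convention `BlockInfo::q_offset` uses in upstream FlashAttention to support variable-length batches. A minimal sketch of what the updated helper likely computes, assuming a `sum_s_q` cumulative-row member as in upstream `BlockInfo` (the member name and varlen convention are assumptions, not taken from this diff):

```cpp
// Hypothetical sketch of the updated offset helper, modeled on
// BlockInfo::q_offset in upstream FlashAttention; sum_s_q == -1 is
// assumed to mark fixed-length (non-varlen) batches.
template <typename index_t>
__forceinline__ __device__ index_t zoh_offset(const index_t batch_stride,
                                              const index_t row_stride,
                                              const int bidb) const {
    // Fixed-length batches step by the whole-batch stride; varlen batches
    // step by the cumulative number of rows before this sequence times the
    // per-row stride, which is why row_stride must now be passed in.
    return sum_s_q == -1 ? bidb * batch_stride
                         : uint32_t(sum_s_q) * row_stride;
}
```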