diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index e0176bc..15e53df 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -231,7 +231,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::SmemLayoutBias{} ); - // Golobal to Shared Memory operation + // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; @@ -273,10 +273,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -351,7 +351,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block), tMasksMask, @@ -359,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block), tBiassBias, @@ -449,7 +449,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, @@ -457,7 +457,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, @@ -549,7 +549,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); if (n_block > n_block_min) { FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, @@ -557,7 +557,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, @@ -883,10 +883,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -932,7 +932,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, @@ -940,7 +940,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, @@ -1051,7 +1051,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, @@ -1059,7 +1059,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, @@ -1159,7 +1159,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_bias_col_stride; } FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, @@ -1167,7 +1167,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_Mask( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 281a2dd..15c773f 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -107,9 +107,11 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutMask = decltype(tile_to_shape( SmemLayoutAtomMask{}, Shape, Int>{})); + using SmemCopyAtomMask = Copy_Atom, Element>; using SmemLayoutBias = decltype(tile_to_shape( SmemLayoutAtomBias{}, Shape, Int>{})); + using SmemCopyAtomBias = Copy_Atom, Element>; // Shared memory layout for output using SmemLayoutAtomO = decltype( @@ -268,6 +270,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutMask = decltype(tile_to_shape( SmemLayoutAtomMask{}, make_shape(Int{}, Int{}))); + using SmemCopyAtomMask = Copy_Atom, elem_type>; using SmemLayoutAtomBias = decltype( composition(Swizzle{}, @@ -276,6 +279,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutBias = decltype(tile_to_shape( SmemLayoutAtomBias{}, make_shape(Int{}, Int{}))); + using SmemCopyAtomBias = Copy_Atom, elem_type>; // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 @@ -322,7 +326,6 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom, elem_type>; - using SmemCopyAtomBias = Copy_Atom, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); diff --git a/csrc/src/utils.h b/csrc/src/utils.h index b3c71e1..1828abe 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -495,7 +495,7 @@ __forceinline__ __device__ void copy( template -__forceinline__ __device__ void copy_Mask( +__forceinline__ __device__ void copy_MN( TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, const int max_M=0, const int max_N=0