From 88f7ba4d48286bee9210e79df5b714da9c5c64c5 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 14 Jul 2025 10:09:31 +0800
Subject: [PATCH 1/4] Fixes function name and typo in Flash attention kernel

Corrects function calls from copy_Mask to copy_MN throughout the kernel
implementation and fixes a spelling error in a comment from "Golobal" to
"Global". These changes ensure proper function naming consistency and
improve code readability.
---
 csrc/src/flash_fwd_kernel.h | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index e0176bc..bfb66f0 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -231,7 +231,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         typename Kernel_traits::SmemLayoutBias{}
     );
 
-    // Golobal to Shared Memory operation
+    // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
@@ -351,7 +351,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         tKsK, tKVcKV, tKVpKV,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Mask,
         tMaskgMask(_, _, _, n_block),
         tMasksMask,
@@ -359,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Bias,
         tBiasgBias(_, _, _, n_block),
         tBiassBias,
@@ -449,7 +449,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
     // This cp_async_fence needs to be in the if block, otherwise the synchronization
     // isn't right and we get race conditions.
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Mask,
         tMaskgMask(_, _, _, n_block - 1),
         tMasksMask,
@@ -457,7 +457,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Bias,
         tBiasgBias(_, _, _, n_block - 1),
         tBiassBias,
@@ -549,7 +549,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     __syncthreads();
     if (n_block > n_block_min) {
         FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
-        FLASH_NAMESPACE::copy_Mask(
+        FLASH_NAMESPACE::copy_MN(
             gmem_tiled_copy_Mask,
             tMaskgMask(_, _, _, n_block - 1),
             tMasksMask,
@@ -557,7 +557,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             binfo.actual_seqlen_q - m_block * kBlockM,
             binfo.actual_seqlen_k - (n_block - 1) * kBlockN
         );
-        FLASH_NAMESPACE::copy_Mask(
+        FLASH_NAMESPACE::copy_MN(
             gmem_tiled_copy_Bias,
             tBiasgBias(_, _, _, n_block - 1),
             tBiassBias,
@@ -932,7 +932,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tKgK, tKsK, tKVcKV, tKVpKV,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Mask,
         tMaskgMask,
         tMasksMask,
@@ -940,7 +940,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Bias,
         tBiasgBias,
         tBiassBias,
@@ -1051,7 +1051,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
     // This cp_async_fence needs to be in the if block, otherwise the synchronization
     // isn't right and we get race conditions.
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Mask,
         tMaskgMask,
         tMasksMask,
@@ -1059,7 +1059,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Bias,
         tBiasgBias,
         tBiassBias,
@@ -1159,7 +1159,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_bias_col_stride;
     }
     FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Mask,
         tMaskgMask,
         tMasksMask,
@@ -1167,7 +1167,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
-    FLASH_NAMESPACE::copy_Mask(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_Bias,
         tBiasgBias,
         tBiassBias,

From 6e1f369634c931d29c8b6872852854522f4a3ec6 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 14 Jul 2025 12:11:09 +0800
Subject: [PATCH 2/4] Renames function to better reflect its purpose

Changes function name from copy_Mask to copy_MN to improve code clarity and
better indicate the function's role in copying data across the M and N
dimensions.
---
 csrc/src/utils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/src/utils.h b/csrc/src/utils.h
index b3c71e1..1828abe 100644
--- a/csrc/src/utils.h
+++ b/csrc/src/utils.h
@@ -495,7 +495,7 @@ __forceinline__ __device__ void copy(
 
 template
-__forceinline__ __device__ void copy_Mask(
+__forceinline__ __device__ void copy_MN(
     TiledCopy tiled_copy, Tensor const &S, Tensor &D,
     Tensor const &identity_MN, const int max_M=0, const int max_N=0

From 47c3dd5cc2dcfaf340e23c19f9f0e7af881c72d4 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 14 Jul 2025 22:20:05 +0800
Subject: [PATCH 3/4] Standardizes copy atom alignment for mask and bias

Unifies memory copy alignment to 128 bytes for both mask and bias operations
in forward and backward kernel traits. Removes inconsistent 64-byte alignment
for bias copy atom in backward kernel and establishes consistent 128-byte
alignment across all mask and bias copy operations.
---
 csrc/src/kernel_traits.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h
index 281a2dd..15c773f 100644
--- a/csrc/src/kernel_traits.h
+++ b/csrc/src/kernel_traits.h
@@ -107,9 +107,11 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutMask = decltype(tile_to_shape(
         SmemLayoutAtomMask{},
         Shape, Int>{}));
+    using SmemCopyAtomMask = Copy_Atom, Element>;
     using SmemLayoutBias = decltype(tile_to_shape(
         SmemLayoutAtomBias{},
         Shape, Int>{}));
+    using SmemCopyAtomBias = Copy_Atom, Element>;
 
     // Shared memory layout for output
     using SmemLayoutAtomO = decltype(
@@ -268,6 +270,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutMask = decltype(tile_to_shape(
         SmemLayoutAtomMask{},
         make_shape(Int{}, Int{})));
+    using SmemCopyAtomMask = Copy_Atom, elem_type>;
     using SmemLayoutAtomBias = decltype(
         composition(Swizzle{},
@@ -276,6 +279,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutBias = decltype(tile_to_shape(
         SmemLayoutAtomBias{},
         make_shape(Int{}, Int{})));
+    using SmemCopyAtomBias = Copy_Atom, elem_type>;
 
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -322,7 +326,6 @@ struct Flash_bwd_kernel_traits : public Base {
         SmemLayoutAtomdQ{},
         make_shape(Int{}, Int{})));
     using SmemCopyAtomdQ = Copy_Atom, elem_type>;
-    using SmemCopyAtomBias = Copy_Atom, elem_type>;
 
     // Double buffer for sQ
     static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);

From 417e5054503be087976cfd5bed9d8380ed4464b9 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 14 Jul 2025 22:21:07 +0800
Subject: [PATCH 4/4] Fixes shared memory copy atom types for mask and bias

Replaces incorrect SmemCopyAtomO with proper SmemCopyAtomMask and
SmemCopyAtomBias types in tiled copy operations. Ensures type safety and
correct memory access patterns for mask and bias tensors in attention
computation kernels.
---
 csrc/src/flash_fwd_kernel.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index bfb66f0..15e53df 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -273,10 +273,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
-    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
-    auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
 
@@ -883,10 +883,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
-    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
-    auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
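
Note for readers reviewing the rename in patches 1 and 2: the max_M/max_N parameters
visible in the utils.h hunk are what motivate the copy_MN name, since the copy is
predicated on both the M (query) and N (key) extents of a tile. The snippet below is
a minimal, hypothetical CUDA sketch of that idea only; the kernel name copy_mn_sketch,
the ld_src leading-dimension parameter, and the zero-fill policy for out-of-bounds
elements are illustrative assumptions, not the repository's actual CuTe-based
implementation.

    // Hypothetical sketch: bounds-checked copy of a kBlockM x kBlockN tile.
    // Elements outside the valid (max_M, max_N) region are zero-filled.
    #include <cuda_runtime.h>

    template <int kBlockM, int kBlockN>
    __global__ void copy_mn_sketch(const float* __restrict__ src, float* __restrict__ dst,
                                   int ld_src, int max_M, int max_N) {
        // One thread block handles one tile; threads stride over its elements.
        for (int idx = threadIdx.x; idx < kBlockM * kBlockN; idx += blockDim.x) {
            const int m = idx / kBlockN;  // row inside the tile (query/M dimension)
            const int n = idx % kBlockN;  // column inside the tile (key/N dimension)
            const bool in_bounds = (m < max_M) && (n < max_N);
            dst[m * kBlockN + n] = in_bounds ? src[m * ld_src + n] : 0.0f;
        }
    }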