From 395e9a4683966ac078b9305aafb8ed8d9096b3f6 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 26 Aug 2025 18:57:44 +0800 Subject: [PATCH 1/8] Consolidates mask and bias layouts into unified PS layout Replaces separate SmemLayoutAtomMask and SmemLayoutAtomBias with a single SmemLayoutAtomPS to reduce code duplication and improve memory layout consistency. Introduces kPBlockN parameter with configurable block sizes (16, 32, or 64) and dedicated swizzle pattern for better memory access patterns. Unifies global memory copy operations for mask and bias into a single GmemTiledCopyMaskBias with improved 128-bit alignment and 8 values per read. --- csrc/src/kernel_traits.h | 50 +++++++++++++--------------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index f7a38d2..c899c42 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -73,6 +73,9 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + static constexpr int kSwizzlePS = 3; using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, @@ -89,18 +92,11 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{} ) ); - using SmemLayoutAtomMask = decltype( - composition( - Swizzle{}, - Layout, - Stride<_8, _1>>{} - ) - ); - using SmemLayoutAtomBias = decltype( + using SmemLayoutAtomPS = decltype( composition( - Swizzle{}, - Layout, - Stride<_8, _1>>{} + Swizzle{}, + Layout, Int>, + Stride, _1>>{} ) ); @@ -127,20 +123,13 @@ struct Flash_fwd_kernel_traits : public Base { ); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - using SmemLayoutMask = decltype( - tile_to_shape( - SmemLayoutAtomMask{}, - Shape, Int>{} - ) - ); - using SmemCopyAtomMask = Copy_Atom, Element>; - using SmemLayoutBias = decltype( + using SmemLayoutPS = decltype( tile_to_shape( - SmemLayoutAtomBias{}, + SmemLayoutAtomPS{}, Shape, Int>{} ) ); - using SmemCopyAtomBias = Copy_Atom, Element>; + using SmemCopyAtomPS = Copy_Atom, Element>; // Shared memory layout for output using SmemLayoutAtomO = decltype( @@ -162,8 +151,8 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element); - static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element); + static constexpr int kSmemMaskSize = size(SmemLayoutPS{}) * sizeof(Element); + static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element); // Shared memory size with QKV matrices and mask/bias matrices static constexpr int kSmemSize = (Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; @@ -196,20 +185,13 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyMask = decltype( - make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>{} - ) - ); // Val layout, 4 vals per read - using GmemTiledCopyBias = decltype( + using GmemTiledCopyMaskBias = decltype( make_tiled_copy( - Copy_Atom, Element>{}, + Copy_Atom, Element>{}, GmemLayoutAtom{}, - Layout>{} + Layout>{} ) - ); // Val layout, 4 vals per read + ); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy( Copy_Atom, Element>{}, From 29912d4bb9295ea461691e7da1c0c9732f85fbc5 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 26 Aug 2025 18:58:06 +0800 Subject: [PATCH 2/8] Unifies mask and bias memory operations using shared layout Consolidates separate mask and bias memory copy operations into a unified approach by replacing distinct layout types and copy operations with shared AtomPS layout and MaskBias copy operations. Reduces code duplication and improves memory access patterns by using the same layout configuration for both mask and bias tensors in shared memory operations. --- csrc/src/flash_fwd_kernel.h | 92 ++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 9c72926..a146602 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -218,31 +218,29 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutMask{} + typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( sMask.data() + size(sMask), - typename Kernel_traits::SmemLayoutBias{} + typename Kernel_traits::SmemLayoutAtomPS{} ); // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; - auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; - auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; + auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) - Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) - Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor 
tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -267,10 +265,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -298,10 +296,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("\n"); // } // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -354,13 +352,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN @@ -460,13 +458,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -558,13 +556,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -845,31 +843,29 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutMask{} + typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( sMask.data() + size(sMask), - typename Kernel_traits::SmemLayoutBias{} + typename Kernel_traits::SmemLayoutAtomPS{} ); // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; - auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; - auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; + auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) - Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) - Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) + Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) + Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -891,10 +887,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = 
smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -907,10 +903,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask))); // (BLK_M, BLK_N) -> (blk_m, blk_n) Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias))); // (BLK_M, BLK_N) -> (blk_m, blk_n) // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); @@ -947,13 +943,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN @@ -1074,13 +1070,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN @@ -1190,13 +1186,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Mask, + gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_Bias, + gmem_tiled_copy_MaskBias, tBiasgBias, tBiassBias, tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN From 56c7c9bcff567c3525599593412c991bcd785c6e Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 26 Aug 2025 22:07:33 +0800 Subject: [PATCH 3/8] Removes unused PBlockN parameter from kernel traits Eliminates the kPBlockN constant and its static assertion since it was not being used effectively in the layout configuration. Simplifies the SmemLayoutAtomPS composition by directly using kBlockN instead of the intermediate kPBlockN parameter. --- csrc/src/kernel_traits.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index c899c42..3319f71 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -73,8 +73,6 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; - static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); static constexpr int kSwizzlePS = 3; using TiledMma = TiledMMA< @@ -95,8 +93,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutAtomPS = decltype( composition( Swizzle{}, - Layout, Int>, - Stride, _1>>{} + Layout, Int>, + Stride, _1>>{} ) ); From 834e334e91e8d503cc8f7c7f32d4b71c8d08996f Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 26 Aug 2025 22:07:47 +0800 Subject: [PATCH 4/8] Replaces compute capability checks with shared memory queries Improves kernel selection logic by dynamically querying device shared memory limits instead of relying on hardcoded compute capability checks. Uses actual shared memory per block availability to determine optimal kernel configurations, enabling better performance across different GPU architectures without requiring architecture-specific branching logic. Simplifies the codebase by removing compute capability detection and associated conditional logic while maintaining performance optimization goals. 
--- csrc/src/flash_fwd_launch_template.h | 115 ++++++++++++++------------- 1 file changed, 61 insertions(+), 54 deletions(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index 89feeae..d4d8b57 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -162,78 +162,89 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - run_flash_fwd, Is_causal>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 176 * 1024) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 128) is 27% slower for seqlen=2k - // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 224 * 1024) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } - } else { + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 160 * 1024) { run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); } template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - 
bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. - // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if (max_smem_per_block >= 192 * 1024) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + // For sm86 or sm89, 64 x 64 (48 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. + // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } template @@ -249,12 +260,8 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - - // For A100, we want to run with 64 x 64 (112KB smem). - // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); + if (max_smem_per_block >= 224 * 1024) { + run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); } From 4c7c27f99b75f23ab396a4b7d69b87889b5b5184 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:04:26 +0800 Subject: [PATCH 5/8] Improves attention mask handling and expands test coverage Fixes dynamic mask preparation to properly handle invalid topk values by checking against minimum dtype values before scattering to attention mask. Expands benchmark test configurations to include comprehensive coverage across multiple head dimensions (32, 64, 96, 128, 256) and sequence lengths, providing more thorough validation of attention mechanisms. Re-enables previously disabled triton and flex attention test suites to ensure complete equivalence testing across all implementation variants. 
--- benchmarks/forward_equivalence.py | 260 ++++++++++++++++++++++-------- 1 file changed, 195 insertions(+), 65 deletions(-) diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index b3b0883..0035147 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -88,11 +88,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) @@ -518,28 +519,70 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), - (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), - (1, 2, 1, 3, 512, 128, True), - (1, 2, 1, 1, 512, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 
4096, 128, False), + + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 4096, 4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -672,27 +715,71 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), - (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + + # Not support head_dim > 128 in triton yet + # (1, 2, 1, 128, 128, 128, True), + # (1, 2, 1, 128, 128, 128, False), + # (1, 2, 1, 256, 256, 256, True), + # (1, 2, 1, 256, 256, 256, False), + # (1, 2, 1, 512, 512, 256, True), + # (1, 2, 1, 512, 512, 256, False), + # (1, 2, 1, 1024, 1024, 256, True), + # (1, 2, 1, 1024, 1024, 256, False), + # (1, 2, 1, 2048, 2048, 256, True), + # (1, 2, 1, 2048, 2048, 256, False), + # (1, 2, 1, 4096, 4096, 256, True), + # (1, 2, 1, 4096, 
4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -843,27 +930,70 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): # Test configurations for Flex Attention test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) - (1, 1, 1, 64, 64, 32, True), - (1, 1, 1, 64, 64, 32, False), - (1, 1, 1, 128, 128, 32, True), - (1, 1, 1, 128, 128, 32, False), - (1, 1, 1, 256, 256, 32, True), - (1, 1, 1, 256, 256, 32, False), - (1, 1, 1, 512, 512, 32, True), - (1, 1, 1, 512, 512, 32, False), - (1, 1, 1, 1024, 1024, 32, True), - (1, 1, 1, 1024, 1024, 32, False), - (1, 1, 1, 2048, 2048, 32, True), - (1, 1, 1, 2048, 2048, 32, False), - (1, 1, 1, 4096, 4096, 32, True), - (1, 1, 1, 4096, 4096, 32, False), - (1, 2, 1, 64, 64, 32, True), - (2, 1, 1, 128, 128, 32, True), - (2, 2, 1, 128, 128, 32, True), - (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 4096, 4096, 256, False), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -1051,13 +1181,13 @@ def main(): print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍") test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'triton']: - # print("\n" + "🔥" + " Starting Python vs Triton Tests " + "🔥") - # test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'triton']: + print("\n" + "🔥" + " Starting Python vs Triton 
Tests " + "🔥") + test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'flex']: - # print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟") - # test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'flex']: + print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟") + test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold) # Print overall summary From 56d25fbee478e32829f1eb7975bfa92921187fae Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:04:37 +0800 Subject: [PATCH 6/8] Optimizes block size for small head dimensions Increases block size from 64 to 128 for head dimensions <= 32 to improve memory throughput and computational efficiency for smaller attention heads. The nested ternary operator now handles three cases: - Head dim <= 32: uses 128 block size - Head dim <= 64: uses 64 block size - Head dim >= 128: uses 32 block size --- csrc/src/flash_fwd_launch_template.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index d4d8b57..d15298e 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -155,7 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions - constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32); + constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32)); run_flash_splitkv_fwd, Is_causal>(params, stream); } From 08990628ae14634c574e40c3491558129d2d61b8 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:07:13 +0800 Subject: [PATCH 7/8] Improves dynamic mask handling and updates comments Enhances the prepare_dynamic_mask function by capturing both values and indices from torch.topk operation, then filtering out invalid entries based on minimum dtype values. This prevents invalid indices from being included in the attention mask. Updates inline comments to standardize "INF" terminology and removes outdated debugging comments from test cases. 
--- benchmarks/backward_equivalence.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 4c2c509..3a8598f 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -88,11 +88,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) @@ -561,19 +562,19 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 256, 256, 32, False), (1, 2, 1, 512, 512, 32, True), (1, 2, 1, 512, 512, 32, False), - (1, 2, 1, 1024, 1024, 32, True), # some -Inf and Inf in dbias, Idk why + (1, 2, 1, 1024, 1024, 32, True), (1, 2, 1, 1024, 1024, 32, False), (1, 2, 1, 2048, 2048, 32, True), (1, 2, 1, 2048, 2048, 32, False), - (1, 2, 1, 4096, 4096, 32, True), # some NAN in dbias, Idk why + (1, 2, 1, 4096, 4096, 32, True), (1, 2, 1, 4096, 4096, 32, False), (1, 2, 1, 128, 128, 64, True), (1, 2, 1, 128, 128, 64, False), - (1, 2, 1, 256, 256, 64, True), # some NAN in dbias, Idk why + (1, 2, 1, 256, 256, 64, True), (1, 2, 1, 256, 256, 64, False), (1, 2, 1, 512, 512, 64, True), (1, 2, 1, 512, 512, 64, False), - (1, 2, 1, 1024, 1024, 64, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 64, False), (1, 2, 1, 2048, 2048, 64, True), (1, 2, 1, 2048, 2048, 64, False), @@ -585,26 +586,26 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 256, 256, 96, False), (1, 2, 1, 512, 512, 96, True), (1, 2, 1, 512, 512, 96, False), - (1, 2, 1, 1024, 1024, 96, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 96, False), (1, 2, 1, 2048, 2048, 96, True), (1, 2, 1, 2048, 2048, 96, False), (1, 2, 1, 4096, 4096, 96, True), (1, 2, 1, 4096, 4096, 96, False), - (1, 2, 1, 128, 128, 128, True), # some NAN in dbias, Idk why + (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 128, 128, 128, True), (1, 2, 1, 256, 256, 128, True), (1, 2, 1, 256, 256, 128, False), (1, 2, 1, 512, 512, 128, True), (1, 2, 1, 512, 512, 128, False), - (1, 2, 1, 1024, 1024, 128, True), # some NAN in dbias, Idk why + (1, 2, 1, 1024, 1024, 128, True), # some INF in dbias, Idk why (1, 2, 1, 1024, 1024, 128, False), (1, 2, 1, 2048, 2048, 128, True), (1, 2, 1, 2048, 2048, 128, False), (1, 2, 1, 4096, 4096, 128, True), (1, 2, 1, 4096, 4096, 128, False), - # Not support head_dim > 128 yet in sm 80 + # Not support head_dim > 128 in sm80 yet # (1, 2, 1, 128, 128, 256, True), # (1, 2, 1, 128, 128, 128, False), # (1, 2, 1, 256, 256, 256, True), From 14f45ab17b33b39e7ddf9c969ec8bf45096351c9 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:08:17 +0800 Subject: [PATCH 8/8] Fixes attention mask handling for invalid topk values Improves the dynamic mask preparation by properly handling cases where topk values are invalid (equal to minimum dtype value). 
Previously, the mask would incorrectly include positions with invalid attention scores, potentially causing incorrect attention computations. Now validates topk values before setting mask positions, ensuring only valid attention scores are preserved in the final mask. --- benchmarks/forward_performance.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 8ba8ec0..95c2920 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -110,11 +110,12 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: - topk_indices = torch.topk( + topk_values, topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ).indices + ) + valid_topk = topk_values != min_dtype attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)