From e6168c0f986221d8d6109d913386dfeec86dd4ee Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 27 Sep 2025 23:18:13 +0800 Subject: [PATCH 01/16] Split fast path for even MN; fix OOB clear Introduces compile-time branching to separate the even-tile fast path from ragged edges. On even tiles, performs unguarded bulk copies and removes per-element predicates. On ragged tiles, guards M/N bounds, switches to element-wise copy, and explicitly clears out-of-bounds regions. Reduces runtime branching and divergence, improving correctness on partial tiles and performance on full tiles. --- csrc/flash_dmattn/src/utils.h | 66 +++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 52fb7a0..0d0a4ac 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -594,19 +594,32 @@ __forceinline__ __device__ void copy_mask( CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + if constexpr (Is_even_MN) { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - if (Is_even_MN || predicate_N(n)) { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, n)); + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } + } + } else { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (predicate_N(n)) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = S(i, m, n); + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } @@ -675,24 +688,33 @@ __forceinline__ __device__ void copy_bias( CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + + if constexpr (Is_even_MN) { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - if (Is_even_MN || predicate_N(n)) { - // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = S(i, m, n); + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } + } + } else { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (predicate_N(n)) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = S(i, m, n); + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, n)); } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } From 0b5654eae0eb6e79df7b2e4c845148ebac69fcc1 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 27 Sep 2025 23:18:40 +0800 Subject: [PATCH 02/16] Adds seq-len padding and unpadding for K/V Pads key/value sequence length to a multiple of 8 and adjusts 
mask/bias accordingly to satisfy kernel alignment. Stores the original length and slices gradients/bias in backward to restore shapes. Improves correctness and supports non-multiple-of-8 sequence lengths without shape mismatches. --- flash_dmattn/flash_dmattn_interface.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 5b5a889..223a8fd 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -247,14 +247,14 @@ def forward( q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - # seqlen_k_og = k.shape[1] - # if seqlen_k_og % 8 != 0: - # k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - # v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) - # if mask is not None: - # mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) - # if bias is not None: - # bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) + seqlen_k_og = k.shape[1] + if seqlen_k_og % 8 != 0: + k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + if mask is not None: + mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) + if bias is not None: + bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -274,7 +274,7 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic - # ctx.seqlen_k_og = seqlen_k_og + ctx.seqlen_k_og = seqlen_k_og out = out_padded[..., :head_size_og] @@ -318,10 +318,10 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - # if ctx.seqlen_k_og % 8 != 0: - # dk = dk[:, : ctx.seqlen_k_og, :, :] - # dv = dv[:, : ctx.seqlen_k_og, :, :] - # dbias = dbias[..., : ctx.seqlen_k_og] + if ctx.seqlen_k_og % 8 != 0: + dk = dk[:, : ctx.seqlen_k_og, :, :] + dv = dv[:, : ctx.seqlen_k_og, :, :] + dbias = dbias[..., : ctx.seqlen_k_og] return dq, dk, dv, None, dbias, None, None, None, None, None, None From d9a5248e59da7317dad7189a5d8a46aba958afe4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 30 Sep 2025 12:29:18 +0800 Subject: [PATCH 03/16] Fixes smem copy types for mask and bias Uses dedicated shared-memory copy ops for mask and bias to match their layouts, preventing stride/type mismatches in attention computation and improving correctness/perf. Applies to both regular and split-KV paths and cleans minor whitespace. 
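A loose host-side analogy of the stride/type hazard these dedicated copy atoms guard against (a NumPy sketch, not the CuTe path; tile sizes are arbitrary): reinterpreting a tile with the wrong element width silently changes its apparent row length and element stride.

# Illustrative sketch only, not part of the patch.
import numpy as np

kBlockM, kBlockN = 4, 8
mask_tile = (np.arange(kBlockM * kBlockN, dtype=np.uint8) % 2).reshape(kBlockM, kBlockN)

# Matching element type: shape and strides are what the layout expects.
same_type = mask_tile.view(np.uint8)
assert same_type.shape == (kBlockM, kBlockN)

# Mismatched element type: a 1-byte mask read as 2-byte elements halves the
# apparent row length and doubles the effective element stride.
wrong_type = mask_tile.view(np.uint16)
assert wrong_type.shape == (kBlockM, kBlockN // 2)
print(mask_tile.strides, wrong_type.strides)   # (8, 1) vs (8, 2)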
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 0e4d69b..38a6428 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -267,10 +267,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); @@ -1026,7 +1026,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -1059,10 +1058,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma); auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPS{}, tiled_mma); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma); auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); From f7a05f30a372db82bc421b485c3f78f9c9eee81f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 30 Sep 2025 16:33:53 +0800 Subject: [PATCH 04/16] Defines ElementMask and propagates through traits Standardizes the mask element type as uint8_t in base traits and exposes it in forward/backward kernel traits. Improves consistency and avoids missing-type compile errors where the mask type is referenced, while easing future type changes. 
--- csrc/flash_dmattn/src/kernel_traits.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index f6a4e42..040cdab 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -23,6 +23,7 @@ struct Flash_kernel_traits { static constexpr bool Has_cp_async = false; #endif + using ElementMask = uint8_t; using ElementAccum = float; using index_t = int64_t; @@ -55,6 +56,7 @@ struct Flash_fwd_kernel_traits : public Base { using Element = typename Base::Element; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; + using ElementMask = typename Base::ElementMask; static constexpr bool Has_cp_async = Base::Has_cp_async; using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; @@ -233,6 +235,7 @@ template< > struct Flash_bwd_kernel_traits : public Base { using Element = typename Base::Element; + using ElementMask = typename Base::ElementMask; using ElementAccum = typename Base::ElementAccum; using index_t = typename Base::index_t; static constexpr bool Has_cp_async = Base::Has_cp_async; From 0cb79bf374d28f86fe293af7858ec28d4191682c Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 1 Oct 2025 00:43:10 +0800 Subject: [PATCH 05/16] Refactors mask/bias copy policies for Flash attn Updates forward and backward paths to use a non-vectorized copy for masks and a hardware-tuned global copy for bias, avoiding unsafe 128B alignment assumptions on masks and improving portability. Improves correctness on potentially unaligned mask accesses and aligns bias copies with the chosen gmem policy, with minor cleanups in tiled copy definitions. 
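A quick sketch of the alignment concern (all sizes below are assumptions for illustration): a 128-bit vectorized copy needs each row's base address to be 16-byte aligned, which half-precision Q/K/V rows satisfy by construction, while a byte-sized mask row does not unless its stride happens to be a multiple of 16.

# Illustrative sketch only: when is a 128-bit (16-byte) vectorized load safe?
# The stride values below are assumptions, not taken from the kernel.
def row_base_is_16B_aligned(base_addr: int, row: int, row_stride_elems: int, itemsize: int) -> bool:
    return (base_addr + row * row_stride_elems * itemsize) % 16 == 0

base = 0  # assume the allocation itself is 16-byte aligned

# fp16 rows with head_dim = 64: stride is 64 * 2 = 128 bytes -> always aligned.
assert all(row_base_is_16B_aligned(base, r, 64, 2) for r in range(8))

# uint8 mask rows with seqlen_k = 100: stride is 100 bytes, so most rows drift
# off 16-byte boundaries; an unconditional 128-bit copy atom is unsafe there,
# while element-wise (or narrower) copies remain valid.
misaligned = [r for r in range(8) if not row_base_is_16B_aligned(base, r, 100, 1)]
print(misaligned)   # [1, 2, 3, 5, 6, 7]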
--- csrc/flash_dmattn/src/kernel_traits.h | 45 ++++++++++++++++----------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 040cdab..9bc237c 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -131,7 +131,8 @@ struct Flash_fwd_kernel_traits : public Base { Shape, Int>{} ) ); - using SmemCopyAtomPS = Copy_Atom, Element>; + using SmemCopyAtomMask = Copy_Atom; + using SmemCopyAtomBias = Copy_Atom, Element>; // Shared memory layout for output using SmemLayoutAtomO = decltype( @@ -189,14 +190,14 @@ struct Flash_fwd_kernel_traits : public Base { ); // Val layout, 8 vals per read using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom, Element>{}, + Copy_Atom{}, GmemLayoutAtom{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyBias = decltype( make_tiled_copy( - Copy_Atom, Element>{}, + Copy_Atom{}, GmemLayoutAtom{}, Layout>{} ) @@ -350,6 +351,8 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom, elem_type>; + using SmemCopyAtomMask = Copy_Atom; + using SmemCopyAtomBias = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype( composition( @@ -433,46 +436,52 @@ struct Flash_bwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype( + using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom, elem_type>{}, + Copy_Atom{}, GmemLayoutAtom{}, - Layout>{} + Layout>{} ) - ); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype( + ); // Val layout, 8 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( - Copy_Atom, elem_type>{}, + Copy_Atom{}, GmemLayoutAtom{}, - Layout>{} + Layout>{} ) - ); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype( + ); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{} ) ); // Val layout, 8 vals per store - using GmemTiledCopyMask = decltype( + using GmemTiledCopydBias = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyBias = decltype( + using GmemTiledCopydKV = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtom{}, - Layout>{} + Layout>{} ) - ); // Val layout, 8 vals per read + ); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy( + Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{} + ) + ); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, Stride< _8, _1>>, // Thread layout, 8 threads per row - Layout, Stride< _16, _1>> // Thread layout, 16 threads per row - + Layout, Stride< _16, _1>> // Thread layout, 16 threads per row >; using GmemTiledCopydQaccum = decltype( make_tiled_copy( From 94fe6c17d55c17ab8b846e245aff14f57e274d98 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 1 Oct 2025 00:43:56 +0800 Subject: [PATCH 06/16] Removes redundant sync after bias copy Removes block-wide barriers that were only needed when bias loads were scalar. With vectorized bias copies and async copy fencing in place, the extra synchronization is unnecessary. Reduces sync overhead and stalls, improving forward attention performance without affecting correctness. 
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 38a6428..c798a02 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -400,9 +400,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } cute::cp_async_fence(); } @@ -560,9 +557,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -723,9 +717,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -1159,9 +1150,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } cute::cp_async_fence(); } @@ -1350,9 +1338,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -1540,9 +1525,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. From 0ad8237eb6c278738e09276c4c0178d7f393700a Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 1 Oct 2025 00:44:59 +0800 Subject: [PATCH 07/16] Vectorizes dBias copy; removes sync barriers Adds a dedicated memory copy path for bias gradients and uses proper shared-memory partitioning for mask/bias, aligning with the compute tile. Replaces scalar bias copies with vectorized transactions, allowing removal of explicit synchronization after bias copy operations. Improves performance and avoids layout mismatches in bias-enabled backward passes. 
--- csrc/flash_dmattn/src/flash_bwd_kernel.h | 30 ++++++++---------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 780e616..3554802 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -278,6 +278,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in GmemTiledCopydO gmem_tiled_copy_dO; auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias; + auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx); auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); using GmemLayoutAtomdQaccum = std::conditional_t< !Seq_parallel, @@ -300,7 +302,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); - Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias); + Tensor tdBiasgdBias = gmem_thr_copy_dBias.partition_D(gdBias); Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N) Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); @@ -350,20 +352,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); - // auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp); - // auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); - // Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); - // auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp); - // auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); - // Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); - // Partition sP and sdS to match the accumulator partitioning // This has to be tiled_mma_sdp, not tiled_mma_dkv // auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); - Tensor tSsMask = smem_thr_copy_PdS.partition_S(sMask); - Tensor tSsBias = smem_thr_copy_PdS.partition_S(sBias); + auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp); + auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); + auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp); + auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); + Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom, AtomNum), PIPE_M, PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom, AtomNum), PIPE_M, PIPE_N) @@ -608,9 +607,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiascBias, tBiaspBias, 
binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } } @@ -853,14 +849,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if constexpr (Has_bias) { // Write dS to dBias FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, + gmem_tiled_copy_dBias, tBiassBias, tdBiasgdBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } // if (cute::thread0()) { print(tPrP); } @@ -1013,9 +1006,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiascBias, tBiaspBias, binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); } } } From 0cd9eab2008414026618d2d35c47b0dd2221e6f9 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 1 Oct 2025 01:41:45 +0800 Subject: [PATCH 08/16] Unifies bias copy logic and clears OOB tiles Collapses separate even/uneven paths into a single unrolled loop that uses tiled copies for in-bounds regions and clears out-of-bounds elements when requested. Replaces scalar element-wise copies with vectorized/tiled copies on valid tiles to improve performance and reduce code duplication while preserving correctness on partial tiles. --- csrc/flash_dmattn/src/utils.h | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 0d0a4ac..f195bf5 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -689,32 +689,19 @@ __forceinline__ __device__ void copy_bias( CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - if constexpr (Is_even_MN) { - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - } - } - } else { - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) < max_M) { - #pragma unroll - for (int n = 0; n < size<2>(S); ++n) { - if (predicate_N(n)) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = S(i, m, n); - } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, n)); - } + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); } } } From b7314012eaa47175d8186cdaccc7fcdd22c1e6bb Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:25:08 +0800 Subject: [PATCH 09/16] Removes clamping; guards optional bias slice Removes tensor clamping in forward/backward to preserve true values and reduce overhead. Guards slicing of an optional bias to avoid None errors when sequence length isn't divisible by 8. 
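The pad/unpad round trip now performed by the interface can be summarized with the standalone sketch below (helper names are illustrative, not the library API; tensors follow the (batch, seqlen, heads, head_dim) layout used in this file):

# Illustrative sketch of the pad / unpad round trip; the real logic lives
# inside the autograd Function in flash_dmattn_interface.py.
import torch
import torch.nn.functional as F

def pad_kv_to_multiple_of_8(k, v, mask=None, bias=None):
    seqlen_k_og = k.shape[1]                 # (batch, seqlen_k, heads, head_dim)
    pad = (8 - seqlen_k_og % 8) % 8
    if pad:
        k = F.pad(k, [0, 0, 0, 0, 0, pad])              # pad the seqlen dim
        v = F.pad(v, [0, 0, 0, 0, 0, pad])
        if mask is not None:
            mask = F.pad(mask, [0, pad], value=False)   # mask: (..., seqlen_k)
        if bias is not None:
            bias = F.pad(bias, [0, pad], value=0.0)     # bias: (..., seqlen_k)
    return k, v, mask, bias, seqlen_k_og

def unpad_grads(dk, dv, dbias, seqlen_k_og):
    if seqlen_k_og % 8 != 0:
        dk = dk[:, :seqlen_k_og]
        dv = dv[:, :seqlen_k_og]
        if dbias is not None:                # bias is optional, hence the guard
            dbias = dbias[..., :seqlen_k_og]
    return dk, dv, dbias

k = torch.randn(1, 100, 2, 64)
v = torch.randn(1, 100, 2, 64)
k_p, v_p, _, _, og = pad_kv_to_multiple_of_8(k, v)
assert k_p.shape[1] == 104 and og == 100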
--- flash_dmattn/flash_dmattn_interface.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 223a8fd..2c267d4 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -95,7 +95,7 @@ def _flash_dmattn_forward( softcap, return_softmax, ) - _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) + # _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) return out, softmax_lse, S_dmask @@ -170,7 +170,7 @@ def _flash_dmattn_backward( softcap, deterministic, ) - _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) + # _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) return softmax_d @@ -321,7 +321,8 @@ def backward( if ctx.seqlen_k_og % 8 != 0: dk = dk[:, : ctx.seqlen_k_og, :, :] dv = dv[:, : ctx.seqlen_k_og, :, :] - dbias = dbias[..., : ctx.seqlen_k_og] + if dbias is not None: + dbias = dbias[..., : ctx.seqlen_k_og] return dq, dk, dv, None, dbias, None, None, None, None, None, None From 07cd983a1e5ef1103c4c5e8d6adee90ac3511843 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:27:11 +0800 Subject: [PATCH 10/16] Simplifies masked copy; fixes active reduction Unifies mask handling for even/odd shapes and N predicates, always using the tiled path and clearing OOB uniformly. Removes the type-cast template and per-element copy, reducing branching and improving performance. Fixes block activity detection by syncing and OR-reducing over the destination after copy, preventing false negatives; renames the output flag for clarity. 
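Functionally, the corrected activity check boils down to the following NumPy stand-in (the real kernel reduces per thread over its shared-memory fragment and then combines the flags with __syncthreads_or):

# Illustrative sketch only: copy the mask tile first (clearing out-of-bounds
# entries), then decide from the written destination whether the block has work.
import numpy as np

def copy_mask_and_check_active(src_tile, valid_rows, valid_cols):
    dst = np.zeros_like(src_tile)                        # OOB region cleared
    dst[:valid_rows, :valid_cols] = src_tile[:valid_rows, :valid_cols]
    any_active = bool(dst.any())                         # OR-reduce over what was written
    return dst, any_active

mask = np.zeros((4, 8), dtype=np.uint8)
mask[1, 5] = 1                                           # one active key position
_, active = copy_mask_and_check_active(mask, valid_rows=4, valid_cols=8)
assert active
_, active = copy_mask_and_check_active(mask, valid_rows=4, valid_cols=4)
assert not active                                        # the active column fell out of bounds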
--- csrc/flash_dmattn/src/utils.h | 51 ++++++++++++++--------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index f195bf5..13d204b 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -594,32 +594,19 @@ __forceinline__ __device__ void copy_mask( CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - if constexpr (Is_even_MN) { - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - } - } - } else { - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) < max_M) { - #pragma unroll - for (int n = 0; n < size<2>(S); ++n) { - if (predicate_N(n)) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = S(i, m, n); - } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, n)); - } + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); } } } @@ -627,7 +614,7 @@ __forceinline__ __device__ void copy_mask( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=false, typename To_type=void, + bool Is_even_MN=true, bool Clear_OOB_MN=false, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3 @@ -635,7 +622,7 @@ template < __forceinline__ __device__ void copy_mask_with_or_reduce( TiledCopy tiled_copy, Tensor const &S, Tensor &D, - bool &block_active, + bool &active, Tensor const &identity_MN, Tensor const &predicate_N, const int max_M=0 ) { @@ -645,18 +632,13 @@ __forceinline__ __device__ void copy_mask_with_or_reduce( CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N - bool any_active = false; #pragma unroll for (int m = 0; m < size<1>(S); ++m) { if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || predicate_N(n)) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - any_active |= S(i, m, n); - D(i, m, n) = static_cast(S(i, m, n)); - } + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); } @@ -666,7 +648,14 @@ __forceinline__ __device__ void copy_mask_with_or_reduce( } } - block_active = __syncthreads_or(any_active); + __syncthreads(); + + bool active_local = false; + #pragma unroll + for (int i = 0; i < size(D); ++i) { + active_local |= D(i); + } + active = __syncthreads_or(active_local); } //////////////////////////////////////////////////////////////////////////////////////////////////// From c6920eb0d548786cfe2e4d909808278f9d370512 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:28:27 +0800 Subject: [PATCH 11/16] Vectorizes mask/bias I/O; refactors mem layouts Separates per-matrix global-memory layouts and thread mapping to account for differing element sizes, improving coalescing and alignment. 
Switches mask transfers to aligned auto-vectorized paths and widens mask load width, plus adds divisibility assertions to catch misconfigurations early. Cleans up and clarifies shared-memory layout comments/structure for mask and bias, while preserving Q/K/V/O behavior. --- csrc/flash_dmattn/src/kernel_traits.h | 64 +++++++++++++++++---------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 9bc237c..7c4f225 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -85,7 +85,7 @@ struct Flash_fwd_kernel_traits : public Base { Tile, _16, _16> >; - // Shared memory layout for Q matrix and Mask matrix + // Shared memory layout for Q matrix using SmemLayoutAtomQ = decltype( composition( Swizzle{}, @@ -94,14 +94,6 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>{} ) ); - using SmemLayoutAtomPS = decltype( - composition( - Swizzle{}, - Layout, Int>, - Stride, _1>>{} - ) - ); - using SmemLayoutQ = decltype( tile_to_shape( SmemLayoutAtomQ{}, @@ -125,13 +117,21 @@ struct Flash_fwd_kernel_traits : public Base { ); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + // Shared memory layout for Mask and Bias matrices + using SmemLayoutAtomPS = decltype( + composition( + Swizzle{}, + Layout, Int>, + Stride, _1>>{} + ) + ); using SmemLayoutPS = decltype( tile_to_shape( SmemLayoutAtomPS{}, Shape, Int>{} ) ); - using SmemCopyAtomMask = Copy_Atom; + using SmemCopyAtomMask = Copy_Atom, ElementMask>; using SmemCopyAtomBias = Copy_Atom, Element>; // Shared memory layout for output @@ -160,18 +160,34 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size with QKV matrices and mask/bias matrices static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? kSmemBiasSize : 0); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemElemsPerLoadQKVO = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerLoadMask = sizeof(cute::uint128_t) / sizeof(ElementMask); + static constexpr int kGmemElemsPerLoadBias = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoadQKVO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadQKVO"); + static_assert(kBlockN % kGmemElemsPerLoadMask == 0, "kBlockN must be a multiple of kGmemElemsPerLoadMask"); + static_assert(kBlockN % kGmemElemsPerLoadBias == 0, "kBlockN must be a multiple of kGmemElemsPerLoadBias"); // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. // For example, for d=128, smem is split into 2 "pages", each page takes care of columns // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, // to the same banks. 
- static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1> + static constexpr int kGmemThreadsPerRowQKVO = kBlockKSmem / kGmemElemsPerLoadQKVO; + static constexpr int kGmemThreadsPerRowMask = kBlockN / kGmemElemsPerLoadMask; + static constexpr int kGmemThreadsPerRowBias = kBlockN / kGmemElemsPerLoadBias; + static_assert(kNThreads % kGmemThreadsPerRowQKVO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowQKVO"); + static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask"); + static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias"); + using GmemLayoutAtomQKVO = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomMask = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomBias = Layout< + Shape , Int>, + Stride, _1> >; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading @@ -184,28 +200,28 @@ struct Flash_fwd_kernel_traits : public Base { using GmemTiledCopyQKV = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{} + Copy_Atom, ElementMask>{}, + GmemLayoutAtomMask{}, + Layout>{} ) - ); // Val layout, 8 vals per read + ); // Val layout, 16 vals per read using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomBias{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( make_tiled_copy( Copy_Atom, Element>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per store From 3d88f28e04a92a313cab487e2291fa2936005706 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:29:34 +0800 Subject: [PATCH 12/16] Vectorizes mask copy and splits gmem layouts Updates memory tiling to use per-type layouts (QKVO, mask, bias) with matching vector widths and thread mapping. Vectorizes mask copies in shared/global memory and increases mask read width to 16 for better bandwidth. Adds stronger compile-time checks to enforce alignment and divisibility, reducing misaligned accesses and improving coalescing and stability. 
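The per-type thread mapping follows directly from the 16-byte transaction size; the arithmetic behind the new static_asserts looks like this sketch (block sizes are example values, not a specific configuration):

# Illustrative sketch: elements per 128-bit load and threads per row for each
# operand type. Block sizes are assumed example values.
VEC_BYTES = 16                       # one cute::uint128_t transaction

def per_row_mapping(row_elems, itemsize, n_threads):
    elems_per_load = VEC_BYTES // itemsize
    assert row_elems % elems_per_load == 0, "row length must divide evenly into vectors"
    threads_per_row = row_elems // elems_per_load
    assert n_threads % threads_per_row == 0, "thread block must tile whole rows"
    return elems_per_load, threads_per_row

kBlockN, kBlockKSmem, kNThreads = 128, 64, 128
print(per_row_mapping(kBlockKSmem, 2, kNThreads))   # fp16 Q/K/V/O: (8, 8)
print(per_row_mapping(kBlockN, 1, kNThreads))       # uint8 mask:   (16, 8)
print(per_row_mapping(kBlockN, 2, kNThreads))       # fp16 bias:    (8, 16)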
--- csrc/flash_dmattn/src/kernel_traits.h | 52 +++++++++++++++++---------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 7c4f225..884fe0a 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -367,7 +367,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); using SmemCopyAtomPdS = Copy_Atom, elem_type>; - using SmemCopyAtomMask = Copy_Atom; + using SmemCopyAtomMask = Copy_Atom, ElementMask>; using SmemCopyAtomBias = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype( @@ -427,15 +427,31 @@ struct Flash_bwd_kernel_traits : public Base { : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemBiasdSSize + kSmemPMaskSize) ); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemElemsPerLoadQKVO = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerLoadMask = sizeof(cute::uint128_t) / sizeof(ElementMask); + static constexpr int kGmemElemsPerLoadBias = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoadQKVO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadQKVO"); + static_assert(kBlockN % kGmemElemsPerLoadMask == 0, "kBlockN must be a multiple of kGmemElemsPerLoadMask"); + static_assert(kBlockN % kGmemElemsPerLoadBias == 0, "kBlockN must be a multiple of kGmemElemsPerLoadBias"); // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem // to affect speed in practice. 
- static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout< - Shape , Int>, - Stride, _1> + static constexpr int kGmemThreadsPerRowQKVO = kBlockKSmem / kGmemElemsPerLoadQKVO; + static constexpr int kGmemThreadsPerRowMask = kBlockN / kGmemElemsPerLoadMask; + static constexpr int kGmemThreadsPerRowBias = kBlockN / kGmemElemsPerLoadBias; + static_assert(kNThreads % kGmemThreadsPerRowQKVO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowQKVO"); + static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask"); + static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias"); + using GmemLayoutAtomQKVO = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomMask = Layout< + Shape , Int>, + Stride, _1> + >; + using GmemLayoutAtomBias = Layout< + Shape , Int>, + Stride, _1> >; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading @@ -448,49 +464,49 @@ struct Flash_bwd_kernel_traits : public Base { using GmemTiledCopyQKV = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopyMask = decltype( make_tiled_copy( - Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{} + Copy_Atom, ElementMask>{}, + GmemLayoutAtomMask{}, + Layout>{} ) - ); // Val layout, 8 vals per read + ); // Val layout, 16 vals per read using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom{}, - GmemLayoutAtom{}, + GmemLayoutAtomBias{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per store using GmemTiledCopydBias = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomBias{}, Layout>{} ) ); // Val layout, 8 vals per read using GmemTiledCopydKV = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, + GmemLayoutAtomQKVO{}, Layout>{} ) ); // Val layout, 8 vals per store From 69df087da42a6dbf8341ec4181d8b18a0e1a458d Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:33:04 +0800 Subject: [PATCH 13/16] Refactors layout definitions for Gmem in Flash kernel traits for improved readability --- csrc/flash_dmattn/src/kernel_traits.h | 28 +++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 884fe0a..4412984 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -229,15 +229,15 @@ struct Flash_fwd_kernel_traits : public Base { // Accumulator layout for output using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, - Layout, Stride< _8, _1>>, // Thread layout, 8 threads per row - Layout, Stride< _16, _1>> // Thread layout, 16 threads per row + Layout, Stride<_8, _1>>, // Thread layout, 8 threads per row + Layout, Stride<_16, _1>> // Thread layout, 16 threads per row >; using GmemTiledCopyOaccum = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, 
GmemLayoutAtomOaccum{}, - Layout>{} + Layout>{} ) ); // Val layout, 4 vals per store }; @@ -442,15 +442,15 @@ struct Flash_bwd_kernel_traits : public Base { static_assert(kNThreads % kGmemThreadsPerRowMask == 0, "kNThreads must be a multiple of kGmemThreadsPerRowMask"); static_assert(kNThreads % kGmemThreadsPerRowBias == 0, "kNThreads must be a multiple of kGmemThreadsPerRowBias"); using GmemLayoutAtomQKVO = Layout< - Shape , Int>, + Shape, Int>, Stride, _1> >; using GmemLayoutAtomMask = Layout< - Shape , Int>, + Shape, Int>, Stride, _1> >; using GmemLayoutAtomBias = Layout< - Shape , Int>, + Shape, Int>, Stride, _1> >; @@ -486,7 +486,7 @@ struct Flash_bwd_kernel_traits : public Base { make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtomQKVO{}, - Layout>{} + Layout>{} ) ); // Val layout, 8 vals per store using GmemTiledCopydBias = decltype( @@ -500,35 +500,35 @@ struct Flash_bwd_kernel_traits : public Base { make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtomQKVO{}, - Layout>{} + Layout>{} ) ); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtomQKVO{}, - Layout>{} + Layout>{} ) ); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, - Layout, Stride< _8, _1>>, // Thread layout, 8 threads per row - Layout, Stride< _16, _1>> // Thread layout, 16 threads per row + Layout, Stride<_8, _1>>, // Thread layout, 8 threads per row + Layout, Stride<_16, _1>> // Thread layout, 16 threads per row >; using GmemTiledCopydQaccum = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, - Layout>{} + Layout>{} ) ); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( make_tiled_copy( Copy_Atom, ElementAccum>{}, - Layout, // Thread layout, 8 threads per row + Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, - Layout>{} + Layout>{} ) ); // Val layout, 1 val per store }; From bbfbbc37caf45376bf7fa0bee23e4fd1ec7fb93c Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:38:18 +0800 Subject: [PATCH 14/16] Use ElementMask and split mask copy/reduce Standardizes mask dtype to an explicit element type in global/shared memory to fix type mismatches and ensure alignment. Aligns the shared mask buffer via a placeholder and updates the layout to avoid misaligned accesses. Replaces fused mask copy+reduce with a generic copy followed by an explicit OR-reduction and barrier for clearer synchronization and correctness. Unifies bias handling onto the generic copy path. 
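The placeholder is easiest to see as byte arithmetic; a sketch with assumed tile sizes shows how reserving the mask slot in the 2-byte compute type keeps the following bias buffer 16-byte aligned even though the mask itself is then viewed as uint8:

# Illustrative sketch of the shared-memory carve-out (example tile sizes).
kBlockM, kBlockN, kHeadDim = 128, 32, 64
ELEM_BYTES, MASK_BYTES = 2, 1                   # fp16 compute type vs uint8 mask

sV_bytes    = kBlockN * kHeadDim * ELEM_BYTES
mask_offset = sV_bytes                                   # sMaskPlace starts here
place_bytes = kBlockM * kBlockN * ELEM_BYTES             # reserved as if it held fp16
bias_offset = mask_offset + place_bytes                  # sBias starts here

assert mask_offset % 16 == 0 and bias_offset % 16 == 0   # both 128-bit aligned
assert kBlockM * kBlockN * MASK_BYTES <= place_bytes     # uint8 mask fits inside
print(mask_offset, bias_offset)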
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 173 ++++++++++------------- 1 file changed, 76 insertions(+), 97 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index c798a02..576de90 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -54,6 +54,7 @@ template(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); @@ -216,13 +217,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sMask = make_tensor( + Tensor sMaskPlace = make_tensor( Has_mask ? sV.data() + size(sV) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + typename Kernel_traits::SmemLayoutPS{} + ); // For pointers alignment only + Tensor sMask = make_tensor( + make_smem_ptr(reinterpret_cast(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutPS{} ); Tensor sBias = make_tensor( - Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(), + typename Kernel_traits::SmemLayoutPS{} ); // Global to Shared Memory operation @@ -364,25 +369,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block), tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. @@ -394,7 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k - n_block * kBlockN ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block), tBiassBias, tBiascBias, tBiaspBias, @@ -524,24 +523,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if (n_block > n_block_min) { if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. 
- - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -551,7 +545,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, tBiaspBias, @@ -684,24 +678,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if (n_block > n_block_min) { if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -711,7 +700,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, tBiascBias, tBiaspBias, @@ -834,6 +823,7 @@ template(params.mask_ptr) + col_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, make_stride(params.mask_row_stride, _1{}) ); @@ -1001,13 +991,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sMask = make_tensor( + Tensor sMaskPlace = make_tensor( Has_mask ? sV.data() + size(sV) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + typename Kernel_traits::SmemLayoutPS{} + ); // For pointers alignment only + Tensor sMask = make_tensor( + make_smem_ptr(reinterpret_cast(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutPS{} ); Tensor sBias = make_tensor( - Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), - typename Kernel_traits::SmemLayoutAtomPS{} + Has_bias ? (Has_mask ? 
sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(), + typename Kernel_traits::SmemLayoutPS{} ); // Global to Shared Memory operation @@ -1115,24 +1109,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. @@ -1144,7 +1133,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k - n_block * kBlockN ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, @@ -1305,24 +1294,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -1332,9 +1316,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, + tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); @@ -1492,24 +1476,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. 
+ __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } if (any_active_next) { @@ -1519,9 +1498,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV ); if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, + tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM ); From d226164b849a8c8365e049d2e816efb43e0ce04b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:38:44 +0800 Subject: [PATCH 15/16] Use QKVO threads-per-row for reduction partitioning Aligns reduction configuration with the QKVO-specific per-row thread count to keep template and divisor consistent. Fixes a mismatch that could mis-partition threads, improving correctness and consistency in backward preprocessing. --- csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h index c8a35c0..d6710b2 100644 --- a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h @@ -148,9 +148,9 @@ inline __device__ void compute_dot_do_o( tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM ); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, dP_sum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); if (Clear_dQaccum) { // We're actually not zero'ing out all of dQaccum, but only the part that we're going to From c426443f711d7ea345485090a4bdf72878c308b2 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 2 Oct 2025 00:40:46 +0800 Subject: [PATCH 16/16] Corrects mask smem type; unifies copy/reduce Uses a dedicated mask element type with aligned shared memory, separating mask typing from shared buffers to prevent misalignment and aliasing. Replaces combined mask copy+reduce with a generic copy, explicit barrier, and a separate OR-reduction to ensure accurate activity detection. Unifies bias/mask transfers via generic copy utilities and updates the dot-product threading trait, improving correctness across mixed element types and preparing for varied mask formats. 
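With an assumed block configuration, the reduction partitioning that patch 15 aligns with the QKVO copy works out as follows (a sketch of the arithmetic only, not the kernel code):

# Illustrative sketch: how many rows of dO each reduction pass covers once the
# QKVO-specific threads-per-row value is used. Configuration values are assumed.
kNThreads   = 128
kBlockKSmem = 64
itemsize    = 2                                    # fp16/bf16
kGmemElemsPerLoadQKVO  = 16 // itemsize            # 8 elements per 128-bit load
kGmemThreadsPerRowQKVO = kBlockKSmem // kGmemElemsPerLoadQKVO    # 8 threads per row
rows_per_pass = kNThreads // kGmemThreadsPerRowQKVO              # 16 rows reduced at once

# The kernel passes kNThreads / kGmemThreadsPerRowQKVO into dot_do_o; using the
# mask- or bias-specific row width there would mis-partition the reduction,
# which is what this change fixes.
assert kNThreads % kGmemThreadsPerRowQKVO == 0
print(kGmemThreadsPerRowQKVO, rows_per_pass)       # 8 16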
--- csrc/flash_dmattn/src/flash_bwd_kernel.h | 71 +++++++++++------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 3554802..e8578e2 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -80,6 +80,7 @@ template(sMaskPlace.data().get())), + typename Kernel_traits::SmemLayoutMaskBiasPdS{} ); Tensor sP = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutMaskBiasPdS{} ); Tensor sPt = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutPdStransposed{} ); Tensor sPtNoSwizzle = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{} ); // sMask, sP and sdQ share the same memory so be careful Tensor sdQ = make_tensor( - sMask.data(), + sMaskPlace.data(), typename Kernel_traits::SmemLayoutdQ{} ); @@ -572,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // // if (cute::thread(1, 0)) { print(tKrK); } if constexpr (Has_mask) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - m_block * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for current interation. 
+ FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active, + smem_thr_copy_Mask + ); } FLASH_NAMESPACE::copy( @@ -601,7 +601,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (any_active) { if constexpr (Has_bias) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, @@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } if (Is_first) { cute::copy(tdOrdO, tdOsdO); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); } @@ -848,7 +848,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in __syncthreads(); if constexpr (Has_bias) { // Write dS to dBias - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_dBias, tBiassBias, tdBiasgdBias, tBiascBias, tBiaspBias, @@ -879,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if constexpr (Has_mask) { // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - (m_block - 1) * kBlockM - // ); - // FLASH_NAMESPACE::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration - - FLASH_NAMESPACE::copy_mask_with_or_reduce( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Mask, tMaskgMask, tMasksMask, - any_active_next, tMaskcMask, tMaskpMask, binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + __syncthreads(); + // Do OR-reduce on the mask to see if any active threads for next iteration + FLASH_NAMESPACE::mask_or_reduce( + tMasksMask, + any_active_next, + smem_thr_copy_Mask + ); } // Advance gdO @@ -1000,7 +995,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); if (any_active_next) { - FLASH_NAMESPACE::copy_bias( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, tBiascBias, tBiaspBias, @@ -1014,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (Is_first && m_block > m_block_min) { cute::copy(tdOrdO, tdOsdO); - dot_do_o( + dot_do_o( tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO) ); }