From 6a33f4ddb791f67fec5dafe9c4484ad3b3bfe0e3 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 2 Sep 2025 17:55:35 +0800 Subject: [PATCH 1/5] Optimizes mask handling in backward kernel computation Reduces redundant mask copying and clearing operations by moving mask checks earlier in the loop and eliminating unnecessary register allocations for fully masked blocks. Consolidates async fence calls to reduce synchronization overhead and removes redundant clear operations on accumulator fragments when blocks are skipped due to masking. --- csrc/src/flash_bwd_kernel.h | 57 +++++++------------------------------ 1 file changed, 10 insertions(+), 47 deletions(-) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index db4f0e7..8681427 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -602,44 +602,20 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); cute::cp_async_wait<0>(); __syncthreads(); - // Copy mask from smem to registers and do OR-reduce to see if any active threads - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_view = smem_thr_copy_PdS.retile_D(tSrMask); + // Do OR-reduce on the mask to see if any active threads Tensor tSsMask_view = smem_thr_copy_PdS.retile_S(tSsMask); bool any_active_local = false; #pragma unroll - for (int i = 0; i < size(tSrMask_view); ++i) { - Element m = tSsMask_view(i); - any_active_local |= (m != Element(0)); - tSrMask_view(i) = m; - } + for (int i = 0; i < size(tSsMask_view); ++i) { any_active_local |= (tSsMask_view(i) != Element(0)); } bool any_active = __syncthreads_or(any_active_local); // Early skip for fully masked blocks if (!any_active) { - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA - - clear(acc_dp); - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M, MMA_K tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - if (Is_first || Seq_parallel) { - clear(acc_dq); - } else { - Tensor acc_dq_reshaped_load = make_tensor( - acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) - ); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped_load); - } if (Double_buffer && m_block > m_block_min) { // Double buffer for sQ @@ -653,23 +629,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tQgQ, tQsQ, tQcQ, tQpQ ); - FLASH_NAMESPACE::cp_async_fence(); } - Tensor tdSrdS = make_tensor(shape(acc_dp)); - clear(tdSrdS); - Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); - __syncthreads(); - // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiassBias, tdBiasgdBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, - binfo.actual_seqlen_k - n_block * kBlockN - ); - if (m_block > m_block_min) { // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); @@ -700,7 +661,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tMaskcMask, 
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::cp_async_fence(); } if (m_block > m_block_min) { @@ -714,7 +674,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); } if (!Double_buffer && m_block > m_block_min) { - __syncthreads(); // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); FLASH_NAMESPACE::copy( @@ -722,7 +681,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tQgQ, tQsQ, tQcQ, tQpQ ); - FLASH_NAMESPACE::cp_async_fence(); } if (m_block > m_block_min) { @@ -735,9 +693,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiascBias, binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::cp_async_fence(); } + FLASH_NAMESPACE::cp_async_fence(); + if (Is_first && m_block > m_block_min) { cute::copy(tdOrdO, tdOsdO); dot_do_o( @@ -750,7 +709,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in __syncthreads(); Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); clear(tdQrdQ); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M, BLK_K) -> (blk_m, blk_k) Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); @@ -765,6 +723,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in continue; } + clear(acc_s); + Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } @@ -787,7 +747,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy bias from smem to registers + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); Tensor tSrBias = make_tensor(shape(acc_s)); Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); From 3610bd3a48962154a3c94cf40851b9169d979bf0 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 11:32:26 +0800 Subject: [PATCH 2/5] Optimizes computation by skipping inactive blocks Adds early mask checking to determine if any threads are active before performing expensive computations. This optimization prevents unnecessary work when entire blocks are masked out, improving performance by: - Moving mask evaluation earlier in the computation pipeline - Conditionally executing bias loading and gemm operations - Tracking active state across iterations to avoid redundant work - Reducing memory transfers and computation overhead for masked regions The change maintains correctness while significantly reducing wasted cycles in attention backward pass kernels when dealing with padded sequences. 
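Note: the block-skip pattern that patches 1 and 2 build up is easier to see in isolation. The sketch below is a minimal, self-contained model — masked_block_kernel and tile_elems are illustrative names, not identifiers from this codebase, and it reads the mask straight from global memory instead of from a staged shared-memory tile — showing the two-level reduction: each thread ORs its slice of the mask tile into a local flag, __syncthreads_or combines the flags block-wide, and a fully masked tile exits before any expensive work.

    #include <cstdio>
    #include <cuda_runtime.h>

    // One CUDA block per mask tile; a tile whose mask is all zero is skipped.
    __global__ void masked_block_kernel(const float* mask, float* out, int tile_elems) {
        // Per-thread partial OR over a strided slice of this block's tile.
        bool any_active_local = false;
        for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
            any_active_local |= (mask[blockIdx.x * tile_elems + i] != 0.0f);
        }
        // Block-wide OR: nonzero in every thread iff any thread saw an active element.
        bool any_active = __syncthreads_or(any_active_local);
        if (!any_active) { return; }  // the flag is uniform, so all threads exit together
        // ... expensive per-tile work (gemms, softmax, gradient accumulation) goes here ...
        if (threadIdx.x == 0) { out[blockIdx.x] = 1.0f; }
    }

    int main() {
        const int tiles = 4, tile_elems = 256;
        float *mask, *out;
        cudaMallocManaged(&mask, tiles * tile_elems * sizeof(float));
        cudaMallocManaged(&out, tiles * sizeof(float));
        for (int i = 0; i < tiles * tile_elems; ++i) { mask[i] = 0.0f; }
        for (int t = 0; t < tiles; ++t) { out[t] = 0.0f; }
        mask[2 * tile_elems + 7] = 1.0f;  // only tile 2 has an active element
        masked_block_kernel<<<tiles, 128>>>(mask, out, tile_elems);
        cudaDeviceSynchronize();
        for (int t = 0; t < tiles; ++t) { printf("tile %d computed: %g\n", t, out[t]); }
        cudaFree(mask); cudaFree(out);
        return 0;
    }

Unlike this sketch, the real loop cannot simply return from a skipped block: it still has to advance the global-memory pointers and prefetch the next iteration's mask tile, which is why the patch's skip path keeps the copy/advance calls while dropping the gemms and accumulator clears.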
--- csrc/src/flash_bwd_kernel.h | 622 ++++++++++++++++-------------------- 1 file changed, 274 insertions(+), 348 deletions(-) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index 8681427..850c93c 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -552,24 +552,40 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tKgK, tKsK, - tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN - ); FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + cute::cp_async_fence(); + FLASH_NAMESPACE::cp_async_wait<0>(); + + // Do OR-reduce on the mask to see if any active threads + Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask); + bool any_active_local = false; + bool any_active_local_next = false; // to be updated later for next iteration + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } + bool any_active = __syncthreads_or(any_active_local); + bool any_active_next = any_active; // to be updated later for next iteration + + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, + tKgK, tKsK, + tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN ); + + if (any_active) { + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_MaskBias, + tBiasgBias, tBiassBias, + tBiascBias, + binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + if (!Kernel_traits::Is_V_in_regs) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -605,252 +621,136 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in cute::cp_async_wait<0>(); __syncthreads(); - // Do OR-reduce on the mask to see if any active threads - Tensor tSsMask_view = smem_thr_copy_PdS.retile_S(tSsMask); - bool any_active_local = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_view); ++i) { any_active_local |= (tSsMask_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - - // Early skip for fully masked blocks - if (!any_active) { - - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - - if (Double_buffer && m_block > m_block_min) { - // Double buffer for sQ - const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ); - tQsQ.data() = tQsQ.data() + sQ_offset; - tSsQ.data() = tSsQ.data() + sQ_offset; - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ - ); - } - - if (m_block > m_block_min) { - // Advance gdO - tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); - if (Is_first) { - tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOrdO, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgO, tdOrO, - tQcQ, tQpQ - ); - } else { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOsdO, - tQcQ, tQpQ - ); - } - // Advance gMask - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); - } + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // (MMA=4, MMA_M, MMA_K) - if (m_block > m_block_min) { - gLSE.data() = gLSE.data() + (-int(kBlockM)); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } - gdPsum.data() = gdPsum.data() + (-int(kBlockM)); - } + if (any_active) { + clear(acc_s); - if (Double_buffer) { // Double buffer for sQ - tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); - } - if (!Double_buffer && m_block > m_block_min) { - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ - ); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } + + // if (cute::thread0()) { print(sK); } + // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); + // #pragma unroll + // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { + // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); + // } + // if (cute::thread0()) { print(tSrK); } + FLASH_NAMESPACE::gemm( + acc_s, + tSrQ, tSrK, tSsQ, tSsK, + tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, + smem_thr_copy_QdO, smem_thr_copy_KV + ); + if constexpr (Is_softcap) { + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - if (m_block > m_block_min) { - // Advance gBias and gdBias - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); - tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), 
FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + // if (cute::thread(32, 0)) { print(scores); } + + // Softcapping - calculating dTanh and scaling dS later with it + [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); + if constexpr (Is_softcap) { + FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } - FLASH_NAMESPACE::cp_async_fence(); + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond + // actual_seqlen_k, because acc_s would be some finite value for those indices. + // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, + // so the result would still be correct. + // However, it's possible that the values in acc_s are so large that they overflow + // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. + // So we need to mask out the elements beyond actual_seqlen_k. + FLASH_NAMESPACE::apply_mask( + scores, mask, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); - if (Is_first && m_block > m_block_min) { - cute::copy(tdOrdO, tdOsdO); - dot_do_o( - tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow) - ); - } + // if (cute::thread(32, 0)) { print(scores); } + // Compute the exponential value. + // FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + FLASH_NAMESPACE::scale_apply_exp2(scores, lse, float(M_LOG2E)); + // Convert scores from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom, AtomNum), MMA_M, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + // if (cute::thread0()) { print(tPaP); } + // __syncthreads(); + // if (cute::thread0()) { print(sP); } + + clear(acc_dp); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); + // #pragma unroll + // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { + // #pragma unroll + // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { + // acc_dp_reshaped(mi, ni) = -dP_sum(mi); + // } + // } + + // if (cute::thread0()) { print(dP_sum); } + + FLASH_NAMESPACE::gemm( + acc_dp, + tdPrdO, tdPrV, tdPsdO, tdPsV, + tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, + smem_thr_copy_QdO, smem_thr_copy_KV + ); - if (Is_last) { - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - clear(tdQrdQ); - tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M, BLK_K) -> (blk_m, blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + // Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (p >= 0 ? 
dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll - for (int m = 0; m < size<1>(tdQgdQ); ++m) { - if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { - cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); - } + for (int ni = 0; ni < size<1>(dS); ++ni) { + float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } + dS(mi, ni) = scaled_ds; } } - - continue; + // if (cute::thread0()) { print(dS); } } - clear(acc_s); - - Tensor dP_sum = make_fragment_like(lse); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } - - // if (cute::thread0()) { print(sK); } - // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); - // #pragma unroll - // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { - // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); - // } - // if (cute::thread0()) { print(tSrK); } - FLASH_NAMESPACE::gemm( - acc_s, - tSrQ, tSrK, tSsQ, tSsK, - tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, - smem_thr_copy_QdO, smem_thr_copy_KV - ); - if constexpr (Is_softcap) { - FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); - } - - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); - cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - - // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - // if (cute::thread(32, 0)) { print(scores); } - - // Softcapping - calculating dTanh and scaling dS later with it - [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); - if constexpr (Is_softcap) { - FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); - } - - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. - // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. - FLASH_NAMESPACE::apply_mask( - scores, mask, bias, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); - - // if (cute::thread(32, 0)) { print(scores); } - // Compute the exponential value. 
- // FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - FLASH_NAMESPACE::scale_apply_exp2(scores, lse, float(M_LOG2E)); - // Convert scores from fp32 to fp16/bf16 - Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); - Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom, AtomNum), MMA_M, MMA_N) - cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); - // if (cute::thread0()) { print(tPaP); } - // __syncthreads(); - // if (cute::thread0()) { print(sP); } - - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA - - clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); - // #pragma unroll - // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { - // #pragma unroll - // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { - // acc_dp_reshaped(mi, ni) = -dP_sum(mi); - // } - // } - - // if (cute::thread0()) { print(dP_sum); } - - FLASH_NAMESPACE::gemm( - acc_dp, - tdPrdO, tdPrV, tdPsdO, tdPsV, - tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, - smem_thr_copy_QdO, smem_thr_copy_KV - ); + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - // Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](float p, float dp, float d) { - return p * (p >= 0 ? 
dp - d : d); - }; - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); - if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } - dS(mi, ni) = scaled_ds; + if (any_active) { + if (Is_first || Seq_parallel) { + clear(acc_dq); + } else { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor( + acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) + ); + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); } } - // if (cute::thread0()) { print(dS); } - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M, MMA_K - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - if (Is_first || Seq_parallel) { - clear(acc_dq); - } else { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor( - acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) - ); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); - } if (Double_buffer && m_block > m_block_min) { // Double buffer for sQ @@ -859,69 +759,52 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tSsQ.data() = tSsQ.data() + sQ_offset; // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::cp_async_fence(); + if (any_active) { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, + tQgQ, tQsQ, + tQcQ, tQpQ + ); + FLASH_NAMESPACE::cp_async_fence(); + } } - Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); - // Convert dS from fp32 to fp16 - Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); - Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); - __syncthreads(); - // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiassBias, tdBiasgdBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, - binfo.actual_seqlen_k - n_block * kBlockN - ); + if (any_active) { + // Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = FLASH_NAMESPACE::convert_type(acc_dp); + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + // Write dS to dBias + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_MaskBias, + tBiassBias, tdBiasgdBias, + tBiascBias, + binfo.actual_seqlen_q - m_block * kBlockM, + binfo.actual_seqlen_k - n_block * kBlockN + ); - // if (cute::thread0()) { print(tPrP); } - // Layout p_l = tPrP.layout(); - // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); - // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - FLASH_NAMESPACE::gemm( - acc_dv, - tdVrPt, tdVrdO, tdVsPt, tdVsdOt, - tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, - smem_thr_copy_PdSt, smem_thr_copy_QdOt - ); - // if (cute::thread0() && n_block == 0 
&& m_block == 0) { print(tdVrPt); } - // if (cute::thread0()) { print(acc_dv); } + // if (cute::thread0()) { print(tPrP); } + // Layout p_l = tPrP.layout(); + // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); + // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); + // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + FLASH_NAMESPACE::gemm( + acc_dv, + tdVrPt, tdVrdO, tdVsPt, tdVsdOt, + tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, + smem_thr_copy_PdSt, smem_thr_copy_QdOt + ); + // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } + // if (cute::thread0()) { print(acc_dv); } - __syncthreads(); // Need syncthreads since we're writing to the same sdO location + __syncthreads(); // Need syncthreads since we're writing to the same sdO location + } if (m_block > m_block_min) { - // Advance gdO - tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); - if (Is_first) { - tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOrdO, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgO, tdOrO, - tQcQ, tQpQ - ); - } else { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_dO, - tdOgdO, tdOsdO, - tQcQ, tQpQ - ); - } // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); FLASH_NAMESPACE::copy_MN( @@ -931,16 +814,51 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::cp_async_fence(); + FLASH_NAMESPACE::cp_async_wait<0>(); + + // Do OR-reduce on the mask to see if any active threads for next iteration + any_active_local_next = false; + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } + any_active_next = __syncthreads_or(any_active_local_next); + + // Advance gdO + tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); + if (any_active_next) { + if (Is_first) { + tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, + tdOgdO, tdOrdO, + tQcQ, tQpQ + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, + tdOgO, tdOrO, + tQcQ, tQpQ + ); + } else { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, + tdOgdO, tdOsdO, + tQcQ, tQpQ + ); + } + + FLASH_NAMESPACE::cp_async_fence(); + } } - FLASH_NAMESPACE::gemm( - acc_dq, - tdQrdS, tdQrKt, tdQsdS, tdQsKt, - tiled_mma_dq, - smem_tiled_copy_dS, smem_tiled_copy_Kt, - smem_thr_copy_dS, smem_thr_copy_Kt - ); - // if (cute::thread0()) { print(acc_dq); } + if (any_active) { + FLASH_NAMESPACE::gemm( + acc_dq, + tdQrdS, tdQrKt, tdQsdS, tdQsKt, + tiled_mma_dq, + smem_tiled_copy_dS, smem_tiled_copy_Kt, + smem_thr_copy_dS, smem_thr_copy_Kt + ); + // if (cute::thread0()) { print(acc_dq); } + } if (m_block > m_block_min) { gLSE.data() = gLSE.data() + (-int(kBlockM)); @@ -949,67 +867,73 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in gdPsum.data() = gdPsum.data() + (-int(kBlockM)); } - if (!Is_last) { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor( - acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), 
get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) - ); - if (!Seq_parallel) { - cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + if (any_active) { + if (!Is_last) { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor( + acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout())) + ); + if (!Seq_parallel) { + cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + } else { + // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + } } else { - // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - } else { - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + + FLASH_NAMESPACE::gemm( + acc_dk, + tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, + tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, + smem_thr_copy_PdSt, smem_thr_copy_QdOt + ); + // if (cute::thread0()) { print(acc_dk); } + + __syncthreads(); // Need syncthreads since we're using the sBias smem for accumulating acc_dk } - FLASH_NAMESPACE::gemm( - acc_dk, - tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, - tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, - smem_thr_copy_PdSt, smem_thr_copy_QdOt - ); - // if (cute::thread0()) { print(acc_dk); } if (Double_buffer) { // Double buffer for sQ tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? 
size(sQ) : -size(sQ)); } if (!Double_buffer && m_block > m_block_min) { - __syncthreads(); // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, - tQgQ, tQsQ, - tQcQ, tQpQ - ); - FLASH_NAMESPACE::cp_async_fence(); + if (any_active_next) { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, + tQgQ, tQsQ, + tQcQ, tQpQ + ); + } } - __syncthreads(); // Need syncthreads since we're using the sBias smem for accumulating acc_dk - if (m_block > m_block_min) { // Advance gBias and gdBias tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); - FLASH_NAMESPACE::cp_async_fence(); + if (any_active_next) { + FLASH_NAMESPACE::copy_MN( + gmem_tiled_copy_MaskBias, + tBiasgBias, tBiassBias, + tBiascBias, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + ); + } } + FLASH_NAMESPACE::cp_async_fence(); + if (Is_first && m_block > m_block_min) { cute::copy(tdOrdO, tdOsdO); dot_do_o( @@ -1033,6 +957,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } } + any_active = any_active_next; + } From 6b855d2461da368e7ef8af323f0245771f1c8854 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 11:40:53 +0800 Subject: [PATCH 3/5] Conditionally returns bias gradient based on input Tracks whether bias parameter was provided during forward pass and only returns bias gradient during backward pass when bias was originally given. Prevents unnecessary computation and memory allocation for bias gradients when no bias is used in the attention mechanism. 
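Note: a stripped-down model of what this patch changes at the autograd boundary may help. The class below is illustrative rather than the library's actual interface — the real forward takes the full attention argument list, here reduced to one tensor plus the optional bias — but the pattern is the same: record at forward time whether the caller supplied a bias, and return None in that gradient slot when it did not.

    import torch

    class AttnWithOptionalBias(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, bias=None):
            # Remember whether the caller actually passed a bias.
            ctx.return_dbias = bias is not None
            if bias is None:
                # Materialize a zero bias so the compute path stays uniform.
                bias = torch.zeros_like(q)
            return q + bias  # stand-in for the real attention kernel

        @staticmethod
        def backward(ctx, dout):
            # One gradient per forward input; the bias slot gets None when the
            # caller never provided a bias, so autograd propagates nothing for it.
            dbias = dout.clone() if ctx.return_dbias else None
            return dout, dbias

With bias omitted, AttnWithOptionalBias.apply(q).sum().backward() populates q.grad and hands autograd None for the bias slot, which is the unnecessary gradient allocation the commit message refers to.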
---
 flash_dmattn/flash_dmattn_interface.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index d4f968f..1acdfdd 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -420,7 +420,9 @@ def forward(
         )
         if mask is None:
             mask = torch.ones((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
+        return_dbias = True
         if bias is None:
+            return_dbias = False
             bias = torch.zeros((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -465,6 +467,7 @@
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.return_dbias = return_dbias
         out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -510,7 +513,9 @@
         dk = dk[:, : ctx.seqlen_k, :, :]
         dv = dv[:, : ctx.seqlen_k, :, :]
         dbias = dbias[..., : ctx.seqlen_k]
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None
+        if ctx.return_dbias:
+            return dq, dk, dv, None, dbias, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashDMAttnVarlenFunc(torch.autograd.Function):
@@ -544,8 +549,10 @@ def forward(
         )
         if mask is None:
             mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
+        return_dbias = True
         if bias is None:
             bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
+            return_dbias = False
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -606,6 +613,7 @@
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
+        ctx.return_dbias = return_dbias
         out = out_padded[..., :head_size_og]

         if return_softmax:
@@ -658,7 +666,9 @@
         if ctx.seqlen_k_og != ctx.max_seqlen_k:
             dbias = dbias[:, :, :ctx.seqlen_k_og]
-        return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        if ctx.return_dbias:
+            return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_dmattn_func(

From 89223d57eaaf628d844e98b2490f0e3a67fda422 Mon Sep 17 00:00:00 2001
From: Jingze Shi
Date: Wed, 3 Sep 2025 11:45:08 +0800
Subject: [PATCH 4/5] Update csrc/src/flash_bwd_kernel.h

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 csrc/src/flash_bwd_kernel.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h
index 850c93c..87376fc 100644
--- a/csrc/src/flash_bwd_kernel.h
+++ b/csrc/src/flash_bwd_kernel.h
@@ -844,7 +844,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 tQcQ, tQpQ
             );
         }
-        FLASH_NAMESPACE::cp_async_fence();
     }
 }

From 21af55867f54f95d49baf20a79b0365ec900b633 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Wed, 3 Sep 2025 11:50:01 +0800
Subject: [PATCH 5/5] Fixes variable initialization in flash attention kernels

Initializes any_active_next to false instead of any_active to prevent
potential issues with unintended carry-over of active state between
iterations in the kernel loops.

Changes affect both forward and backward kernel implementations to
ensure consistent behavior across the codebase.
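Note: the hazard this patch closes is easiest to see in a reduced model of the loop's flag hand-off. The helper below is illustrative (or_reduce_mask stands in for the shared-memory mask OR-reduce): the flag is only recomputed while another block remains to prefetch, so whenever that recompute is skipped, whatever the flag already holds flows into the final any_active = any_active_next hand-off. Seeding it with false makes that fallback mean "no further work" instead of quietly replaying earlier state.

    #include <cstdio>

    // Stand-in for the shared-memory mask OR-reduce in the real kernels.
    bool or_reduce_mask(int m_block) { return m_block % 2 == 0; }

    void walk_blocks(int m_block_max, int m_block_min) {
        bool any_active = or_reduce_mask(m_block_max - 1);
        bool any_active_next = false;  // the fix: previously seeded with any_active
        for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
            if (m_block > m_block_min) {
                // Only while another block remains is the flag recomputed.
                any_active_next = or_reduce_mask(m_block - 1);
                if (any_active_next) { printf("prefetch tiles for block %d\n", m_block - 1); }
            }
            if (any_active) { printf("compute block %d\n", m_block); }
            // On the last trip the recompute above is skipped, so this hand-off
            // sees either a freshly computed value or the deliberate false seed,
            // never an accidental copy of the pre-loop any_active.
            any_active = any_active_next;
        }
    }

    int main() {
        walk_blocks(/*m_block_max=*/3, /*m_block_min=*/0);
        return 0;
    }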
---
 csrc/src/flash_bwd_kernel.h | 2 +-
 csrc/src/flash_fwd_kernel.h | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h
index 87376fc..643cfcd 100644
--- a/csrc/src/flash_bwd_kernel.h
+++ b/csrc/src/flash_bwd_kernel.h
@@ -568,7 +568,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
     bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = any_active; // to be updated later for next iteration
+    bool any_active_next = false; // to be updated later for next iteration

     FLASH_NAMESPACE::copy(
         gmem_tiled_copy_QKV,
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 6cb5209..77a5a19 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     #pragma unroll
     for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
     bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = any_active; // to be updated later for next iteration
+    bool any_active_next = false; // to be updated later for next iteration

     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     if (any_active) {
@@ -1016,7 +1016,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     #pragma unroll
     for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
     bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = any_active; // to be updated later for next iteration
+    bool any_active_next = false; // to be updated later for next iteration

     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     if (any_active) {