From 55af05a698541c3e84212d0c7a9ce299999e83f2 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Thu, 22 May 2025 06:19:58 +0000
Subject: [PATCH 1/2] Initial plan for issue

From 882cf8997d0db6d430cd9b0450275fbbdd8653c2 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Thu, 22 May 2025 06:28:27 +0000
Subject: [PATCH 2/2] Fix attention score calculation in flash_attention_fwd_kernel.h

Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>
---
 csrc/src/flash_attention_fwd_kernel.h     |  12 +-
 csrc/src/flash_attention_fwd_kernel.h.bak | 699 ++++++++++++++++++++++
 2 files changed, 707 insertions(+), 4 deletions(-)
 create mode 100644 csrc/src/flash_attention_fwd_kernel.h.bak

diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h
index 99caf1a..b5f0471 100644
--- a/csrc/src/flash_attention_fwd_kernel.h
+++ b/csrc/src/flash_attention_fwd_kernel.h
@@ -454,8 +454,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                         auto mask_values_row = sDynamicMaskValues(m_idx, _);
                         auto predicate_k_row = sPredicate(m_idx, _);
                         if (predicate_k_row(k_idx)) {
-                            // Scale the attention score before adding mask value, matching Python's behavior
-                            acc_s(mma, mi, ki) = acc_s(mma, mi, ki) * params.scale_softmax + static_cast<ElementAccum>(mask_values_row(k_idx));
+                            // First scale the attention score
+                            ElementAccum scaled_score = acc_s(mma, mi, ki) * params.scale_softmax;
+                            // Then add the mask value, matching Python's behavior
+                            acc_s(mma, mi, ki) = scaled_score + static_cast<ElementAccum>(mask_values_row(k_idx));
                         } else {
                             // For positions where mask is 0, set attention score to -INFINITY so they don't contribute to softmax
                             acc_s(mma, mi, ki) = -INFINITY;
@@ -571,8 +573,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                         auto mask_values_row = sDynamicMaskValues(m_idx, _);
                         auto predicate_k_row = sPredicate(m_idx, _);
                         if (predicate_k_row(k_idx)) {
-                            // Scale the attention score before adding mask value, matching Python's behavior
-                            acc_s(mma, mi, ki) = acc_s(mma, mi, ki) * params.scale_softmax + static_cast<ElementAccum>(mask_values_row(k_idx));
+                            // First scale the attention score
+                            ElementAccum scaled_score = acc_s(mma, mi, ki) * params.scale_softmax;
+                            // Then add the mask value, matching Python's behavior
+                            acc_s(mma, mi, ki) = scaled_score + static_cast<ElementAccum>(mask_values_row(k_idx));
                         } else {
                             // For positions where mask is 0, set attention score to -INFINITY so they don't contribute to softmax
                             acc_s(mma, mi, ki) = -INFINITY;
diff --git a/csrc/src/flash_attention_fwd_kernel.h.bak b/csrc/src/flash_attention_fwd_kernel.h.bak
new file mode 100644
index 0000000..622c195
--- /dev/null
+++ b/csrc/src/flash_attention_fwd_kernel.h.bak
@@ -0,0 +1,699 @@
+/******************************************************************************
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
+ ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +#include // For at::cuda::philox::unpack + +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "mask.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bidb, const int bidh, const int m_block, const BlockInfo &binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : ( + params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1) + ); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; // query_block_len + constexpr int kBlockN = Kernel_traits::kBlockN; // key_block_len + constexpr int kHeadDim = Kernel_traits::kHeadDim; // head_dim + + // Check if there are any queries to process in the block + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + // Compute the actual range of N blocks to process + const int n_block_min = 0; + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN) + ); + } + + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
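+    // In that case the output rows of this query block are written as zeros and their LSE entries
+    // are set to INFINITY before returning.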
+ if ((Is_causal || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor( + make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{}) + ); + Tensor gO = local_tile( + mO(_, bidh, _), + Shape, Int>{}, + make_coord(m_block, 0) + ); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
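+    // Only the trailing key blocks of a query block can be partially out of range (because of the
+    // causal condition or a seqlen_k that is not a multiple of kBlockN), so starting at
+    // n_block_max - 1 lets the peeled "masking" iterations below handle them first and the
+    // steady-state loop run without those checks.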
+ + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + bool has_causal_mask = params.causal_mask_ptr != nullptr && Is_causal; + + // Golobal memory tensor configuration + Tensor mQ = make_tensor( + make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{}) + ); + Tensor gQ = local_tile( + mQ(_, bidh, _), + Shape, Int>{}, + make_coord(m_block, 0) + ); // (kBlockM, kHeadDim) + + Tensor mK = make_tensor( + make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{}) + ); + Tensor gK = local_tile( + mK(_, bidh / params.h_h_k_ratio, _), + Shape, Int>{}, + make_coord(_, 0) + ); // (kBlockN, kHeadDim, nblocksN) + + Tensor mV = make_tensor( + make_gmem_ptr(reinterpret_cast(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{}) + ); + Tensor gV = local_tile( + mV(_, bidh / params.h_h_k_ratio, _), + Shape, Int>{}, + make_coord(_, 0) + ); // (kBlockN, kHeadDim, nblocksN) + + Tensor gP = make_tensor( + make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{}) + ); + + Tensor mZeroHold = make_tensor( + make_gmem_ptr(reinterpret_cast(params.zero_hold_ptr) + binfo.q_offset(params.zero_hold_batch_stride, params.zero_hold_row_stride, bidb)), + make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_stride(params.zero_hold_head_stride, params.zero_hold_row_stride, _1{}) + ); + Tensor gZeroHold = local_tile( + mZeroHold(bidh / params.h_h_k_ratio, _, _), + Shape, Int>{}, + make_coord(m_block, n_block_max - 1) + ); // (kBlockM, kBlockN) + + // Shared memory layout configuration + Tensor sQ = make_tensor( + make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{} + ); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor( + sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sV = make_tensor( + sK.data() + size(sK), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sVt = make_tensor( + sV.data(), + typename Kernel_traits::SmemLayoutVtransposed{} + ); + Tensor sVtNoSwizzle = make_tensor( + sV.data().get(), + typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} + ); + + // Dynamic mask related shared memory. Use a running char* pointer for robust allocation. 
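+    // The buffers below are carved out sequentially after sV, in this order: sZeroHold (Element),
+    // sDynamicMaskValues (float), sDynamicMaskSortKeys (float), sDynamicMaskSortIndices (int),
+    // sNonZeroIndices (int) and sPredicate (bool), advancing by the kSmem*Size constants from
+    // Kernel_traits.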
+ char* dynamic_smem_current_ptr = reinterpret_cast(sV.data().get()) + size(sV) * sizeof(Element); + Tensor sZeroHold = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type + typename Kernel_traits::SmemLayoutZeroHold{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize; + Tensor sDynamicMaskValues = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type + typename Kernel_traits::SmemLayoutDynamicMaskValues{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemMaskValuesSize; + Tensor sDynamicMaskSortKeys = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type + typename Kernel_traits::SmemLayoutDynamicMaskSortKeys{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemSortKeysSize; + Tensor sDynamicMaskSortIndices = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type + typename Kernel_traits::SmemLayoutDynamicMaskSortIndices{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemSortIndicesSize; + Tensor sNonZeroIndices = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type + typename Kernel_traits::SmemLayoutNonZeroIndices{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemNonZeroIndicesSize; + Tensor sPredicate = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // bool type + typename Kernel_traits::SmemLayoutPredicate{} + ); + + // Golobal to Shared Memory operation + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold; + auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold); + Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold); + + // Matrix Multiply Accumulate + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + Tensor tSgS = thr_mma.partition_C(gP); + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // Copy Atom retiling + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = 
smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // For sZeroHold -> registers (if needed, though mask.h operates on smem directly) + // auto smem_tiled_copy_ZeroHold = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + // auto smem_thr_copy_ZeroHold = smem_tiled_copy_ZeroHold.get_thread_slice(tidx); + // Tensor tSsZeroHold = smem_thr_copy_ZeroHold.partition_S(sZeroHold); + + // PREDICATES + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Identity tensor for gZeroHold -> sZeroHold copy + Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold))); + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Predicate for ZeroHold GMEM copy + Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold); + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tZeroHoldpZeroHold = make_tensor(make_shape(size<2>(tZeroHoldsZeroHold))); // N-dim predicate for ZeroHold + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + #pragma unroll + for (int k = 0; k < size(tZeroHoldpZeroHold); ++k) { + tZeroHoldpZeroHold(k) = true; // All elements are valid for the moment + } + } + + // Prologue + // Init dynamic mask processor + DynamicMask dynamic_mask(params.keep_window_size); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM + ); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + // If share Q and K smem, wait and sync + if (Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + // Reverse iteration over N blocks + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, + tKgK(_, _, _, n_block), + tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN + ); + cute::cp_async_fence(); + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + // For ZeroHold, Is_even_K in copy refers to the kBlockN dimension alignment for vectorization, + // which is generally true. The boundary is handled by the length argument. 
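+    // (When !Is_even_K the tZeroHoldpZeroHold predicates were set to all-true above; the row bound
+    // binfo.actual_seqlen_k - n_block * kBlockN passed as the length argument does the clipping.)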
+ FLASH_NAMESPACE::copy( + gmem_tiled_copy_ZeroHold, + tZeroHoldgZeroHold, + tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold, + binfo.actual_seqlen_k - n_block * kBlockN + ); + cute::cp_async_fence(); + + clear(acc_o); + + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + // Calculating the actual number of keys in the block + const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN); + + // Process dynamic mask for each query row in the current block + for (int m_idx = 0; m_idx < kBlockM; ++m_idx) { + // Get the global index of the current query + const int query_idx = m_block * kBlockM + m_idx; + if (query_idx >= binfo.actual_seqlen_q) { + continue; + } + + // Apply the dynamic mask to the current query row + auto mask_values_row = sDynamicMaskValues(m_idx, _); // float + auto zero_hold_row = sZeroHold(m_idx, _); // half/bfloat16 + auto sort_keys_row = sDynamicMaskSortKeys(m_idx, _); // float + auto sort_indices_row = sDynamicMaskSortIndices(m_idx, _); // int + dynamic_mask.template apply_mask_1rowblock< + typename decltype(mask_values_row)::engine_type, typename decltype(mask_values_row)::layout_type, + typename decltype(zero_hold_row)::engine_type, typename decltype(zero_hold_row)::layout_type, + typename decltype(sort_keys_row)::engine_type, typename decltype(sort_keys_row)::layout_type, + typename decltype(sort_indices_row)::engine_type, typename decltype(sort_indices_row)::layout_type, + Element, Is_causal + >( + mask_values_row, + zero_hold_row, + query_idx, + block_key_len, + mask_values_row, + sort_keys_row, + sort_indices_row + ); + __syncthreads(); + // Find the non-zero positions + auto predicate_k_row = sPredicate(m_idx, _); // bool + for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) { + predicate_k_row(k_idx) = (mask_values_row(k_idx) != 0.0f); + } + __syncthreads(); + } + + // Execute sparse matrix multiplication + FLASH_NAMESPACE::sparse_gemm( + acc_s, + tSrQ, + tSrK, tSsQ, tSsK, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K, + sPredicate // Active key predicates + ); + + // Apply mask values to attention scores (zero_hold states contain mask values to add 
to attention scores) + for (int mma = 0; mma < size<0>(acc_s); ++mma) { + for (int mi = 0; mi < size<1>(acc_s); ++mi) { + for (int ki = 0; ki < size<2>(acc_s); ++ki) { + int m_idx = mi; + int k_idx = ki; + if (m_idx < kBlockM && k_idx < block_key_len) { + auto mask_values_row = sDynamicMaskValues(m_idx, _); + auto predicate_k_row = sPredicate(m_idx, _); + if (predicate_k_row(k_idx)) { + // First scale the attention score + acc_s(mma, mi, ki) = acc_s(mma, mi, ki) * params.scale_softmax + static_cast(mask_values_row(k_idx)); + } else { + // For positions where mask is 0, set attention score to -INFINITY so they don't contribute to softmax + acc_s(mma, mi, ki) = -INFINITY; + } + } + } + } + } + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, 1.0f) + : softmax.template softmax_rescale_o(acc_s, acc_o, 1.0f); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + FLASH_NAMESPACE::sparse_gemm_rs( + acc_o, + tOrP, tOrVt, tOsVt, + tiled_mma, smem_tiled_copy_V, smem_thr_copy_V, + sPredicate // 应用相同的谓词来进行稀疏V矩阵乘法 + ); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + // calculate the actual number of keys in the block + const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN); + + // Process dynamic mask for each query row in the current block + for (int m_idx = 0; m_idx < kBlockM; ++m_idx) { + // Get the global index of the current query + const int query_idx = m_block * kBlockM + m_idx; + if (query_idx >= binfo.actual_seqlen_q) { + continue; + } + + // Apply the dynamic mask to the current query row + auto mask_values_row = sDynamicMaskValues(m_idx, _); // float + auto zero_hold_row = sZeroHold(m_idx, _); // half/bfloat16 + auto sort_keys_row = sDynamicMaskSortKeys(m_idx, _); // float + auto sort_indices_row = sDynamicMaskSortIndices(m_idx, _); // int + dynamic_mask.template apply_mask_1rowblock< + typename decltype(mask_values_row)::engine_type, typename decltype(mask_values_row)::layout_type, + typename decltype(zero_hold_row)::engine_type, typename decltype(zero_hold_row)::layout_type, + typename decltype(sort_keys_row)::engine_type, typename decltype(sort_keys_row)::layout_type, + typename 
decltype(sort_indices_row)::engine_type, typename decltype(sort_indices_row)::layout_type, + Element, /*Is_causal=*/false + >( + mask_values_row, + zero_hold_row, + query_idx, + block_key_len, + mask_values_row, + sort_keys_row, + sort_indices_row + ); + __syncthreads(); + // Find the non-zero positions + auto predicate_k_row = sPredicate(m_idx, _); // bool + for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) { + predicate_k_row(k_idx) = (mask_values_row(k_idx) != 0.0f); + } + __syncthreads(); + } + + FLASH_NAMESPACE::sparse_gemm( + acc_s, + tSrQ, + tSrK, tSsQ, tSsK, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K, + sPredicate // Active key predicates + ); + + // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) + for (int mma = 0; mma < size<0>(acc_s); ++mma) { + for (int mi = 0; mi < size<1>(acc_s); ++mi) { + for (int ki = 0; ki < size<2>(acc_s); ++ki) { + int m_idx = mi; + int k_idx = ki; + if (m_idx < kBlockM && k_idx < block_key_len) { + auto mask_values_row = sDynamicMaskValues(m_idx, _); + auto predicate_k_row = sPredicate(m_idx, _); + if (predicate_k_row(k_idx)) { + // First scale the attention score + acc_s(mma, mi, ki) = acc_s(mma, mi, ki) * params.scale_softmax + static_cast(mask_values_row(k_idx)); + } else { + // For positions where mask is 0, set attention score to -INFINITY so they don't contribute to softmax + acc_s(mma, mi, ki) = -INFINITY; + } + } + } + } + } + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + softmax.template softmax_rescale_o(acc_s, acc_o, 1.0f); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + + FLASH_NAMESPACE::sparse_gemm_rs( + acc_o, + tOrP, tOrVt, tOsVt, + tiled_mma, smem_tiled_copy_V, smem_thr_copy_V, + sPredicate // 应用相同的谓词来进行稀疏V矩阵乘法 + ); + + } + + // Epilogue + + // 后处理和输出归一化 + Tensor lse = softmax.template normalize_softmax_lse( + acc_o, params.scale_softmax, 1.0f + ); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
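+    // When Q and K share the same smem buffer, however, we do need a barrier before sO reuses it.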
+    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
+
+    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
+
+    Tensor mO = make_tensor(
+        make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+        make_shape(binfo.actual_seqlen_q, params.h, params.d),
+        make_stride(params.o_row_stride, params.o_head_stride, _1{})
+    );
+    Tensor gO = local_tile(
+        mO(_, bidh, _),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_coord(m_block, 0)
+    );  // (kBlockM, kHeadDim)
+    Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo);
+
+    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+
+    __syncthreads();
+
+    Tensor tOrO = make_tensor(shape(tOgO));
+    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
+
+    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
+    static_assert(decltype(size<0>(taccOcO))::value == 4);
+    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
+    if (get<1>(taccOcO_row(0)) == 0) {
+        #pragma unroll
+        for (int mi = 0; mi < size(lse); ++mi) {
+            const int row = get<0>(taccOcO_row(mi));
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
+        }
+    }
+
+    // Construct identity layout for sO
+    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    // Repeat the partitioning with identity layouts
+    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+    Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
+    if (!Is_even_K) {
+        #pragma unroll
+        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+    }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+    );
+}
+
+template
+inline __device__ void compute_attn(const Params &params) {
+    const int m_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+
+    // Call the main compute function
+    compute_attn_1rowblock(params, bidb, bidh, m_block);
+}
+
+} // namespace FLASH_NAMESPACE
\ No newline at end of file
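
Note on the hunks above: both apply the same two-step update, so the per-element semantics can be checked against a small stand-alone reference sketch in plain C++. The helper name reference_masked_score and its parameter names are illustrative only and do not appear in the kernel.

#include <limits>

// Illustrative helper, not part of the patch: reference semantics of the patched
// score update. The raw QK^T score is scaled first, then the dynamic-mask value
// is added for keys the predicate keeps; keys whose mask value is zero receive
// -infinity so that softmax assigns them no weight.
inline float reference_masked_score(float raw_qk,          // acc_s(mma, mi, ki) before scaling
                                    float scale_softmax,   // params.scale_softmax
                                    float mask_value,      // sDynamicMaskValues(m_idx, k_idx)
                                    bool  key_is_active) { // sPredicate(m_idx, k_idx)
    if (!key_is_active) {
        return -std::numeric_limits<float>::infinity();
    }
    float scaled_score = raw_qk * scale_softmax;  // first scale the attention score
    return scaled_score + mask_value;             // then add the mask value
}

The same sPredicate tensor also drives both sparse GEMMs in the kernel (sparse_gemm for Q*K^T and sparse_gemm_rs for P*V), so keys that are masked out contribute neither to the attention scores nor to the output.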