From 109063fd19e4dff8f8d0eafa777eb5f020b55a3c Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 17 May 2025 19:48:05 +0800 Subject: [PATCH 01/18] Add SmemLayout for Dynamic Mask --- csrc/src/kernel_traits.h | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 7fd8272..56de831 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -141,6 +141,34 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kDynamicMaskBufferPerQuery = kMaxKeysPerBlock * (2 * sizeof(float) + sizeof(int)); static constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery; + // Dynamic mask shared memory layouts + using SmemLayoutDynamicMaskValues = decltype( + tile_to_shape( + composition(Swizzle{}, + Layout>, + Stride, _1>>{}), + Shape, Int>{} + ) + ); + + using SmemLayoutDynamicMaskSortKeys = decltype( + tile_to_shape( + composition(Swizzle{}, + Layout>, + Stride, _1>>{}), + Shape, Int>{} + ) + ); + + using SmemLayoutDynamicMaskSortIndices = decltype( + tile_to_shape( + composition(Swizzle{}, + Layout>, + Stride, _1>>{}), + Shape, Int>{} + ) + ); + // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); From aafb1f742007ff6ca5df8a23167700fa73bd31b4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 17 May 2025 19:57:52 +0800 Subject: [PATCH 02/18] Add SmemLayout for Dynamic Mask --- csrc/src/kernel_traits.h | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 56de831..c2e28e3 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -138,7 +138,11 @@ struct Flash_fwd_kernel_traits : public Base { // Dynamic mask memory allocation constants static constexpr int kMaxKeysPerBlock = kBlockN; - static constexpr int kDynamicMaskBufferPerQuery = kMaxKeysPerBlock * (2 * sizeof(float) + sizeof(int)); + static constexpr int kMaskValuesSize = kMaxKeysPerBlock * sizeof(float); + static constexpr int kNonZeroIndicesSize = kMaxKeysPerBlock * sizeof(int); + static constexpr int kSortKeysSize = kMaxKeysPerBlock * sizeof(float); + static constexpr int kSortIndicesSize = kMaxKeysPerBlock * sizeof(int); + static constexpr int kDynamicMaskBufferPerQuery = kMaskValuesSize + kNonZeroIndicesSize + kSortKeysSize + kSortIndicesSize; static constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery; // Dynamic mask shared memory layouts @@ -169,6 +173,15 @@ struct Flash_fwd_kernel_traits : public Base { ) ); + using SmemLayoutNonZeroIndices = decltype( + tile_to_shape( + composition(Swizzle{}, + Layout>, + Stride, _1>>{}), + Shape, Int>{} + ) + ); + // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); @@ -176,7 +189,7 @@ struct Flash_fwd_kernel_traits : public Base { // Base shared memory size without dynamic mask buffer static constexpr int kSmemSize = Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
 
-    // Total shared memory size including dynamic mask buffer
+    // Total shared memory size including dynamic mask buffer and nonzero indices
     static constexpr int kSmemSizeWithMask = kSmemSize + kTotalDynamicMaskBuffer;
 
     // Global memory access configuration
@@ -460,4 +473,4 @@ struct Flash_bwd_kernel_traits : public Base {
                                   Layout>{}));  // Val layout, 8 vals per read
 };
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file

From b2b322ee593ca011fe4f7ce41ad6e164e5d5def9 Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Sat, 17 May 2025 20:46:08 +0800
Subject: [PATCH 03/18] Add forward kernel

---
 csrc/src/flash_attention_fwd_kernel.h | 811 ++++++++++++++++++++++++++
 1 file changed, 811 insertions(+)
 create mode 100644 csrc/src/flash_attention_fwd_kernel.h

diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h
new file mode 100644
index 0000000..12a91f1
--- /dev/null
+++ b/csrc/src/flash_attention_fwd_kernel.h
@@ -0,0 +1,811 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "namespace_config.h"
+#include  // For at::cuda::philox::unpack
+
+#include 
+
+#include 
+#include 
+#include 
+
+#include "block_info.h"
+#include "kernel_traits.h"
+#include "utils.h"
+#include "softmax.h"
+#include "mask.h"
+
+namespace FLASH_NAMESPACE {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+    );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape>{}, make_coord(m_block));
+}
+
+template
+inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
+
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockM = Kernel_traits::kBlockM;
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+    constexpr int kNWarps = Kernel_traits::kNWarps;
+
+    // Check if there are any queries to process in this block
+    const BlockInfo binfo(params, bidb);
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+    // Compute the actual N block range to process
+    const int n_block_min = 0;
+    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
+    if (Is_causal) {
+        n_block_max = std::min(n_block_max,
+                               cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+    }
+
+    // If there are no N blocks to process, set the output to 0 and return
+    if (n_block_max <= n_block_min) {
+        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+        Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0));
+        Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo);
+
+        // Initialize the output to 0
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor(shape(tOgO));
+        clear(tOrO);
+
+        // Build the output predicate
+        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+        Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
+            #pragma unroll
+            for (int k = 0; k < size(tOpO); ++k) {
+                tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+            }
+        }
+
+        // Write out the zeroed output
+        FLASH_NAMESPACE::copy(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+
+        // Set LSE to -infinity
+        #pragma unroll
+        for (int m = 0; m < size<1>(tOgO); ++m) {
+            const int mi = get<0>(tOcO(0, m, 0));
+            if (mi < binfo.actual_seqlen_q - m_block * kBlockM) {
+                gLSE(mi) = -INFINITY;
+            }
+        }
+        return;
+    }
+
+    // Global memory tensor configuration
+    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr)
+                                        + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                          make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                          make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+
+    Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{},
+                          make_coord(m_block, 0));
+
+    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr)
+                                        + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
+                          make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                          make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+
+    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{},
+                          make_coord(_, 0));
+
+    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr)
+                                        + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
+                          make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                          make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+
+    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{},
+                          make_coord(_, 0));
+
+    Tensor mZeroHold = make_tensor(make_gmem_ptr(reinterpret_cast(params.zero_hold_ptr)
+                                               + bidb * params.zero_hold_batch_stride),
+                                  make_shape(params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k),  // Assuming h is num_kv_heads for zero_hold
+                                  make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{}));
+    Tensor gZeroHold = local_tile(mZeroHold(bidh / params.h_h_k_ratio, _, _),  // Use bidh / params.h_h_k_ratio if zero_hold is per kv_head
+                                 Shape, Int>{},
+                                 make_coord(m_block, 0));  // m_block for query row, n_block for key column
+
+    Tensor mCausalMask = params.causal_mask_ptr != nullptr
+        ? make_tensor(make_gmem_ptr(reinterpret_cast(params.causal_mask_ptr)
+                                   + bidb * params.causal_mask_batch_stride),
+                     make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+                     make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{}))
+        : Tensor();  // Empty tensor if no causal mask is provided
+    Tensor gCausalMask = params.causal_mask_ptr != nullptr
+        ? local_tile(mCausalMask(0, _, _),
+                    Shape, Int>{},
+                    make_coord(_, 0))
+        : Tensor();  // Empty tensor if no causal mask is provided
+
+
+    // Shared memory configuration
+    // Shared memory layout for QKV
+    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
+                           typename Kernel_traits::SmemLayoutQ{});
+    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
+                           typename Kernel_traits::SmemLayoutKV{});
+    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+    // Shared memory layout for the dynamic mask
+    Tensor sZeroHold = make_tensor(sV.data().get() + size(sV), typename Kernel_traits::SmemLayoutZeroHold{});
+    Tensor sCausalMask = params.causal_mask_ptr != nullptr
+        ? make_tensor(sZeroHold.data().get() + size(sZeroHold),
+                     typename Kernel_traits::SmemLayoutZeroHold{})
+        : Tensor();
+    Tensor sDynamicMaskValues = make_tensor(
+        (params.causal_mask_ptr != nullptr ?
+            sCausalMask.data().get() + size(sCausalMask) :
+            sZeroHold.data().get() + size(sZeroHold)),
+        typename Kernel_traits::SmemLayoutDynamicMaskValues{}
+    );
+    Tensor sDynamicMaskSortKeys = make_tensor(
+        sDynamicMaskValues.data().get() + size(sDynamicMaskValues),
+        typename Kernel_traits::SmemLayoutDynamicMaskSortKeys{}
+    );
+    Tensor sDynamicMaskSortIndices = make_tensor(
+        sDynamicMaskSortKeys.data().get() + size(sDynamicMaskSortKeys),
+        typename Kernel_traits::SmemLayoutDynamicMaskSortIndices{}
+    );
+    Tensor sNonZeroIndices = make_tensor(
+        sDynamicMaskSortIndices.data().get() + size(sDynamicMaskSortIndices),
+        typename Kernel_traits::SmemLayoutNonZeroIndices{}
+    );
+    Tensor sPredicate = make_tensor(
+        sNonZeroIndices.data().get() + size(sNonZeroIndices),
+        typename Kernel_traits::SmemLayoutZeroHold{}
+    );
+
+
+    // Set up copies from global memory to shared memory
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold;
+    auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_CausalMask;
+    auto gmem_thr_copy_CausalMask = gmem_tiled_copy_CausalMask.get_thread_slice(tidx);
+
+    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);
+    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);
+    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+    Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
+    Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
+    Tensor tCausalMaskgCausalMask = params.causal_mask_ptr != nullptr
+        ? gmem_thr_copy_CausalMask.partition_S(gCausalMask)
+        : Tensor();
+    Tensor tCausalMasksCausalMask = params.causal_mask_ptr != nullptr
+        ? gmem_thr_copy_CausalMask.partition_D(sCausalMask)
+        : Tensor();
+
+    // Set up the matrix multiply-accumulate
+    typename Kernel_traits::TiledMma tiled_mma;
+    auto thr_mma = tiled_mma.get_thread_slice(tidx);
+    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
+    Tensor tSrK = thr_mma.partition_fragment_B(sK);
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);
+
+    // Set up copies from shared memory to registers
+    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+    auto smem_tiled_copy_ZeroHold = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_ZeroHold = smem_tiled_copy_ZeroHold.get_thread_slice(tidx);
+    Tensor tSsZeroHold = smem_thr_copy_ZeroHold.partition_S(sZeroHold);
+
+    auto smem_tiled_copy_CausalMask = params.causal_mask_ptr != nullptr
+        ? make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma)
+        : decltype(make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma)){};
+    auto smem_thr_copy_CausalMask = params.causal_mask_ptr != nullptr
+        ? smem_tiled_copy_CausalMask.get_thread_slice(tidx)
+        : decltype(smem_tiled_copy_CausalMask.get_thread_slice(tidx)){};
+    Tensor tSsCausalMask = params.causal_mask_ptr != nullptr
+        ? smem_thr_copy_CausalMask.partition_S(sCausalMask)
+        : Tensor();
+
+    // Set up predicates
+    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
+    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));
+    Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
+    Tensor cCausalMask = params.causal_mask_ptr != nullptr
+        ? make_identity_tensor(make_shape(size<0>(sCausalMask), size<1>(sCausalMask)))
+        : Tensor();
+
+    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
+    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
+    Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
+    Tensor tCausalMaskcCausalMask = params.causal_mask_ptr != nullptr
+        ? gmem_thr_copy_CausalMask.partition_S(cCausalMask)
+        : Tensor();
+
+    Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
+    Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK)));
+    Tensor tZeroHoldpZeroHold = make_tensor(make_shape(size<2>(tZeroHoldsZeroHold)));
+    Tensor tCausalMaskpCausalMask = params.causal_mask_ptr != nullptr
+        ? make_tensor(make_shape(size<2>(tCausalMasksCausalMask)))
+        : Tensor();
+
+    // Set up predicates for the K dimension
+    if (!Is_even_K) {
+        #pragma unroll
+        for (int k = 0; k < size(tQpQ); ++k) {
+            tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+        }
+        #pragma unroll
+        for (int k = 0; k < size(tKVpKV); ++k) {
+            tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+        }
+    }
+
+    // Initialize the dynamic mask processor
+    DynamicMask dynamic_mask(params.keep_window_size);
+
+    // Load Q into shared memory
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+        binfo.actual_seqlen_q - m_block * kBlockM
+    );
+
+    if (Kernel_traits::Is_Q_in_regs) {
+        cute::cp_async_fence();
+    }
+
+    // If Q and K share shared memory, we need to wait and synchronize
+    if (Kernel_traits::Share_Q_K_smem) {
+        FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
+        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+        __syncthreads();
+    }
+
+    // Iterate over the N blocks in reverse order
+    int n_block = n_block_max - 1;
+
+    // Load the first K block into shared memory
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+        binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    cute::cp_async_fence();
+
+    // Load the first ZeroHold block into shared memory
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
+        binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    cute::cp_async_fence();
+
+    // Load the first CausalMask block into shared memory (if present)
+    if (params.causal_mask_ptr != nullptr) {
+        FLASH_NAMESPACE::copy(
+            gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
+            binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        cute::cp_async_fence();
+    }
+
+    // Load Q from shared memory into registers (if needed)
+    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
+        FLASH_NAMESPACE::cp_async_wait<1>();
+        __syncthreads();
+        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+    }
+
+    // Initialize the output accumulator
+    Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{});
+    clear(acc_o);
+
+    // Create the softmax calculator
+    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
+
+    // Process blocks that need masking (usually the last few blocks)
+    constexpr int n_masking_steps = (!Is_causal)
+        ? 1
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+
+    #pragma unroll
+    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+        // Wait for the K data
+        FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
+
+        // Load the V block into shared memory
+        FLASH_NAMESPACE::copy(
+            gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV,
+            binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        cute::cp_async_fence();
+
+        // Compute the actual number of keys in the block
+        const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN);
+
+        // Process the dynamic mask for each query row in the current block
+        const int queries_in_block = min(kBlockM, binfo.actual_seqlen_q - m_block * kBlockM);
+        for (int m_idx = 0; m_idx < queries_in_block; ++m_idx) {
+            // Get the global index of the current query
+            const int query_idx = m_block * kBlockM + m_idx;
+
+            // Get the dynamic mask memory for the current query row
+            Tensor mask_values = sDynamicMaskValues(m_idx, _);
+            Tensor sort_keys = sDynamicMaskSortKeys(m_idx, _);
+            Tensor sort_indices = sDynamicMaskSortIndices(m_idx, _);
+            Tensor nonzero_indices = sNonZeroIndices(m_idx, _);
+            Tensor predicate_k = sPredicate(m_idx, _);
+
+            // Get the zero_hold and causal_mask for the current query row
+            const Element* zero_hold_row = &sZeroHold[m_idx][0];
+            const Element* causal_mask_row = params.causal_mask_ptr != nullptr ?
+                &sCausalMask[m_idx][0] : nullptr;
+
+            // Use the DynamicMask struct to apply the mask
+            dynamic_mask.apply_mask_1rowblock(
+                mask_values,
+                zero_hold_row,
+                causal_mask_row,
+                block_key_len,
+                mask_values.data().get(),
+                sort_keys.data().get(),
+                reinterpret_cast(sort_indices.data().get())
+            );
+            __syncthreads();
+
+            // Initialize the key-activity predicate
+            if (tidx == 0) {
+                // A single thread is enough to initialize the whole predicate array
+                #pragma unroll
+                for (int k_idx = 0; k_idx < kBlockN; ++k_idx) {
+                    predicate_k(k_idx) = false;
+                }
+            }
+            __syncthreads();
+
+            // Find the nonzero positions. The counter lives in shared memory so that
+            // atomicAdd coordinates across all threads in the block; a per-thread
+            // local would not be visible to the other threads.
+            __shared__ int nonzero_count;
+            if (tidx == 0) { nonzero_count = 0; }
+            __syncthreads();
+            // Each thread handles a portion of the key positions
+            for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) {
+                if (mask_values(k_idx) != 0.0f) {
+                    // Atomically increment the count and claim an index slot
+                    int idx = atomicAdd(&nonzero_count, 1);
+                    if (idx < Kernel_traits::kMaxKeysPerBlock) {
+                        nonzero_indices(idx) = k_idx;
+                        // Mark this key as active
+                        predicate_k(k_idx) = true;
+                    }
+                }
+            }
+            __syncthreads();
+
+            // If there are no nonzero keys, skip the current query row
+            if (nonzero_count == 0) {
+                continue;
+            }
+
+            // Handle the multi-query-head case (MQA/GQA)
+            const int num_queries_per_kv = params.h_h_k_ratio;
+
+            // For each query head within the query group
+            for (int q_group_idx = 0; q_group_idx < num_queries_per_kv; q_group_idx++) {
+                // Create an accumulator for the attention scores
+                Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});
+                clear(acc_s);
+
+                // Perform the sparse matrix multiplication
+                FLASH_NAMESPACE::sparse_gemm_rs(
+                    acc_s(_, m_idx, _),  // accumulator for the current query row
+                    tSrQ(_, m_idx, _),   // current query
+                    tSrK,                // keys
+                    tSsK,                // keys in shared memory
+                    tiled_mma,
+                    smem_tiled_copy_K,
+                    smem_thr_copy_K,
+                    predicate_k          // predicate for the active keys
+                );
+
+                // Apply the additive mask (the zero_hold state is both the mask and the value added to the attention scores)
+                for (int s_idx = 0; s_idx < size(acc_s); ++s_idx) {
+                    const int k_idx = get<2>(thr_mma.get_slice_idx(s_idx, acc_s));
+                    if (k_idx < block_key_len && predicate_k(k_idx)) {
+                        acc_s(s_idx) += static_cast(mask_values[k_idx]);
+                    }
+                }
+
+                // Run softmax and update the output accumulator
+                if (q_group_idx == 0 && n_block == n_block_max - 1) {
+                    softmax.template softmax_rescale_o(
+                        acc_s, acc_o, params.scale_softmax_log2);
+                } else {
+                    softmax.template softmax_rescale_o(
+                        acc_s, acc_o, params.scale_softmax_log2);
+                }
+
+                // Convert the float scores to Element type for the output computation
+                Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
+                Tensor tOrP = make_tensor(
+                    rP.data(),
+                    FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())
+                );
+
+                // Compute the output for this query head
+                FLASH_NAMESPACE::sparse_gemm_rs(
+                    acc_o,   // output accumulator
+                    tOrP,    // attention weights
+                    tOrVt,   // value vectors
+                    tOsVt,   // value vectors in shared memory
+                    tiled_mma,
+                    smem_tiled_copy_V,
+                    smem_thr_copy_V,
+                    predicate_k  // apply the same predicate for the sparse V matmul
+                );
+            }
+            __syncthreads();
+        }
+
+        // Wait for the V data
+        FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
+
+        // Prepare to load the next K block (if any)
+        if (n_block > n_block_min) {
+            FLASH_NAMESPACE::copy(
+                gmem_tiled_copy_QKV, tKgK(_, _, _, n_block-1), tKsK, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - (n_block-1) * kBlockN
+            );
+            cute::cp_async_fence();
+
+            // Load the next ZeroHold block into shared memory
+            FLASH_NAMESPACE::copy(
+                gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block-1), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
+                binfo.actual_seqlen_k - (n_block-1) * kBlockN
+            );
+            cute::cp_async_fence();
+
+            // Load the next CausalMask block into shared memory (if present)
+            if (params.causal_mask_ptr != nullptr) {
+                FLASH_NAMESPACE::copy(
+                    gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block-1), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
+                    binfo.actual_seqlen_k - (n_block-1) * kBlockN
+                );
+                cute::cp_async_fence();
+            }
+        }
+
+        // Early exit check
+        if (n_masking_steps > 1 && n_block <= n_block_min) {
+            break;
+        }
+    }
+
+    // Process blocks that do not need masking
+    for (; n_block >= n_block_min; --n_block) {
+        // Wait for the K data
+        FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
+
+        // Load the V block into shared memory
+        FLASH_NAMESPACE::copy(
+            gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV
+        );
+        cute::cp_async_fence();
+
+        // Compute the actual number of keys in the block
+        const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN);
+        const int queries_in_block = min(kBlockM, binfo.actual_seqlen_q - m_block * kBlockM);
+
+        // Process the dynamic mask for each query row in the current block
+        for (int m_idx = 0; m_idx < queries_in_block; ++m_idx) {
+            // Get the zero-state row for the current query
+            Tensor mask_values = sDynamicMaskValues(m_idx, _);
+            Tensor sort_keys = sDynamicMaskSortKeys(m_idx, _);
+            Tensor sort_indices = sDynamicMaskSortIndices(m_idx, _);
+            Tensor nonzero_indices = sNonZeroIndices(m_idx, _);
+            Tensor predicate_k = sPredicate(m_idx, _);
+
+            // Get the zero_hold for the current query row
+            const Element* zero_hold_row = &sZeroHold[m_idx][0];
+
+            // Use the DynamicMask struct to apply the mask, without a causal mask
+            dynamic_mask.apply_mask_1rowblock(
+                mask_values,
+                zero_hold_row,
+                nullptr,  // no causal mask
+                block_key_len,
+                mask_values.data().get(),
+                sort_keys.data().get(),
+                reinterpret_cast(sort_indices.data().get())
+            );
+            __syncthreads();
+
+            // Initialize the key-activity predicate
+            if (tidx == 0) {
+                // A single thread is enough to initialize the whole predicate array
+                #pragma unroll
+                for (int k_idx = 0; k_idx < kBlockN; ++k_idx) {
+                    predicate_k(k_idx) = false;
+                }
+            }
+            __syncthreads();
+
+            // Find the nonzero positions. As above, the counter must live in shared
+            // memory for atomicAdd to coordinate across the block.
+            __shared__ int nonzero_count;
+            if (tidx == 0) { nonzero_count = 0; }
+            __syncthreads();
+            // Each thread handles a portion of the key positions
+            for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) {
+                if (mask_values(k_idx) != 0.0f) {
+                    // Atomically increment the count and claim an index slot
+                    int idx = atomicAdd(&nonzero_count, 1);
+                    if (idx < Kernel_traits::kMaxKeysPerBlock) {
+                        nonzero_indices(idx) = k_idx;
+                        // Mark this key as active
+                        predicate_k(k_idx) = true;
+                    }
+                }
+            }
+            __syncthreads();
+
+            // If there are no nonzero keys, skip the current query row
+            if (nonzero_count == 0) {
+                continue;
+            }
+
+            // Handle the multi-query-head case (MQA/GQA)
+            const int num_queries_per_kv = params.h_h_k_ratio;
+
+            // For each query head within the query group
+            for (int q_group_idx = 0; q_group_idx < num_queries_per_kv; q_group_idx++) {
+                // Create an accumulator for the attention scores
+                Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});
+                clear(acc_s);
+
+                // Perform the sparse matrix multiplication
+                FLASH_NAMESPACE::sparse_gemm_rs(
+                    acc_s(_, m_idx, _),  // accumulator for the current query row
+                    tSrQ(_, m_idx, _),   // current query
+                    tSrK,                // keys
+                    tSsK,                // keys in shared memory
+                    tiled_mma,
+                    smem_tiled_copy_K,
+                    smem_thr_copy_K,
+                    predicate_k          // predicate for the active keys
+                );
+
+                // Apply the additive mask
+                for (int s_idx = 0; s_idx < size(acc_s); ++s_idx) {
+                    const int k_idx = get<2>(thr_mma.get_slice_idx(s_idx, acc_s));
+                    if (k_idx < block_key_len && predicate_k(k_idx)) {
+                        acc_s(s_idx) += static_cast(mask_values[k_idx]);
+                    }
+                }
+
+                // Run softmax and update the output accumulator
+                softmax.template softmax_rescale_o(
+                    acc_s, acc_o, params.scale_softmax_log2);
+
+                // Convert the float scores to Element type for the output computation
+                Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
+                Tensor tOrP = make_tensor(
+                    rP.data(),
+                    FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())
+                );
+
+                // Compute the output for this query head
+                FLASH_NAMESPACE::sparse_gemm_rs(
+                    acc_o,   // output accumulator
+                    tOrP,    // attention weights
+                    tOrVt,   // value vectors
+                    tOsVt,   // value vectors in shared memory
+                    tiled_mma,
+                    smem_tiled_copy_V,
+                    smem_thr_copy_V,
+                    predicate_k  // apply the same predicate for the sparse V matmul
+                );
+            }
+            __syncthreads();
+        }
+
+        // Wait for the V data
+        FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
+
+        if (n_block > n_block_min) {
+            // Prepare to load the next K block (if any)
+            FLASH_NAMESPACE::copy(
+                gmem_tiled_copy_QKV, tKgK(_, _, _, n_block-1), tKsK, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - (n_block-1) * kBlockN
+            );
+            cute::cp_async_fence();
+
+            // Load the next ZeroHold block into shared memory
+            FLASH_NAMESPACE::copy(
+                gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block-1), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
+                binfo.actual_seqlen_k - (n_block-1) * kBlockN
+            );
+            cute::cp_async_fence();
+
+            // Load the next CausalMask block into shared memory (if present)
+            if (params.causal_mask_ptr != nullptr) {
+                FLASH_NAMESPACE::copy(
+                    gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block-1), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
+                    binfo.actual_seqlen_k - (n_block-1) * kBlockN
+                );
+                cute::cp_async_fence();
+            }
+        }
+    }
+
+    // Post-processing and output normalization
+    Tensor lse = softmax.template normalize_softmax_lse(
+        acc_o, params.scale_softmax, 1.0f
+    );
+
+    // Convert acc_o to Element type
+    Tensor rO = FLASH_NAMESPACE::convert_type(acc_o);
+
+    // Prepare shared memory for the output
+    Tensor sO = make_tensor(
+        sQ.data(),
+        typename Kernel_traits::SmemLayoutO{}
+    );
+
+    // Set up the copy from the accumulator to shared memory
+    auto smem_tiled_copy_O = make_tiled_copy_C(
+        typename Kernel_traits::SmemCopyAtomO{},
+        tiled_mma
+    );
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
+
+    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
+    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
+
+    // Make sure the shared memory region is safe to reuse
+    if (Kernel_traits::Share_Q_K_smem) {
+        __syncthreads();
+    }
+
+    // Copy the output to shared memory
+    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
+
+    // Set up the global memory output tensor
+    Tensor mO = make_tensor(
+        make_gmem_ptr(reinterpret_cast(params.o_ptr)
+                     + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+        make_shape(binfo.actual_seqlen_q, params.h, params.d),
+        make_stride(params.o_row_stride, params.o_head_stride, _1{})
+    );
+
+    Tensor gO = local_tile(
+        mO(_, bidh, _),
+        Shape, Int>{},
+        make_coord(m_block, 0)
+    );
+
+    Tensor gLSE = get_lse_tile(
+        params, bidb, bidh, m_block, binfo
+    );
+
+    // Set up the copy from shared memory to global memory
+    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+
+    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
+    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+
+    __syncthreads();
+
+    // Copy from shared memory to registers in preparation for the global memory write
+    Tensor tOrO = make_tensor(shape(tOgO));
+    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
+
+    // Set up the output predicates
+    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));
+    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+    Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
+
+    if (!Is_even_K) {
+        #pragma unroll
+        for (int k = 0; k < size(tOpO); ++k) {
+            tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+        }
+    }
+
+    // Write the output to global memory
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_O,
+        tOrO,
+        tOgO,
+        tOcO,
+        tOpO,
+        binfo.actual_seqlen_q - m_block * kBlockM
+    );
+
+    // Write the LSE values to global memory
+    Tensor caccO = make_identity_tensor(Shape, Int>{});
+    Tensor taccOcO = thr_mma.partition_C(caccO);
+    static_assert(decltype(size<0>(taccOcO))::value == 4);
+
+    // Convert the tensor to (2, 2) form, then take only the row indices
+    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));
+
+    // Only the first thread writes the LSE values
+    if (get<1>(taccOcO_row(0)) == 0) {
+        #pragma unroll
+        for (int mi = 0; mi < size(lse); ++mi) {
+            if (m_block * kBlockM + get<0>(taccOcO_row(mi)) < binfo.actual_seqlen_q) {
+                gLSE(get<0>(taccOcO_row(mi))) = lse(mi);
+            }
+        }
+    }
+}
+
+template
+inline __device__ void compute_attn(const Params &params) {
+    const int m_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+
+    // Call the main compute function
+    compute_attn_1rowblock(params, bidb, bidh, m_block);
+}
+
+} // namespace FLASH_NAMESPACE
\ No newline at end of file

From 3423ea1b39e67aa284a4d1c04b38ec8033f14cbe Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Sat, 17 May 2025 21:54:24 +0800
Subject: [PATCH 04/18] Add forward launch template

---
 csrc/src/flash_fwd_launch_template.h | 301 +++++++++++++++++++++++++++
 1 file changed, 301 insertions(+)
 create mode 100644 csrc/src/flash_fwd_launch_template.h

diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
new file mode 100644
index 0000000..15e079f
--- /dev/null
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -0,0 +1,301 @@
+#define FLASH_ATTENTION_ENABLE_BF16
+/******************************************************************************
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+#include "namespace_config.h"
+#include  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include "static_switch.h"
+#include "hardware_info.h"
+#include "flash.h"
+#include "flash_fwd_kernel.h"
+
+namespace FLASH_NAMESPACE {
+
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
+template \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        FLASH_NAMESPACE::compute_attn(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        FLASH_NAMESPACE::compute_attn_splitkv(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
+    static_assert(Log_max_splits >= 1);
+    FLASH_NAMESPACE::combine_attn_seqk_parallel(params);
+}
+
+template
+void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    const size_t smem_size = Kernel_traits::kSmemSizeWithMask;
+    // printf("smem_size = %d\n", smem_size);
+
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
+
+    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+    dim3 grid(num_m_block, params.b, params.h);
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+    const bool return_softmax = params.p_ptr != nullptr;
+
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+
+                    // Will only return softmax if dropout, to reduce compilation time.
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    auto kernel = &flash_fwd_kernel;
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    kernel<<>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
+            });
+        });
+    });
+}
+
+template
+void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
+
+    const size_t smem_size = Kernel_traits::kSmemSizeWithMask;
+
+    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+
+                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        auto kernel = &flash_fwd_splitkv_kernel;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        kernel<<>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
+                });
+            });
+        });
+    });
+    if (params.num_splits > 1) {
+        // We want kBlockM to be as small as possible for more parallelism.
+        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            // Also check for dynamic mask here
+            DYNAMIC_MASK_SWITCH(use_dynamic_mask, UseDynamicMaskConst, [&] {
+                if (params.num_splits <= 2) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 4) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 8) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 16) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 32) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 64) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                } else if (params.num_splits <= 128) {
+                    flash_fwd_splitkv_combine_kernel<<>>(params);
+                }
+            });
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+    }
+}
+
+template
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+
+    // Pass the dynamic mask flag appropriately
+    const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+    if (use_dynamic_mask) {
+        run_flash_splitkv_fwd, Is_causal>(params, stream);
+    } else {
+        run_flash_splitkv_fwd, Is_causal>(params, stream);
+    }
+}
+
+template
+void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 32;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+        });
+    });
+}
+
+template
+void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 64;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            if constexpr(!Is_dropout) {
+                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+                // Using block size (64 x 256) is 27% slower for seqlen=2k
+                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            }
+        });
+    });
+}
+
+template
+void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 96;
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+            if (is_sm8x) {
+                if constexpr(!Is_causal) {
+                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                } else {
+                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                }
+            } else {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            }
+        });
+    });
+}
+
+template
+void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 128;
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            if constexpr(!Is_dropout) {
+                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+                if (is_sm8x) {
+                    if constexpr(!Is_causal) {
+                        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                    } else {
+                        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                    }
+                } else {
+                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                }
+            } else {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            }
+        });
+    });
+}
+
+template
+void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 192;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            if constexpr(!Is_dropout) {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            }
+        });
+    });
+}
+
+template
+void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 256;
+    int device;
+    cudaGetDevice(&device);
+    int max_smem_per_sm, max_smem_per_block;
+    cudaError status_ = cudaDeviceGetAttribute(
+        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
+    status_ = cudaDeviceGetAttribute(
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    if (status_ != cudaSuccess) {
+        C10_CUDA_CHECK(status_);
+    }
+    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // Check for dynamic mask
+        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
+        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
+            // For A100, we want to run with 128 x 64 (128KB smem).
+            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            }
+        });
+    });
+}
+
+} // namespace FLASH_NAMESPACE
\ No newline at end of file

From 9cc80666bfdcf531dfb2b0bdb841f6d2e295d801 Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Sat, 17 May 2025 22:20:48 +0800
Subject: [PATCH 05/18] Remove Invalid Params from launch template

---
 csrc/src/flash_fwd_launch_template.h | 214 +++++++++------------------
 1 file changed, 72 insertions(+), 142 deletions(-)

diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
index 15e079f..3b5b215 100644
--- a/csrc/src/flash_fwd_launch_template.h
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -30,9 +30,9 @@ namespace FLASH_NAMESPACE {
 template \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_attn(params);
+        FLASH_NAMESPACE::compute_attn(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L
     FLASH_NAMESPACE::combine_attn_seqk_parallel(params);
 }
 
-template
+template
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const size_t smem_size = Kernel_traits::kSmemSizeWithMask;
     // printf("smem_size = %d\n", smem_size);
@@ -69,20 +69,16 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-
-                    // Will only return softmax if dropout, to reduce compilation time.
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
-                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    auto kernel = &flash_fwd_kernel;
-                    if (smem_size >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                    }
-                    kernel<<>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
+                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                auto kernel = &flash_fwd_kernel;
+                if (smem_size >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                }
+                kernel<<>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
             });
         });
     });
@@ -104,18 +100,15 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
         BOOL_SWITCH(params.num_splits > 1, Split, [&] {
             BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-
-                    // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    auto kernel = &flash_fwd_splitkv_kernel;
-                    if (smem_size >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                    }
-                    kernel<<>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
+                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                auto kernel = &flash_fwd_splitkv_kernel;
+                if (smem_size >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                }
+                kernel<<>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
@@ -127,24 +120,21 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
     EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-        // Also check for dynamic mask here
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, UseDynamicMaskConst, [&] {
-            if (params.num_splits <= 2) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 4) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 8) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 16) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 32) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 64) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            } else if (params.num_splits <= 128) {
-                flash_fwd_splitkv_combine_kernel<<>>(params);
-            }
-        });
+        if (params.num_splits <= 2) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 4) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 8) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 16) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 32) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 64) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        } else if (params.num_splits <= 128) {
+            flash_fwd_splitkv_combine_kernel<<>>(params);
+        }
         C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
 }
@@ -156,50 +146,22 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-
-    // Pass the dynamic mask flag appropriately
-    const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-    if (use_dynamic_mask) {
-        run_flash_splitkv_fwd, Is_causal>(params, stream);
-    } else {
-        run_flash_splitkv_fwd, Is_causal>(params, stream);
-    }
+    run_flash_splitkv_fwd, Is_causal>(params, stream);
 }
 
 template
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-        });
-    });
+    run_flash_fwd, Is_causal>(params, stream);
 }
 
 template
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            if constexpr(!Is_dropout) {
-                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-                // Using block size (64 x 256) is 27% slower for seqlen=2k
-                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            }
-        });
-    });
+    // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+    // Using block size (64 x 256) is 27% slower for seqlen=2k
+    // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+    run_flash_fwd, Is_causal>(params, stream);
 }
 
 template
@@ -207,22 +169,16 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            }
-        });
-    });
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    if (is_sm8x) {
+        if constexpr(!Is_causal) {
+            run_flash_fwd, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd, Is_causal>(params, stream);
+        }
+    } else {
+        run_flash_fwd, Is_causal>(params, stream);
+    }
 }
 
 template
@@ -230,43 +186,23 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            if constexpr(!Is_dropout) {
-                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
-                if (is_sm8x) {
-                    if constexpr(!Is_causal) {
-                        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                    } else {
-                        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                    }
-                } else {
-                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            }
-        });
-    });
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+    if (is_sm8x) {
+        if constexpr(!Is_causal) {
+            run_flash_fwd, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd, Is_causal>(params, stream);
+        }
+    } else {
+        run_flash_fwd, Is_causal>(params, stream);
+    }
 }
 
 template
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            if constexpr(!Is_dropout) {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            }
-        });
-    });
+    run_flash_fwd, Is_causal>(params, stream);
 }
 
 template
@@ -283,19 +219,13 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // Check for dynamic mask
-        const bool use_dynamic_mask = params.zero_hold_ptr != nullptr && params.keep_window_size > 0;
-        DYNAMIC_MASK_SWITCH(use_dynamic_mask, Is_dynamic, [&] {
-            // For A100, we want to run with 128 x 64 (128KB smem).
-            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
-            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-            }
-        });
-    });
+    // For A100, we want to run with 128 x 64 (128KB smem).
+    // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } } // namespace FLASH_NAMESPACE \ No newline at end of file From e12b9c650cd5df0c8e3765155e85faa65ee11fae Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:20:07 +0800 Subject: [PATCH 06/18] Check if there are any queries to process in the block --- csrc/src/flash_attention_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 12a91f1..a97d188 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -63,7 +63,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - // 检查块内是否有要处理的查询 + // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; From ba0f1f1bce23eb1afa1ed10a73f751fc867c90d3 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:25:22 +0800 Subject: [PATCH 07/18] We exit early and write 0 to gO and gLSE --- csrc/src/flash_attention_fwd_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index a97d188..895e99e 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -61,7 +61,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); @@ -75,8 +74,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); } - // 如果没有N块要处理,设置输出为0并返回 - if (n_block_max <= n_block_min) { + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
From e12b9c650cd5df0c8e3765155e85faa65ee11fae Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:20:07 +0800 Subject: [PATCH 06/18] Check if there are any queries to process in the block --- csrc/src/flash_attention_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 12a91f1..a97d188 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -63,7 +63,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - // 检查块内是否有要处理的查询 + // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; From ba0f1f1bce23eb1afa1ed10a73f751fc867c90d3 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:25:22 +0800 Subject: [PATCH 07/18] We exit early and write 0 to gO and gLSE --- csrc/src/flash_attention_fwd_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index a97d188..895e99e 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -61,7 +61,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); @@ -75,8 +74,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); } - // 如果没有N块要处理,设置输出为0并返回 - if (n_block_max <= n_block_min) { + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || !Is_even_MN) && n_block_max <= n_block_min) { Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); From 33dceb117adf5b82ede79744172a77ebbcea2dff Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:31:53 +0800 Subject: [PATCH 08/18] We exit early and write 0 to gO and gLSE --- csrc/src/flash_attention_fwd_kernel.h | 32 ++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 895e99e..25d72a0 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -77,19 +77,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || !Is_even_MN) && n_block_max <= n_block_min) { - Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); + Tensor mO = make_tensor( + make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{}) + ); + Tensor gO = local_tile( + mO(_, bidh, _), + Shape, Int>{}, + make_coord(m_block, 0) + ); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); - // 初始化输出为0 typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_tensor(shape(tOgO)); clear(tOrO); - - // 构建输出谓词 - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O.partition_D(cO); Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { @@ -98,19 +105,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - - // 写出清零后的输出 - FLASH_NAMESPACE::copy( + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); - - // 设置LSE为无穷小 #pragma unroll for (int m = 0; m < size<1>(tOgO); ++m) { - const int mi = get<0>(tOcO(0, m, 0)); - if (mi < binfo.actual_seqlen_q - m_block * kBlockM) { - gLSE(mi) = -INFINITY; - } + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } } return; }
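For a row block that ends up with no key blocks at all, the hunk above zeroes the output tile and writes an LSE sentinel for every valid row; the `get<1>(tOcO(0, m, 0)) == 0` test keeps only one thread column per row writing the sentinel. A minimal standalone restatement of that behavior in plain CUDA (illustrative shapes; the real kernel does this through the CuTe partitions shown in the diff, and on fp16 data rather than float):

#include <math.h>

// Zero the O tile and mark each valid row's LSE. rows_left stands in for
// binfo.actual_seqlen_q - m_block * kBlockM.
__global__ void zero_out_row_block(float* O, float* lse,
                                   int rows_left, int kBlockM, int kHeadDim) {
    for (int m = threadIdx.x; m < kBlockM; m += blockDim.x) {
        if (m < rows_left) {
            for (int k = 0; k < kHeadDim; ++k) O[m * kHeadDim + k] = 0.f;
            lse[m] = INFINITY;  // sentinel: no keys contributed to this row
        }
    }
}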
From 157275a478a20af618a090936dc6ade21732dd7d Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 10:32:39 +0800 Subject: [PATCH 09/18] Compute the actual range of N blocks to process --- csrc/src/flash_attention_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 25d72a0..cb58125 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -66,7 +66,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - // 计算实际要处理的N块范围 + // Compute the actual range of N blocks to process const int n_block_min = 0; int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal) {
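The patch below introduces `row_offset_p`, a flat index into the attention-score buffer P, which this offset formula treats as a row-major (batch, head, seqlen_q_rounded, seqlen_k_rounded) array; it starts at column `(n_block_max - 1) * kBlockN` because the key blocks are walked in reverse. A quick host-side check of the arithmetic (all concrete values here are illustrative assumptions):

#include <cassert>

int main() {
    const long long num_heads = 8, seqlen_q_rounded = 256, seqlen_k_rounded = 512;
    const long long kBlockM = 128, kBlockN = 64;
    const long long bidb = 1, bidh = 3, m_block = 1, n_block_max = 8;
    const long long row_offset_p =
        ((bidb * num_heads + bidh) * seqlen_q_rounded + m_block * kBlockM)
            * seqlen_k_rounded
        + (n_block_max - 1) * kBlockN;
    // (1*8 + 3) = 11 (b, h) planes in, row 128: (11*256 + 128) * 512 + 448
    assert(row_offset_p == 1507776);
    return 0;
}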
From 76bf08aed3052d19a7ab59391a7c6d91145e0b4f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 12:50:26 +0800 Subject: [PATCH 10/18] Update Global memory tensor configuration --- csrc/src/flash_attention_fwd_kernel.h | 122 +++++++++++++++++--------- 1 file changed, 80 insertions(+), 42 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index cb58125..8ac2c02 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -116,51 +116,89 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi } return; } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } - // 全局内存张量配置 - Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) - + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.q_row_stride, params.q_head_stride, _1{})); - - Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + bool has_causal_mask = params.causal_mask_ptr != nullptr; + + // Global memory tensor configuration + Tensor mQ = make_tensor( + make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{}) + ); + Tensor gQ = local_tile( + mQ(_, bidh, _), + Shape, Int>{}, + make_coord(m_block, 0) + ); // (kBlockM, kHeadDim) - Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) - + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, params.h_k, params.d), - make_stride(params.k_row_stride, params.k_head_stride, _1{})); - - Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, - make_coord(_, 0)); + Tensor mK = make_tensor( + make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{}) + ); + Tensor gK = local_tile( + mK(_, bidh / params.h_h_k_ratio, _), + Shape, Int>{}, + make_coord(_, 0) + ); // (kBlockN, kHeadDim, nblocksN) - Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) - + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, params.h_k, params.d), - make_stride(params.v_row_stride, params.v_head_stride, _1{})); - - Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, - make_coord(_, 0)); - - Tensor mZeroHold = make_tensor(make_gmem_ptr(reinterpret_cast(params.zero_hold_ptr) - + bidb * params.zero_hold_batch_stride), - make_shape(params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k), // Assuming h is num_kv_heads for zero_hold - make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{})); - Tensor gZeroHold = local_tile(mZeroHold(bidh / params.h_h_k_ratio, _, _), // Use bidh / params.h_h_k_ratio if zero_hold is per kv_head - Shape, Int>{}, - make_coord(m_block, 0)); // m_block for query row, n_block for key column - - Tensor mCausalMask = params.causal_mask_ptr != nullptr - ? make_tensor(make_gmem_ptr(reinterpret_cast(params.causal_mask_ptr) - + bidb * params.causal_mask_batch_stride), - make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{})) - : Tensor(); // Empty tensor if no causal mask is provided - Tensor gCausalMask = params.causal_mask_ptr != nullptr - ?
local_tile(mCausalMask(0, _, _), - Shape, Int>{}, - make_coord(_, 0)) - : Tensor(); // Empty tensor if no causal mask is provided + Tensor mV = make_tensor( + make_gmem_ptr(reinterpret_cast(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{}) + ); + Tensor gV = local_tile( + mV(_, bidh / params.h_h_k_ratio, _), + Shape, Int>{}, + make_coord(_, 0) + ); // (kBlockN, kHeadDim, nblocksN) + + Tensor gP = make_tensor( + make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{}) + ); + + Tensor mZeroHold = make_tensor( + make_gmem_ptr(reinterpret_cast(params.zero_hold_ptr) + bidb * params.zero_hold_batch_stride), + make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{}) + ); + Tensor gZeroHold = local_tile( + mZeroHold(bidh / params.h_h_k_ratio, _, _), + Shape, Int>{}, + make_coord(m_block, 0) + ); + + Tensor mCausalMask = has_causal_mask ? + make_tensor( + make_gmem_ptr(reinterpret_cast(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride), + make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{}) + ) : + make_tensor( + static_cast(nullptr), + make_shape(1, 1, 1), + make_stride(0, 0, 0) + ); + Tensor gCausalMask = has_causal_mask ? + local_tile( + mCausalMask(0, _, _), + Shape, Int>{}, + make_coord(m_block, 0) + ) : + make_tensor( + static_cast(nullptr), + make_shape(1, 1), + make_stride(0, 0) + ); From 58c09f2b4be4bdab6ea015d09526375b962858b4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 20:34:57 +0800 Subject: [PATCH 11/18] Compute the actual range of N blocks to process --- csrc/src/flash_attention_fwd_kernel.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 8ac2c02..869981e 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -58,9 +58,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi // The thread index. const int tidx = threadIdx.x; - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kBlockM = Kernel_traits::kBlockM; // query_block_len + constexpr int kBlockN = Kernel_traits::kBlockN; // key_block_len + constexpr int kHeadDim = Kernel_traits::kHeadDim; // head_dim // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); @@ -70,8 +70,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi const int n_block_min = 0; int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN) + ); } // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
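The `+ binfo.actual_seqlen_k - binfo.actual_seqlen_q` term added above makes the causal clamp bottom-right aligned: query row i may attend keys j <= i + (seqlen_k - seqlen_q), which matters whenever the key sequence is longer than the query sequence (e.g. decoding against a cache). A small sketch with illustrative numbers:

#include <algorithm>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

// n_block_max under bottom-right-aligned causal masking, as in the hunk above.
int n_block_max_causal(int m_block, int kBlockM, int kBlockN,
                       int seqlen_q, int seqlen_k) {
    const int n_block_max = ceil_div(seqlen_k, kBlockN);
    return std::min(n_block_max,
                    ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
}
// seqlen_q = 128, seqlen_k = 512, kBlockM = kBlockN = 64:
//   m_block = 0 -> ceil_div( 64 + 384, 64) = 7 of 8 key blocks
//   m_block = 1 -> ceil_div(128 + 384, 64) = 8 of 8 key blocks
// The old formula would stop at ceil_div(64, 64) = 1 and ceil_div(128, 64) = 2,
// wrongly masking keys that the trailing queries are allowed to see.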
From 3cb16ad7ab8778762327783062b35e28e3638011 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 20:36:10 +0800 Subject: [PATCH 12/18] Add guard condition for causal mask --- csrc/src/flash_attention_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 869981e..7d62025 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -126,7 +126,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - bool has_causal_mask = params.causal_mask_ptr != nullptr; + bool has_causal_mask = params.causal_mask_ptr != nullptr && Is_causal; // Global memory tensor configuration
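The next patch replaces the chained `data() + size()` arithmetic of the shared-memory buffers with a single running `char*` cursor over the dynamic-mask arena. A minimal sketch of that bump-pointer pattern follows; the struct, names, and sizes here are illustrative assumptions, while the real kernel advances a raw pointer by the `Kernel_traits::kSmem*Size` byte counts:

// Bump-pointer carving of one shared-memory arena. Each region advances the
// cursor by its own byte size, so adding or removing a buffer cannot
// silently overlap its neighbours.
struct SmemBump {
    char* p;
    template <class T>
    __device__ T* take(int bytes) {
        T* r = reinterpret_cast<T*>(p);
        p += bytes;
        return r;
    }
};

// Usage inside a kernel (illustrative sizes):
//   extern __shared__ char smem_[];
//   SmemBump bump{smem_ + bytes_used_by_QKV_tiles};
//   float* mask_values  = bump.take<float>(kSmemMaskValuesSize);
//   int*   sort_indices = bump.take<int>(kSmemSortIndicesSize);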
From a5c84dae6e6241f9ad8ca89b98e2c334b742ef87 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 20:57:40 +0800 Subject: [PATCH 13/18] Shared memory layout configuration --- csrc/src/flash_attention_fwd_kernel.h | 81 ++++++++++++++++++--------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 7d62025..618b4c5 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -201,46 +201,73 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi make_shape(1, 1), make_stride(0, 0) ); - - - - // 共享内存配置 - // QKV的共享内存布局 - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - // Dynamic mask的共享内存布局 - Tensor sZeroHold = make_tensor(sV.data().get() + size(sV), typename Kernel_traits::SmemLayoutZeroHold{}); - Tensor sCausalMask = params.causal_mask_ptr != nullptr - ? make_tensor(sZeroHold.data().get() + size(sZeroHold), - typename Kernel_traits::SmemLayoutZeroHold{}) - : Tensor(); + + // Shared memory layout configuration + Tensor sQ = make_tensor( + make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{} + ); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor( + sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sV = make_tensor( + sK.data() + size(sK), + typename Kernel_traits::SmemLayoutKV{} + ); + Tensor sVt = make_tensor( + sV.data(), + typename Kernel_traits::SmemLayoutVtransposed{} + ); + Tensor sVtNoSwizzle = make_tensor( + sV.data().get(), + typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} + ); + + // Dynamic mask related shared memory. Use a running char* pointer for robust allocation. + char* dynamic_smem_current_ptr = reinterpret_cast(sV.data() + size(sV)); + Tensor sZeroHold = make_tensor( + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), + typename Kernel_traits::SmemLayoutZeroHold{} + ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize; + auto causal_mask_layout_smem = typename Kernel_traits::SmemLayoutCausalMask{}; + Tensor sCausalMask = has_causal_mask ? + make_tensor(make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), causal_mask_layout_smem) + : make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); // Dummy + + if (has_causal_mask) { + dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize; + } Tensor sDynamicMaskValues = make_tensor( - (params.causal_mask_ptr != nullptr ? - sCausalMask.data().get() + size(sCausalMask) : - sZeroHold.data().get() + size(sZeroHold)), + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type typename Kernel_traits::SmemLayoutDynamicMaskValues{} ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemMaskValuesSize; Tensor sDynamicMaskSortKeys = make_tensor( - sDynamicMaskValues.data().get() + size(sDynamicMaskValues), + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type typename Kernel_traits::SmemLayoutDynamicMaskSortKeys{} ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemSortKeysSize; Tensor sDynamicMaskSortIndices = make_tensor( - sDynamicMaskSortKeys.data().get() + size(sDynamicMaskSortKeys), + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type typename Kernel_traits::SmemLayoutDynamicMaskSortIndices{} ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemSortIndicesSize; Tensor sNonZeroIndices = make_tensor( - sDynamicMaskSortIndices.data().get() + size(sDynamicMaskSortIndices), + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type typename Kernel_traits::SmemLayoutNonZeroIndices{} ); + + dynamic_smem_current_ptr += Kernel_traits::kSmemNonZeroIndicesSize; Tensor sPredicate = make_tensor( - sNonZeroIndices.data().get() + size(sNonZeroIndices), - typename Kernel_traits::SmemLayoutZeroHold{} + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type + typename Kernel_traits::SmemLayoutPredicate{} ); From 72fde7821a873c5e412896cdc18c05150e7b7b70 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 21:09:46 +0800 Subject: [PATCH 14/18] Global to Shared Memory operation --- csrc/src/flash_attention_fwd_kernel.h | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 618b4c5..c28ae4a 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -269,9 +269,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type typename Kernel_traits::SmemLayoutPredicate{} ); - - // 设置全局内存到共享内存的拷贝 + // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold; @@ -281,18 +280,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK =
gmem_thr_copy_QKV.partition_S(gK); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold); Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold); - Tensor tCausalMaskgCausalMask = params.causal_mask_ptr != nullptr - ? gmem_thr_copy_CausalMask.partition_S(gCausalMask) - : Tensor(); - Tensor tCausalMasksCausalMask = params.causal_mask_ptr != nullptr - ? gmem_thr_copy_CausalMask.partition_D(sCausalMask) - : Tensor(); + auto tCausalMaskgCausalMask = has_causal_mask ? + gmem_thr_copy_CausalMask.partition_S(gCausalMask) : + make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); + auto tCausalMasksCausalMask = has_causal_mask ? + gmem_thr_copy_CausalMask.partition_D(sCausalMask) : + make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); // 设置矩阵乘法操作 typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrK = thr_mma.partition_fragment_B(sK); Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); From 8224b24a3468f86148c32abdfa0f7519126dd477 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 21:13:40 +0800 Subject: [PATCH 15/18] Add Matrix Multiply Accumulate --- csrc/src/flash_attention_fwd_kernel.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index c28ae4a..c42eb97 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -293,12 +293,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi gmem_thr_copy_CausalMask.partition_D(sCausalMask) : make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); - // 设置矩阵乘法操作 + // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); - Tensor tSrK = thr_mma.partition_fragment_B(sK); - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + Tensor tSgS = thr_mma.partition_C(gP); + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // 设置从共享内存到寄存器的拷贝 auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); From 74bd9d7ec9cd04bfb8fe9b3515e7596095235cec Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 18 May 2025 21:23:43 +0800 Subject: [PATCH 16/18] Add Copy Atom retiling --- csrc/src/flash_attention_fwd_kernel.h | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index c42eb97..db0b964 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -302,7 +302,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - // 设置从共享内存到寄存器的拷贝 + // Copy Atom retiling auto smem_tiled_copy_Q = make_tiled_copy_A(typename
Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); @@ -315,19 +315,16 @@ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - auto smem_tiled_copy_ZeroHold = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_ZeroHold = smem_tiled_copy_ZeroHold.get_thread_slice(tidx); - Tensor tSsZeroHold = smem_thr_copy_ZeroHold.partition_S(sZeroHold); - - auto smem_tiled_copy_CausalMask = params.causal_mask_ptr != nullptr - ? make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma) - : decltype(make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma)){}; - auto smem_thr_copy_CausalMask = params.causal_mask_ptr != nullptr - ? smem_tiled_copy_CausalMask.get_thread_slice(tidx) - : decltype(smem_tiled_copy_CausalMask.get_thread_slice(tidx)){}; - Tensor tSsCausalMask = params.causal_mask_ptr != nullptr - ? smem_thr_copy_CausalMask.partition_S(sCausalMask) - : Tensor(); + // For sZeroHold -> registers (if needed, though mask.h operates on smem directly) + // auto smem_tiled_copy_ZeroHold = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + // auto smem_thr_copy_ZeroHold = smem_tiled_copy_ZeroHold.get_thread_slice(tidx); + // Tensor tSsZeroHold = smem_thr_copy_ZeroHold.partition_S(sZeroHold); + + // For sCausalMask -> registers (if needed) + // using CausalMaskSmemCopyAtom = typename Kernel_traits::SmemCopyAtom; // Assuming Element type + // auto smem_tiled_copy_CausalMask_smem = make_tiled_copy_B(CausalMaskSmemCopyAtom{}, tiled_mma); + // auto smem_thr_copy_CausalMask_smem = smem_tiled_copy_CausalMask_smem.get_thread_slice(tidx); + // Tensor tSsCausalMask = has_causal_mask ? smem_thr_copy_CausalMask_smem.partition_S(sCausalMask) : empty_smem_tensor_for_copy_D; // 设置谓词 Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
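The `// Dummy` placeholders above (and the gCausalMask placeholder the next patch cleans up) lean on one CuTe property: with an all-zero stride, every logical coordinate maps to offset 0, so a placeholder tensor never spans more than one element and is never dereferenced while `has_causal_mask` is false. A host-side sketch of the idea, as an illustration only and not kernel code:

#include <cute/tensor.hpp>
using namespace cute;

int main() {
    float dummy = 42.f;
    // All-zero strides: every (i, j) maps to offset 0, so a single float
    // backs the whole logical 4 x 4 shape.
    auto zero_layout = make_layout(make_shape(4, 4), make_stride(0, 0));
    auto t = make_tensor(&dummy, zero_layout);
    return t(3, 2) == 42.f ? 0 : 1;  // every t(i, j) aliases dummy
}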
From 60e3d194d3447458c224dfb63912df4c022e6b51 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 19 May 2025 10:56:28 +0800 Subject: [PATCH 17/18] Update Global memory tensor configuration --- csrc/src/flash_attention_fwd_kernel.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index db0b964..6ad5deb 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -179,27 +179,29 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi make_coord(m_block, 0) ); - Tensor mCausalMask = has_causal_mask ? + auto mCausalMask = has_causal_mask ? make_tensor( make_gmem_ptr(reinterpret_cast(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride), make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{}) ) : make_tensor( - static_cast(nullptr), + make_gmem_ptr(static_cast(nullptr)), make_shape(1, 1, 1), - make_stride(0, 0, 0) + make_stride(static_cast(0), static_cast(0), _1{}) ); - Tensor gCausalMask = has_causal_mask ? + + auto gCausalMask = has_causal_mask ? local_tile( mCausalMask(0, _, _), Shape, Int>{}, make_coord(m_block, 0) ) : make_tensor( - static_cast(nullptr), - make_shape(1, 1), - make_stride(0, 0) + make_gmem_ptr(static_cast(nullptr)), + make_layout( + Shape, Int>{}, + make_stride(static_cast(0), _1{})) ); From 1352882d39c9738e72c3b9da032f0129a45233b6 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 19 May 2025 12:26:45 +0800 Subject: [PATCH 18/18] Update Dynamic mask related shared memory --- csrc/src/flash_attention_fwd_kernel.h | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 6ad5deb..de501fa 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -228,17 +228,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ); // Dynamic mask related shared memory. Use a running char* pointer for robust allocation. - char* dynamic_smem_current_ptr = reinterpret_cast(sV.data() + size(sV)); + char* dynamic_smem_current_ptr = reinterpret_cast(sV.data().get()) + size(sV) * sizeof(Element); Tensor sZeroHold = make_tensor( make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), typename Kernel_traits::SmemLayoutZeroHold{} ); dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize; - auto causal_mask_layout_smem = typename Kernel_traits::SmemLayoutCausalMask{}; - Tensor sCausalMask = has_causal_mask ? - make_tensor(make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), causal_mask_layout_smem) - : make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); // Dummy + auto causal_mask_smem_ptr = has_causal_mask + ? make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)) + : make_smem_ptr(static_cast(nullptr)); + Tensor sCausalMask = make_tensor( + causal_mask_smem_ptr, + typename Kernel_traits::SmemLayoutCausalMask{} + ); if (has_causal_mask) { dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
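One unit-of-measure note on the `dynamic_smem_current_ptr` line in the hunk above: advancing an `Element*` already moves in `sizeof(Element)`-byte steps, so the `size(sV) * sizeof(Element)` byte count must be added after the cast to `char*`, never to the `Element*` itself, or the cursor overshoots the V tile by a factor of `sizeof(Element)`. A two-line check of the equivalence:

#include <cassert>
#include <cstdint>

int main() {
    using Element = std::uint16_t;  // stands in for the kernel's half type
    Element buf[8];
    char* by_bytes    = reinterpret_cast<char*>(buf) + 8 * sizeof(Element);
    char* by_elements = reinterpret_cast<char*>(buf + 8);
    assert(by_bytes == by_elements);  // same end-of-buffer address
    return 0;
}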