diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h
index 9d3f14c..a668b3c 100644
--- a/csrc/src/flash_attention_fwd_kernel.h
+++ b/csrc/src/flash_attention_fwd_kernel.h
@@ -169,40 +169,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     );
     Tensor mZeroHold = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.zero_hold_ptr) + bidb * params.zero_hold_batch_stride),
+        make_gmem_ptr(reinterpret_cast<Element *>(params.zero_hold_ptr) + binfo.q_offset(params.zero_hold_batch_stride, params.zero_hold_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-        make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{})
+        make_stride(params.zero_hold_head_stride, params.zero_hold_row_stride, _1{})
     );
     Tensor gZeroHold = local_tile(
         mZeroHold(bidh / params.h_h_k_ratio, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
-        make_coord(m_block, 0)
-    );
-
-    auto mCausalMask = has_causal_mask ?
-        make_tensor(
-            make_gmem_ptr(reinterpret_cast<Element *>(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride),
-            make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-            make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{})
-        ) :
-        make_tensor(
-            make_gmem_ptr(static_cast<Element *>(nullptr)),
-            make_shape(1, 1, 1),
-            make_stride(static_cast<index_t>(0), static_cast<index_t>(0), _1{})
-        );
-
-    auto gCausalMask = has_causal_mask ?
-        local_tile(
-            mCausalMask(0, _, _),
-            Shape<Int<kBlockM>, Int<kBlockN>>{},
-            make_coord(m_block, 0)
-        ) :
-        make_tensor(
-            make_gmem_ptr(static_cast<Element *>(nullptr)),
-            make_layout(
-                Shape<Int<kBlockM>, Int<kBlockN>>{},
-                make_stride(static_cast<index_t>(0), _1{}))
-        );
+        make_coord(m_block, n_block_max - 1)
+    );  // (kBlockM, kBlockN)

     // Shared memory layout configuration
     Tensor sQ = make_tensor(
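[Reviewer note, not part of the patch] The hunk above replaces the fixed `bidb * params.zero_hold_batch_stride` base offset with `binfo.q_offset(...)`, which lets the zero-hold tensor follow the same addressing as Q/K/V for variable-length (unpadded) batches. A minimal sketch of the idea, assuming a `q_offset` shaped like the one in flash-attention's block_info.h (`sum_s == -1` signals a fixed-length batch); the struct and values here are illustrative only:

    #include <cassert>
    #include <cstdint>

    struct BlockInfoSketch {
        int sum_s;  // cumulative start of this sequence for varlen; -1 for fixed-length
        // Fixed-length: batch bidb starts at bidb * batch_stride.
        // Varlen: sequences are packed back to back, so the start is sum_s * row_stride.
        int64_t q_offset(int64_t batch_stride, int64_t row_stride, int bidb) const {
            return sum_s == -1 ? int64_t(bidb) * batch_stride
                               : int64_t(sum_s) * row_stride;
        }
    };

    int main() {
        BlockInfoSketch fixed{-1}, varlen{384};
        assert(fixed.q_offset(/*batch_stride=*/4096, /*row_stride=*/64, /*bidb=*/2) == 2 * 4096);
        assert(varlen.q_offset(4096, 64, 2) == 384 * 64);
        return 0;
    }
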
@@ -230,22 +205,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Dynamic mask related shared memory. Use a running char* pointer for robust allocation.
     char* dynamic_smem_current_ptr = reinterpret_cast<char *>(sV.data().get()) + size(sV) * sizeof(Element);
     Tensor sZeroHold = make_tensor(
-        make_smem_ptr(reinterpret_cast<Element *>(dynamic_smem_current_ptr)),
+        make_smem_ptr(reinterpret_cast<Element *>(dynamic_smem_current_ptr)),  // Element type
         typename Kernel_traits::SmemLayoutZeroHold{}
     );
     dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize;
-    auto causal_mask_smem_ptr = has_causal_mask ?
-        make_smem_ptr(reinterpret_cast<Element *>(dynamic_smem_current_ptr)) :
-        make_smem_ptr(static_cast<Element *>(nullptr));
-    Tensor sCausalMask = make_tensor(
-        causal_mask_smem_ptr,
-        typename Kernel_traits::SmemLayoutCausalMask{}
-    );
-
-    if (has_causal_mask) {
-        dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
-    }
     Tensor sDynamicMaskValues = make_tensor(
         make_smem_ptr(reinterpret_cast<float *>(dynamic_smem_current_ptr)),  // float type
         typename Kernel_traits::SmemLayoutDynamicMaskValues{}
@@ -280,8 +244,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold;
     auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_CausalMask;
-    auto gmem_thr_copy_CausalMask = gmem_tiled_copy_CausalMask.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -291,12 +253,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
     Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
     Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
-    decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask;
-    decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask;
-    if (has_causal_mask) {
-        tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask);
-        tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask);
-    }

     // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMma tiled_mma;
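[Reviewer note, not part of the patch] The hunks above keep the "running char*" bump allocation over the dynamic shared-memory block and drop the causal-mask region. A standalone sketch of that allocation pattern (sizes and names here are hypothetical, not the kernel's actual traits):

    #include <cstdio>
    #include <cuda_fp16.h>

    __global__ void bump_alloc_demo() {
        extern __shared__ char smem[];
        char* p = smem;
        // Carve an Element-typed region, then bump by its byte size,
        // as the kernel does with kSmemZeroHoldSize.
        __half* zero_hold = reinterpret_cast<__half*>(p);
        p += 128 * sizeof(__half);  // 256 bytes, preserves 4-byte alignment
        // The next region reuses the bumped pointer, like sDynamicMaskValues.
        float* mask_vals = reinterpret_cast<float*>(p);
        p += 128 * sizeof(float);
        if (threadIdx.x == 0) {
            zero_hold[0] = __float2half(1.0f);
            mask_vals[0] = 2.0f;
            printf("regions at %p and %p\n", (void*)zero_hold, (void*)mask_vals);
        }
    }

    int main() {
        size_t smem_bytes = 128 * sizeof(__half) + 128 * sizeof(float);
        bump_alloc_demo<<<1, 32, smem_bytes>>>();
        return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
    }

The pattern relies on each region's byte size preserving the alignment the next type needs, which the kernel's power-of-two tile sizes make easy to guarantee.
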
@@ -336,23 +292,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
     // Identity tensor for gZeroHold -> sZeroHold copy
     Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
-    // Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask
-    Tensor cCausalMask = make_identity_tensor(make_shape(
-        has_causal_mask ? size<0>(sCausalMask) : Int<1>{},
-        has_causal_mask ? size<1>(sCausalMask) : Int<1>{}
-    ));

     // Repeat the partitioning with identity layouts
     Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);        // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);     // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
     // Predicate for ZeroHold GMEM copy
     Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
-    // Predicate for CausalMask GMEM copy
-    Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask);

     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
     Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold)));    // N-dim predicate for ZeroHold
-    Tensor tCausalMaskpCausalMask = make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask)));    // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask)

     // Set predicates for k bounds
     if (!Is_even_K) {
         #pragma unroll
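[Reviewer note, not part of the patch] The surviving identity-tensor/predicate code computes one bool per K-column (`tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d`) so the inner copy never re-derives bounds per element. A scalar sketch of the same idea in plain C++ (no CuTe; `kBlockK` and `d` are stand-ins for the tile width and `params.d`):

    #include <cstdio>
    #include <vector>

    int main() {
        const int kBlockK = 8;  // compile-time tile width in the kernel
        const int d = 5;        // runtime bound, like params.d
        std::vector<float> src(kBlockK, 1.0f), dst(kBlockK, 0.0f);

        // "Identity tensor": coordinate k maps to itself, so the bound can be
        // evaluated once per column and cached as a predicate.
        bool pred[kBlockK];
        for (int k = 0; k < kBlockK; ++k) pred[k] = k < d;

        // The copy then just branches on the cached predicate.
        for (int k = 0; k < kBlockK; ++k) if (pred[k]) dst[k] = src[k];

        for (int k = 0; k < kBlockK; ++k) printf("%g ", dst[k]);  // 1 1 1 1 1 0 0 0
        printf("\n");
        return 0;
    }
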
@@ -363,71 +311,61 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
+        #pragma unroll
+        for (int k = 0; k < size(tZeroHoldpZeroHold); ++k) {
+            tZeroHoldpZeroHold(k) = true;  // All elements are valid for the moment
+        }
     }

-    // 初始化动态掩码处理器
+    // Prologue
+    // Init dynamic mask processor
     DynamicMask dynamic_mask(params.keep_window_size);
-
-    // 加载Q到共享内存
+    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
         gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
         binfo.actual_seqlen_q - m_block * kBlockM
     );
-
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
-
-    // 如果共享Q和K的内存,需要等待并同步
+    // If share Q and K smem, wait and sync
     if (Kernel_traits::Share_Q_K_smem) {
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
         __syncthreads();
     }
-
-    // 反向迭代N块
+    // Reverse iteration over N blocks
     int n_block = n_block_max - 1;
-
-    // 加载第一个K块到共享内存
+    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+        gmem_tiled_copy_QKV,
+        tKgK(_, _, _, n_block),
+        tKsK, tKVcKV, tKVpKV,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
     cute::cp_async_fence();
-
-    // 加载第一个ZeroHold块到共享内存
-    FLASH_NAMESPACE::copy(
-        gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
-        binfo.actual_seqlen_k - n_block * kBlockN
-    );
-    cute::cp_async_fence();
-
-    // 加载第一个CausalMask块到共享内存(如果有)
-    if (params.causal_mask_ptr != nullptr) {
-        FLASH_NAMESPACE::copy(
-            gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
-            binfo.actual_seqlen_k - n_block * kBlockN
-        );
-        cute::cp_async_fence();
-    }
-
-    // 将Q从共享内存加载到寄存器(如果需要)
     if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
         FLASH_NAMESPACE::cp_async_wait<1>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }
-
-    // 初始化输出累加器
-    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
+    // For ZeroHold, Is_even_K in copy refers to the kBlockN dimension alignment for vectorization,
+    // which is generally true. The boundary is handled by the length argument.
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_ZeroHold,
+        tZeroHoldgZeroHold,
+        tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
+        binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    cute::cp_async_fence();
     clear(acc_o);
-
-    // 创建softmax计算器
+
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
     // Process the blocks that need masking (usually the last few blocks)
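
[Reviewer note, not part of the patch] The reordered prologue above issues the K copy and then the ZeroHold copy, each followed by `cute::cp_async_fence()`; every fence closes one commit group, and a later `cp_async_wait<N>` blocks until at most N groups remain in flight. A minimal sketch of that commit/wait discipline with the raw CUDA pipeline primitives the CuTe wrappers map onto (sm_80+; the grid and data here are illustrative only):

    #include <cuda_pipeline.h>

    __global__ void pipeline_demo(const float* gK, float* out, int n_blocks) {
        extern __shared__ float sK[];
        // Reverse iteration over N blocks, as in the kernel's prologue.
        for (int n_block = n_blocks - 1; n_block >= 0; --n_block) {
            int idx = n_block * blockDim.x + threadIdx.x;
            // Stage the tile asynchronously, then close the commit group
            // (the cp_async_fence analogue).
            __pipeline_memcpy_async(&sK[threadIdx.x], &gK[idx], sizeof(float));
            __pipeline_commit();
            // Drain that group (the cp_async_wait<0> analogue) and sync.
            __pipeline_wait_prior(0);
            __syncthreads();
            out[idx] = sK[threadIdx.x] * 2.0f;
            __syncthreads();  // don't overwrite sK while others still read it
        }
    }

    int main() {
        const int kBlockN = 32, n_blocks = 4, n = kBlockN * n_blocks;
        float *gK, *out;
        cudaMalloc(&gK, n * sizeof(float));
        cudaMalloc(&out, n * sizeof(float));
        cudaMemset(gK, 0, n * sizeof(float));
        pipeline_demo<<<1, kBlockN, kBlockN * sizeof(float)>>>(gK, out, n_blocks);
        cudaDeviceSynchronize();
        cudaFree(gK); cudaFree(out);
        return 0;
    }

The kernel overlaps more than this sketch (it waits with `cp_async_wait<1>` so the K load can still be in flight while Q moves to registers), but the fence/wait pairing is the same.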