From 235f3ce75a4a7cb421bf745cfbfd22f1d80c519d Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 19 May 2025 13:29:57 +0800 Subject: [PATCH] Update PREDICATES --- csrc/src/flash_attention_fwd_kernel.h | 58 +++++++++++++-------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 8cfd9e0..9d3f14c 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -235,9 +235,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize; - auto causal_mask_smem_ptr = has_causal_mask - ? make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)) - : make_smem_ptr(static_cast(nullptr)); + auto causal_mask_smem_ptr = has_causal_mask ? + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)) : + make_smem_ptr(static_cast(nullptr)); Tensor sCausalMask = make_tensor( causal_mask_smem_ptr, typename Kernel_traits::SmemLayoutCausalMask{} @@ -247,31 +247,31 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize; } Tensor sDynamicMaskValues = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type typename Kernel_traits::SmemLayoutDynamicMaskValues{} ); dynamic_smem_current_ptr += Kernel_traits::kSmemMaskValuesSize; Tensor sDynamicMaskSortKeys = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // float type typename Kernel_traits::SmemLayoutDynamicMaskSortKeys{} ); dynamic_smem_current_ptr += Kernel_traits::kSmemSortKeysSize; Tensor sDynamicMaskSortIndices = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type typename Kernel_traits::SmemLayoutDynamicMaskSortIndices{} ); dynamic_smem_current_ptr += Kernel_traits::kSmemSortIndicesSize; Tensor sNonZeroIndices = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // int type typename Kernel_traits::SmemLayoutNonZeroIndices{} ); dynamic_smem_current_ptr += Kernel_traits::kSmemNonZeroIndicesSize; Tensor sPredicate = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type typename Kernel_traits::SmemLayoutPredicate{} ); @@ -331,29 +331,29 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // auto smem_thr_copy_CausalMask_smem = smem_tiled_copy_CausalMask_smem.get_thread_slice(tidx); // Tensor tSsCausalMask = has_causal_mask ? smem_thr_copy_CausalMask_smem.partition_S(sCausalMask) : empty_smem_tensor_for_copy_D; - // 设置谓词 - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); - Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold))); - Tensor cCausalMask = params.causal_mask_ptr != nullptr - ? make_identity_tensor(make_shape(size<0>(sCausalMask), size<1>(sCausalMask))) - : Tensor(); - - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); - Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold); - Tensor tCausalMaskcCausalMask = params.causal_mask_ptr != nullptr - ? gmem_thr_copy_CausalMask.partition_S(cCausalMask) - : Tensor(); - + // PREDICATES + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Identity tensor for gZeroHold -> sZeroHold copy + Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold))); + // Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask + Tensor cCausalMask = make_identity_tensor(make_shape( + has_causal_mask ? size<0>(sCausalMask) : Int<1>{}, + has_causal_mask ? size<1>(sCausalMask) : Int<1>{} + )); + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Predicate for ZeroHold GMEM copy + Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold); + // Predicate for CausalMask GMEM copy + Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask); + // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - Tensor tZeroHoldpZeroHold = make_tensor(make_shape(size<2>(tZeroHoldsZeroHold))); - Tensor tCausalMaskpCausalMask = params.causal_mask_ptr != nullptr - ? make_tensor(make_shape(size<2>(tCausalMasksCausalMask))) - : Tensor(); - - // 设置K维度的谓词 + Tensor tZeroHoldpZeroHold = make_tensor(make_shape(size<2>(tZeroHoldsZeroHold))); // N-dim predicate for ZeroHold + Tensor tCausalMaskpCausalMask = make_tensor(make_shape(size<2>(tCausalMasksCausalMask))); // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask) + // Set predicates for k bounds if (!Is_even_K) { #pragma unroll for (int k = 0; k < size(tQpQ); ++k) {