From 1e582494059b216c767e6f4e04ecf47a0217cf2b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 19 May 2025 12:41:58 +0800 Subject: [PATCH] Update golobal to Shared Memory operation --- csrc/src/flash_attention_fwd_kernel.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index de501fa..8cfd9e0 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -291,12 +291,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold); Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold); - auto tCausalMaskgCausalMask = has_causal_mask ? - gmem_thr_copy_CausalMask.partition_S(gCausalMask) : - make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); - auto tCausalMasksCausalMask = has_causal_mask ? - gmem_thr_copy_CausalMask.partition_D(sCausalMask) : - make_tensor(static_cast(nullptr), make_shape(Int<1>{}, Int<1>{}), make_stride(0,0)); + decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask; + decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask; + if (has_causal_mask) { + tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask); + tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask); + } // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma;