From e422efb9196a16b5f15a2752ce82a07de2de9f4a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 May 2025 02:45:40 +0000 Subject: [PATCH 1/3] Initial plan for issue From 765fe4c80cd1b7a8c899c5d379a41b2f199beceb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 May 2025 02:50:06 +0000 Subject: [PATCH 2/3] Fix mask application in sparse attention computation Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- csrc/src/flash_attention_fwd_kernel.h | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 0a8b9ae..ae441fd 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -434,21 +434,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); } - // 执行稀疏矩阵乘法 + // Execute sparse matrix multiplication FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, - sPredicate // 活跃键的谓词 + sPredicate // Active key predicates ); - // 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值) + // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) for (int mma = 0; mma < size<0>(acc_s); ++mma) { for (int mi = 0; mi < size<1>(acc_s); ++mi) { for (int ki = 0; ki < size<2>(acc_s); ++ki) { - int m_idx = mi; // 或者根据你的tile映射 + int m_idx = mi; int k_idx = ki; if (m_idx < kBlockM && k_idx < block_key_len) { auto mask_values_row = sDynamicMaskValues(m_idx, _); @@ -554,9 +554,26 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, - sPredicate // 活跃键的谓词 + sPredicate // Active key predicates ); + // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) + for (int mma = 0; mma < size<0>(acc_s); ++mma) { + for (int mi = 0; mi < size<1>(acc_s); ++mi) { + for (int ki = 0; ki < size<2>(acc_s); ++ki) { + int m_idx = mi; + int k_idx = ki; + if (m_idx < kBlockM && k_idx < block_key_len) { + auto mask_values_row = sDynamicMaskValues(m_idx, _); + auto predicate_k_row = sPredicate(m_idx, _); + if (predicate_k_row(k_idx)) { + acc_s(mma, mi, ki) += static_cast(mask_values_row(k_idx)); + } + } + } + } + }; + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { From 1e22daf319082b9a19f53cd1505c2130ed782889 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 May 2025 02:50:43 +0000 Subject: [PATCH 3/3] Fix semicolon typo in for loop Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com> --- csrc/src/flash_attention_fwd_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index ae441fd..13a196e 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -572,7 +572,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } } } - }; + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads();