diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index 0a8b9ae..13a196e 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -434,21 +434,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi __syncthreads(); } - // 执行稀疏矩阵乘法 + // Execute sparse matrix multiplication FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, - sPredicate // 活跃键的谓词 + sPredicate // Active key predicates ); - // 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值) + // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) for (int mma = 0; mma < size<0>(acc_s); ++mma) { for (int mi = 0; mi < size<1>(acc_s); ++mi) { for (int ki = 0; ki < size<2>(acc_s); ++ki) { - int m_idx = mi; // 或者根据你的tile映射 + int m_idx = mi; int k_idx = ki; if (m_idx < kBlockM && k_idx < block_key_len) { auto mask_values_row = sDynamicMaskValues(m_idx, _); @@ -554,9 +554,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, - sPredicate // 活跃键的谓词 + sPredicate // Active key predicates ); + // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) + for (int mma = 0; mma < size<0>(acc_s); ++mma) { + for (int mi = 0; mi < size<1>(acc_s); ++mi) { + for (int ki = 0; ki < size<2>(acc_s); ++ki) { + int m_idx = mi; + int k_idx = ki; + if (m_idx < kBlockM && k_idx < block_key_len) { + auto mask_values_row = sDynamicMaskValues(m_idx, _); + auto predicate_k_row = sPredicate(m_idx, _); + if (predicate_k_row(k_idx)) { + acc_s(mma, mi, ki) += static_cast(mask_values_row(k_idx)); + } + } + } + } + } + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) {