Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions csrc/src/flash_attention_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,21 +434,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
}

// 执行稀疏矩阵乘法
// Execute sparse matrix multiplication
FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s,
tSrQ,
tSrK, tSsQ, tSsK,
tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K,
sPredicate // 活跃键的谓词
sPredicate // Active key predicates
);

// 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值)
// Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores)
for (int mma = 0; mma < size<0>(acc_s); ++mma) {
for (int mi = 0; mi < size<1>(acc_s); ++mi) {
for (int ki = 0; ki < size<2>(acc_s); ++ki) {
int m_idx = mi; // 或者根据你的tile映射
int m_idx = mi;
int k_idx = ki;
if (m_idx < kBlockM && k_idx < block_key_len) {
auto mask_values_row = sDynamicMaskValues(m_idx, _);
Expand Down Expand Up @@ -554,9 +554,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
tSrK, tSsQ, tSsK,
tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K,
sPredicate // 活跃键的谓词
sPredicate // Active key predicates
);

// Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores)
for (int mma = 0; mma < size<0>(acc_s); ++mma) {
Copy link

Copilot AI May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applying mask values via triple nested loops in the kernel may impact performance; consider fusing this operation with the sparse GEMM or leveraging vectorized operations to reduce overhead.

Copilot uses AI. Check for mistakes.
for (int mi = 0; mi < size<1>(acc_s); ++mi) {
for (int ki = 0; ki < size<2>(acc_s); ++ki) {
int m_idx = mi;
int k_idx = ki;
if (m_idx < kBlockM && k_idx < block_key_len) {
auto mask_values_row = sDynamicMaskValues(m_idx, _);
auto predicate_k_row = sPredicate(m_idx, _);
Comment on lines +567 to +568
Copy link

Copilot AI May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The use of _ as an index placeholder in sDynamicMaskValues(m_idx, _) and sPredicate(m_idx, _) may be unclear to readers; consider documenting or renaming this placeholder for better readability.

Suggested change
auto mask_values_row = sDynamicMaskValues(m_idx, _);
auto predicate_k_row = sPredicate(m_idx, _);
// `col_idx` represents the column index for the current row `m_idx`.
auto mask_values_row = sDynamicMaskValues(m_idx, col_idx);
auto predicate_k_row = sPredicate(m_idx, col_idx);

Copilot uses AI. Check for mistakes.
if (predicate_k_row(k_idx)) {
acc_s(mma, mi, ki) += static_cast<ElementAccum>(mask_values_row(k_idx));
}
}
}
}
}
Comment on lines +561 to +575
Copy link

Copilot AI May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The mask application loop is duplicated in both causal and non-causal paths; consider extracting it into a helper function to improve maintainability and avoid code duplication.

Suggested change
for (int mma = 0; mma < size<0>(acc_s); ++mma) {
for (int mi = 0; mi < size<1>(acc_s); ++mi) {
for (int ki = 0; ki < size<2>(acc_s); ++ki) {
int m_idx = mi;
int k_idx = ki;
if (m_idx < kBlockM && k_idx < block_key_len) {
auto mask_values_row = sDynamicMaskValues(m_idx, _);
auto predicate_k_row = sPredicate(m_idx, _);
if (predicate_k_row(k_idx)) {
acc_s(mma, mi, ki) += static_cast<ElementAccum>(mask_values_row(k_idx));
}
}
}
}
}
apply_mask_to_scores<ElementAccum>(
acc_s, kBlockM, block_key_len, sDynamicMaskValues, sPredicate
);

Copilot uses AI. Check for mistakes.

FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();
if (n_block > n_block_min) {
Expand Down