-
Notifications
You must be signed in to change notification settings - Fork 39
Fix Dynamic Mask Attention Integration in FlashAttention CUDA Kernel #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -434,21 +434,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi | |||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // 执行稀疏矩阵乘法 | ||||||||||||||||||||||||||||||||||||||
| // Execute sparse matrix multiplication | ||||||||||||||||||||||||||||||||||||||
| FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( | ||||||||||||||||||||||||||||||||||||||
| acc_s, | ||||||||||||||||||||||||||||||||||||||
| tSrQ, | ||||||||||||||||||||||||||||||||||||||
| tSrK, tSsQ, tSsK, | ||||||||||||||||||||||||||||||||||||||
| tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, | ||||||||||||||||||||||||||||||||||||||
| smem_thr_copy_Q, smem_thr_copy_K, | ||||||||||||||||||||||||||||||||||||||
| sPredicate // 活跃键的谓词 | ||||||||||||||||||||||||||||||||||||||
| sPredicate // Active key predicates | ||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值) | ||||||||||||||||||||||||||||||||||||||
| // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) | ||||||||||||||||||||||||||||||||||||||
| for (int mma = 0; mma < size<0>(acc_s); ++mma) { | ||||||||||||||||||||||||||||||||||||||
| for (int mi = 0; mi < size<1>(acc_s); ++mi) { | ||||||||||||||||||||||||||||||||||||||
| for (int ki = 0; ki < size<2>(acc_s); ++ki) { | ||||||||||||||||||||||||||||||||||||||
| int m_idx = mi; // 或者根据你的tile映射 | ||||||||||||||||||||||||||||||||||||||
| int m_idx = mi; | ||||||||||||||||||||||||||||||||||||||
| int k_idx = ki; | ||||||||||||||||||||||||||||||||||||||
| if (m_idx < kBlockM && k_idx < block_key_len) { | ||||||||||||||||||||||||||||||||||||||
| auto mask_values_row = sDynamicMaskValues(m_idx, _); | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -554,9 +554,26 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi | |||||||||||||||||||||||||||||||||||||
| tSrK, tSsQ, tSsK, | ||||||||||||||||||||||||||||||||||||||
| tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, | ||||||||||||||||||||||||||||||||||||||
| smem_thr_copy_Q, smem_thr_copy_K, | ||||||||||||||||||||||||||||||||||||||
| sPredicate // 活跃键的谓词 | ||||||||||||||||||||||||||||||||||||||
| sPredicate // Active key predicates | ||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores) | ||||||||||||||||||||||||||||||||||||||
| for (int mma = 0; mma < size<0>(acc_s); ++mma) { | ||||||||||||||||||||||||||||||||||||||
| for (int mi = 0; mi < size<1>(acc_s); ++mi) { | ||||||||||||||||||||||||||||||||||||||
| for (int ki = 0; ki < size<2>(acc_s); ++ki) { | ||||||||||||||||||||||||||||||||||||||
| int m_idx = mi; | ||||||||||||||||||||||||||||||||||||||
| int k_idx = ki; | ||||||||||||||||||||||||||||||||||||||
| if (m_idx < kBlockM && k_idx < block_key_len) { | ||||||||||||||||||||||||||||||||||||||
| auto mask_values_row = sDynamicMaskValues(m_idx, _); | ||||||||||||||||||||||||||||||||||||||
| auto predicate_k_row = sPredicate(m_idx, _); | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+567
to
+568
|
||||||||||||||||||||||||||||||||||||||
| auto mask_values_row = sDynamicMaskValues(m_idx, _); | |
| auto predicate_k_row = sPredicate(m_idx, _); | |
| // `col_idx` represents the column index for the current row `m_idx`. | |
| auto mask_values_row = sDynamicMaskValues(m_idx, col_idx); | |
| auto predicate_k_row = sPredicate(m_idx, col_idx); |
Copilot
AI
May 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The mask application loop is duplicated in both causal and non-causal paths; consider extracting it into a helper function to improve maintainability and avoid code duplication.
| for (int mma = 0; mma < size<0>(acc_s); ++mma) { | |
| for (int mi = 0; mi < size<1>(acc_s); ++mi) { | |
| for (int ki = 0; ki < size<2>(acc_s); ++ki) { | |
| int m_idx = mi; | |
| int k_idx = ki; | |
| if (m_idx < kBlockM && k_idx < block_key_len) { | |
| auto mask_values_row = sDynamicMaskValues(m_idx, _); | |
| auto predicate_k_row = sPredicate(m_idx, _); | |
| if (predicate_k_row(k_idx)) { | |
| acc_s(mma, mi, ki) += static_cast<ElementAccum>(mask_values_row(k_idx)); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| apply_mask_to_scores<ElementAccum>( | |
| acc_s, kBlockM, block_key_len, sDynamicMaskValues, sPredicate | |
| ); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Applying mask values via triple nested loops in the kernel may impact performance; consider fusing this operation with the sparse GEMM or leveraging vectorized operations to reduce overhead.