Fix Dynamic Mask Attention Integration in FlashAttention CUDA Kernel #12
Conversation
Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>
Pull Request Overview
This PR ensures dynamic mask values are correctly applied to attention scores in the non-causal code path and standardizes in-code comments to English.
- Adds mask-application loops to the non-causal attention computation, mirroring the causal path
- Updates Chinese comments to clear English descriptions
```cpp
for (int mma = 0; mma < size<0>(acc_s); ++mma) {
    for (int mi = 0; mi < size<1>(acc_s); ++mi) {
        for (int ki = 0; ki < size<2>(acc_s); ++ki) {
            int m_idx = mi;
            int k_idx = ki;
            if (m_idx < kBlockM && k_idx < block_key_len) {
                auto mask_values_row = sDynamicMaskValues(m_idx, _);
                auto predicate_k_row = sPredicate(m_idx, _);
                if (predicate_k_row(k_idx)) {
                    acc_s(mma, mi, ki) += static_cast<ElementAccum>(mask_values_row(k_idx));
                }
            }
        }
    }
}
```
Copilot (AI) commented on May 22, 2025
[nitpick] The mask application loop is duplicated in both causal and non-causal paths; consider extracting it into a helper function to improve maintainability and avoid code duplication.
Suggested change:

```diff
-for (int mma = 0; mma < size<0>(acc_s); ++mma) {
-    for (int mi = 0; mi < size<1>(acc_s); ++mi) {
-        for (int ki = 0; ki < size<2>(acc_s); ++ki) {
-            int m_idx = mi;
-            int k_idx = ki;
-            if (m_idx < kBlockM && k_idx < block_key_len) {
-                auto mask_values_row = sDynamicMaskValues(m_idx, _);
-                auto predicate_k_row = sPredicate(m_idx, _);
-                if (predicate_k_row(k_idx)) {
-                    acc_s(mma, mi, ki) += static_cast<ElementAccum>(mask_values_row(k_idx));
-                }
-            }
-        }
-    }
-}
+apply_mask_to_scores<ElementAccum>(
+    acc_s, kBlockM, block_key_len, sDynamicMaskValues, sPredicate
+);
```
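Such a helper could look something like the sketch below, written against plain row-major containers rather than the CUTE tensor views the kernel actually uses. The helper name and parameter order merely echo the suggestion above; everything else is illustrative, not the real kernel code.

```cpp
#include <vector>
#include <cstddef>

// Simplified stand-in for the kernel's mask-application step: add the
// precomputed mask value to each attention score whose predicate is set.
// Rows index queries (m), columns index keys (k).
template <typename ElementAccum>
void apply_mask_to_scores(std::vector<std::vector<ElementAccum>>& acc_s,
                          std::size_t block_m, std::size_t block_key_len,
                          const std::vector<std::vector<ElementAccum>>& mask_values,
                          const std::vector<std::vector<bool>>& predicate) {
    for (std::size_t m = 0; m < block_m; ++m) {
        for (std::size_t k = 0; k < block_key_len; ++k) {
            if (predicate[m][k]) {
                acc_s[m][k] += mask_values[m][k];  // active key: apply mask bias
            }
        }
    }
}
```

Both the causal and non-causal paths could then call this single helper, so the predicate-gated addition lives in one place.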
```cpp
// Apply mask values to attention scores (zero_hold states contain mask values to add to attention scores)
for (int mma = 0; mma < size<0>(acc_s); ++mma) {
```
Copilot (AI) commented on May 22, 2025
Applying mask values via triple nested loops in the kernel may impact performance; consider fusing this operation with the sparse GEMM or leveraging vectorized operations to reduce overhead.
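One way to picture the fusion this comment suggests: instead of a second sweep over `acc_s`, the predicate-gated mask bias is added in the same pass that finalizes each score, i.e. in the GEMM epilogue. A rough CPU-side sketch, assuming simple dense inputs; the real kernel would do this inside its MMA epilogue over CUTE fragments, and all names here are illustrative:

```cpp
#include <vector>
#include <cstddef>

// Fused variant: compute each attention score and apply the predicate-gated
// mask bias in one pass, rather than a separate triple loop over acc_s.
inline std::vector<std::vector<float>> scores_with_fused_mask(
    const std::vector<std::vector<float>>& q,           // [block_m][d]
    const std::vector<std::vector<float>>& k,           // [block_k][d]
    const std::vector<std::vector<float>>& mask_values, // [block_m][block_k]
    const std::vector<std::vector<bool>>& predicate) {  // [block_m][block_k]
    std::size_t block_m = q.size(), block_k = k.size();
    std::vector<std::vector<float>> acc_s(block_m, std::vector<float>(block_k, 0.f));
    for (std::size_t m = 0; m < block_m; ++m) {
        for (std::size_t n = 0; n < block_k; ++n) {
            float s = 0.f;
            for (std::size_t d = 0; d < q[m].size(); ++d) {
                s += q[m][d] * k[n][d];        // dot(q_m, k_n)
            }
            if (predicate[m][n]) {
                s += mask_values[m][n];        // mask applied in the same pass
            }
            acc_s[m][n] = s;
        }
    }
    return acc_s;
}
```

The win is that each score is touched once while it is still in registers, instead of being written out and re-read by a separate mask loop.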
```cpp
auto mask_values_row = sDynamicMaskValues(m_idx, _);
auto predicate_k_row = sPredicate(m_idx, _);
```
Copilot (AI) commented on May 22, 2025
[nitpick] The use of _ as an index placeholder in sDynamicMaskValues(m_idx, _) and sPredicate(m_idx, _) may be unclear to readers; consider documenting or renaming this placeholder for better readability.
Suggested change:

```diff
-auto mask_values_row = sDynamicMaskValues(m_idx, _);
-auto predicate_k_row = sPredicate(m_idx, _);
+// `col_idx` represents the column index for the current row `m_idx`.
+auto mask_values_row = sDynamicMaskValues(m_idx, col_idx);
+auto predicate_k_row = sPredicate(m_idx, col_idx);
```
@copilot double check
I've double-checked the code changes and everything looks correct. The issue was that mask values weren't being applied to the attention scores in the non-causal path, which I fixed by adding the appropriate code block (lines 560-574). I also fixed a syntax error (an extra semicolon) and translated the remaining comments to English for consistency. Both the causal and non-causal paths now apply mask values to the attention scores in the same way. Commit 1e22daf contains these changes and should resolve the numerical discrepancies mentioned in issue #11.
Issue Summary
This PR fixes the integration of dynamic mask attention (DMA) into the FlashAttention CUDA kernel. The standalone mask computation in mask.h was working correctly, but the attention outputs showed large numerical discrepancies between the Python and CUDA implementations.
Root Cause
The core issue was that the mask values were computed correctly and used to set the predicates for sparse computation, but they were never added to the attention scores in the non-causal path.
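To make the discrepancy concrete, here is a minimal, purely illustrative reconstruction (not the actual kernel code): before the fix, the non-causal path produced the raw q·k score, while the Python reference also added the mask bias at every active position, so the two outputs differed by exactly that bias.

```cpp
// Illustrative only: the pre-fix non-causal path computed mask values (and
// used them for the sparsity predicate) but never added them to the score,
// so the CUDA output drifted from the Python reference by the mask bias.
struct ScorePair {
    float buggy;      // pre-fix non-causal path: mask bias omitted
    float reference;  // what the Python reference computes
};

inline ScorePair score_at(float qk_dot, float mask_value, bool predicate) {
    ScorePair out;
    out.buggy = qk_dot;
    out.reference = predicate ? qk_dot + mask_value : qk_dot;
    return out;
}
```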
Implementation Details
The fix adds a consistent pattern to both the causal and non-causal code paths:
- Add the mask values from sDynamicMaskValues to the attention scores (acc_s) for active keys

The PR also improves code readability by rewriting the in-code comments in English.
Testing
This change addresses the integration issues described in the issue, where the Python and CUDA implementations showed large numerical discrepancies even though the mask computation itself was correct.
Fixes #11.