## Overview
This issue tracks the integration of the `dynamic_mask_attention_python` logic into the CUDA FlashAttention kernel. The integration is split into two main parts:
- Dynamic Mask Calculation (a minimal reference sketch follows this list):
  - The Python side precomputes `zero_hold_states` and passes them to the CUDA backend.
  - The CUDA kernel (see `mask.h`) now supports dynamic mask logic, including causal masking and top-k selection.
  - All mask-related CUDA logic passes standalone tests with perfect accuracy.
- Sparse Attention Weight Computation:
  - The CUDA kernel (`flash_attention_fwd_kernel.h`) is updated to use the mask logic from `mask.h` within `compute_attn_1rowblock`.
  - Sparse attention computation uses `sparse_gemm_rs` (in `utils.h`).
  - Parameter structures (`flash.h`) are updated to support the new logic.
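For reference, here is a minimal Python sketch of the two parts, assuming a `[batch, heads, query_len, key_len]` layout for `zero_hold_states`; the function names and the exact top-k/bias convention are illustrative rather than the repository's actual reference implementation:

```python
import torch

def dynamic_mask(zero_hold_states, keep_window_size, causal=True):
    # Part 1: causal masking + top-k selection over the precomputed zero-hold scores.
    # zero_hold_states: [batch, heads, query_len, key_len]
    _, _, q_len, k_len = zero_hold_states.shape
    scores = zero_hold_states.float().clone()
    if causal:
        causal_mask = torch.ones(q_len, k_len, dtype=torch.bool,
                                 device=scores.device).tril()
        scores = scores.masked_fill(~causal_mask, float("-inf"))
    # Keep only the keep_window_size highest-scoring keys per query row.
    topk_idx = scores.topk(min(keep_window_size, k_len), dim=-1).indices
    keep = torch.zeros_like(scores, dtype=torch.bool).scatter(-1, topk_idx, True)
    # Kept positions carry their score as an additive bias; everything else
    # (including causally masked positions) stays at -inf.
    bias = torch.where(keep, scores, torch.full_like(scores, float("-inf")))
    return bias, keep

def sparse_attention_reference(q, k, v, bias):
    # Part 2: the bias is added to the scaled QK^T logits before softmax,
    # which is what the sparse_gemm_rs path in the kernel must reproduce.
    logits = q.float() @ k.float().transpose(-1, -2) / q.shape[-1] ** 0.5 + bias
    return torch.softmax(logits, dim=-1) @ v.float()
```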
## Current Status

### ✅ Dynamic Mask Calculation (`mask.h`)
- Unit tests (see test logs) show perfect agreement between Python and CUDA for the mask computation (the check is sketched below):
  - Max/mean difference: 0.0
  - Nonzero/top-k position match: 1.0
- Both causal and non-causal settings, a range of batch/head/seq/key sizes, and different `keep_window_size` values are covered.
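The standalone comparison behind these numbers is essentially the following sketch, assuming both sides return the additive bias and the boolean keep mask (how the CUDA side is bound into Python is left out):

```python
import torch

def check_mask_equivalence(py_bias, cuda_bias, py_keep, cuda_keep):
    # Compare mask values only where both sides are finite, so the -inf
    # (masked-out) entries do not turn the difference into NaN.
    finite = torch.isfinite(py_bias) & torch.isfinite(cuda_bias)
    diff = (py_bias[finite] - cuda_bias[finite]).abs()
    print("max diff      :", diff.max().item())   # currently 0.0
    print("mean diff     :", diff.mean().item())  # currently 0.0
    # Compare the selected (nonzero / top-k) positions themselves.
    print("position match:", (py_keep == cuda_keep).float().mean().item())  # currently 1.0
```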
### ❌ End-to-End Dynamic Mask Attention (`flash_attention_fwd_kernel.h`)
- Integration into the main attention kernel is incomplete or incorrect.
- When running full attention equivalence tests (of the form sketched below):
  - Large numerical discrepancies appear between the Python and CUDA results.
  - Example: max absolute difference > 3.5, mean difference ~0.88, and `allclose` fails.
  - The difference is not random noise but systematic, indicating a logic or indexing bug.
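For context, the failing end-to-end check is of roughly this shape (tolerances are illustrative; `py_out` and `cuda_out` are the reference and kernel outputs):

```python
import torch

def check_attention_equivalence(py_out, cuda_out, atol=1e-2, rtol=1e-2):
    diff = (py_out.float() - cuda_out.float()).abs()
    # Currently: max abs diff > 3.5, mean diff ~0.88, allclose fails.
    print("max abs diff:", diff.max().item())
    print("mean diff   :", diff.mean().item())
    print("allclose    :", torch.allclose(py_out.float(), cuda_out.float(),
                                          atol=atol, rtol=rtol))
```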
## Key Observations
- Mask logic is correct in isolation (`mask.h`).
- Attention output diverges after integrating the mask logic into the main kernel.
- The error is not due to data type, device, or memory layout issues (all tensors are contiguous and on CUDA).
- The error is not due to the mask calculation itself, but likely due to how the mask is applied in the attention computation.
## Suspected Issues
- Indexing or broadcasting mismatch between the mask and the attention computation.
- Incorrect application of the mask values to the attention logits or weights.
- Potential off-by-one or block/row/col misalignment in the CUDA kernel when mapping mask values to the correct positions.
- Possible error in the integration of `sparse_gemm` or `sparse_gemm_rs`, e.g., not masking the correct elements or not adding the mask values to the correct logits (the expected per-block behavior is sketched after this list).
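To make the last two points concrete, the per-block quantity the kernel must match can be written as the following sketch, assuming row-major tiles of size `kBlockM` x `kBlockN` starting at offsets `m_start` / `n_start` (these names are illustrative); any shift between the bias tile and the QK^T tile would produce exactly the kind of systematic error seen above:

```python
import torch

def reference_block_logits(q_block, k_block, bias, m_start, n_start, softmax_scale):
    # q_block: [kBlockM, head_dim], k_block: [kBlockN, head_dim]
    # bias:    full [query_len, key_len] additive mask from the reference.
    kBlockM, kBlockN = q_block.shape[0], k_block.shape[0]
    # The bias tile consumed by this block must be exactly this slice:
    # rows m_start..m_start+kBlockM, cols n_start..n_start+kBlockN.
    bias_tile = bias[m_start:m_start + kBlockM, n_start:n_start + kBlockN]
    # The bias is added to the scaled logits BEFORE softmax; positions outside
    # the top-k window are -inf and must contribute nothing after softmax.
    return q_block.float() @ k_block.float().transpose(-1, -2) * softmax_scale + bias_tile
```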
## Next Steps
- Audit the CUDA kernel logic in `flash_attention_fwd_kernel.h`:
  - Ensure that the mask values are added to the correct attention logits (before softmax).
  - Double-check the mapping between mask indices and the actual key positions.
  - Confirm that the mask is applied identically to the Python reference.
- Add debug prints or assertions (e.g., dump intermediate logits, mask values, and output for a single block) to compare CUDA and Python step by step.
- Check the integration of `sparse_gemm` and `sparse_gemm_rs`:
  - Make sure the predicates and mask values are used consistently.
  - Validate that only the intended elements are included in the attention computation.
- Test with minimal examples (e.g., batch=1, head=1, seq=4, key=4, keep_window=2) and compare all intermediate tensors (a debugging harness along these lines is sketched below).
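A minimal harness along these lines makes the step-by-step comparison concrete; `py_attn` and `cuda_attn` are placeholders for the Python reference and the CUDA binding, here assumed to also return their intermediate logits and attention weights:

```python
import torch

def debug_compare(py_attn, cuda_attn, keep_window_size=2, seed=0):
    # Tiny problem size: batch=1, head=1, seq=4, key=4, keep_window=2.
    torch.manual_seed(seed)
    b, h, q_len, k_len, d = 1, 1, 4, 4, 8
    q = torch.randn(b, h, q_len, d, device="cuda", dtype=torch.float16)
    k = torch.randn(b, h, k_len, d, device="cuda", dtype=torch.float16)
    v = torch.randn(b, h, k_len, d, device="cuda", dtype=torch.float16)
    zero_hold = torch.randn(b, h, q_len, k_len, device="cuda", dtype=torch.float16)

    # Both callables are assumed to return (output, logits, attn_weights)
    # so the first diverging intermediate is easy to spot.
    py_out, py_logits, py_w = py_attn(q, k, v, zero_hold, keep_window_size)
    cu_out, cu_logits, cu_w = cuda_attn(q, k, v, zero_hold, keep_window_size)

    for name, a, c in [("logits", py_logits, cu_logits),
                       ("weights", py_w, cu_w),
                       ("output", py_out, cu_out)]:
        diff = (a.float() - c.float()).abs()
        diff = diff[torch.isfinite(diff)]  # ignore -inf vs -inf entries
        print(f"{name:7s} max={diff.max().item():.4f} mean={diff.mean().item():.4f}")
```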
## Summary Table
| Component | Status | Notes |
|---|---|---|
| `mask.h` (dynamic mask) | ✅ | Passes all standalone tests, matches Python perfectly |
| `flash_attention_fwd_kernel.h` | ❌ | Fails equivalence tests, large systematic errors in output |
| Parameter passing | ✅ | All tensors are contiguous, correct dtype, device, and shape |
| Python reference | ✅ | Correct, matches `mask.h` CUDA output |
## Request for Comments
- Any suggestions for debugging strategies at the CUDA kernel level?
- Anyone with experience integrating sparse/dynamic masking into blockwise attention kernels is welcome to comment.
- If you spot a likely cause in the integration code, please point to the relevant lines.
This issue will be updated as integration progresses.