
[Integration Progress Report] Dynamic Mask Attention Integration into FlashAttention #11

@LoserCheems

Description


Overview

This issue tracks the integration of the dynamic_mask_attention_python logic into the CUDA FlashAttention kernel. The integration is split into two main parts:

  1. Dynamic Mask Calculation

    • The Python side precomputes the zero_hold_states tensor and passes it to the CUDA backend.
    • The CUDA kernel (see mask.h) now supports dynamic mask logic, including causal masking and top-k selection.
    • All mask-related CUDA logic passes standalone tests with perfect accuracy.
  2. Sparse Attention Weight Computation

    • The CUDA kernel (flash_attention_fwd_kernel.h) is updated to use the mask logic from mask.h within compute_attn_1rowblock.
    • Sparse attention computation uses sparse_gemm_rs (in utils.h).
    • Parameter structures are updated (flash.h) to support the new logic.
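
For reference, here is a minimal Python sketch of the two-part flow described above: building the dynamic mask from zero_hold_states (causal masking plus top-k selection with keep_window_size), then adding it to the attention logits before softmax. The tensor shapes and the exact semantics of zero_hold_states are assumptions for illustration, not the repository's actual implementation.

```python
# Hypothetical reference sketch (not the repository's actual code).
# Assumed: zero_hold_states has shape [batch, heads, q_len, k_len], its values
# are added to the attention logits, only the top `keep_window_size` keys per
# query row stay active, and causal positions are always excluded.
import torch
import torch.nn.functional as F

def dynamic_mask_reference(zero_hold_states, keep_window_size, causal=True):
    b, h, q_len, k_len = zero_hold_states.shape
    bias = zero_hold_states.clone()
    if causal:
        causal_mask = torch.ones(q_len, k_len, device=bias.device).triu(1).bool()
        bias = bias.masked_fill(causal_mask, float("-inf"))
    if keep_window_size < k_len:
        # Top-k selection per query row: everything outside the top-k becomes -inf.
        topk = bias.topk(keep_window_size, dim=-1).indices
        keep = torch.zeros_like(bias).scatter_(-1, topk, 1.0).bool()
        bias = bias.masked_fill(~keep, float("-inf"))
    return bias

def attention_reference(q, k, v, bias):
    scale = q.shape[-1] ** -0.5
    # Mask values are added to the logits before softmax.
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale + bias
    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, v)
```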

Current Status

✅ Dynamic Mask Calculation (mask.h)

  • Unit tests (see test logs) show perfect agreement between Python and CUDA for mask computation:
    • Max/mean difference: 0.0
    • Nonzero/top-k position match: 1.0
  • Coverage includes causal and non-causal modes, a range of batch/head/sequence/key sizes, and multiple keep_window_size values.
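
For context, a sketch of the kind of comparison behind these numbers (the tensor names are placeholders for the CUDA extension output and the Python reference; the actual test harness may differ):

```python
# Hypothetical comparison sketch; `python_mask` / `cuda_mask` stand in for the
# Python reference output and the CUDA backend output (names assumed).
import torch

def mask_equivalence_report(python_mask, cuda_mask, keep_window_size):
    diff = (python_mask - cuda_mask).abs()
    # Compare only entries that are finite in both, so -inf masked positions
    # do not poison the statistics.
    finite = torch.isfinite(python_mask) & torch.isfinite(cuda_mask)
    max_diff = diff[finite].max().item() if finite.any() else 0.0
    mean_diff = diff[finite].mean().item() if finite.any() else 0.0
    # Top-k position agreement per query row (sorted so ordering is canonical).
    py_idx = python_mask.topk(keep_window_size, dim=-1).indices.sort(-1).values
    cu_idx = cuda_mask.topk(keep_window_size, dim=-1).indices.sort(-1).values
    match = (py_idx == cu_idx).float().mean().item()
    return {"max_diff": max_diff, "mean_diff": mean_diff, "topk_position_match": match}
```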

❌ End-to-End Dynamic Mask Attention (flash_attention_fwd_kernel.h)

  • Integration into the main attention kernel is incomplete or incorrect.
  • When running full attention equivalence tests:
    • Large numerical discrepancies between Python and CUDA results.
    • Example: max absolute difference > 3.5, mean difference ~0.88, allclose fails.
    • The difference is not random noise but systematic, indicating a logic or indexing bug.

Key Observations

  • Mask logic is correct in isolation (mask.h).
  • Attention output diverges after integrating mask logic into the main kernel.
  • The error is not due to data type, device, or memory layout issues (all tensors are contiguous and on CUDA).
  • The error is not due to the mask calculation itself, but likely due to how the mask is applied in the attention computation.

Suspected Issues

  • Indexing or broadcasting mismatch between the mask and the attention computation.
  • Incorrect application of the mask values to the attention logits or weights.
  • Potential off-by-one or block/row/col misalignment in the CUDA kernel when mapping mask values to the correct positions.
  • Possible error in the integration of sparse_gemm or sparse_gemm_rs, e.g., not masking the correct elements or not adding the mask values to the correct logits.
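
To make the suspected failure mode concrete, here is a hedged Python sketch of a block-by-block check: the mask bias is sliced with the same global column range as the corresponding K block, and the concatenated per-block logits are compared against the unblocked reference. A block/row/col misalignment or off-by-one of the kind suspected above shows up immediately in such a check. Shapes and the block size are assumptions.

```python
# Hypothetical check (names/shapes assumed): process keys block-by-block the way
# a blockwise kernel does, adding the per-block slice of the mask bias to the
# per-block logits, and confirm the result matches the global reference.
import torch

def blockwise_logits(q, k, bias, block_n=64):
    scale = q.shape[-1] ** -0.5
    k_len = k.shape[-2]
    chunks = []
    for col_start in range(0, k_len, block_n):
        col_end = min(col_start + block_n, k_len)
        logits_blk = torch.matmul(q, k[..., col_start:col_end, :].transpose(-2, -1)) * scale
        # The bias slice must use the same global column range as the K block;
        # a shifted or transposed slice here reproduces a systematic error.
        logits_blk = logits_blk + bias[..., col_start:col_end]
        chunks.append(logits_blk)
    return torch.cat(chunks, dim=-1)

# Expected to hold:
# torch.allclose(blockwise_logits(q, k, bias),
#                torch.matmul(q, k.transpose(-2, -1)) * q.shape[-1] ** -0.5 + bias)
```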

Next Steps

  1. Audit the CUDA kernel logic in flash_attention_fwd_kernel.h:

    • Ensure that the mask values are added to the correct attention logits (before softmax).
    • Double-check the mapping between mask indices and the actual key positions.
    • Confirm that the mask is applied identically to the Python reference.
  2. Add debug prints or assertions (e.g., dump intermediate logits, mask values, and output for a single block) to compare CUDA and Python step-by-step.

  3. Check the integration of sparse_gemm and sparse_gemm_rs:

    • Make sure the predicates and mask values are used consistently.
    • Validate that only the intended elements are included in the attention computation.
  4. Test with minimal examples (e.g., batch=1, head=1, seq=4, key=4, keep_window=2) and compare all intermediate tensors; a sketch follows below.
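
A possible minimal repro for step 4, using the reference helpers sketched earlier in this issue. The CUDA binding name is an assumption; substitute the real entry point.

```python
# Hypothetical minimal repro; `dma_cuda.forward` is an assumed binding name.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
b, h, q_len, k_len, d, keep = 1, 1, 4, 4, 8, 2
q = torch.randn(b, h, q_len, d, device=device)
k = torch.randn(b, h, k_len, d, device=device)
v = torch.randn(b, h, k_len, d, device=device)
zero_hold_states = torch.randn(b, h, q_len, k_len, device=device)

bias = dynamic_mask_reference(zero_hold_states, keep_window_size=keep, causal=True)
out_ref = attention_reference(q, k, v, bias)
# out_cuda = dma_cuda.forward(q, k, v, zero_hold_states, keep)  # assumed binding

# With a 4x4 problem, every intermediate (bias, logits, softmax weights, output)
# is small enough to print and diff element-by-element against the reference.
print(bias)
print(out_ref)
```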


Summary Table

| Component | Status | Notes |
| --- | --- | --- |
| mask.h (dynamic mask) | ✅ | Passes all standalone tests, matches Python perfectly |
| flash_attention_fwd_kernel | ❌ | Fails equivalence tests, large systematic errors in output |
| Parameter passing | ✅ | All tensors are contiguous, correct dtype, device, and shape |
| Python reference | ✅ | Correct, matches mask.h CUDA output |

Request for Comments

  • Any suggestions for debugging strategies at the CUDA kernel level?
  • Anyone with experience integrating sparse/dynamic masking into blockwise attention kernels is welcome to comment.
  • If you spot a likely cause in the integration code, please point to the relevant lines.

This issue will be updated as integration progresses.
