[Integration Progress Report] Dynamic Mask Attention Integration into FlashAttention #13

@LoserCheems

Description

Overview

This report summarizes the current status and issues encountered during the integration of the dynamic_mask_attention_python logic into the FlashAttention CUDA backend, focusing on the dynamic mask computation and sparse attention weight calculation.


✅ Progress

  • Dynamic Mask Calculation

    • The Python side now precomputes zero_hold_states and passes them to the CUDA backend.
    • The CUDA kernel (mask.h) correctly implements the dynamic mask logic, including causal masking and top-k selection (a reference sketch of this workflow follows this list).
    • Standalone tests for mask.h show perfect agreement between Python and CUDA outputs (max diff = 0, all mask positions match).
  • Sparse Attention Weight Calculation

    • The CUDA kernel (flash_attention_fwd_kernel.h) integrates the dynamic mask logic, using the mask to select active keys and perform sparse attention computation.
    • The sparse_gemm_rs utility is in place for masked matrix multiplication.
    • Parameter structures (flash.h) are updated to support the new workflow.
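
For reference, below is a minimal PyTorch sketch of the dynamic-mask workflow described above: a precomputed zero_hold_states bias, causal masking, and top-k selection controlled by keep_window_size. Only the names zero_hold_states and keep_window_size come from this report; the shapes, dtypes, and exact semantics are assumptions for illustration, not the actual mask.h or Python-side API.

```python
import torch

def dynamic_mask_sketch(zero_hold_states, keep_window_size, is_causal=True):
    # Assumed layout: zero_hold_states is an additive bias of shape
    # [batch, num_kv_heads, query_len, key_len] (an assumption for this sketch).
    mask_values = zero_hold_states.clone()
    query_len, key_len = mask_values.shape[-2:]

    if is_causal:
        # Causal masking: query i may only attend to keys up to its own position.
        causal = torch.ones(query_len, key_len, dtype=torch.bool,
                            device=mask_values.device).tril(key_len - query_len)
        mask_values = mask_values.masked_fill(~causal, float("-inf"))

    # Top-k selection: keep at most keep_window_size keys per query.
    k = min(keep_window_size, key_len)
    topk_idx = mask_values.topk(k, dim=-1).indices
    active = torch.zeros_like(mask_values, dtype=torch.bool)
    active.scatter_(-1, topk_idx, True)

    # Inactive keys receive -inf so they contribute nothing after softmax.
    return mask_values.masked_fill(~active, float("-inf")), active
```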

❌ Issues

  • End-to-End Equivalence Test Failure

    • When running the full equivalence test (test_dynamic_mask_attention_equivalence.py), the CUDA and Python outputs diverge significantly:
      • Example: max absolute diff ≈ 3.53, mean diff ≈ 0.88, not within tolerance.
      • The largest difference is observed in the attention output, not in the mask itself.
    • The mask logic itself is verified to be correct in isolation.
  • Potential Causes

    • Mismatch in Data Layout or Indexing: There may be a discrepancy in how the mask is applied to the attention scores or how the indices are mapped between Python and CUDA.
    • Incorrect Mask Application in CUDA: The mask values may not be added to the correct positions in the attention score matrix before softmax (see the sketch after this list for the expected ordering).
    • Broadcasting/Shape Issues: The output shapes are correct, but there may be subtle bugs in how the mask is expanded or broadcast in the CUDA kernel.
    • Numerical Type/Precision Issues: The Python reference uses float32, while the CUDA kernel may use bfloat16 or half, but the observed differences are too large to be explained by precision alone.
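
To make the "incorrect mask application" hypothesis concrete, here is a minimal sketch of the ordering the reference is expected to follow: the mask bias is added to the raw attention logits before softmax, at the positions selected by top-k. This is illustrative only and does not model compute_attn_1rowblock or sDynamicMaskValues.

```python
import math
import torch

def masked_attention_sketch(q, k, v, mask_values):
    # q: [batch, heads, query_len, head_dim]; k, v: [batch, heads, key_len, head_dim]
    # mask_values: additive bias with -inf at inactive (non-top-k) positions.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    # The bias must be added to the raw logits *before* softmax and at the
    # same (query, key) positions as in the Python reference; applying it
    # after softmax, or with shifted indices, produces exactly the kind of
    # large output divergence reported above.
    scores = scores + mask_values

    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```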

📝 Next Steps

  1. Deep Audit of CUDA Mask Application:

    • Review how sDynamicMaskValues is added to the attention scores in compute_attn_1rowblock.
    • Ensure that the mask is applied to the same positions as in the Python reference, especially after top-k selection.
  2. Debugging/Instrumentation:

    • Add debug prints (or device-to-host copies) of intermediate tensors (e.g., mask values, attention logits) in both Python and CUDA for a small test case.
    • Compare the masked attention logits before softmax between Python and CUDA (a comparison helper sketch follows this list).
  3. Test with Simplified Inputs:

    • Use a minimal test case (e.g., batch=1, heads=1, seq=4) with fixed values to manually verify each step.
  4. Check Indexing Consistency:

    • Confirm that the mapping from non-zero mask indices to key/value vectors is consistent between Python and CUDA.
  5. Precision Consistency:

    • Ensure both implementations use the same dtype for all intermediate computations during debugging.
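
As a starting point for steps 2–5, a small comparison helper along the following lines can be run on dumped intermediates such as the pre-softmax logits. The names py_logits and cuda_logits are placeholders; how the CUDA-side intermediates are copied back to the host is not covered here.

```python
import torch

def compare_intermediates(py_logits, cuda_logits, rtol=1e-3, atol=1e-3):
    # Cast both sides to float32 first so dtype differences (float32 vs
    # bfloat16/half) do not hide layout or indexing bugs.
    a, b = py_logits.float(), cuda_logits.float()
    diff = (a - b).abs()
    print(f"max diff: {diff.max().item():.6f}, mean diff: {diff.mean().item():.6f}")
    if not torch.allclose(a, b, rtol=rtol, atol=atol):
        idx = torch.unravel_index(diff.argmax(), diff.shape)
        pos = tuple(i.item() for i in idx)
        print(f"largest mismatch at {pos}: "
              f"python={a[idx].item():.6f}, cuda={b[idx].item():.6f}")

# Example usage once both sides have been dumped for the minimal
# batch=1, heads=1, seq=4 case suggested in step 3:
# compare_intermediates(py_pre_softmax_logits, cuda_pre_softmax_logits)
```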

📋 Summary Table

| Component | Status | Notes |
| --- | --- | --- |
| Mask logic (mask.h) | ✅ Pass | Standalone tests perfect |
| CUDA kernel integration | ⚠️ Issue | Large output differences in full pipeline |
| Python reference | ✅ Pass | Output as expected |
| End-to-end equivalence | ❌ Fail | Max diff > 3, mean diff > 0.8 |

🏷️ Labels

  • integration
  • bug
  • cuda
  • attention
  • dynamic-mask
  • equivalence-test

📎 Attachments

PyTorch version: 2.6.0+cu126
Device: cuda
GPU: NVIDIA GeForce RTX 4090
Random seed: 42

Test configuration 1:
batch_size=1, num_kv_heads=2, key_len=8
query_len=4, head_dim=4, keep_window_size=2
Computing original Python implementation result...
Python time: 0.170754 s
CUDA time: 0.000556 s
Speedup: 307.25x
Result analysis:
Max difference: 0.00000000
Mean difference: 0.00000000
Results of the two implementations equal: True
Non-zero element position match rate: 1.0000
Top-8 element match rate: 1.0000

Test configuration 2:
batch_size=1, num_kv_heads=2, key_len=200
query_len=100, head_dim=8, keep_window_size=50
Computing original Python implementation result...
Python time: 0.014448 s
CUDA time: 0.000201 s
Speedup: 71.80x
Result analysis:
Max difference: 0.00000000
Mean difference: 0.00000000
Results of the two implementations equal: True
Non-zero element position match rate: 1.0000
Top-128 element match rate: 1.0000

Test configuration 3:
batch_size=1, num_kv_heads=2, key_len=2048
query_len=2048, head_dim=128, keep_window_size=2048
Computing original Python implementation result...
Python time: 0.169489 s
CUDA time: 0.000752 s
Speedup: 225.46x
Result analysis:
Max difference: 0.00000000
Mean difference: 0.00000000
Results of the two implementations equal: True
Non-zero element position match rate: 1.0000
Top-128 element match rate: 1.0000

Test configuration 4:
batch_size=1, num_kv_heads=1, key_len=40
query_len=40, head_dim=16, keep_window_size=100
Computing original Python implementation result...
Python time: 0.008414 s
CUDA time: 0.000200 s
Speedup: 42.16x
Result analysis:
Max difference: 0.00000000
Mean difference: 0.00000000
Results of the two implementations equal: True
Non-zero element position match rate: 1.0000
Top-40 element match rate: 1.0000

The above are the mask.h tests; they pass with no issues.

Testing equivalence of the Python prototype and the CUDA implementation

Device in use: cuda

Test configuration 1/4:
batch_size=1, num_heads=1, num_kv_heads=1
query_len=32, key_len=32, head_dim=32
is_causal=True
Required shared memory size: 9472
Original result: torch.Size([1, 32, 1, 32]), torch.float32
CUDA result: torch.Size([1, 32, 1, 32]), torch.float32
Result analysis:
Max absolute difference: 3.53487039
Mean absolute difference: 0.88376451
Results of the two implementations equal (rtol=1e-3, atol=1e-3): No

Position of max difference: batch=0, query=13, head=0, dim=20
Python value: -2.230183
CUDA value: 1.304688
Difference: 3.534870
Mean difference for this head at this position: 0.95409763

Performance comparison:
Python implementation: 2504.19 ms
CUDA implementation: 31.49 ms
Speedup: 79.52x

Test result: FAILED
Difference too large; stopping further tests.


Please assign to CUDA kernel maintainers for further investigation.
