
Conversation

@LoserCheems (Collaborator) commented Nov 7, 2025

Summary

  • Purpose: Introduce a compact dynamic-mask attention (DMA) path optimized with Triton that delivers substantially faster training while preserving exact numerical equivalence with the baseline DMA.
  • Outcome: On bf16 and large-window workloads, this implementation achieves ~1.6× end-to-end speedup for forward+backward with identical outputs and gradients to the original implementation.

Design

  • Compact representation: Gather K/V/B into compact buffers (CuK, CuV, CuB) using attn_indices and build a boolean mask CuM of shape [query_len × window_size]. This eliminates memory traffic and compute for keys outside the window; see the reference sketch at the end of this list.
  • Streaming softmax (forward): Iterate over compact blocks, apply bias and mask, compute stable log-sum-exp statistics (lse), and accumulate outputs. Skip fully-inactive tiles.
  • Backward rematerialization:
    • Compute Delta = Σ(o * do) per row in fp32.
    • For each column block, rematerialize scores/probabilities under CuM, accumulate dV/dK/dB/dQ in fp32, then cast to input dtype.
    • Scatter-add compact dK/dV/dB back to the original sequence dimension using attn_indices.
  • GQA mapping: Map Q heads to KV heads via h_h_k_ratio = nheads // nheads_k in both forward and backward.
  • Numerical stability: All accumulations are performed in fp32, with stable lse tracking identical to the baseline DMA. Causal masking is supported and pre-applied into CuM.
  • Constraints: head_dim ≤ 128, dtype in {fp16, bf16}, attn_indices is int64 and must be valid for the chosen window_size.
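  • Reference sketch (illustration only): the PyTorch snippet below mirrors the compact forward semantics described above — gathering via attn_indices, GQA head mapping, and bias-plus-mask softmax — without the streaming lse, tile skipping, or block-wise fp32 accumulation of the actual Triton kernels. Function and variable names are chosen for the example and are not part of the library API; the causal branch assumes query_len == key_len.

    import torch

    def compact_dma_forward_reference(query, key, value, attn_bias, attn_indices,
                                      is_causal=False, softmax_scale=None):
        # query: (B, H, Lq, D); key/value: (B, Hk, Lk, D); attn_bias: (B, Hk, Lk)
        # attn_indices: (B, Hk, W) int64 positions into key_len, sorted ascending
        B, H, Lq, D = query.shape
        Hk, W = key.shape[1], attn_indices.shape[-1]
        scale = softmax_scale if softmax_scale is not None else D ** -0.5
        h_h_k_ratio = H // Hk  # GQA/MQA: query head h reads KV head h // h_h_k_ratio

        # Gather the compact buffers CuK/CuV/CuB along the key dimension.
        idx = attn_indices[..., None].expand(-1, -1, -1, D)        # (B, Hk, W, D)
        cu_k = torch.gather(key, 2, idx)
        cu_v = torch.gather(value, 2, idx)
        cu_b = torch.gather(attn_bias, 2, attn_indices)            # (B, Hk, W)

        # Expand KV heads to query heads.
        cu_k = cu_k.repeat_interleave(h_h_k_ratio, dim=1)
        cu_v = cu_v.repeat_interleave(h_h_k_ratio, dim=1)
        cu_b = cu_b.repeat_interleave(h_h_k_ratio, dim=1)
        kv_pos = attn_indices.repeat_interleave(h_h_k_ratio, dim=1)  # original key positions

        # Scores over the compact window only: (B, H, Lq, W).
        scores = torch.einsum("bhqd,bhwd->bhqw", query.float(), cu_k.float()) * scale
        scores = scores + cu_b.float()[:, :, None, :]
        if is_causal:
            q_pos = torch.arange(Lq, device=query.device)[:, None]
            cu_m = kv_pos[:, :, None, :] <= q_pos                  # True = keep (CuM analogue)
            scores = scores.masked_fill(~cu_m, float("-inf"))

        probs = torch.softmax(scores, dim=-1)
        probs = torch.nan_to_num(probs)  # rows with no active keys would be NaN; the kernels skip such tiles
        out = torch.einsum("bhqw,bhwd->bhqd", probs, cu_v.float())
        return out.to(query.dtype)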

Changes

  • New fast path: flash_dmattn.flash_dmattn_triton_special.triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None).
  • Triton kernels:
    • _fwd_preprocess: gather K/V/B into CuK/CuV/CuB and construct CuM with row/col/causal masking.
    • _fwd_kernel: streaming softmax over compact tiles with stable lse.
    • _bwd_preprocess_do_o_dot: compute per-row Delta.
    • _bwd_kernel + _bwd_kernel_one_col_block: column-block backward with fp32 accumulators, then scatter back via indices.
  • Public API: No breaking changes; adds the triton_dmattn_func convenience entrypoint.
  • Behavior: Supports causal and non-causal; works with GQA/MQA (Q heads divisible by KV heads). Assumes attn_indices are valid indices into key_len.

Implementation notes

  • Safety: Double-masking (row and col) prevents OOB loads/stores on non-divisible tiles; causal is pre-baked into CuM.
  • Precision: Internal reductions are fp32; final outputs cast to the input dtype (fp16/bf16).
  • Performance: Skips fully inactive tiles; reduces memory bandwidth via compact buffers; autotune configs provided for common BLOCK_M/N and num_warps.
  • Indices: attn_indices must lie in [0, key_len); the provided generator (e.g., topk_indices) guarantees validity. If externally produced indices are used, optional guards can be added during the scatter step to drop invalid entries (see the sketch after this list).
  • Limits: Designed for head_dim ≤ 128; extending beyond may require additional kernel variants.
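  • Guarded scatter sketch (illustration only): a minimal PyTorch analogue of the scatter-add step with the optional validity guard mentioned above. Names are illustrative and not part of the library; the real kernels perform this inside the backward Triton kernels with fp32 accumulators.

    import torch

    def scatter_compact_grads(dcu_k, dcu_v, dcu_b, attn_indices, key_len):
        # dcu_k/dcu_v: (B, Hk, W, D) fp32 compact grads; dcu_b: (B, Hk, W); attn_indices: (B, Hk, W) int64
        B, Hk, W, D = dcu_k.shape
        dk = torch.zeros(B, Hk, key_len, D, dtype=dcu_k.dtype, device=dcu_k.device)
        dv = torch.zeros_like(dk)
        db = torch.zeros(B, Hk, key_len, dtype=dcu_b.dtype, device=dcu_b.device)

        # Optional guard for externally supplied indices: zero out-of-range entries
        # instead of letting them corrupt the scatter.
        valid = ((attn_indices >= 0) & (attn_indices < key_len)).to(dcu_k.dtype)
        safe_idx = attn_indices.clamp(0, key_len - 1)

        idx_d = safe_idx[..., None].expand(-1, -1, -1, D)
        dk.scatter_add_(2, idx_d, dcu_k * valid[..., None])
        dv.scatter_add_(2, idx_d, dcu_v * valid[..., None])
        db.scatter_add_(2, safe_idx, dcu_b * valid)
        return dk, dv, db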

Tests

  • Correctness: Exact numerical equivalence to the baseline DMA for forward and backward across large-window causal settings and GQA.
  • Performance (100 runs, bf16):
    • Config: batch=2, num_heads=16, num_kv_heads=8, query_len=8192, key_len=8192, head_dim=128, window_size=2048
    • Baseline DMA (triton) fwd+bwd: 29.750222ms ± 0.265306ms
    • Triton special fwd+bwd: 18.768802ms ± 0.224953ms
    • Speedup: ~1.59×
  • Minimal runnable example:
    import torch
    from flash_dmattn.flash_dmattn_triton_special import triton_dmattn_func
    from flash_dmattn.utils.mask import topk_indices
    
    device = 'cuda'
    dtype = torch.bfloat16
    batch, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size = 2, 16, 8, 8192, 8192, 128, 2048
    
    # Shapes follow the description above: Q is (batch, num_heads, query_len, head_dim),
    # K/V are (batch, num_kv_heads, key_len, head_dim), bias is (batch, num_kv_heads, key_len).
    query = torch.randn(batch, num_heads, query_len, head_dim, device=device, dtype=dtype, requires_grad=True)
    key = torch.randn(batch, num_kv_heads, key_len, head_dim, device=device, dtype=dtype, requires_grad=True)
    value = torch.randn(batch, num_kv_heads, key_len, head_dim, device=device, dtype=dtype, requires_grad=True)
    attn_bias = torch.randn(batch, num_kv_heads, key_len, device=device, dtype=dtype, requires_grad=True)
    # Select the window_size keys with the largest bias per (batch, kv head); returns sorted int64 indices.
    attn_indices = topk_indices(attn_bias, window_size)
    
    out = triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=True)
    out.sum().backward()
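
  • Benchmark sketch (illustration only): timings like the ones above can be reproduced with a CUDA-event harness along these lines; this is not the benchmark script used for the reported numbers, and it reuses the tensors from the example above.

    def time_fwd_bwd(run, iters=100, warmup=10):
        # run() must execute one forward+backward pass.
        for _ in range(warmup):
            run()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(iters):
            start.record()
            run()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # milliseconds
        times = torch.tensor(times)
        return times.mean().item(), times.std().item()

    def one_step():
        # Clear grads so they do not accumulate across timed iterations.
        for t in (query, key, value, attn_bias):
            t.grad = None
        out = triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=True)
        out.sum().backward()

    mean_ms, std_ms = time_fwd_bwd(one_step)
    print(f"fwd+bwd: {mean_ms:.6f}ms ± {std_ms:.6f}ms")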
    

Documentation

  • API reference: Add triton_dmattn_func to the English API docs, including input shapes, dtype constraints, and notes on attn_indices.
  • Integration guide: Brief section showing how to compute attn_indices (e.g., via topk_indices; see the sketch after this list) and how to switch between the baseline and the special Triton path.
  • Performance notes: Document typical speedups and constraints (head_dim ≤ 128, bf16/fp16).
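  • Sketch for the integration guide (assumption): a minimal illustration of how a top-k helper over the bias could look, assuming the documented (batch_size, num_kv_heads, key_len) layout; the actual topk_indices in flash_dmattn/utils/mask.py may differ.

    import torch

    def topk_indices_sketch(attn_bias: torch.Tensor, window_size: int) -> torch.Tensor:
        # attn_bias: (batch_size, num_kv_heads, key_len)
        # Keep the window_size largest bias entries per (batch, kv head) row and return
        # their positions sorted ascending so gathered K/V blocks stay in sequence order.
        idx = attn_bias.topk(window_size, dim=-1).indices   # (batch_size, num_kv_heads, window_size)
        return idx.sort(dim=-1).values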

Checklist

  • Linked issue provided
  • API stabilised
  • Tests added or updated
  • Docs added or updated
  • No known performance regressions

Enables fused Triton forward/backward paths for dynamic masked attention to reduce padding overhead and deliver faster windowed attention execution.
Introduces reusable top-k extraction on the bias tensor to simplify downstream mask logic.
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR adds a specialized Triton implementation for flash dynamic masked attention along with a utility function to extract top-k indices from attention bias. The implementation introduces a gather-based approach where attention is computed only on a subset of key-value pairs selected by top-k indices.

Key changes:

  • Added topk_indices utility function to extract and sort top-k indices from attention bias
  • Implemented a new Triton-based flash attention variant that uses gathered K/V/bias values
  • Added preprocessing, forward, and backward kernels for the specialized implementation

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.

Files reviewed:
  • flash_dmattn/utils/mask.py: Added topk_indices function to compute sorted top-k indices from attention bias
  • flash_dmattn/flash_dmattn_triton_special.py: New file implementing specialized Triton kernels for flash dynamic masked attention with gather-based optimization


(batch_size, num_kv_heads, key_len).
window_size (int): The number of top elements to consider for the mask.
**kwargs: Additional keyword arguments.

Copilot AI commented Nov 7, 2025: Trailing whitespace should be removed.

mask=valid_idx,
other=0.0,
)

Copilot AI commented Nov 7, 2025: Trailing whitespace should be removed.

mask=(start_n + offs_n) < window_size,
other=0.0,
)

Copilot AI commented Nov 7, 2025: Trailing whitespace should be removed.


# Compute dp
dp = tl.dot(do, tl.trans(v))

Copilot AI commented Nov 7, 2025: Trailing whitespace should be removed.

)

# We could have padded the head dimension
dq = dq[..., : do.shape[-1]]

Copilot AI commented Nov 7, 2025: Trailing whitespace should be removed (after `dq = dq[..., : do.shape[-1]]`).

dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
db = tl.zeros([BLOCK_N], dtype=tl.float32)

# Load k and v, them will stay in SRAM throughout

Copilot AI commented Nov 7, 2025: Corrected spelling of 'them' to 'they'.

Suggested change:
-   # Load k and v, them will stay in SRAM throughout
+   # Load k and v, they will stay in SRAM throughout

acc_s += tl.where(m, 0, float("-inf"))

# Compute p
m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)

Copilot AI commented Nov 7, 2025: The variable lse_i is initialized with negative infinity at line 359. When all mask elements are False and any_active prevents computation, lse_i remains negative infinity throughout. At line 516, this would store negative infinity values. Consider adding a comment explaining this edge case behavior or adding a guard to handle empty masks explicitly.

Suggested change:
-   m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)
+   # Guard against the case where all mask elements are False and lse_i remains -inf.
+   if tl.all(~m):
+       # All elements are masked out; set m_ij to 0 (or another safe value).
+       m_ij = tl.zeros([acc_s.shape[0]], dtype=acc_s.dtype)
+   else:
+       m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)


lse_i = tl.load(LSE + offs_m_curr)
# p = tl.exp(acc_s - lse_i[:, None])
p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None])

Copilot AI commented Nov 7, 2025: This guard handles the case when lse_i is negative infinity, but the forward kernel at line 463 doesn't have the same protection. Consider applying consistent handling of negative-infinity LSE values in both forward and backward passes, or document why they differ.

assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128

Copilot AI commented Nov 7, 2025: Variable seqlen_k_rounded is not used.

Suggested change:
-   seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128
+   # seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128  # Removed unused variable

@LoserCheems merged commit 926bb35 into main on Nov 7, 2025