Description
Is your feature request related to a problem? Please describe.
In some scenarios we want to run attention over a fixed sparse pattern (e.g., sliding windows, block-sparse layouts, or precomputed neighbor sets) instead of a dense attention mask. Currently FSA supports a dense `attn_mask`, which is flexible but incurs overhead when the actual attention pattern is sparse and structured.
Describe the solution you'd like
I would like to add support for an `attn_indices` argument with shape `(batch, num_kv_heads, q_len, window)`, where each entry stores the key indices to attend to for a given query position and KV head. The design would be:
- `attn_indices[b, h_k, i, :]` is the list of key positions for query position `i` in batch `b` and KV head `h_k`.
- `-1` indicates a padded / inactive position that should be ignored.
- `attn_indices` and `attn_mask` are mutually exclusive (only one of them can be provided).
- Causal masking is handled on the Python side by ensuring indices that violate causality are set to `-1`, so the kernel only needs to respect the indices and the `-1` convention.
The core kernels would be extended (or new kernels added) to use these indices for gathering keys/values and computing attention only on the specified sparse positions.
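For illustration, here is a minimal sketch of how a caller could build such an `attn_indices` tensor for a causal sliding-window pattern; `sliding_window_indices` is a hypothetical helper used only to make the shape and `-1` convention concrete, not part of the existing API:

```python
import torch

def sliding_window_indices(batch, num_kv_heads, q_len, k_len, window, device="cpu"):
    """Build causal sliding-window attn_indices of shape (batch, num_kv_heads, q_len, window)."""
    q_pos = torch.arange(q_len, device=device).unsqueeze(-1)    # (q_len, 1)
    offsets = torch.arange(window, device=device).unsqueeze(0)  # (1, window)
    # Query i attends to keys [i - window + 1, ..., i].
    idx = q_pos - (window - 1) + offsets                        # (q_len, window)
    # Out-of-range or causality-violating positions are marked inactive with -1.
    invalid = (idx < 0) | (idx >= k_len) | (idx > q_pos)
    idx = idx.masked_fill(invalid, -1)
    return idx.expand(batch, num_kv_heads, q_len, window).contiguous()

# Example: batch=2, 4 KV heads, 8 queries/keys, window of 3.
attn_indices = sliding_window_indices(2, 4, 8, 8, 3)
print(attn_indices[0, 0])  # row i contains [i-2, i-1, i], with -1 padding at the start
```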
Describe alternatives you've considered
- Using a dense `attn_mask` to encode the same sparse pattern, but this is less efficient and more memory-heavy, especially when the window size is small compared to `q_len` and `k_len`.
- Implementing custom sparse patterns outside FSA with separate kernels, which would duplicate effort and lose the benefits of the existing Flash-style implementation.
- Relying on external libraries for sparse attention, but integration and compatibility with the existing FSA API and kernels would be more complex.
Implementation details
- CUDA/Triton changes:
  - Add support for an `attn_indices` tensor in the Triton path (`flash_dmattn_triton.py` or a new sparse kernel), with the semantics described above.
  - Option 1: introduce separate sparse kernels (e.g., `_fwd_kernel_sparse`, `_bwd_kernel_sparse`) that operate on index lists instead of dense masks.
  - Option 2: extend the existing kernels with a `HAS_INDICES` `tl.constexpr` flag and additional arguments, though this may be more complex for autotuning and maintenance.
  - Backward needs to correctly accumulate gradients into `dk`/`dv` and handle `num_heads != num_kv_heads` (GQA/MQA) using the indices; see the reference sketch after this list for the intended semantics.
- Python API:
  - Expose an `attn_indices` argument in the Triton-based attention function with shape `(batch, num_kv_heads, q_len, window)`.
  - Enforce that `attn_indices` and `attn_mask` are mutually exclusive.
  - Perform causal preprocessing in Python: any index that violates causality (e.g., `key_index > query_index` for causal attention) is set to `-1`; see the validation sketch after this list.
  - Optionally support simple broadcasting along batch or head dimensions, similar to how `attn_mask`/bias are handled.
- Performance implications:
  - For small windows (e.g., local attention), this should reduce both compute and memory compared to dense masks, especially at long sequence lengths. For instance, with `q_len = k_len = 16384` and `window = 512`, the index tensor stores 32× fewer entries per head than a dense mask, before counting the further savings from indexing per KV head rather than per query head.
  - The cost of gathering keys/values via indices needs to be evaluated, but with a reasonably small `window` this should be competitive.
  - Autotuning configurations may need to be revisited for the sparse kernels.
- Compatibility concerns:
  - The new feature should work with both fp16 and bf16, and respect the existing constraints on head dimension.
  - Need to validate behavior on different GPU architectures supported by FSA’s kernels.
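To make the intended semantics concrete (and to provide a correctness oracle for the new kernels), below is a hedged, unfused PyTorch reference that gathers keys/values via `attn_indices`, masks `-1` entries, and maps query heads to KV heads for GQA/MQA. It deliberately materializes the gathered tensors, so it is for testing only; `sparse_attn_reference` and its signature are assumptions, not existing FSA API.

```python
import torch
import torch.nn.functional as F

def sparse_attn_reference(q, k, v, attn_indices, softmax_scale=None):
    """q: (B, H, Q, D); k, v: (B, H_kv, K, D); attn_indices: (B, H_kv, Q, W) with -1 = inactive."""
    B, H, Q, D = q.shape
    H_kv, W = k.shape[1], attn_indices.shape[-1]
    scale = softmax_scale if softmax_scale is not None else D ** -0.5
    group = H // H_kv                          # query heads per KV head (assumes H % H_kv == 0)

    valid = attn_indices >= 0                  # (B, H_kv, Q, W)
    idx = attn_indices.clamp(min=0)            # safe gather index for -1 slots
    # Gather the W selected keys/values per (batch, kv_head, query): (B, H_kv, Q, W, D).
    gather_idx = idx.unsqueeze(-1).expand(B, H_kv, Q, W, D)
    k_sel = torch.gather(k.unsqueeze(2).expand(B, H_kv, Q, -1, D), 3, gather_idx)
    v_sel = torch.gather(v.unsqueeze(2).expand(B, H_kv, Q, -1, D), 3, gather_idx)

    # Broadcast KV heads to query heads (GQA/MQA): (B, H, Q, W, D) and (B, H, Q, W).
    k_sel = k_sel.repeat_interleave(group, dim=1)
    v_sel = v_sel.repeat_interleave(group, dim=1)
    valid = valid.repeat_interleave(group, dim=1)

    scores = torch.einsum("bhqd,bhqwd->bhqw", q, k_sel) * scale
    scores = scores.masked_fill(~valid, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    # Rows whose indices are all -1 produce NaNs after softmax; define their output as zero.
    probs = torch.nan_to_num(probs, nan=0.0)
    return torch.einsum("bhqw,bhqwd->bhqd", probs, v_sel)
```

Since the reference is plain autograd-friendly PyTorch, it also yields the expected `dq`/`dk`/`dv`, which helps validate that the fused backward accumulates gradients into the correct `dk`/`dv` rows when `num_heads != num_kv_heads`.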
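Similarly, a sketch of the Python-side validation and causal preprocessing described above; `_preprocess_attn_indices` is a hypothetical name, and the right-aligned causal convention is an assumption that would need to match FSA's existing causal handling:

```python
import torch

def _preprocess_attn_indices(attn_indices, attn_mask, q_len, k_len, causal):
    if attn_indices is not None and attn_mask is not None:
        raise ValueError("attn_indices and attn_mask are mutually exclusive; pass only one.")
    if attn_indices is None:
        return None
    if attn_indices.shape[-2] != q_len:
        raise ValueError(f"attn_indices q_len mismatch: {attn_indices.shape[-2]} vs {q_len}")
    # Out-of-range indices are treated as inactive rather than raising.
    attn_indices = attn_indices.masked_fill(attn_indices >= k_len, -1)
    if causal:
        # Assumes query i is aligned with key position i + (k_len - q_len) (right-aligned
        # convention); any index attending to the "future" becomes -1.
        q_pos = torch.arange(q_len, device=attn_indices.device) + (k_len - q_len)
        attn_indices = attn_indices.masked_fill(attn_indices > q_pos.view(1, 1, -1, 1), -1)
    return attn_indices
```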
Use case
- Typical sequence lengths: long-context settings (e.g., 4k–32k tokens), where full dense attention is expensive.
- Target applications: long document processing, code generation, and any model using local or block-sparse attention patterns.
- This feature would allow users to define custom sparse patterns (local windows, dilated windows, neighborhood-based attention, etc.) while still leveraging FSA’s optimized kernels, improving both speed and memory efficiency.
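As one example beyond plain local windows, a dilated-window pattern can be expressed with the same `-1` convention; this mirrors the hypothetical `sliding_window_indices` sketch above and is likewise only illustrative:

```python
import torch

def dilated_window_indices(batch, num_kv_heads, q_len, k_len, window, dilation=2, device="cpu"):
    """Causal dilated window: query i attends to keys i, i-dilation, i-2*dilation, ..."""
    q_pos = torch.arange(q_len, device=device).unsqueeze(-1)    # (q_len, 1)
    offsets = torch.arange(window, device=device).unsqueeze(0)  # (1, window)
    idx = q_pos - dilation * (window - 1 - offsets)             # oldest to newest key position
    idx = idx.masked_fill((idx < 0) | (idx >= k_len), -1)       # out-of-range -> inactive
    return idx.expand(batch, num_kv_heads, q_len, window).contiguous()
```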
Additional context
- The proposed `attn_indices` design keeps the kernel logic relatively simple by delegating causal handling and invalid indices to the Python side (the `-1` convention), and avoids combining multiple masking mechanisms inside the kernel.
- It aligns well with common sparse attention patterns used in long-context transformers.
Related work
- Various long-context transformer architectures (e.g., Longformer-style sliding windows, block-sparse attention) rely on precomputed local or structured indices.
- This feature would make it easier to map such designs onto FSA’s kernels without relying on dense masks.