[FEATURE REQUEST] Support for Sparse Attention via attn_indices #210

@LoserCheems

Description

Is your feature request related to a problem? Please describe.
In some scenarios we want to run attention over a fixed sparse pattern (e.g., sliding windows, block-sparse layouts, or precomputed neighbor sets) instead of a dense attention mask. FSA currently supports a dense attn_mask, which is flexible but incurs unnecessary compute and memory overhead when the actual attention pattern is sparse and structured.

Describe the solution you'd like
I would like to add support for an attn_indices argument with shape (batch, num_kv_heads, q_len, window), whose last axis stores, for each query position and KV head, the key positions to attend to. The design would be (a construction sketch follows the list):

  • attn_indices[b, h_k, i, :] is a list of key positions for query position i in batch b and KV head h_k.
  • -1 indicates a padded / inactive position that should be ignored.
  • attn_indices and attn_mask are mutually exclusive (only one of them can be provided).
  • Causal masking is handled on the Python side by ensuring indices that violate causality are set to -1, so the kernel only needs to respect the indices and the -1 convention.
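
A minimal construction sketch for the proposed layout, assuming PyTorch and a causal sliding window; the helper name is hypothetical and not part of any existing FSA API:

```python
import torch

def sliding_window_indices(batch, num_kv_heads, q_len, window, device="cpu"):
    # For query i, candidate keys are i - window + 1 .. i (causal local window).
    q_pos = torch.arange(q_len, device=device).unsqueeze(-1)    # (q_len, 1)
    offsets = torch.arange(window, device=device).unsqueeze(0)  # (1, window)
    idx = q_pos - (window - 1) + offsets                        # (q_len, window)
    idx = idx.masked_fill(idx < 0, -1)  # out-of-range slots -> padded (-1)
    # Broadcast to the full (batch, num_kv_heads, q_len, window) layout.
    return idx.expand(batch, num_kv_heads, q_len, window).contiguous()

indices = sliding_window_indices(batch=2, num_kv_heads=4, q_len=16, window=8)
assert indices.shape == (2, 4, 16, 8)
# Causal by construction: no index ever exceeds its query position.
assert (indices <= torch.arange(16).view(1, 1, 16, 1)).all()
```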

The core kernels would be extended (or new kernels added) to use these indices for gathering keys/values and computing attention only on the specified sparse positions.
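
To pin down the intended semantics, here is a dense-gather PyTorch reference that the kernels would need to reproduce (no Flash-style tiling, correctness only; it assumes num_heads == num_kv_heads for brevity):

```python
import torch

def sparse_attention_reference(q, k, v, attn_indices):
    # q: (B, H, Q, D); k, v: (B, H, K, D); attn_indices: (B, H, Q, W).
    B, H, Q, D = q.shape
    W = attn_indices.shape[-1]
    valid = attn_indices >= 0        # (B, H, Q, W)
    idx = attn_indices.clamp(min=0)  # safe gather index for padded slots
    # Gather the W keys/values each query attends to.
    idx_e = idx.unsqueeze(-1).expand(B, H, Q, W, D)
    k_sel = torch.gather(k.unsqueeze(2).expand(B, H, Q, -1, D), 3, idx_e)
    v_sel = torch.gather(v.unsqueeze(2).expand(B, H, Q, -1, D), 3, idx_e)
    scores = torch.einsum("bhqd,bhqwd->bhqw", q, k_sel) / D**0.5
    scores = scores.masked_fill(~valid, float("-inf"))  # -1 slots ignored
    p = torch.softmax(scores, dim=-1)
    p = p.masked_fill(~valid, 0.0)  # also zeroes NaNs from fully padded rows
    return torch.einsum("bhqw,bhqwd->bhqd", p, v_sel)
```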

Describe alternatives you've considered

  • Using a dense attn_mask to encode the same sparse pattern, but a dense mask costs O(q_len · k_len) memory and masked-out compute versus O(q_len · window) for indices, which hurts especially when the window size is small compared to q_len and k_len (a conversion sketch for cross-checking the two follows this list).
  • Implementing custom sparse patterns outside FSA with separate kernels, which would duplicate effort and lose the benefits of the existing Flash-style implementation.
  • Relying on external libraries for sparse attention, but integration and compatibility with the existing FSA API and kernels would be more complex.
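
To make the dense-mask comparison concrete, and to let tests cross-check the sparse path against the existing dense path, the index form can be expanded back into the mask it replaces; a sketch with a hypothetical helper name:

```python
import torch

def indices_to_dense_mask(attn_indices, k_len):
    B, H, Q, W = attn_indices.shape
    # Route padded (-1) slots to a scratch column, then drop it.
    idx = attn_indices.masked_fill(attn_indices < 0, k_len)
    mask = torch.zeros(B, H, Q, k_len + 1, dtype=torch.bool,
                       device=attn_indices.device)
    mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
    return mask[..., :k_len]  # True where attention is allowed
```

The dense form stores q_len × k_len entries per head versus q_len × window for the index form, which is where the memory savings come from.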

Implementation details

  • CUDA/Triton changes:

    • Add support for an attn_indices tensor in the Triton path (flash_dmattn_triton.py or a new sparse kernel), with the semantics described above.
    • Option 1: introduce separate sparse kernels (e.g., _fwd_kernel_sparse, _bwd_kernel_sparse) that operate on index lists instead of dense masks; a minimal inner-loop sketch appears at the end of this section.
    • Option 2: extend the existing kernels with a HAS_INDICES tl.constexpr flag and additional arguments, though this may be more complex for autotuning and maintenance.
    • The backward pass needs to correctly accumulate gradients into dk/dv (e.g., via atomic adds or a scatter-add, since multiple queries may index the same key) and handle num_heads != num_kv_heads (GQA/MQA) using the indices.
  • Python API:

    • Expose an attn_indices argument in the Triton-based attention function with shape (batch, num_kv_heads, q_len, window).
    • Enforce that attn_indices and attn_mask are mutually exclusive.
    • Perform causal preprocessing in Python: any index that violates causality (e.g., key_index > query_index for causal attention) is set to -1; a preprocessing sketch appears at the end of this section.
    • Optionally support simple broadcasting along batch or head dimensions, similar to how attn_mask/bias are handled.
  • Performance implications:

    • For small windows (e.g., local attention), this should reduce both compute and memory compared to dense masks, especially at long sequence lengths; a back-of-envelope comparison appears at the end of this section.
    • The cost of gathering keys/values via indices needs to be evaluated, since index-driven loads are less coalesced than the contiguous block loads used by the dense path, but with a reasonably small window this should be competitive.
    • Autotuning configurations may need to be revisited for the sparse kernels.
  • Compatibility concerns:

    • The new feature should work with both fp16 and bf16, and respect the existing constraints on head dimension.
    • Need to validate behavior on different GPU architectures supported by FSA’s kernels.
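
Referring back to the CUDA/Triton item above, here is a heavily simplified sketch of what Option 1's sparse forward inner loop could look like. It only illustrates how a Triton kernel would gather K rows through the index list and honor the -1 convention; online-softmax rescaling, the V gather/accumulation, batch/head pointer offsets, GQA head mapping, and the output store are all omitted, and every name is hypothetical:

```python
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel_sparse_sketch(Q, K, Idx,
                              stride_qm, stride_kn, stride_im,
                              q_len, window, softmax_scale,
                              BLOCK_M: tl.constexpr, BLOCK_W: tl.constexpr,
                              HEAD_DIM: tl.constexpr):
    start_m = tl.program_id(0) * BLOCK_M
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_DIM)
    q = tl.load(Q + offs_m[:, None] * stride_qm + offs_d[None, :],
                mask=offs_m[:, None] < q_len, other=0.0)
    for start_w in range(0, window, BLOCK_W):
        offs_w = start_w + tl.arange(0, BLOCK_W)
        # Per-query key indices for this window tile; -1 marks padded slots.
        idx = tl.load(Idx + offs_m[:, None] * stride_im + offs_w[None, :],
                      mask=(offs_m[:, None] < q_len) &
                           (offs_w[None, :] < window),
                      other=-1)
        valid = idx >= 0
        safe_idx = tl.where(valid, idx, 0)
        # Gather one K row per (query, window-slot) pair.
        k = tl.load(K + safe_idx[:, :, None] * stride_kn
                      + offs_d[None, None, :])
        qk = tl.sum(q[:, None, :] * k, axis=2) * softmax_scale
        qk = tl.where(valid, qk, float("-inf"))
        # ... online-softmax update and V accumulation would go here ...
```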
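
For the Python API item, the validation and causal preprocessing could look roughly like this (a sketch; the helper name is hypothetical and the real entry point in flash_dmattn_triton.py may differ):

```python
import torch

def _prepare_attn_indices(attn_indices, attn_mask, q_len, k_len, causal):
    if attn_indices is not None and attn_mask is not None:
        raise ValueError("attn_indices and attn_mask are mutually exclusive")
    if attn_indices is None:
        return None
    if attn_indices.shape[-2] != q_len or int(attn_indices.max()) >= k_len:
        raise ValueError("attn_indices out of range for (q_len, k_len)")
    if causal:
        # Any index past its query position becomes -1, so the kernel only
        # has to honor the -1 convention.
        q_pos = torch.arange(q_len, device=attn_indices.device).view(1, 1, -1, 1)
        attn_indices = attn_indices.masked_fill(attn_indices > q_pos, -1)
    return attn_indices
```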
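
And for the performance item, a back-of-envelope version of the memory claim, assuming a 32k context and a 128-token window:

```python
q_len = k_len = 32768
window = 128
dense_mask_elems = q_len * k_len  # ~1.07e9 entries (~1 GiB as bool) per head
indices_elems = q_len * window    # ~4.2e6 entries (~16 MiB as int32) per head
print(dense_mask_elems // indices_elems)  # 256x fewer entries
```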

Use case

  • Typical sequence lengths: long-context settings (e.g., 4k–32k tokens), where full dense attention is expensive.
  • Target applications: long document processing, code generation, and any model using local or block-sparse attention patterns.
  • This feature would allow users to define custom sparse patterns (local windows, dilated windows, neighborhood-based attention, etc.) while still leveraging FSA’s optimized kernels, improving both speed and memory efficiency; a dilated-window example follows.
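
For example, a dilated causal window (query i attends to i, i - d, i - 2d, ...) is easy to express in this scheme. A small illustration with a hypothetical helper, omitting the broadcast to the full (batch, num_kv_heads, q_len, window) layout:

```python
import torch

def dilated_window_indices(q_len, window, dilation):
    q_pos = torch.arange(q_len).unsqueeze(-1)          # (q_len, 1)
    steps = torch.arange(window).flip(0).unsqueeze(0)  # (1, window): w-1 .. 0
    idx = q_pos - steps * dilation                     # (q_len, window)
    return idx.masked_fill(idx < 0, -1)                # causal + -1 padding

idx = dilated_window_indices(q_len=16, window=4, dilation=2)
assert idx[7].tolist() == [1, 3, 5, 7]  # query 7 attends to keys 1, 3, 5, 7
```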

Additional context

  • The proposed attn_indices design keeps the kernel logic relatively simple by delegating causal handling and invalid indices to the Python side (-1 convention), and avoids combining multiple masking mechanisms inside the kernel.
  • It aligns well with common sparse attention patterns used in long-context transformers.

Related work

  • Various long-context transformer architectures (e.g., Longformer-style sliding windows, block-sparse attention) rely on precomputed local or structured indices.
  • This feature would make it easier to map such designs onto FSA’s kernels without relying on dense masks.
