To use the DashAttention kernels and run the example, install the package in editable mode:

```bash
pip install -e .
```

For the benchmark environment setup, please refer to the corresponding folder of each benchmark.
The DashAttention interface can be used as follows:

```python
import torch
from dash_attn import dash_attn

# Example sizes; with enable_gqa=True, query_heads is a multiple of kv_heads.
batch, query_heads, kv_heads, seq_len, head_dim = 1, 32, 8, 4096, 128
chunk_size = 64
device, dtype = "cuda", torch.bfloat16  # the kernels require a CUDA device

queries = torch.randn(batch, query_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
keys = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
values = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
# head_cls has shape [kv_heads, head_dim]
head_cls = torch.randn(kv_heads, head_dim, device=device, dtype=dtype).contiguous()

model = dash_attn(
    chunk_size=chunk_size,
    enable_gqa=True,
    estimate_diagonal=True,
    return_active_blocks=True,
)
out, active_blocks = model(queries, keys, values, head_cls)
```

We also provide an example of using DashAttention in Llama-architecture models:

```bash
python ./example/run_niah.py
```
DashAttention implements the attention mechanism introduced in DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention. The method replaces fixed-budget top-k block routing with an adaptive, differentiable sparse router, then refines the selected regions with token-level softmax attention.
The implementation follows the three-stage hierarchy described in the paper:

- Local chunk summarization: `dash_attn.prefill.summarize_chunk` and `dash_attn.decoding.summarize_chunk` build one learned key summary per KV chunk.
- Entmax block routing: `score_blocks` computes sparse chunk supports and routing priors from query-to-summary scores.
- Prior-induced sparse softmax: `full_attn` applies token-level attention only over routed chunks, using the Stage 1 prior to preserve differentiability through the hierarchy.
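The three stages can be illustrated with a single-query NumPy sketch. It is not the actual kernel code: it ignores batching, heads, causal masking and the `estimate_diagonal` handling. It also uses sparsemax (entmax with α = 2) as the sparse router, whereas the α used by DashAttention may differ. Both the attention-pooling summary built from `head_cls` and the `sigma * log(prior)` bias in the last stage are assumed forms chosen for illustration.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()


def sparsemax(z: np.ndarray) -> np.ndarray:
    """Sparsemax (entmax with alpha=2): a probability vector with exact zeros."""
    z = np.asarray(z, dtype=float)
    zs = np.sort(z)[::-1]
    cssv = np.cumsum(zs)
    k = np.arange(1, z.size + 1)
    k_max = k[1 + k * zs > cssv][-1]  # size of the support
    tau = (cssv[k_max - 1] - 1.0) / k_max  # threshold
    return np.maximum(z - tau, 0.0)


def summarize_chunks(K, head_cls, chunk_size):
    """Chunk summarization (assumed form): attention-pool each complete chunk with head_cls."""
    d = K.shape[-1]
    n_chunks = K.shape[0] // chunk_size
    out = np.empty((n_chunks, d))
    for j in range(n_chunks):
        Kj = K[j * chunk_size:(j + 1) * chunk_size]
        out[j] = softmax(Kj @ head_cls / np.sqrt(d)) @ Kj
    return out


def route(q, summaries, scaling_factor=1.0):
    """Block routing: sparse prior over chunks from query-to-summary scores."""
    scores = scaling_factor * (summaries @ q) / np.sqrt(q.shape[-1])
    return sparsemax(scores)


def sparse_attend(q, K, V, prior, chunk_size, sigma=1.0):
    """Sparse softmax: token-level attention restricted to routed chunks."""
    active = np.flatnonzero(prior > 0)
    idx = np.concatenate([np.arange(j * chunk_size, (j + 1) * chunk_size) for j in active])
    logits = K[idx] @ q / np.sqrt(q.shape[-1])
    # Hypothetical prior injection: add sigma * log(prior) of each token's chunk,
    # so gradients of the output also flow into the routing prior.
    logits = logits + sigma * np.log(np.repeat(prior[active], chunk_size))
    return softmax(logits) @ V[idx], active


rng = np.random.default_rng(0)
T, d, C = 256, 16, 64  # tokens, head_dim, chunk_size
K, V = rng.standard_normal((T, d)), rng.standard_normal((T, d))
q, head_cls = rng.standard_normal(d), rng.standard_normal(d)

prior = route(q, summarize_chunks(K, head_cls, C), scaling_factor=4.0)
out, active = sparse_attend(q, K, V, prior, C)
print("routed chunks:", active.tolist(), "output shape:", out.shape)
```

Because sparsemax always leaves at least one chunk with nonzero weight, the last stage never attends over an empty set.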
The public kernel wrapper is `dash_attn.dash_attn_interface.dash_attn`. It supports both prefill and decoding: prefill summarizes the current sequence and stores complete chunk summaries, while decoding reuses the chunk-summary cache and appends newly completed chunks.
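The prefill/decoding split can be pictured as a small cache of chunk summaries. The sketch below is hypothetical: `ChunkSummaryCache` and `mean_pool` are illustrative names, and mean pooling stands in for the learned summary. It only shows the bookkeeping. Prefill summarizes every complete chunk and keeps the incomplete tail, then each decoded key is buffered until its chunk completes, at which point one summary is appended, bounded by a `max_chunks`-style capacity.

```python
from typing import Callable, List

Vector = List[float]


def mean_pool(chunk: List[Vector]) -> Vector:
    """Placeholder summary (assumption): the real kernels learn the summary."""
    return [sum(col) / len(chunk) for col in zip(*chunk)]


class ChunkSummaryCache:
    """Hypothetical sketch of a prefill/decode chunk-summary cache."""

    def __init__(self, chunk_size: int, max_chunks: int,
                 summarize: Callable[[List[Vector]], Vector] = mean_pool):
        self.chunk_size = chunk_size
        self.max_chunks = max_chunks  # preallocated capacity
        self.summarize = summarize
        self.summaries: List[Vector] = []
        self.pending: List[Vector] = []  # keys of the current, incomplete chunk

    def _push(self, chunk: List[Vector]) -> None:
        if len(self.summaries) >= self.max_chunks:
            raise RuntimeError("chunk-summary cache is full; increase max_chunks")
        self.summaries.append(self.summarize(chunk))

    def prefill(self, keys: List[Vector]) -> None:
        n_full = len(keys) // self.chunk_size
        for j in range(n_full):
            self._push(keys[j * self.chunk_size:(j + 1) * self.chunk_size])
        # The tail that does not fill a chunk waits for decoding.
        self.pending = list(keys[n_full * self.chunk_size:])

    def append(self, key: Vector) -> bool:
        """Add one decoded key; return True if it completed a chunk."""
        self.pending.append(key)
        if len(self.pending) == self.chunk_size:
            self._push(self.pending)
            self.pending = []
            return True
        return False


cache = ChunkSummaryCache(chunk_size=4, max_chunks=8)
cache.prefill([[float(i)] for i in range(10)])  # 2 complete chunks, 2 pending keys
cache.append([10.0])
completed = cache.append([11.0])  # fills the third chunk
print(len(cache.summaries), completed)
```

Keeping the incomplete tail outside the summary cache means a summary is computed once per chunk, only when the chunk is final.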
The wrapper is constructed as follows:

```python
from dash_attn import dash_attn

attn = dash_attn(
    chunk_size=64,
    enable_gqa=True,
    estimate_diagonal=True,
    scaling_factor=1.0,
    return_active_blocks=False,
)
```

Important arguments:
| Argument | Description |
|---|---|
| `chunk_size` | Number of tokens per routed KV chunk. |
| `enable_gqa` | Enables grouped-query attention support when query heads outnumber KV heads. |
| `estimate_diagonal` | Includes special handling for the current or near-diagonal chunk. |
| `scaling_factor` | Scales routing logits before sparse block selection; this is the main knob for sparsity. |
| `return_active_blocks` | Returns the number of active routed blocks per token for sparsity analysis. |
| `max_chunks` | Preallocated chunk-summary cache capacity used during decoding. |
| `sigma` | Controls the strength of the Stage 1 routing prior used by Stage 2. |
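To see why `scaling_factor` acts as a sparsity knob, here is a small sketch with sparsemax standing in for the entmax router. Two assumptions: the α used by DashAttention may differ from sparsemax's α = 2, and the factor is assumed to multiply the routing logits. Larger factors sharpen the scores, so fewer chunks keep nonzero weight; counting those chunks per query gives a quantity of the same kind that `return_active_blocks` reports.

```python
from typing import List


def sparsemax(z: List[float]) -> List[float]:
    """Euclidean projection of z onto the probability simplex (entmax, alpha=2)."""
    cumsum, k_max, support_sum = 0.0, 0, 0.0
    for k, zk in enumerate(sorted(z, reverse=True), start=1):
        cumsum += zk
        if 1 + k * zk > cumsum:  # zk is still inside the support
            k_max, support_sum = k, cumsum
    tau = (support_sum - 1.0) / k_max  # threshold subtracted from every score
    return [max(x - tau, 0.0) for x in z]


def active_blocks(logits: List[float], scaling_factor: float) -> int:
    """Number of chunks with nonzero routing weight after scaling the logits."""
    probs = sparsemax([scaling_factor * x for x in logits])
    return sum(1 for p in probs if p > 0)


logits = [2.0, 1.5, 1.0, 0.2, -0.5]  # toy query-to-summary scores for 5 chunks
for s in (0.5, 1.0, 4.0):
    print(f"scaling_factor={s}: {active_blocks(logits, s)} active blocks")
```

With these toy scores the active count drops from 3 to 2 to 1 as the factor grows, since scaling up the scores pushes sparsemax toward a one-hot argmax.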
Inputs are expected in `[batch, heads, seq_len, head_dim]` layout for queries, keys, and values; `head_cls` has shape `[kv_heads, head_dim]`.
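A hypothetical helper that validates these layouts before calling the kernel is sketched below; `check_dash_inputs` and `kv_head_for` are illustrative names, not part of the package. It assumes the conventional grouped-query mapping in which consecutive query heads share one KV head. It also does not require query and key sequence lengths to match, on the assumption that during decoding the queries cover only the new tokens.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_dash_inputs(q: Shape, k: Shape, v: Shape, head_cls: Shape,
                      enable_gqa: bool = True) -> int:
    """Hypothetical validator; returns the number of query heads per KV head."""
    if len(q) != 4 or len(k) != 4 or len(v) != 4:
        raise ValueError("queries, keys, values must be [batch, heads, seq_len, head_dim]")
    if k != v:
        raise ValueError(f"keys {k} and values {v} must have the same shape")
    bq, hq, _, dq = q  # query seq_len may differ from key seq_len (e.g. decoding)
    bk, hk, _, dk = k
    if bq != bk or dq != dk:
        raise ValueError("batch and head_dim must match between queries and keys")
    if head_cls != (hk, dk):
        raise ValueError(f"head_cls must be [kv_heads, head_dim] = {(hk, dk)}, got {head_cls}")
    if hq != hk and not enable_gqa:
        raise ValueError("query_heads != kv_heads requires enable_gqa=True")
    if hq % hk != 0:
        raise ValueError("query_heads must be a multiple of kv_heads")
    return hq // hk


def kv_head_for(query_head: int, group: int) -> int:
    """Conventional GQA mapping: each block of `group` query heads shares a KV head."""
    return query_head // group


group = check_dash_inputs((1, 32, 4096, 128), (1, 8, 4096, 128), (1, 8, 4096, 128), (8, 128))
print(group, kv_head_for(13, group))
```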
DashAttention includes a Llama-compatible modeling implementation in `dash_attn.models.llama`. `LlamaConfig` defaults to `attn_implementation="dash_attn"` and adds DashAttention-specific fields such as `chunk_size`, `estimate_diagonal`, `sigma`, and `scaling_factor`.

```python
from dash_attn.models.llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "fasa-org/MiniCPM-4-8B-DashAttention",
    attn_implementation="dash_attn",
    torch_dtype="auto",
)
```

To inspect routing behavior, call the model with `return_active_blocks=True`, then read `model.get_active_blocks()`.
- `example/run_niah.py` runs a needle-in-a-haystack style generation example and reports measured sparsity.
- `test/test_smoke.py` checks the standalone DashAttention kernel wrapper.
- `test/test_llama_dash_attn.py` checks the Llama integration and active-block reporting.
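The exact sparsity definition used by `run_niah.py` is not stated here. One plausible definition, sketched below with a hypothetical `routed_sparsity` helper, compares per-token active-block counts for a single head against the number of chunks each token can causally see: positions `0..t` overlap `t // chunk_size + 1` chunks. It assumes the counts are a flat per-token sequence, for example one row of the active-block output converted with `.tolist()`.

```python
from typing import Sequence


def routed_sparsity(active_blocks: Sequence[int], chunk_size: int) -> float:
    """Fraction of causally visible chunks that routing skipped (hypothetical metric)."""
    visible_total = 0
    for t, active in enumerate(active_blocks):
        visible = t // chunk_size + 1  # chunks overlapping positions 0..t
        if not 0 <= active <= visible:
            raise ValueError(f"token {t}: {active} active blocks, only {visible} visible")
        visible_total += visible
    return 1.0 - sum(active_blocks) / visible_total


# Toy per-token counts for one head: 6 tokens, chunk_size=2, so up to 3 chunks visible.
counts = [1, 1, 1, 1, 2, 2]
print(f"sparsity = {routed_sparsity(counts, chunk_size=2):.3f}")
```

A fully dense router (every visible chunk active) gives a sparsity of 0 under this definition.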
Run the test suite with:

```bash
pytest
```

The current kernels require CUDA-capable hardware.
We release our 8B models for reproducibility.
| Model | Link |
|---|---|
| 8B-FullAttn | Hugging Face |
| 8B-InfLLMv2 | Hugging Face |
| 8B-NSA | Hugging Face |
| 8B-DashAttention | Hugging Face |
- Performance: please refer to the README.
This project is released under the BSD-3-Clause License.
This repository is developed with the aid of RULER, OLMES, InfLLMv2, and NSA-triton.
```bibtex
@article{dash-attention,
  title={DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention},
  author={Huang, Yuxiang and Gon{\c{c}}alves, Nuno M. T. and Alvetreti, Federico and Li, Lei and Han, Xu and Ponti, Edoardo M. and Martins, Andr{\'e} F. T. and Treviso, Marcos V.},
  journal={arXiv preprint arXiv:2605.18753},
  year={2026}
}
```