"Some attention is all you need." — apologies to Vaswani et al. (2017)
CSA is a training-free, inference-time procedure that reduces attention compute in pretrained transformers by an order of magnitude while preserving language-modeling quality. It measures per-(layer, head) attention diffuseness on a small calibration pass over the dense model, allocates per-cell key budgets proportional to that diffuseness, and applies the allocation as a per-query top-k mask at inference. No fine-tuning, no labels, no gradients, no architecture changes.
| Configuration | Dense ppl | % of dense reads at dense-matching quality | Per-position max Δ |
|---|---|---|---|
| GPT-2 small / 1024 ctx / WikiText | 26.49 | 5.7% | ±1.2 ppl |
| Pythia 410M / 1024 ctx / WikiText | 16.26 | 9.6% | +0.13 ppl |
| Pythia 410M / 2048 ctx / WikiText | 15.27 | 9.3% | ±0.07 ppl |
| Pythia 410M / 1024 ctx / Code | 3.78 | 9.4% | ±0.01 ppl |
| Qwen 2.5 1.5B / 1024 ctx / WikiText | 8.76 | 9.9% | ±0.02 ppl |
| Llama 3.2 1B / 1024 ctx / WikiText | 9.71 | 9.4% | ±0.04 ppl |
| Llama 3.2 1B / 2048 ctx / WikiText | 9.12 | 8.9% | ±0.06 ppl |
| Llama 3.2 3B / 1024 ctx / WikiText | 7.79 | 9.6% | ±0.04 ppl |
Across four model families (GPT-2, GPT-NeoX, Qwen2, Llama), five scales (124M, 410M, 1B, 1.5B, 3B parameters), two context lengths, two corpus types, and two attention architectures (multi-head and grouped-query), CSA matches dense quality at 5.7–9.9% of dense attention reads per query, uniformly across every position in the context window.
See docs/whitepaper.md for the full writeup, including the resource-translation analysis (FLOPs / KV bandwidth / KV memory / end-to-end speedup) and comparisons against learned routing, BigBird-style trained-in sparse attention, and KV eviction methods.
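As a rough illustration of what calibration produces, the sketch below scores each (layer, head) cell and splits a total per-query read budget in proportion to those scores. It is a minimal sketch, not the library's code: the entropy-style diffuseness score, the toy shapes, and the total_reads knob are illustrative assumptions; the statistic and allocation rule CSA actually uses are described in the whitepaper.

# A minimal sketch of the calibration idea, NOT the library's implementation.
# Assumption: diffuseness is illustrated as normalized attention entropy per head;
# csa.fit may use a different statistic and allocation rule.
import torch

def head_diffuseness(attn_probs: torch.Tensor) -> torch.Tensor:
    """attn_probs: (heads, queries, keys) attention weights for one layer.
    Returns one scalar per head: mean entropy, normalized to [0, 1]."""
    eps = 1e-12
    entropy = -(attn_probs * (attn_probs + eps).log()).sum(dim=-1)        # (heads, queries)
    max_entropy = torch.log(torch.tensor(float(attn_probs.shape[-1])))
    return (entropy / max_entropy).mean(dim=-1)                           # (heads,)

def allocate_budgets(diffuseness: torch.Tensor, total_reads: int, k_min: int = 1) -> torch.Tensor:
    """Split a total per-query read budget across (layer, head) cells in
    proportion to diffuseness, with a floor of k_min keys per cell."""
    weights = diffuseness / diffuseness.sum()
    return (weights * total_reads).floor().long().clamp(min=k_min)

# Toy stand-in: 2 layers x 4 heads of random attention maps. In real use these
# come from a calibration forward pass over the dense model (output_attentions=True).
layer_probs = [torch.softmax(torch.randn(4, 16, 16), dim=-1) for _ in range(2)]
diff = torch.stack([head_diffuseness(p) for p in layer_probs])            # (layers, heads)
budgets = allocate_budgets(diff, total_reads=2 * 4 * 8)
print(budgets)                                                            # per-(layer, head) key budgets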
git clone <repo-url> csa
cd csa
uv sync

Requires Python ≥ 3.11. Tested on macOS (Apple Silicon, MPS) and Linux.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import csa
# Load any HuggingFace causal LM with eager attention + fp32
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-1B",  # or any GPT-2 / GPT-NeoX / Qwen2 / Llama / Mistral model
    attn_implementation="eager",
    dtype=torch.float32,
).eval()
tok = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
# Calibrate once: measure diffuseness, allocate budgets, pick cap
calibration_text = "..." # ~10k tokens of representative text
eval_text = "..." # ~2k tokens for the cap sweep
allocation, k_max = csa.fit(
    model,
    tok(calibration_text, return_tensors="pt").input_ids[0],
    tok(eval_text, return_tensors="pt").input_ids[0],
)
# Save the policy if you want to reload it later
csa.save_policy(allocation, k_max, "policy.json")
# Use the model with the policy active
prompt_ids = tok("The future of artificial intelligence", return_tensors="pt").input_ids
with csa.apply_policy(model, allocation, k_max):
    output = model.generate(prompt_ids, max_new_tokens=40)
print(tok.decode(output[0]))

A complete end-to-end example with comparisons against dense generation is in examples/csa_example.py.
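For a quick side-by-side without leaving the quickstart, the snippet below generates once with dense attention and once under the policy (a minimal sketch; generate defaults to greedy decoding, so the two continuations are directly comparable).

# Dense baseline vs. the same model under the CSA policy.
dense_output = model.generate(prompt_ids, max_new_tokens=40)           # dense attention
with csa.apply_policy(model, allocation, k_max):
    sparse_output = model.generate(prompt_ids, max_new_tokens=40)      # CSA policy active
print("dense :", tok.decode(dense_output[0], skip_special_tokens=True))
print("sparse:", tok.decode(sparse_output[0], skip_special_tokens=True))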
- Training-free. No gradients, no labels, no fine-tuning. Works on closed-weight models.
- Model-agnostic procedure. The four steps (measure diffuseness, allocate budgets, sweep the cap, apply the policy) run on any HuggingFace transformer with an eager attention path. Verified on GPT-2, Pythia (GPT-NeoX), Qwen 2.5 (GQA), and Llama 3.2.
- Cheap to apply. Tens of seconds for calibration; tens of minutes for the cap sweep; the deployed policy is a small JSON file of L × H + 1 integers (see the back-of-envelope sketch after this list).
- Composable. Stacks with KV-cache quantization (e.g. KIVI), speculative decoding, and permanent KV eviction (StreamingLLM / H2O).
- Quality preserved uniformly across context positions — not just on average. Long-range queries are served as well as short-range queries.
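For a sense of the "small JSON file" claim, a back-of-envelope check (hypothetical numbers assuming Llama 3.2 1B's 16 layers and 32 query heads; other models scale the same way):

# One key budget per (layer, head) cell plus the global cap k_max.
# Shape below assumes Llama 3.2 1B; swap in your model's layer/head counts.
layers, heads = 16, 32
n_integers = layers * heads + 1
print(n_integers)   # 513 integers -> a JSON file of a few kilobytes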
CSA does not provide a wall-clock speedup on its own — turning the FLOPs reduction into latency wins requires a sparse-attention kernel (Triton or FlashAttention block-sparse). For memory footprint reduction, CSA's per-(layer, head) caps compose with a permanent eviction policy. See the whitepaper's resource-analysis section for the explicit measured-vs-projected accounting.
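To make the limitation concrete, here is a dense emulation of the per-query top-k mask in plain PyTorch: every attention score is still computed, which is why a block-sparse kernel (one that skips the masked reads entirely) is needed for wall-clock gains. This is an illustrative sketch with a single shared budget, not the per-head patch in src/csa/sparse_attention.py.

import torch

def topk_masked_attention(q, k, v, k_budget: int):
    """q, k, v: (heads, seq, head_dim). Keep only each query's k_budget
    highest-scoring causally valid keys; everything else is masked before softmax."""
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5             # (heads, seq, seq)
    seq = q.shape[-2]
    causal = torch.ones(seq, seq, dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    # Per-query top-k over keys: the kth-largest score is the cutoff.
    cutoff = scores.topk(min(k_budget, seq), dim=-1).values[..., -1:]   # (heads, seq, 1)
    scores = scores.masked_fill(scores < cutoff, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

heads, seq, dim = 4, 32, 16
q, k, v = (torch.randn(heads, seq, dim) for _ in range(3))
print(topk_masked_attention(q, k, v, k_budget=8).shape)                 # torch.Size([4, 32, 16])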
src/csa/
  __init__.py + api.py — the public library (allocate, cap_sweep, apply_policy, fit)
  sparse_attention.py — attention-forward patches with top-k / top-p support
  attention_stats.py — per-(layer, head) diffuseness measurement
  eval.py — sliding-window perplexity + corpus loaders
  oracle.py — per-query k* recording for router-training experiments
  router.py — learned router experiment (matched but did not surpass CSA)
  runlog.py — reproducibility metadata embedded in result JSONs
examples/
  csa_example.py — end-to-end usage demonstration
notebooks/ — reproducible experiment scripts (numbered 01–16)
results/ — JSON + PNG outputs from every experiment in the paper
docs/
  whitepaper.md — the paper
  hypothesis.md — framing context
  experiments.md — chronological experiment log
All experiments in the paper are reproducible from this repository. Each script writes both raw JSON results and a plot to results/, with run metadata (git commit, command, package versions) embedded.
# Calibration + per-layer attention statistics (~30 sec)
uv run python notebooks/03_attention_stats.py
# Cap sweep (~5 min on Mac MPS)
uv run python notebooks/13_cap_sweep.py --smoke
# Cross-model validation (each ~15-30 min on Mac MPS)
uv run python notebooks/15_pythia_validation.py --smoke
uv run python notebooks/15_pythia_validation.py --smoke --model Qwen/Qwen2.5-1.5B

Detailed results for every experiment are in docs/experiments.md.
MIT.
