perf: page-gather decode microbench and paged-attention ADR (#117) #145
Phase 0 feasibility spike for epic #116 (unified paged KV cache). Measures the decode-time cost of gathering scattered physical KV blocks on MLX so the downstream phases can choose between (A) gather-then-SDPA and (B) a fused Metal paged-attention kernel, and pick the pool tensor layout.

`examples/page_gather_microbench.rs` is a synthetic, op-level bench (no model load). It allocates fake K/V with `zeros` and times four decode-step paths across a context/batch/block-size sweep: contiguous SDPA (the lower bound and the effective cost of the current `paged_decode_attention_dense_compat` path), gather-then-SDPA for two pool layouts (`[num_blocks, block_size, n_kv_heads, head_dim]` and the head-split `[n_kv_heads, num_blocks, block_size, head_dim]`), gather-only for each layout to isolate gather cost, and the per-step `slice_update` block append for each layout. Block ids are assigned in reverse pool order over a 2x-slack pool to force genuinely scattered reads, and every path attends over a block-aligned `ctx_pad` so the comparison is apples to apples; the reported `frag%` is the block-size internal fragmentation. Output is both an aligned human-readable table and a `CSV:`-prefixed machine-readable block. The timing harness mirrors `examples/bridge_overhead_microbench.rs` (warmup, then eval-per-iter bracketed by `synchronize_default`).

`docs/adr/0001-paged-attention-gather-vs-fused-kernel.md` records the decision: adopt (A) gather-then-SDPA for Phases 1-5 (#118-#122) and defer the (B) fused Metal kernel to Phase 6 (#123), keep the existing default block size, and pick the pool layout from the measured `take`/`slice_update` numbers. The empirical values (crossover context length, layout choice, hardware, results table) are left as `<!--FILL_...-->` sentinels for the spike machine to fill after running the bench.
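The block alignment and scattered block-id setup described for the bench can be sketched as follows. This is a minimal illustration, not the code in `examples/page_gather_microbench.rs`: the function names are hypothetical, it covers a single sequence only, it assumes `ctx > 0` and `block > 0`, and it assumes `frag%` is measured as unused padded slots relative to `ctx_pad`.

```rust
/// Number of blocks needed to cover `ctx` tokens (ceiling division).
/// Assumes `block > 0`.
fn blocks_for(ctx: usize, block: usize) -> usize {
    (ctx + block - 1) / block
}

/// Context length rounded up to a whole number of blocks.
fn ctx_pad(ctx: usize, block: usize) -> usize {
    blocks_for(ctx, block) * block
}

/// Internal fragmentation in percent: padded-but-unused token slots
/// relative to `ctx_pad` (an assumed definition, see the lead-in).
fn frag_pct(ctx: usize, block: usize) -> f64 {
    let pad = ctx_pad(ctx, block);
    100.0 * (pad - ctx) as f64 / pad as f64
}

/// Physical block ids for one sequence, handed out from the top of a pool
/// that holds twice as many blocks as the sequence needs. Logical block 0
/// maps to the last pool block, logical block 1 to the one before it, etc.
fn reverse_block_ids(ctx: usize, block: usize) -> Vec<usize> {
    let needed = blocks_for(ctx, block);
    let pool_blocks = 2 * needed; // 2x slack
    (0..needed).map(|i| pool_blocks - 1 - i).collect()
}
```

With a 16-token block, a 1000-token context pads to 1008 tokens (63 blocks), while a 1024-token context is already block-aligned and has zero fragmentation.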
Also establishes `docs/adr/` with an index `README.md`, links it from `docs/README.md`, and adds `scripts/run_page_gather_microbench.sh` (runs the bench under `caffeinate -i`).
Ran `examples/page_gather_microbench.rs` on the spike machine (Apple M1 Ultra, 128 GB, macOS 26.5, `--release --features metal,accelerate`) and filled the ADR 0001 sentinels with the measured results: the 24-row table (per-cell minimum of two sweeps), the hardware line, the gather-overhead crossover, the pool layout decision, and the block-size note.

Findings: layout A (`[num_blocks, block_size, n_kv_heads, head_dim]`) is on average 2.1x faster on gather-then-SDPA than the head-split layout, so it is the chosen pool layout, and `slice_update` block-append cost is layout-insensitive. Single-sequence gather overhead stays under ~15% below 4096 tokens, rising to ~56% at 16384 and ~67% at 32768, while batched decode (batch 4) is already ~48% at 1024 tokens and 2x to 3x the contiguous SDPA cost past 4096. This confirms (A) gather-then-SDPA for Phases 1-5 and keeps the fused Metal kernel (B, #123) deferred to the long-context or batched regime. Also applies a rustfmt fix to the example.
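One plausible reading of the overhead percentages above is the extra time of gather-then-SDPA relative to contiguous SDPA. The sketch below assumes that definition; the ADR may compute it differently. The function names and the `(ctx, contiguous_us, gather_sdpa_us)` row shape are hypothetical and not taken from the bench.

```rust
/// Relative overhead of gather-then-SDPA over contiguous SDPA, in percent.
/// Both arguments are per-call times in the same unit (e.g. microseconds).
fn gather_overhead_pct(contiguous_us: f64, gather_sdpa_us: f64) -> f64 {
    100.0 * (gather_sdpa_us - contiguous_us) / contiguous_us
}

/// First context length whose overhead exceeds `threshold_pct`.
/// `rows` holds `(ctx, contiguous_us, gather_sdpa_us)` in sweep order.
fn crossover_ctx(rows: &[(usize, f64, f64)], threshold_pct: f64) -> Option<usize> {
    rows.iter()
        .find(|&&(_, contig, gather)| gather_overhead_pct(contig, gather) > threshold_pct)
        .map(|&(ctx, _, _)| ctx)
}
```

Under this definition, a path that costs 2x the contiguous time has 100% overhead and one that costs 3x has 200%, which is how the "2x to 3x" figure relates to the percentages quoted for shorter contexts.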
Add a note clarifying that `gatherA_only` can exceed `gatherA_sdpa` at short context because timing the gather alone forces a full K/V materialization, whereas the gather-then-SDPA path fuses `take`/`reshape`/`transpose` into the fused-SDPA read. This makes `gatherA_sdpa` the decode-relevant number and explains why strategy (A) stays cheap at common context lengths.
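The effect in this note follows from how a warmup-then-timed-loop harness works: whatever each timed step evaluates is what gets paid for. The sketch below shows the general shape using only the standard library. `time_iters` is a hypothetical name, the `sync` closure stands in for a device barrier such as the `synchronize_default` call named in this PR, and placing the barrier around the whole loop rather than around each iteration is an assumption.

```rust
use std::time::{Duration, Instant};

/// Runs `step` `warmup` times untimed, then times `iters` further calls.
/// `sync` is called once before the clock starts and once before it stops.
fn time_iters<F, S>(warmup: usize, iters: usize, mut step: F, mut sync: S) -> Duration
where
    F: FnMut(),
    S: FnMut(),
{
    for _ in 0..warmup {
        step();
    }
    sync(); // drain warmup work before the clock starts
    let start = Instant::now();
    for _ in 0..iters {
        // In a lazy-evaluation framework, `step` must force evaluation of its
        // result; a gather-only step therefore materializes the gathered K/V.
        step();
    }
    sync(); // wait for queued work before reading the clock
    start.elapsed()
}
```

If `step` is the gather alone, the forced evaluation writes the full gathered tensor; if `step` is gather followed by SDPA, only the attention output has to be materialized.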
Implementation Review Summary
Findings Addressed: None. No findings at MEDIUM or above; nothing to auto-fix.
Adds a #[cfg(test)] mod to examples/page_gather_microbench.rs covering the three pure, non-GPU helpers: parse_usize_list (happy path, whitespace tolerance, trailing comma, empty input), per_call_us (round and fractional durations), and the ctx_pad/frag_pct math from run_config (exact-multiple → 0% fragmentation, non-multiple → correct pad and frag, block=1 degenerate case, ctx==block edge case). All 11 tests pass under `cargo test --example page_gather_microbench --features metal,accelerate`. No GPU arrays are constructed in the tests.
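For orientation, helpers with the behavior these tests describe might look like the sketch below. The signatures are assumptions, since the PR text names the helpers but does not show them; in particular, the real `parse_usize_list` may report malformed entries instead of skipping them.

```rust
use std::time::Duration;

/// Parses a comma-separated list such as "1024, 4096," into integers.
/// Whitespace around entries is ignored, empty entries (including a
/// trailing comma) are dropped, and entries that fail to parse are skipped.
fn parse_usize_list(s: &str) -> Vec<usize> {
    s.split(',')
        .map(str::trim)
        .filter(|t| !t.is_empty())
        .filter_map(|t| t.parse().ok())
        .collect()
}

/// Average time per call in microseconds for `iters` calls taking `total`.
/// Assumes `iters > 0`.
fn per_call_us(total: Duration, iters: usize) -> f64 {
    total.as_secs_f64() * 1e6 / iters as f64
}
```

Both functions are pure, which is what lets the tests run without constructing any GPU arrays.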
PR Finalization Complete
Tests: Added. All 11 tests pass.
Documentation: All cross-links verified.
CHANGELOG: Skipped.
Summary
Phase 0 feasibility spike for epic #116 (unified paged KV cache). Adds a synthetic op-level microbench that measures the decode-time cost of gathering scattered physical KV blocks on MLX, plus an ADR that uses those measurements to choose the paged-attention strategy and the KV pool tensor layout for the downstream phases.
This PR delivers the code, the docs, and the ADR with empirical-value sentinels. The bench is run on the spike machine and the measured numbers fill the sentinels in a follow-up (the MLX C++ link is a long cold build, so compiling and running the bench is handled separately from authoring).
What changed
- `examples/page_gather_microbench.rs`: new synthetic, op-level decode microbench (no model load). Times contiguous SDPA (the lower bound, and the effective cost of the current `paged_decode_attention_dense_compat` path), gather-then-SDPA for two candidate pool layouts, gather-only per layout (to isolate gather cost from attention), and the per-step `slice_update` block append per layout. Sweeps context lengths 1024/4096/16384/32768, batch 1/4, block 16/32/64 at D=128, Hq=32, Hkv=8, f16. Forces scattered reads via reverse-order block ids over a 2x-slack pool, pads context to a block-aligned `ctx_pad` so all paths compare apples to apples, and reports block-size internal fragmentation. Emits an aligned human-readable table and a `CSV:`-prefixed machine-readable block. Timing harness mirrors `examples/bridge_overhead_microbench.rs`.
- `docs/adr/0001-paged-attention-gather-vs-fused-kernel.md`: new ADR. Decides (A) gather-then-SDPA for Phases 1-5 (#118, "Phase 1: Global block-pool tensor storage", through #122, "Phase 5: Block-budget admission, eviction, and preemption"), defers the (B) fused Metal paged-attention kernel to Phase 6 (#123, "Phase 6: Fused Metal paged-attention kernel"), keeps the existing default block size, and selects the pool layout from the measured `take`/`slice_update` numbers. Empirical values are `<!--FILL_...-->` sentinels.
- `docs/adr/README.md`: new ADR index, establishing `docs/adr/`.
- `docs/README.md`: links the new ADR directory from the docs layout.
- `scripts/run_page_gather_microbench.sh`: wrapper that runs the bench under `caffeinate -i`.

Test plan
- `cargo run --release --features metal,accelerate --example page_gather_microbench` (run on the spike machine; numbers fill the ADR sentinels).
- `cargo build --release --features metal,accelerate --example page_gather_microbench` compiles.

Closes #117
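The sweep listed under "What changed" (four context lengths, two batch sizes, three block sizes) multiplies out to 24 configurations, matching the 24-row results table mentioned earlier in this PR. A rough sketch of enumerating that grid, and of pulling the machine-readable rows back out of mixed bench output, is shown below. `SweepCell`, `sweep_grid`, `default_grid`, and `csv_rows` are hypothetical names; the loop nesting order and the assumption that every CSV row carries its own `CSV:` prefix are guesses, not taken from the bench.

```rust
/// One configuration in the sweep.
#[derive(Debug, Clone, Copy, PartialEq)]
struct SweepCell {
    ctx: usize,
    batch: usize,
    block: usize,
}

/// Cartesian product of the three sweep axes.
fn sweep_grid(ctxs: &[usize], batches: &[usize], blocks: &[usize]) -> Vec<SweepCell> {
    let mut cells = Vec::with_capacity(ctxs.len() * batches.len() * blocks.len());
    for &ctx in ctxs {
        for &batch in batches {
            for &block in blocks {
                cells.push(SweepCell { ctx, batch, block });
            }
        }
    }
    cells
}

/// The sweep values quoted in the PR description.
fn default_grid() -> Vec<SweepCell> {
    sweep_grid(&[1024, 4096, 16384, 32768], &[1, 4], &[16, 32, 64])
}

/// Keeps only lines that start with "CSV:" and strips the prefix,
/// so table lines and other log output are ignored.
fn csv_rows(output: &str) -> Vec<&str> {
    output
        .lines()
        .filter_map(|line| line.strip_prefix("CSV:"))
        .map(str::trim)
        .collect()
}
```

Prefixing the machine-readable rows lets one run of the bench feed both a human reader and a script without a separate output flag.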