v0.3.0: Autotune & dispatch overhaul across CUDA kernels + faster training on Apple Silicon
Added
- Training on Apple Silicon. New maxsim_train_metal (forward + saved
  argmax) and maxsim_backward_metal Metal kernels mirror the Triton
  maxsim_backward_unified API; a _MaxSimFnMetal(autograd.Function)
  wires them into maxsim_mps so the full training path (forward,
  backward, L2-normalize Jacobian) runs on Metal instead of falling
  back to torch.compile. Inference and unsupported-dtype paths are
  unchanged.
- KD / pairs layout on Metal (4-D D). The MPS Metal kernel now accepts
  D as [Nq, K, Ld, d] directly, so PyLate's colbert_kd_scores and
  colbert_scores_pairwise shapes hit the Metal kernel instead of
  falling back to torch.compile.
- maxsim() 4-D dispatch + maxsim_pairs() on CUDA. maxsim(Q, D) now
  dispatches on D.dim(): 3-D stays on the in-batch cross-product path,
  4-D [Nq, K, Ld, d] runs as a single fused KD launch. New
  maxsim_pairs(Q, D) covers the [B, Lq, d] × [B, Ld, d] → [B] case.
  PyLate's colbert_kd_scores Python for-loop (one maxsim() call per
  query) collapses to one kernel launch: ~10× faster than the loop at
  B=64, K=32, and the pairwise B=5000 shape went from 355 ms (old
  maxsim_varlen packing path) → 0.18 ms (beats flash-maxsim).
- Every benchmark now reports peak VRAM (max_memory_allocated()) per
  variant in stdout and JSON; bench_fp8.py and
  bench_fused_head_train.py also gained --outdir JSON + Markdown
  sidecars. benchmarks/README.md documents the unified CLI and the
  SkyPilot driver.
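
To make the three input layouts named above concrete, here is a minimal pure-Python reference of the shapes involved. It is a sketch, not the library's kernels: the helper names (maxsim_pair, maxsim_in_batch, maxsim_kd, maxsim_pairwise) are hypothetical, it uses nested lists instead of tensors, and it assumes the standard MaxSim definition (sum over query tokens of the max dot product over document tokens) with no masking or normalization.

```python
# Illustrative reference only; all function names here are hypothetical.

def maxsim_pair(q, d):
    """MaxSim for one query [Lq][dim] against one document [Ld][dim]."""
    return sum(
        max(sum(a * b for a, b in zip(q_tok, d_tok)) for d_tok in d)
        for q_tok in q
    )

def maxsim_in_batch(Q, D):
    """3-D D: Q [Nq][Lq][dim] x D [Nd][Ld][dim] -> scores [Nq][Nd]."""
    return [[maxsim_pair(q, d) for d in D] for q in Q]

def maxsim_kd(Q, D):
    """4-D D: Q [Nq][Lq][dim] x D [Nq][K][Ld][dim] -> scores [Nq][K]."""
    return [[maxsim_pair(q, d) for d in docs] for q, docs in zip(Q, D)]

def maxsim_pairwise(Q, D):
    """Pairs: Q [B][Lq][dim] x D [B][Ld][dim] -> scores [B]."""
    return [maxsim_pair(q, d) for q, d in zip(Q, D)]

# Two queries with one token each, two documents with two tokens each.
Q = [[[1.0, 0.0]], [[0.0, 1.0]]]
D = [[[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 1.0]]]

in_batch = maxsim_in_batch(Q, D)   # every query against every document
kd = maxsim_kd(Q, [D, D])          # each query gets its own K=2 candidates
pairs = maxsim_pairwise(Q, D)      # query i against document i only

print(in_batch)  # [[1.0, 3.0], [2.0, 1.0]]
print(pairs)     # [1.0, 1.0]
```

The KD layout differs from the in-batch one only in which documents each query sees, which is why a per-query Python loop over maxsim() can be replaced by a single launch over the 4-D tensor.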
Changed
- CUDA autotune & dispatch overhaul: broadly faster across kernels.
  Five independent wins on the Triton path:
  - Lq bucketed to next pow-of-2 in the maxsim() wrapper (#62).
    Variable-Lq training (ColBERT / ColPali, where the tokenizer's
    per-batch max(Lq) floats step-to-step) used to re-trigger the full
    autotune sweep on every novel Lq. Now collapses to ≤ 9 cache
    entries. Measured 4.7× faster end-to-end on 30 H100 steps with
    Lq ∈ [8, 32] (median step 588 ms → 0.19 ms).
  - KD + pairwise folded into the fast forward + backward kernels
    (#66). A single kd_layout: tl.constexpr switches
    d_global = pid % Nd (in-batch) vs pid (KD / pairs) so PyLate's
    colbert_kd_scores (4-D D) and colbert_scores_pairwise shapes use
    the same dense fast path as in-batch instead of routing through
    score_pairs_packed's packing layer. KD B=64, K=32 now beats
    flash-maxsim (lik/flash 0.94×), down from 1.4–7.7× slower pre-PR.
  - Persistent on-disk autotune cache via Triton ≥ 3.4's
    cache_results=True on all eight autotuned kernels (#64). First run
    on a machine still pays the ~4 s sweep; every subsequent process /
    CI job / training restart loads the JSON winner and skips bench.
    10.6× faster cold start on the second process. Feature-detected:
    older Triton (3.0–3.3) silently keeps the in-memory-only
    behaviour, no dependency floor bump.
  - Small-input forward bypass for Nq*Nd ≤ 500 && d ≤ 256 with
    save_argmax=False (#64). Fixed-config launch (BLOCK_Q=32,
    BLOCK_D=64, num_warps=4, num_stages=2); cold call drops from
    ~0.5–1 ms (autotune sweep) to sub-millisecond. Closes the gap to
    flash-maxsim's _maxsim_fwd_kernel_small on REPL / unit-test
    shapes.
  - New Hopper autotune config BLOCK_Q=128, BLOCK_D=128, num_warps=8,
    num_stages=3 (#57). Closes a 0.85× regression vs flash-maxsim on
    the compute-bound colpali rerank shape (Nq=1, Nd=500,
    Lq=Ld=1024, d=128; now 1.23×). Only picked on the shape it was
    designed for; no regression elsewhere.
  - normalize out of the forward autotune key (#64). The two constexpr
    branches still produce distinct binaries; they now share one
    autotune entry instead of two. Cache cardinality halves.
- Benchmark CLI unified. Experiment subsets are now --only NAME ... on
  every script (replacing the legacy --shape / --shapes flags and the
  older --only for variant selection, which moved to --variants).
  scripts/sky_run_all_benchmarks.yaml and the per-domain Sky yamls
  accept a RUN_ONLY env to pick a subset of tags.
- SkyPilot bench yamls consolidated. Four operator-facing files now
  cover every bench run: sky_benchmark_smoke_test.yaml (was
  sky_run_benchmarks.yaml), sky_run_all_benchmarks.yaml,
  sky_pylate_benchmark.yaml (new, folds the three previous
  sky_lateon_edge.yaml / sky_pylate_realdata{,_long}.yaml), and
  sky_colpali_benchmark.yaml (was sky_colpali_training.yaml).
- MPS range refreshed (M4 2025). Inference metal vs eager is now
  1.9–3.5× and metal vs compile 2.2–14.3× (vs 1.9–3.2× and 1.1–2.0× in
  0.2.0); the gap vs torch.compile widened because MPS Inductor
  regressed sharply on long-Ld inputs. New training (fwd + bwd) table
  lands alongside, with the Metal backward 1.2–1.7× over eager and
  3–4× over torch.compile once shapes amortise launch overhead. Full
  tables in docs/benchmarks.md Apple Silicon section.
- JSON peak-VRAM keys standardized to <variant>_peak_mb across
  bench_pylate_lateon, bench_pylate_realdata, bench_colpali_training,
  bench_colpali_realdata, bench_cached_maxsim, bench_fastplaid, and
  bench_lateon. Breaking for anyone parsing benchmarks/results/*.json
  directly (previously a mix of peak_gb / _peak / mem_*_MB).
- Removed unreferenced exports pylate_compat._bool_mask (shadowed by
  _mask_as_bool) and mps.is_mps_tensor.
- GPU CI moved from the GitHub-hosted runner to AWS CodeBuild (A10G)
  and no longer auto-runs on push to main; opt-in via the
  run-gpu-tests PR label or workflow_dispatch.
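
Two of the dispatch decisions in the autotune item above (Lq bucketing and the small-input bypass) can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the wrapper's code: next_pow2 and small_input_config are hypothetical names, returning None stands in for "fall through to the autotuned path", and how the real wrapper applies the bucket (padding versus keying only) is not stated in these notes. The thresholds and the fixed launch config are the ones quoted above.

```python
from typing import Optional

def next_pow2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

# Fixed launch config quoted in the notes for the small-input bypass.
SMALL_INPUT_CONFIG = {"BLOCK_Q": 32, "BLOCK_D": 64, "num_warps": 4, "num_stages": 2}

def small_input_config(nq: int, nd: int, d: int, save_argmax: bool) -> Optional[dict]:
    """Return the fixed config when the bypass condition holds, else None.

    Hypothetical helper: None means "use the normal autotuned path".
    """
    if not save_argmax and nq * nd <= 500 and d <= 256:
        return SMALL_INPUT_CONFIG
    return None

# Per-batch max(Lq) drifting over [8, 32], as in the measurement above.
observed_lq = list(range(8, 33))
buckets = sorted({next_pow2(lq) for lq in observed_lq})
print(len(observed_lq), "distinct Lq values ->", buckets)
# 25 distinct Lq values -> [8, 16, 32]
```

Keying the autotune cache on the bucket rather than the raw length is what bounds the number of sweeps: 25 distinct lengths in this range map to only three keys.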
Removed
- [breaking] Experimental kernels.
  late_interaction_kernels.experimental and its three research
  variants (soft_maxsim, smooth_maxsim, maxsim_matryoshka) are gone,
  along with reference.maxsim_reference_soft,
  tests/test_{soft,smooth,matryoshka}.py, and the two soft-maxsim
  cases in tests/test_robustness.py. None of them shipped to PyLate,
  colpali_engine, FastPlaid, or NextPlaid; folding research kernels
  into prod was the same mistake as maxsim_xtr in 0.2.0. Users on a
  research path can vendor the kernel source from the pre-0.3.0 git
  history.
- [breaking] Deprecated *_inference shims and
  maxsim_from_hidden_train. The four DeprecationWarning shims from
  0.2.0 are removed:
  - late_interaction_kernels.maxsim_inference → maxsim(...)
  - late_interaction_kernels.fused_head.maxsim_from_hidden_train →
    maxsim_from_hidden(...)
  - late_interaction_kernels.varlen.maxsim_varlen_inference →
    maxsim_varlen(...)
  - late_interaction_kernels.plaid.maxsim_residual_inference →
    maxsim_residual(...)

  Each surviving function already auto-skips the saved argmax buffer
  when no input has requires_grad=True, so behaviour is unchanged.
- [breaking] set_backward_method / get_backward_method removed
  (deprecated in 0.2.0). Migration: replace
  set_backward_method("csr") with maxsim(..., backward="csr") (or
  MaxSimScorer(backward="csr")). maxsim()'s backward=None now resolves
  directly to "auto" instead of reading a module-level global.
- [breaking] reference.xtr_reference and its three CPU-only tests in
  tests/test_reference_cpu.py: the XTR Triton kernel was already
  deleted in 0.2.0.
Fixed
- Masked-row gradient poisoning (correctness). When every doc token
  was masked for a (query, doc) pair, the Triton forward saved an
  argmax of 0 instead of a sentinel, and the unified / atomic
  backwards atomic-added a spurious grad_scores[i, j] * Q[i, s, :]
  into grad_D[d_global, 0, :]. Forward now initialises the running
  argmax to -1; the unified / atomic backwards skip the scatter on
  t < 0. CSR backward is unaffected (sorting naturally drops the
  sentinel). Matches the equivalent fix on the MPS Metal backward.
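
The sentinel fix can be illustrated with a small pure-Python model of one (query, doc) pair. This is a sketch under assumptions rather than the Triton code: forward_pair and backward_grad_d are hypothetical names, mask[t] == True means doc token t is valid, and a fully masked row is assumed to contribute 0 to the score (the notes do not say what the kernel returns there).

```python
NEG_INF = float("-inf")

def forward_pair(q, d, mask):
    """Return (score, argmax) for one query [Lq][dim] and doc [Ld][dim].

    argmax[s] is the winning doc token for query token s, or the
    sentinel -1 when no doc token is valid.
    """
    score, argmax = 0.0, []
    for q_tok in q:
        best, best_t = NEG_INF, -1          # running argmax starts at -1, not 0
        for t, d_tok in enumerate(d):
            if not mask[t]:
                continue
            sim = sum(a * b for a, b in zip(q_tok, d_tok))
            if sim > best:
                best, best_t = sim, t
        if best_t >= 0:                      # assumption: masked rows add 0
            score += best
        argmax.append(best_t)
    return score, argmax

def backward_grad_d(q, d, argmax, grad_score):
    """Scatter grad_score * q[s] into the winning doc token's gradient row."""
    grad_d = [[0.0] * len(d[0]) for _ in d]
    for s, t in enumerate(argmax):
        if t < 0:                            # sentinel: skip the scatter
            continue
        for k, q_val in enumerate(q[s]):
            grad_d[t][k] += grad_score * q_val
    return grad_d

q = [[1.0, 2.0]]
d = [[1.0, 0.0], [0.0, 1.0]]

score_ok, argmax_ok = forward_pair(q, d, [True, True])
score_masked, argmax_masked = forward_pair(q, d, [False, False])
grad_ok = backward_grad_d(q, d, argmax_ok, 1.0)
grad_masked = backward_grad_d(q, d, argmax_masked, 1.0)
```

With the old initial value of 0, the fully masked case would have recorded token 0 as the winner and written grad_score * q[s] into grad_d[0]; with the sentinel, grad_masked stays all zeros.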
Known regressions (queued for 0.3.1)
Four shape-specific perf regressions vs 0.2.0 on the H100 sweep
(NGC 25.06 / torch 2.8 / triton 3.x). All correctness assertions hold;
the likely root cause for each is a winning autotune config rejected
by the tighter prune_forward SRAM model (#73).
- bench_forward text-long (Nq=1, Nd=1k, Lq=32, Ld=1024):
  0.093 → 0.121 ms (+30 %).
- bench_inference_edge LateOn-Code-edge Nd=1k, Ld=1024, d=48:
  0.072 → 0.114 ms (+58 %).
- bench_pylate_realdata Contrastive bs=16, Lq=32, Ld=256: e2e step
  52.3 → 61.6 ms (+18 %); vanilla baseline unchanged.
- bench_pylate_realdata CachedContrastive bs=64, mini=16, Lq=32,
  Ld=300, grad-ckpt: e2e step 305.5 → 318.5 ms; former 1.15× win vs
  vanilla collapsed to 1.00×. Highest-priority of the four.