Releases: hcompai/late-interaction-kernels
v0.4.4: Perf roundup — one-shot norm check, shape-stable compiles, autotuned backwards
Fixed
- maxsim(normalize=False) ran a .item() norm-check device sync on every
call; it now runs once per process, restoring CUDA-graph capture on the
PyLate / colpali-engine hot path.
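The once-per-process pattern behind this fix can be sketched in plain Python. This is a minimal illustration, not the library's implementation: OneShotNormCheck is a hypothetical name, lists of floats stand in for tensors, and raising on failure is an assumption about what the check does.

```python
import math


class OneShotNormCheck:
    """Hypothetical helper: run an expensive validation on the first call only."""

    def __init__(self, tol=1e-3):
        self.tol = tol
        self.done = False  # flips to True after the first successful check
        self.runs = 0      # how many times the expensive check actually ran

    def __call__(self, rows):
        if self.done:
            return  # later calls skip the check (and, on a GPU, the device sync)
        self.runs += 1
        for row in rows:
            norm = math.sqrt(sum(x * x for x in row))
            if abs(norm - 1.0) > self.tol:
                raise ValueError(f"expected unit-norm rows, got norm {norm:.4f}")
        self.done = True


check_norms = OneShotNormCheck()
check_norms([[1.0, 0.0], [0.6, 0.8]])  # first call: validates
check_norms([[3.0, 4.0]])              # second call: skipped, never inspected
```

The trade-off is visible in the last line: after the first call, inputs are no longer inspected, which is what removes the per-call synchronization.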
Changed
- PLAID centroid codes are handled as int32 end to end (any integer dtype is
still accepted; out-of-range codes in plaid_approx_score clamp to
centroid 0).
- Batch sizes (Nq, Nd) are runtime kernel arguments, not constexpr, so there are no
more recompiles per batch shape under dynamic batching.
- GPU family detection is keyed on compute capability, not the device name;
Blackwell now gets first-class autotune configs.
- The varlen, packed-pairs and residual backward launches are autotuned like
the dense backwards (up to 1.42× on H100 at training shapes).
- Backward gradient buffers that the kernels overwrite in full use
torch.empty instead of torch.zeros (atomic-scatter buffers stay zeroed).
- The fp8 autotune key no longer splits on mask presence.
Full rationale and H100 measurements in #111.
v0.4.3: Bug-fix roundup — chunked top-k, empty batches, varlen autotune
Fixed
- Chunked top-k no longer crashes when top_k exceeds the merge width.
retrieve(Q, D, top_k=10, chunk=4) (and the CUDA maxsim_topk it wraps)
raised "selected index k out of range": each chunk contributes at most
chunk columns, but the running merge always asked torch.topk for k.
The merge now clamps k to the merged width; the final output still
honors the documented min(top_k, Nd) contract on both paths.
- MaxSimScorer.forward() respects a caller's torch.no_grad(). The
non-inference branch of _score forced torch.enable_grad(), so
MaxSimScorer()(Q, D) inside torch.no_grad() returned a tensor with
requires_grad=True. It now uses a null context and leaves the caller's
grad mode untouched.
- maxsim_residual_varlen handles an empty corpus (Nd == 0). The
launcher used grid Nq * max(Nd, 1) while the kernel computed
pid // Nd with a constexpr Nd=0. The empty [Nq, 0] result is now
returned before any launch.
- pack_padded handles an empty batch (B == 0). It crashed on
qlen.max() of a zero-element tensor; it now returns the trivially-empty
PackedBatch up front, so maxsim_padded agrees with the CPU reference
([0, C]).
- PyLate legacy mask= is forwarded on the fallback path. The patched
colbert_scores / colbert_kd_scores / colbert_scores_pairwise mapped
a legacy mask= into the fused path but called the original function
with a still-None documents_mask when deferring to PyLate (CPU,
sub-Ampere, LIK_DISABLE=1), silently dropping the doc mask.
- maxsim_backward_lowmem raises a clean RuntimeError without
Triton/CUDA instead of failing inside the launch (the same guard as
maxsim_backward_unified), so backward/__init__'s import-anywhere
promise holds.
- prune_forward falls back to the smallest-footprint configs. When
every autotune config overflowed the shared-memory budget, the fallback
returned the first two list entries, which were configs just pruned for
exceeding the budget. It now returns the two with the smallest estimated
SMEM footprint.
- The KD path (maxsim with 4-D D) validates device agreement for
Q / D / masks like the 3-D path, instead of surfacing an internal
kernel error on mixed devices.
- FP8 fallback really works without Triton. maxsim_inference_fp8's
documented "transparent fallback" imported the Triton-backed maxsim
unconditionally; it now dispatches like retrieve (Triton on CUDA,
compiled reference on MPS, eager reference elsewhere), so the dequantized
bf16 fallback runs on CPU-only installs.
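The chunked top-k fix described in the first item above can be sketched with the standard library. This is an illustrative stand-in, not the library's code: chunked_topk is a hypothetical name, a flat list of scores replaces the score matrix, and heapq.nlargest replaces torch.topk.

```python
import heapq


def chunked_topk(scores, top_k, chunk):
    """Top-k (score, index) pairs over `scores`, scanning `chunk` columns at a time."""
    best = []  # running merged candidates, sorted by descending score
    for start in range(0, len(scores), chunk):
        block = [(s, start + i) for i, s in enumerate(scores[start:start + chunk])]
        merged = best + block
        # Clamp k to the merged width: early merges hold fewer than top_k items.
        # heapq.nlargest tolerates an oversized k, but a strict top-k that
        # rejects k > width needs this clamp to avoid the crash described above.
        k = min(top_k, len(merged))
        best = heapq.nlargest(k, merged)
    return best


scores = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4, 0.6, 0.0, 0.95, 0.05]
top = chunked_topk(scores, top_k=10, chunk=4)  # first merge sees only 4 columns
```

With chunk=4 and top_k=10 the first two merges are narrower than top_k, which is exactly the case the clamp covers. The result length is min(top_k, len(scores)).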
Changed
- maxsim_varlen autotune sweeps are amortized across batch shapes.
max_lq / max_ld were constexpr autotune keys with no bucketing, so
(contrary to the previous docs) every distinct (max_lq, max_ld) pair
re-triggered the full autotune sweep and recompiled the kernel. They
are now exact runtime loop bounds (like the dense kernel's Ld), and
the autotune cache is keyed on power-of-two-bucketed copies (floor 16)
passed as key-only arguments: one sweep per bucket, zero masked
iterations added, argmax buffer still sized on the exact max_lq.
Results are unchanged. The PLAID kernels (maxsim_residual /
maxsim_residual_varlen) deliberately keep their exact max_Ld key: it
is a property of the compressed index, stable across calls, so the sweep
amortizes to one. Bucketing the loop bound there measured 30-50%
steady-state throughput loss on Ld=300 corpora for no benefit, which
is also why the varlen bounds stay exact.
- max_seqlen_* arguments are documented as hard loop bounds, not
hints, in maxsim_varlen, score_pairs_packed, and
maxsim_residual_varlen: a too-small value silently truncated tokens and
returned wrong scores. Caller-supplied values are now checked against the
cu_seqlens maxima with an on-device torch._assert_async (no D2H sync,
same trade-off as pack_padded's length checks).
- maxsim_residual squeezes a 2-D Q back to [Nd] instead of
returning [1, Nd], matching maxsim_residual_varlen and the maxsim
wrapper's convention. Behavior change for callers that relied on the
un-squeezed shape.
- Removed the dead B constexpr from _plaid_approx_score_kernel (it
forced a recompile per batch size), the unused Lq / Ld placeholder
params from the varlen forward kernel, and the dead Nq constexpr from
_varlen_bwd_dQ_kernel. The varlen backward kernels also take max_lq
as a runtime arg instead of a constexpr (mirroring the score_pairs
backward kernels), so a new query-length maximum no longer recompiles
them.
- Docstring corrections: maxsim_inference_fp8 no longer mentions argmax
ties (the inference kernel never computes an argmax), and the lowmem
backward docs say grads are written in the input dtype (fp16 / bf16 /
fp32), not "bf16 grads".
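The effect of keying the autotune cache on bucketed copies of max_lq / max_ld can be sketched as follows. The helper names are hypothetical, and rounding up to the next power of two is an assumption: the notes above only say the keys are power-of-two bucketed with a floor of 16.

```python
def bucket_pow2(n, floor=16):
    """Round n up to the next power of two, never below `floor` (assumed round-up)."""
    if n <= floor:
        return floor
    return 1 << (n - 1).bit_length()


def count_sweeps(shapes):
    """Distinct autotune keys (one sweep each) for exact vs bucketed keys."""
    exact = {(lq, ld) for lq, ld in shapes}
    bucketed = {(bucket_pow2(lq), bucket_pow2(ld)) for lq, ld in shapes}
    return len(exact), len(bucketed)


# Five batches with different (max_lq, max_ld) maxima.
shapes = [(17, 200), (20, 230), (31, 256), (32, 129), (9, 300)]
exact_sweeps, bucketed_sweeps = count_sweeps(shapes)
```

In this sketch only the cache key is bucketed. The loop bounds passed to the kernel would stay the exact values, which is why no masked iterations are added.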
Removed
- The interactive kernel picker page (docs/choose-a-kernel.html). It predates
the API consolidation: with maxsim dispatching on layout and the patchers /
native PyLate & colpali-engine backends covering the framework paths, a
decision tree over many entry points no longer reflects the library. The
README "Choose a kernel" section and links are gone with it.
v0.4.2: Native LIK in PyLate and colpali-engine
Added
- Real-recipe e2e training benchmarks for ColQwen2 and PyLate
(bench_colpali_e2e.py, bench_pylate_e2e.py). Both instrument the loss
head to record per-MaxSim-call VRAM in-train, replay each recorded shape on
an isolated graph (exact forward/saved/backward brackets), and treat OOM as
a recorded sweep outcome rather than a crash; --variant vanilla|lik
toggles the patch, with summarize_*_e2e.py + scripts/sky_*_e2e.yaml
driving the sweep (fresh process per cell). Measured on 1×H100 80 GB:
ColQwen2's MaxSim op costs 7.81 GiB vanilla vs 61 MiB with LIK at B=128
(~130×), step time at parity, and vanilla OOMs at B=128 (a 1.81 GiB request
with 25 GiB reserved-but-unallocated) where LIK trains it — 2× batch
headroom; PyLate (grad-ckpt regime) drops step peak 54.1 → 29.7 GiB at
B=512, runs 1.07–1.12× faster per step, and trains B=1024 where vanilla
OOMs. The ColQwen2 bench targets released colpali-engine 0.3.16 and shims
its two ContrastiveTrainer bugs under transformers 5.x (fixed upstream in
colpali#412, unreleased).
Tables in docs/benchmarks.md.
Changed
- patch_pylate() / patch_colpali_engine() defer to the native LIK
backends. PyLate ≥ 1.5.1 (pylate#222)
and colpali-engine ≥ 0.3.17 (colpali#412)
now ship their own LIK dispatch (pip install "pylate[lik]" /
"colpali-engine[lik]", via auto / PYLATE_SCORES_BACKEND /
COLPALI_SCORES_BACKEND). On those versions the patches are deprecated
no-ops that detect native support by package version and step aside (patching
PyLate would also break ColBERTScores, which forwards backend=); older
versions are unaffected. The native backends call maxsim / maxsim_pairs
/ maxsim_mps by keyword, so those signatures are now pinned by a test.
- benchmarks/ is grouped per comparison stack: kernels/ (incl. the
platform-specific bench_mps.py), plaid/, colpali/, and pylate/, each
e2e bench next to its summarizer. Pure moves: --only tags and JSON output
names are unchanged, so existing results stay comparable. bench_lateon.py
→ kernels/bench_longdoc.py (the value is the long-document regime, Ld up
to 16 384), and the sky_run_all_benchmarks.yaml RUN_ONLY tag lateon →
longdoc.
Fixed
- patch_pylate() works on PyLate 1.5 again. 1.5 renamed the scoring
module (pylate.scores.scores → pylate.scores.colbert) and rerouted the
contrastive losses through ColBERTScores; the patch now detects the
layout, patches the defining module (covering the loss path), and rewrites
only Distillation's import-time capture on 1.5. The pylate extra's
>=1.3.3,<2 range is accurate again, with no more 1.3.3 pin.
Removed
- The previous e2e training benches (bench_colpali_training.py,
bench_colpali_realdata.py, bench_pylate_training.py,
bench_pylate_realdata.py, bench_pylate_lateon.py), their shared
_bench_common.py, and the sky_colpali_benchmark.yaml /
sky_pylate_benchmark.yaml jobs, all superseded by the e2e harnesses above
(bench_colpali_loss.py is kept; historical numbers stay in
docs/benchmarks.md). Plus four stale one-offs: bench_backward_0_5.py,
bench_fastplaid.py, bench_training.py, and the autotune-persistence
reproducer (scripts/_bench_autotune_persistence.py +
scripts/sky_bench_autotune_persistence.yaml).
v0.4.1 — Mask-invariant autotune key
Fixed
Variable-length training no longer pays repeated autotune sweeps.
On ColQwen2 / ColPali training with variable query lengths, a fresh 5–10 s Triton autotune sweep fired every time a query batch first toggled its mask presence — as late as step 14, costing up to 1.6× end-to-end on vidore/docvqa_test_subsampled.
Two causes, both fixed:
- has_q_mask / has_d_mask were in the forward and backward autotune keys. They are constexpr toggles that change codegen but not the winning (BLOCK_Q, BLOCK_D, num_warps, num_stages) tile, so they only fragmented the cache.
- Triton's autotuner also keys on the dtype of every tensor argument, and the absent-mask placeholder was Q (bf16) rather than the real mask dtype (int8), so present-vs-absent re-split the cache regardless of the named key.
Absent optional args now use a dtype-matched placeholder (autotune_placeholder), and the mask flags are out of the keys. Autotune reuses the cached config across mask combinations (Triton still JIT-compiles a correct, separately specialized kernel per constexpr value); steady-state numerics and selected configs are unchanged.
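Why both causes had to be fixed together can be shown with a toy model of the cache key. This is a sketch under stated assumptions, not Triton's actual key layout: autotune_key and keys_for_batches are hypothetical helpers, dtypes are plain strings, and the fixed values 32 and 128 stand in for the remaining named key entries.

```python
def autotune_key(lq_bucket, d_pad, arg_dtypes, named_flags=()):
    """Toy cache key: named key values plus the dtype of every tensor argument."""
    return (lq_bucket, d_pad, tuple(named_flags), tuple(arg_dtypes))


def keys_for_batches(batches, flags_in_key, dtype_matched_placeholder):
    """Collect the distinct cache keys produced by a stream of batches."""
    keys = set()
    for has_q_mask, has_d_mask in batches:
        # An absent mask still needs some tensor passed in its slot.
        absent = "int8" if dtype_matched_placeholder else "bf16"
        dtypes = [
            "bf16",                              # Q
            "bf16",                              # D
            "int8" if has_q_mask else absent,    # q_mask slot
            "int8" if has_d_mask else absent,    # d_mask slot
        ]
        flags = (has_q_mask, has_d_mask) if flags_in_key else ()
        keys.add(autotune_key(32, 128, dtypes, flags))
    return keys


batches = [(False, False), (True, False), (False, True), (True, True)]
before = keys_for_batches(batches, flags_in_key=True, dtype_matched_placeholder=False)
flags_only = keys_for_batches(batches, flags_in_key=False, dtype_matched_placeholder=False)
after = keys_for_batches(batches, flags_in_key=False, dtype_matched_placeholder=True)
```

In this model, dropping the flags alone (flags_only) still yields four keys because the placeholder dtype differs from the mask dtype. Only the combination collapses to a single entry.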
Full changelog: https://github.com/hcompai/late-interaction-kernels/blob/v0.4.1/CHANGELOG.md
v0.4.0: Low-Memory Backward & Fused ColPali Training
Added
- lowmem backward: ~half the training peak memory, deterministic.
A new destination-owned backward that accumulates grad_Q / grad_D in
fp32 registers and writes them straight in the input dtype (bf16/fp16): no
full-size fp32 gradient buffer, no fp32→bf16 transient, and no atomics
(so it is bitwise-reproducible). auto now routes the gradient-heavy
shapes to it: knowledge-distillation / hard-negative layouts (where
grad_D is n_neg-inflated) and large high-contention in-batch squares;
unified still handles the common and long-query cross-products, where it
is fastest. Measured on H100 (bf16, fwd+bwd): a B256 × 16-neg ColPali
step drops from 4.3 GB / 2.36 ms to 2.2 GB / 1.37 ms (≈½ memory,
~1.7× faster); pylate-text B256 from 96 MB to 52 MB. Gradients match
unified to bf16 rounding. Select per-call with backward="lowmem".
- colpali_engine explicit-negative loss heads now fused.
patch_colpali_engine() previously accelerated only the in-batch CE term
of ColbertNegativeCELoss / ColbertPairwiseNegativeCELoss; their
explicit positive and per-query negative scoring stayed on the unfused
einsum. The positive term now routes through maxsim_pairs (diagonal
q[i]·d[i]) and the [B, n_neg, Ld, d] negative slab through 4-D
maxsim (KD layout), so neither materialises the similarity tensor; the
in-batch term keeps reusing the already-patched inner head (no double
work). Pos/neg fusion is CUDA-only: MPS / CPU fall back to the original
einsum for those terms while the in-batch term still accelerates. The 4-D
negative backward auto-routes to lowmem, making the fused heads a
training memory win too (peak ~13–28% below vanilla, widening with
B × n_neg) on top of the speedup (up to 4.31× at B256 × 16-neg
in the MaxSim-isolation bench).
- colpali install extra: pip install "late-interaction-kernels[colpali]" pulls colpali-engine>=0.3.10,<1
for patch_colpali_engine(). CPU-only in CI (colpali_engine's
torchvision tree conflicts with the CUDA torch wheel, so it is never
co-activated with torch-cuda; the GPU parity tests install it
out-of-band), mirroring the pylate extra.
Changed
- Long-query forward chunking: broadly faster at ColPali scale.
maxsim() now splits queries with Lq > 512 into fixed 128-token
chunks, scores each chunk as an independent query through the shared
_maxsim_cross core, and sums the per-chunk MaxSim back per original
query. Summing a per-token max over query tokens is exact, so forward
and backward are numerically identical to the un-chunked path
(autograd flows through the reshape + sum). Long queries launch more,
shorter programs that fill the GPU instead of serialising one long
static_range loop, and the kernel always sees Lq == 128, so the
autotune cache collapses onto a small constant (one entry, plus one
more for tail-padded has_q_mask=True) instead of one per length
bucket. Measured on H100 (bf16) with bench_chunking.py, vs the
un-chunked path: +49–77% at Lq=768, and at Lq=1024 from +24%
in-batch to roughly break-even for rerank.
Shorter queries (ColBERT Lq≤32, long-doc Lq≤512) fall through to
the existing core unchanged, with no regression. Chunking is
cross-product-only; the KD / pairs path (4-D D) is unaffected and
long-Lq KD should use maxsim_varlen.
- Autotuned backward launch params: faster training step. The
backward kernels previously launched with Triton's stock
num_warps=4. Each is one program per output row streaming a single
d_pad vector through a doc loop, so 4 warps over-subscribe the
narrow program; the H100 optimum is 1–2 warps. Every backward kernel
is now tuned with @triton.autotune over a small (num_warps, num_stages)
grid via a shared backward/_autotune.py config module. The key
mirrors the forward autotuner (Lq, d_pad, layout flags; Nd /
Ld stay out), so the cache holds one entry per regime rather than
one per batch size, and atomic-accumulating kernels use
reset_to_zero so autotune trials don't pile onto each other.
Measured on H100 (bf16), tuning lifts auto by ~1.2–1.45× across
the training shapes (see the backward table in benchmarks.md),
the largest gain on the high-contention train-256 reduction, all at
lower peak memory.
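The claim that long-query chunking changes no scores follows from MaxSim being a sum over query tokens of independent per-token maxima. A minimal pure-Python sketch, with hypothetical function names and small integer embeddings so the comparison is exact in plain Python arithmetic:

```python
import random


def maxsim(Q, D):
    """For each query token take the max dot product over doc tokens, then sum."""
    return sum(max(sum(q * d for q, d in zip(qt, dt)) for dt in D) for qt in Q)


def maxsim_chunked(Q, D, chunk=128):
    """Score fixed-size query chunks as independent queries and sum the results."""
    return sum(maxsim(Q[s:s + chunk], D) for s in range(0, len(Q), chunk))


rng = random.Random(0)
Q = [[rng.randint(-3, 3) for _ in range(4)] for _ in range(10)]  # Lq=10, d=4
D = [[rng.randint(-3, 3) for _ in range(4)] for _ in range(7)]   # Ld=7, d=4
same = maxsim_chunked(Q, D, chunk=4) == maxsim(Q, D)             # tail chunk has 2 tokens
```

The sketch lets the last chunk be short. The release notes instead describe a tail-padded chunk with a query mask, so that the kernel always sees the same chunk length.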
Removed
- [breaking] Backward methods atomic and csr. The dense grad_D
strategies collapse to two: unified (fastest, fp32 atomics) and
lowmem (memory-optimal, deterministic). The legacy two-pass atomic
path was strictly dominated by unified, and csr's determinism niche
is now covered by lowmem, so both were deleted along with the CSR
build/sort machinery. backward= now accepts
"auto" | "unified" | "lowmem"; passing "atomic" or "csr" now
raises ValueError. The auto default is unaffected.
Fixed
- maxsim_from_hidden backward leaked a spurious gradient for fully
d-masked documents. A document with every token masked out scores 0
in the forward, but the backward gathered a stale index-0 winner and
added a non-zero contribution to grad_Q / grad_H_d / grad_W /
grad_b. The fused-head kernel now writes a -1 argmax sentinel for
query rows with no valid winner and the backward gates on it, matching
the main maxsim path and the unfused reference (zero gradient).
- Forward-kernel autotune config pruning now sizes its shared-memory
estimate with the padded embedding dim (next_pow2(d)) instead of the
raw d. For non-power-of-2 d the old estimate undercounted SMEM by
up to ~2x and could admit configs that overflow at launch.
- maxsim_residual now raises on zero-length documents when Q requires
grad. An empty doc has no MaxSim winner, so the backward had no correct
gradient and would gather a stale index-0 winner; it now fails fast.
Inference (no grad) is unchanged and still scores an empty doc 0.
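The "up to ~2x" figure in the pruning fix above follows directly from padding to a power of two. A small sketch, assuming the shared-memory estimate scales linearly with the embedding dimension (the actual estimate formula is not given here, and undercount_ratio is a hypothetical name):

```python
def next_pow2(n):
    """Smallest power of two that is >= n."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def undercount_ratio(d):
    """Factor by which a raw-d estimate undercounts a buffer padded to next_pow2(d)."""
    return next_pow2(d) / d


ratios = {d: undercount_ratio(d) for d in (64, 65, 96, 128)}
```

Power-of-two dimensions give a ratio of exactly 1, while a dimension just above a power of two (such as 65, padded to 128) approaches 2.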
v0.3.0: Autotune & dispatch overhaul across CUDA kernels + faster training on Apple Silicon
Added
- Training on Apple Silicon. New maxsim_train_metal (forward +
saved argmax) and maxsim_backward_metal Metal kernels mirror the
Triton maxsim_backward_unified API; a _MaxSimFnMetal(autograd.Function) wires them into maxsim_mps so the full training path
(forward, backward, L2-normalize Jacobian) runs on Metal instead of
falling back to torch.compile. Inference and unsupported-dtype
paths are unchanged.
- KD / pairs layout on Metal (4-D D). The MPS Metal kernel now
accepts D as [Nq, K, Ld, d] directly, so PyLate's
colbert_kd_scores and colbert_scores_pairwise shapes hit the
Metal kernel instead of falling back to torch.compile.
- maxsim() 4-D dispatch + maxsim_pairs() on CUDA. maxsim(Q, D)
now dispatches on D.dim(): 3-D stays on the in-batch cross-product
path, 4-D [Nq, K, Ld, d] runs as a single fused KD launch. New
maxsim_pairs(Q, D) covers the [B, Lq, d] × [B, Ld, d] → [B]
case. PyLate's colbert_kd_scores Python for-loop (one
maxsim() call per query) collapses to one kernel launch,
~10× faster than the loop at B=64, K=32, and the pairwise
B=5000 shape went from 355 ms (old maxsim_varlen packing
path) to 0.18 ms (beats flash-maxsim).
- Every benchmark now reports peak VRAM (max_memory_allocated()) per
variant in stdout and JSON; bench_fp8.py and bench_fused_head_train.py
also gained --outdir JSON + Markdown sidecars. benchmarks/README.md
documents the unified CLI and the SkyPilot driver.
Changed
- CUDA autotune & dispatch overhaul: broadly faster across kernels.
Five independent wins on the Triton path:
  - Lq bucketed to next pow-of-2 in the maxsim() wrapper
    (#62). Variable-Lq training (ColBERT / ColPali, where the
    tokenizer's per-batch max(Lq) floats step-to-step) used to
    re-trigger the full autotune sweep on every novel Lq. Now
    collapses to ≤ 9 cache entries. Measured 4.7× faster
    end-to-end on 30 H100 steps with Lq ∈ [8, 32] (median
    step 588 ms → 0.19 ms).
  - KD + pairwise folded into the fast forward + backward
    kernels (#66). A single kd_layout: tl.constexpr switches
    d_global = pid % Nd (in-batch) vs pid (KD / pairs) so
    PyLate's colbert_kd_scores (4-D D) and
    colbert_scores_pairwise shapes use the same dense fast path
    as in-batch instead of routing through score_pairs_packed's
    packing layer. KD B=64, K=32 now beats flash-maxsim
    (lik/flash 0.94×), down from 1.4–7.7× slower pre-PR.
  - Persistent on-disk autotune cache via Triton ≥ 3.4's
    cache_results=True on all eight autotuned kernels (#64).
    First run on a machine still pays the ~4 s sweep; every
    subsequent process / CI job / training restart loads the JSON
    winner and skips bench. 10.6× faster cold start on the
    second process. Feature-detected: older Triton (3.0–3.3)
    silently keeps the in-memory-only behaviour, no dependency
    floor bump.
  - Small-input forward bypass for Nq*Nd ≤ 500 && d ≤ 256
    with save_argmax=False (#64). Fixed-config launch
    (BLOCK_Q=32, BLOCK_D=64, num_warps=4, num_stages=2); cold
    call drops from ~0.5–1 ms (autotune sweep) to sub-millisecond.
    Closes the gap to flash-maxsim's _maxsim_fwd_kernel_small
    on REPL / unit-test shapes.
  - New Hopper autotune config
    BLOCK_Q=128, BLOCK_D=128, num_warps=8, num_stages=3 (#57). Closes a 0.85× regression
    vs flash-maxsim on the compute-bound colpali rerank shape
    (Nq=1, Nd=500, Lq=Ld=1024, d=128; now 1.23×). Only picked
    on the shape it was designed for; no regression elsewhere.
  - normalize out of the forward autotune key (#64). The two
    constexpr branches still produce distinct binaries; they now
    share one autotune entry instead of two. Cache cardinality
    halves.
- Benchmark CLI unified. Experiment subsets are now --only NAME ...
on every script (replacing the legacy --shape / --shapes flags and
the older --only for variant selection, which moved to --variants).
scripts/sky_run_all_benchmarks.yaml and the per-domain Sky yamls
accept a RUN_ONLY env to pick a subset of tags.
- SkyPilot bench yamls consolidated. Four operator-facing files now
cover every bench run: sky_benchmark_smoke_test.yaml (was
sky_run_benchmarks.yaml), sky_run_all_benchmarks.yaml,
sky_pylate_benchmark.yaml (new, folds the three previous
sky_lateon_edge.yaml / sky_pylate_realdata{,_long}.yaml), and
sky_colpali_benchmark.yaml (was sky_colpali_training.yaml).
- MPS range refreshed (M4 2025). Inference metal vs eager is now
1.9–3.5× and metal vs compile 2.2–14.3× (vs 1.9–3.2× and
1.1–2.0× in 0.2.0); the gap vs torch.compile widened because MPS
Inductor regressed sharply on long-Ld inputs. New training (fwd +
bwd) table lands alongside, with the Metal backward 1.2–1.7× over
eager and 3–4× over torch.compile once shapes amortise launch
overhead. Full tables in docs/benchmarks.md Apple Silicon section.
- JSON peak-VRAM keys standardized to <variant>_peak_mb across
bench_pylate_lateon, bench_pylate_realdata, bench_colpali_training,
bench_colpali_realdata, bench_cached_maxsim, bench_fastplaid, and
bench_lateon. Breaking for anyone parsing benchmarks/results/*.json
directly (previously a mix of peak_gb / _peak / mem_*_MB).
- Removed unreferenced exports pylate_compat._bool_mask (shadowed by
_mask_as_bool) and mps.is_mps_tensor.
- GPU CI moved from the GitHub-hosted runner to AWS CodeBuild (A10G) and
no longer auto-runs on push to main; opt-in via the run-gpu-tests
PR label or workflow_dispatch.
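The kd_layout switch from the autotune and dispatch overhaul above (d_global = pid % Nd for in-batch, pid for KD / pairs) can be sketched as plain index arithmetic. The function name is hypothetical, and deriving the query index as pid // Nd in both layouts is an assumption made for illustration.

```python
def program_indices(pid, nd, kd_layout):
    """Map a flat program id to (query index, global doc index).

    in-batch: every query scores the same `nd` docs, so doc ids wrap around.
    KD/pairs: each query owns its own `nd` docs, so doc ids keep counting up.
    """
    q = pid // nd                             # assumed identical in both layouts
    d_global = pid if kd_layout else pid % nd
    return q, d_global


in_batch = [program_indices(p, 3, kd_layout=False) for p in range(6)]
kd = [program_indices(p, 3, kd_layout=True) for p in range(6)]
```

With 2 queries and 3 documents per query, the in-batch mapping revisits doc ids 0..2 for each query, while the KD mapping addresses 6 distinct documents.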
Removed
- [breaking] Experimental kernels. late_interaction_kernels.experimental
and its three research variants (soft_maxsim, smooth_maxsim,
maxsim_matryoshka) are gone, along with reference.maxsim_reference_soft,
tests/test_{soft,smooth,matryoshka}.py, and the two soft-maxsim cases
in tests/test_robustness.py. None of them shipped to PyLate,
colpali_engine, FastPlaid, or NextPlaid; folding research kernels into
prod was the same mistake as maxsim_xtr in 0.2.0. Users on a research
path can vendor the kernel source from the pre-0.3.0 git history.
- [breaking] Deprecated *_inference shims and maxsim_from_hidden_train.
The four DeprecationWarning shims from 0.2.0 are removed:
  - late_interaction_kernels.maxsim_inference → maxsim(...)
  - late_interaction_kernels.fused_head.maxsim_from_hidden_train → maxsim_from_hidden(...)
  - late_interaction_kernels.varlen.maxsim_varlen_inference → maxsim_varlen(...)
  - late_interaction_kernels.plaid.maxsim_residual_inference → maxsim_residual(...)

  Each surviving function already auto-skips the saved argmax buffer
  when no input has requires_grad=True, so behaviour is unchanged.
- [breaking] set_backward_method / get_backward_method removed
(deprecated in 0.2.0). Migration: replace set_backward_method("csr")
with maxsim(..., backward="csr") (or MaxSimScorer(backward="csr")).
maxsim()'s backward=None now resolves directly to "auto" instead
of reading a module-level global.
- [breaking] reference.xtr_reference and its three CPU-only tests in
tests/test_reference_cpu.py (the XTR Triton kernel was already
deleted in 0.2.0).
Fixed
- Masked-row gradient poisoning (correctness). When every doc token
was masked for a (query, doc) pair, the Triton forward saved an
argmax of 0 instead of a sentinel, and the unified / atomic
backwards atomic-added a spurious grad_scores[i, j] * Q[i, s, :]
into grad_D[d_global, 0, :]. Forward now initialises the running
argmax to -1; the unified / atomic backwards skip the scatter on
t < 0. CSR backward is unaffected (sorting naturally drops the
sentinel). Matches the equivalent fix on the MPS Metal backward.
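The sentinel fix can be illustrated for a single query token and a single document in pure Python. The function names are hypothetical and lists stand in for tensors. The structure mirrors the description above: the running argmax starts at -1, and the backward skips the scatter when it sees the sentinel.

```python
def forward_token(q_tok, doc, mask):
    """Return (score, argmax) over unmasked doc tokens; argmax is -1 if none are valid."""
    best, arg = float("-inf"), -1          # running argmax starts at the sentinel
    for t, (d_tok, keep) in enumerate(zip(doc, mask)):
        if not keep:
            continue
        s = sum(a * b for a, b in zip(q_tok, d_tok))
        if s > best:
            best, arg = s, t
    return (best if arg >= 0 else 0.0), arg  # a fully masked doc scores 0


def backward_doc(q_tok, doc, arg, grad_score):
    """Scatter grad_score * q_tok into the winning doc row, unless there was no winner."""
    grad_d = [[0.0] * len(q_tok) for _ in doc]
    if arg < 0:                            # sentinel: skip the scatter entirely
        return grad_d
    for k, q in enumerate(q_tok):
        grad_d[arg][k] += grad_score * q
    return grad_d


q = [1.0, 2.0]
doc = [[1.0, 0.0], [0.0, 1.0]]
score, arg = forward_token(q, doc, mask=[False, False])  # every token masked
grad = backward_doc(q, doc, arg, grad_score=0.5)         # stays all zeros
```

Had the forward saved an argmax of 0 here, the backward would have written 0.5 * q into row 0 of the document gradient, which is the spurious contribution described above.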
Known regressions (queued for 0.3.1)
Four shape-specific perf regressions vs 0.2.0 on the H100 sweep
(NGC 25.06 / torch 2.8 / triton 3.x). All correctness assertions hold;
the likely root cause for each is a winning autotune config rejected
by the tighter prune_forward SRAM model (#73).
- bench_forward text-long (Nq=1, Nd=1k, Lq=32, Ld=1024):
0.093 → 0.121 ms (+30 %).
- bench_inference_edge LateOn-Code-edge Nd=1k, Ld=1024, d=48:
0.072 → 0.114 ms (+58 %).
- bench_pylate_realdata Contrastive bs=16, Lq=32, Ld=256: e2e step
52.3 → 61.6 ms (+18 %); vanilla baseline unchanged.
- bench_pylate_realdata CachedContrastive bs=64, mini=16, Lq=32, Ld=300, grad-ckpt: e2e step 305.5 → 318.5 ms; the former 1.15× win
vs vanilla collapsed to 1.00×. Highest-priority of the four.
v0.2.0: API cleanup, colpali_engine drop-in, GPU CI
Added
- Self-hosted GPU CI workflow (.github/workflows/gpu-ci.yml) that runs the
CUDA-marked tests on push to main, on PRs touching kernel-related files,
on workflow_dispatch, or on PRs labelled run-gpu-ci. CPU-only CI was
split into .github/workflows/cpu-ci.yml; both workflows now trigger only
when their path filters match.
- Interactive kernel picker (docs/choose-a-kernel.html) and HTML playbook
(docs/how-it-works.html) to help pick the right kernel for a workload.
- patch_colpali_engine() / unpatch_colpali_engine(): colpali_engine
drop-in mirroring patch_pylate. Monkey-patches
BaseVisualRetrieverProcessor.score_multi_vector and the three in-batch
loss heads (ColbertLoss, ColbertPairwiseCELoss, ColbertSigmoidLoss)
to route their einsum("bnd,csd->bcns") + amax(-1) + sum(-2) through the
fused kernel. Negative-mining siblings (ColbertNegativeCELoss,
ColbertPairwiseNegativeCELoss) inherit the in-batch term through their
self.inner_loss reference. Falls back to the original implementation
for use_smooth_max=True, LIK_DISABLE=1, sub-Ampere CUDA, CPU
tensors, and d < 8.
- maxsim_padded: padded-input reranking helper (inspired by
https://github.com/ErikKaum/maxsim). Takes [B, Lq, d] / [B, C, Ld, d]
tensors with per-row lengths, returns [B, C] fp32. Autograd-aware on
every device: CUDA dispatches to the fused pair-list scatter kernel
(forward + backward), CPU / MPS fall back to the pure-PyTorch reference.
The underlying pack_padded(...) building block (which converts to the
packed cu_seqlens layout with a single combined max_seqlen_q /
max_seqlen_d device→host sync) is available from
late_interaction_kernels.padded.
- Fused backward for score_pairs_packed. The pair-list scatter kernel
now saves a [num_pairs, max_lq] argmax buffer when either input has
requires_grad=True and produces grad_Q / grad_D directly on the
packed layout via two atomic-add scatter kernels. Pair-list training is
now O(num_pairs · max_lq · d) on both passes; no [Nq, Nd]
materialisation, no varlen-style off-diagonal compute. Pure inference
pays no overhead (save_argmax=False).
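The packed cu_seqlens layout that pack_padded converts to can be sketched from per-row lengths alone. This assumes the common convention of B + 1 cumulative offsets starting at 0. The helper name pack_lengths is hypothetical and is not the library's pack_padded signature.

```python
from itertools import accumulate


def pack_lengths(lengths):
    """Build cumulative offsets (cu_seqlens) and the max length from per-row lengths."""
    cu_seqlens = [0] + list(accumulate(lengths))
    max_seqlen = max(lengths, default=0)  # default keeps an empty batch well-defined
    return cu_seqlens, max_seqlen


cu, max_len = pack_lengths([3, 5, 2])
# Row i occupies packed positions cu[i] .. cu[i + 1] - 1.
row_1_span = (cu[1], cu[2])
```

The max_seqlen value is what a kernel would use as a loop bound, which is why passing a value smaller than the true maximum would truncate tokens.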
Changed
- maxsim now auto-skips the saved argmax buffer when neither input
has requires_grad=True, matching the dispatch already shipped by
maxsim_varlen and maxsim_residual. maxsim_inference is now a
thin deprecation shim that forwards to maxsim.
- maxsim_from_hidden is now autograd-aware (gradients flow into
whichever of Q / H_d / W / b carry requires_grad=True); the
forward-only path is auto-dispatched when none of them do.
maxsim_from_hidden_train is now a thin deprecation shim.
- [breaking] Bumped minimum PyTorch from 2.1 to 2.5. Older releases
are no longer tested and the torch._assert_async bounds check in
pack_padded now assumes the symbol is present unconditionally.
- Replaced the unconditional CPU-only torch pin with explicit
torch-cpu / torch-cuda optional extras so CUDA installs no longer
pull a CPU wheel by default.
- Cleaned the H100 autotune pool (_autotune.py::_small_d_hopper). Dropped
the two warp_spec=True configs that have been silent no-ops since
Triton 3.5 removed the num_consumer_groups / num_buffers_warp_spec
kwargs (the API moved to compiler-driven warp specialization; without
the kwargs, those entries duplicated other configs in the pool and
occasionally won the autotune sample on noise alone). Also resized
BLOCK_Q=32, BLOCK_D=128 from num_warps=8 to num_warps=4 so it
matches the WGMMA warp-group size we actually want, and added the
matching BLOCK_Q=64, BLOCK_D=128, num_warps=4, num_stages=3 row.
Deprecated
- set_backward_method / get_backward_method now emit
DeprecationWarning. The process-wide global has no functional
advantage over the per-call backward= kwarg on maxsim /
MaxSimScorer and complicates reasoning in multi-thread / multi-rank
setups. Migration: replace set_backward_method("csr") with
maxsim(..., backward="csr") (or MaxSimScorer(backward="csr")).
The globals will be removed in the next breaking release.
- late_interaction_kernels.maxsim_inference: use maxsim(...) directly.
- late_interaction_kernels.fused_head.maxsim_from_hidden_train: use
maxsim_from_hidden(...) directly.
- [breaking] maxsim_inference_scatter → score_pairs_packed; module
scatter.py → score_pairs.py. Shorter name, matches prior art in
https://github.com/ErikKaum/maxsim. Kernel, signature, and semantics
are identical.
- [breaking] Trimmed the top-level public surface to everyday API only:
MaxSimScorer, retrieve, patch_pylate / unpatch_pylate, maxsim,
maxsim_inference, maxsim_varlen, maxsim_padded, and the reference
module. Lower-level / niche kernels must now be imported from their
submodule:
  - score_pairs_packed → late_interaction_kernels.score_pairs
  - pack_padded / PackedBatch → late_interaction_kernels.padded
  - maxsim_from_hidden / maxsim_from_hidden_train → late_interaction_kernels.fused_head
  - plaid_approx_score / maxsim_residual / maxsim_residual_varlen → late_interaction_kernels.plaid
  - maxsim_inference_fp8 → late_interaction_kernels.fp8
  - set_backward_method / get_backward_method → late_interaction_kernels.autograd
- [breaking] Module relocations following the submodule reorganisation.
Direct imports of the old paths now raise ImportError:
  - late_interaction_kernels._mps → late_interaction_kernels.mps.compile_dispatch
  - late_interaction_kernels.metal → late_interaction_kernels.mps.metal
  - late_interaction_kernels.backward_csr → late_interaction_kernels.backward.csr
  - late_interaction_kernels.backward_unified → late_interaction_kernels.backward.unified
  - late_interaction_kernels.{soft,smooth,matryoshka,xtr} → late_interaction_kernels.experimental.{soft,smooth,matryoshka,xtr}
Removed
- [breaking] Top-level deprecation shims for maxsim_forward, maxsim_topk,
maxsim_residual_inference, maxsim_varlen_inference,
maxsim_matryoshka, maxsim_xtr, soft_maxsim, smooth_maxsim,
quantize_fp8_per_tensor, quantize_fp8_per_token,
dequantize_fp8_per_tensor, dequantize_fp8_per_token. Import from
their submodules directly: late_interaction_kernels.{forward, topk, plaid, varlen, experimental, fp8}.
- [breaking] maxsim_xtr (XTR top-K aggregation, the experimental kernel
exposed at late_interaction_kernels.experimental.xtr). The kernel only
ever shipped as a research curiosity; nothing in MaxSimScorer,
retrieve, patch_pylate, or patch_colpali_engine used it. Users
who still need XTR aggregation can take the kernel source from a
pre-0.2.0 release or compose maxsim with a topk + sum on the
output. The companion test (tests/test_xtr.py) is gone too.
Fixed
- Interactive kernel picker (docs/choose-a-kernel.html) now surfaces
maxsim, maxsim_varlen, score_pairs_packed, maxsim_residual, and
maxsim_residual_varlen under the "My own training / inference code"
branch (in addition to "Raw kernel functions"). Previously the combo
custom code + training + packed cu_seqlens returned "No exact match".
- Kernel picker shows a composition recipe when the combo
varlen + top-k retrieval is selected (no single fused kernel covers
that today; the answer is maxsim_varlen followed by torch.topk).
The picker still falls back to the generic "No exact match" message
for combinations no recipe covers.
- late_interaction_kernels.backward.atomic referenced
late_interaction_kernels.backward.backward_csr, a stale path from the
submodule rename that would have raised ModuleNotFoundError the first
time the auto backward path picked CSR on a real GPU. Pointed at the
correct module (...backward.csr).
- score_pairs_packed no longer recompiles or re-autotunes per distinct
(max_lq, max_ld). Both were tl.constexpr and part of the autotune
key, so each distinct max-seqlen bucket triggered a fresh compile +
autotune sweep, the same trap Ld fell into on the dense forward in
0.1.0. The kernel now keys only on d_pad. Pinned by
tests/test_compile_cache.py (single autotune entry across 5 distinct
max_ld / max_lq values).
Documentation
- Spell out what the H100 forward table compares against (eager fp32
reference) and why a torch.compile baseline isn't included on the CUDA
side: Inductor still has to materialize the [Nq · Nd · Lq · Ld]
similarity tensor before max(-1), which is exactly the HBM round-trip
the fused kernel exists to avoid.
- README banner, usage-context clarifications, and a restructured
how-it-works walkthrough.
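For a sense of scale, the arithmetic below estimates the size of that intermediate similarity tensor against the final score matrix. The shape is an illustrative assumption, not one of the benchmarked configurations, and both helper names are hypothetical.

```python
def similarity_tensor_bytes(Nq: int, Nd: int, Lq: int, Ld: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold the full [Nq, Nd, Lq, Ld] similarity tensor (fp32 by default)."""
    return Nq * Nd * Lq * Ld * bytes_per_elem


def score_matrix_bytes(Nq: int, Nd: int, bytes_per_elem: int = 4) -> int:
    """Bytes for the final [Nq, Nd] score matrix that MaxSim returns."""
    return Nq * Nd * bytes_per_elem


# Illustrative shape (an assumption): 32 queries x 1024 docs, 32 query tokens, 256 doc tokens.
full = similarity_tensor_bytes(Nq=32, Nd=1024, Lq=32, Ld=256)
out = score_matrix_bytes(Nq=32, Nd=1024)
print(f"intermediate: {full / 2**30:.2f} GiB, output: {out / 2**10:.0f} KiB")
# intermediate: 1.00 GiB, output: 128 KiB
```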
v0.1.0: Apple Silicon support + variable-Ld training fix
Fixed
- Triton kernels no longer recompile or re-autotune per distinct Ld.
Ld was declared tl.constexpr and (for the autotuned forwards) keyed
the autotuner, but inside the kernels it only drives a runtime
range(0, Ld, BLOCK_D) loop, so variable-length training was paying
one Triton recompile + one autotune sweep per distinct doc length.
Ld is now a runtime arg and is out of the autotune key across
forward, soft, smooth, fp8, fused_head, matryoshka, and the
three backward kernels. Measured 9.3× faster cold start on H100 (4
distinct Ld values, fp16); steady-state per-call performance
unchanged. Pinned by tests/test_compile_cache.py.
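Ld can be a runtime argument because it only sets the trip count of a blocked loop. Below is a plain-Python sketch of that loop shape, assuming a running max over document tokens with a clipped tail block. The name blocked_row_max is hypothetical, and the real kernels are Triton and are not reproduced here.

```python
BLOCK_D = 4  # fixed tile size; in a compiled kernel this would be the compile-time constant


def blocked_row_max(sims: list[float]) -> float:
    """Max over one row of query-token/doc-token similarities, BLOCK_D entries at a time."""
    Ld = len(sims)                       # runtime value: only changes the loop trip count
    running = float("-inf")
    for start in range(0, Ld, BLOCK_D):  # same loop shape as range(0, Ld, BLOCK_D)
        block = sims[start:start + BLOCK_D]  # slicing clips the tail, like a bounds mask
        running = max(running, max(block))
    return running


# The same function body handles every document length; nothing is specialised on Ld.
for Ld in (3, 4, 7, 10):
    row = [((i * 7) % 5) / 10 for i in range(Ld)]
    assert blocked_row_max(row) == max(row)
```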
Added
- Apple Silicon (MPS): fused simdgroup_matrix Metal kernel for forward
MaxSim (late_interaction_kernels.metal.maxsim_inference_metal),
JIT-compiled via torch.mps.compile_shader. Persistent threadgroups serve
8 consecutive j values per launch, Q is register-resident across every
(j, d-chunk) pair, and the cooperative D load stages each row through
per-thread registers so the optional L2-normalize fold pays one threadgroup
write per element instead of three. Forward-only; never materialises the
[Nq · Nd · Lq · Ld] similarity tensor.
- MPS dispatch (late_interaction_kernels._mps): torch.compile-fused
reference (autograd-aware) for training calls, with the Metal kernel
selected for inference when its envelope holds (fp16 / bf16, d ≤ 128,
d % 8 == 0, Nq · Nd ≥ 64 ∧ Ld ≥ 192). Compile-time MSL errors and
device-side faults fall back transparently to the compile path.
- patch_pylate() MPS routing: pylate.scores.colbert_scores /
colbert_kd_scores now route MPS tensors through maxsim_mps. The
maxsim Triton import is now lazy, so pylate_compat is importable on
machines without Triton (e.g. macOS).
- Env overrides: LIK_FORCE_MPS_BACKEND={metal,compile,reference},
LIK_DISABLE_COMPILE=1, LIK_MPS_METAL_MIN_BATCH, LIK_MPS_METAL_MIN_LD.
- benchmarks/bench_mps.py benches Metal / torch.compile / eager
side-by-side and reports metal vs eager, metal vs compile, and
compile vs eager ratios. Apple M4 fp16: 1.9–3.2× over eager
(1.1–2.0× over torch.compile) on realistic inference shapes.
- benchmarks/bench_flash_maxsim.py is back in the runner script and the
documentation; pinned to flash-maxsim==0.2.0 so the published numbers
are reproducible.
- 87 new MPS tests across tests/test_mps.py, tests/test_mps_metal.py,
and tests/test_pylate_compat_mps.py (parity, masks, autograd, dispatch
fallbacks, env overrides, KD layout, PyLate routing).
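The dispatch envelope and the environment overrides listed above can be read as a single predicate. The sketch below is a guess at that logic, not the package's code. The name pick_mps_backend is hypothetical, and the defaults 64 and 192 are taken from the envelope quoted above. Two further things are assumed: that LIK_MPS_METAL_MIN_BATCH / LIK_MPS_METAL_MIN_LD override those two thresholds, and that LIK_DISABLE_COMPILE=1 swaps the compiled path for the eager reference.

```python
import os


def pick_mps_backend(dtype: str, d: int, Nq: int, Nd: int, Ld: int,
                     training: bool = False, env=os.environ) -> str:
    """Return "metal", "compile", or "reference" for an MPS call (illustrative only)."""
    forced = env.get("LIK_FORCE_MPS_BACKEND")
    if forced in ("metal", "compile", "reference"):
        return forced                                  # explicit override wins

    # Assumption: these two variables override the 64 / 192 thresholds.
    min_batch = int(env.get("LIK_MPS_METAL_MIN_BATCH", "64"))
    min_ld = int(env.get("LIK_MPS_METAL_MIN_LD", "192"))

    in_envelope = (
        dtype in ("float16", "bfloat16")               # fp16 / bf16 only
        and d <= 128 and d % 8 == 0                    # embedding-dim limits
        and Nq * Nd >= min_batch and Ld >= min_ld      # size thresholds
    )
    if not training and in_envelope:                   # the Metal kernel is forward-only
        return "metal"
    # Assumption: LIK_DISABLE_COMPILE=1 selects the eager reference instead.
    return "reference" if env.get("LIK_DISABLE_COMPILE") == "1" else "compile"


print(pick_mps_backend("float16", d=128, Nq=8, Nd=16, Ld=256, env={}))   # metal
print(pick_mps_backend("float32", d=128, Nq=8, Nd=16, Ld=256, env={}))   # compile
print(pick_mps_backend("float16", d=128, Nq=8, Nd=16, Ld=256,
                       training=True, env={}))                            # compile
```

Passing env as a plain dict keeps the sketch testable without touching the process environment.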
Changed
- MaxSimScorer / _score() now raises an explicit ValueError when
Q.device != D.device instead of silently dropping through to the eager
reference and surfacing an opaque RuntimeError from torch.matmul,
the same contract retrieve() already enforced.
- docs/benchmarks.md Apple Silicon section rewritten with the Metal
numbers, the dispatch heuristic, and the headline metal vs eager ratio.
- Minimum Python is now 3.10 (was 3.9). pyproject.toml bumps
requires-python = ">=3.10", the Python classifiers, and
tool.ruff.target-version = "py310". uv.lock regenerated; the CI matrix
was already 3.10 / 3.11 / 3.12.
Removed
- All from __future__ import annotations lines across the package, tests,
benchmarks, and examples (67 files). Annotations now use the native PEP 604
/ PEP 585 syntax (X | Y, list[X], dict[K, V]) that Python 3.10
supports at runtime, so no compatibility shim is needed.
Fixed
- patch_pylate() previously gated on q.is_cuda and silently fell through
to PyLate's reference implementation for MPS tensors. Replaced with a
per-call _device_path(Q, D) → {"cuda", "mps", None} switch so Mac users
get the same one-liner upgrade as CUDA users.
- benchmarks/results/ is now .gitignored; benchmark outputs are no
longer tracked in version control.
- MPS torch.compile path no longer trips the inductor symbolic-shape
bug. On torch 2.8 / nightly, MPS inductor fails to lower
S.max(dim=-1) when the reduction axis is symbolic
(cannot determine truth value of Relational: s12 <= 1024 from
codegen_iteration_ranges_entry). Switched the compile call from
dynamic=True to dynamic=False so PyTorch's dynamo cache transparently
recompiles per (Nq, Nd, Lq, Ld) tuple instead. That is fine for typical
inference where shapes are stable, and shape-varying workloads can fall
back to the Metal kernel. Unblocks the 28 MPS tests that were skipped on
this bug; 167 / 167 pass on macOS with no skips on the dispatch /
metal / pylate_compat_mps suites.
- README.md header restored. The Python-3.10 bump (#24) accidentally
stripped <div align="center">, the badge ![]() image syntax, and
reformatted a handful of markdown tables, leaving the landing page
un-centered with plain-text "badges" instead of the shields.io images.
Restored from the pre-#24 revision and re-applied the
python-3.10–3.12 shield change. No other content changes.
v0.0.1: initial release
Initial release for late-interaction-kernels: fused Triton kernels for late-interaction (MaxSim) scoring, with a high-level PyTorch API and PyLate drop-in 🚀
Added
- Core MaxSim kernels: maxsim (autograd-aware) and maxsim_inference
with fused L2-normalize, mask handling, and a unified / csr / atomic
backward selector (set_backward_method, default auto).
- Ragged / packed batches: maxsim_varlen over cu_seqlens-indexed
flat buffers, autograd-aware on both Q and D.
- Pair-list scoring: maxsim_inference_scatter scores arbitrary
(query_index, doc_index) pairs from packed batches and returns
[num_pairs] directly (vLLM-style reranker scheduling).
- Fused D-side head: maxsim_from_hidden (inference) and
maxsim_from_hidden_train (closed-form backward) apply
projection + L2-normalize + MaxSim in a single pass over raw
[Nd, Ld, d_model] hidden states.
- PLAID / ColBERTv2:
plaid_approx_score (approximate scoring) and
maxsim_residual / maxsim_residual_varlen (exact rerank with on-the-fly
2/4/8-bit residual decompression + L2-normalize + MaxSim, forward-only on
varlen).
- FP8 inference: maxsim_inference_fp8 with per-tensor / per-token
e4m3 inputs, fp32 accumulator, and a score-tie fallback harness.
- High-level API: MaxSimScorer(nn.Module) and retrieve(Q, D, top_k),
both with transparent pure-PyTorch CPU fallback so training and retrieval
code is unit-testable on macOS / Windows / CPU-only CI.
- PyLate drop-in: patch_pylate / unpatch_pylate patch
colbert_scores and colbert_kd_scores across Contrastive,
CachedContrastive, and Distillation. LIK_DISABLE=1 is the
process-wide kill switch.
- Experimental kernels:
late_interaction_kernels.experimental ships
soft_maxsim, smooth_maxsim, maxsim_xtr, and maxsim_matryoshka.
- FP8 helpers: late_interaction_kernels.fp8 exposes per-tensor /
per-token quantize / dequantize utilities.
- Per-GPU autotune (Hopper / Ampere / Ada / generic) with shared-memory
pruning; warp specialization on Triton ≥ 3.2 with transparent fallback.
- Pure-PyTorch reference (late_interaction_kernels.reference) used as
ground truth in tests and as the CPU fallback path.
- Test suite covering forward / backward parity, varlen, soft/smooth,
edge cases, PyLate compatibility, CPU fallback, and gradcheck on the
high-level API.
- Benchmarks for every kernel, plus end-to-end PyLate / LateOn training
and retrieval scripts under benchmarks/ and scripts/.
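For readers new to the terminology, here is a pure-Python reference for MaxSim over a cu_seqlens-packed batch. It is a sketch of the math only (for each query token, take the best dot product over the document's tokens, then sum over query tokens), not the package's reference module. The name maxsim_varlen_ref and its argument layout are hypothetical, and normalization and masking are left out.

```python
from itertools import accumulate


def maxsim_varlen_ref(q_flat, cu_q, d_flat, cu_d):
    """Scores [Nq][Nd] for token embeddings packed into flat lists.

    Sequence i occupies rows cu[i]:cu[i + 1] of its flat buffer.
    """
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    scores = []
    for i in range(len(cu_q) - 1):
        q_tokens = q_flat[cu_q[i]:cu_q[i + 1]]
        row = []
        for j in range(len(cu_d) - 1):
            d_tokens = d_flat[cu_d[j]:cu_d[j + 1]]
            # MaxSim: for each query token take its best-matching doc token, then sum.
            row.append(sum(max(dot(q, d) for d in d_tokens) for q in q_tokens))
        scores.append(row)
    return scores


# Two queries (2 and 1 tokens) and two documents (1 and 3 tokens), embedding dim 2.
q_lens, d_lens = [2, 1], [1, 3]
cu_q = [0, *accumulate(q_lens)]   # [0, 2, 3]
cu_d = [0, *accumulate(d_lens)]   # [0, 1, 4]
q_flat = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
d_flat = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 0.0]]

print(maxsim_varlen_ref(q_flat, cu_q, d_flat, cu_d))  # [[1.0, 2.0], [1.0, 1.0]]
```

The cumulative offsets are what let ragged sequences share one flat buffer without padding.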