Add support for late-interaction-kernels (LIK)#412
Conversation
ManuelFay
left a comment
There was a problem hiding this comment.
this is super cool !
I think we need to run a real training with this kernel to see before merging what the gains are: how much can we increase batch size by, how much more speed do we gain?
this would also be a good opportunity to reverify some of the training scripts and the doc here to make sure it s starightforward to do (i am not so sure).
4a733b6 to
1223fbb
Compare
- Add maxsim dispatcher (maxsim_inbatch, maxsim_kd) with a pure-torch einsum reference and a lazily-imported LIK backend - Mirror PyLate's design: [lik] extra, COLPALI_SCORES_BACKEND env var (auto/torch/lik) read per call, LIKUnsupportedError sentinel - Route score_multi_vector and the five ColBERT losses through the dispatcher (negative-doc losses via LIK's kd_layout) - Add CPU dispatch tests plus CUDA parity and training-smoke tests - Fix transformers-5.x trainer breakage (_get_train_sampler signature, single-dataset compute_loss prefixes) - Document the extra and the backend toggle in README and CHANGELOG
- Add bench_train.py: runs training steps with the maxsim dispatcher
instrumented (per-call forward peak and bytes held for backward),
then replays each recorded shape on an isolated graph to bracket
the op's backward exactly
- Add SkyPilot sweep over B in {16..128} x {auto, torch}, fresh
process per cell so an OOM is isolated
- Add summarizer emitting the markdown table and the log-log plot
- Add train-subset loaders so runs skip the full 52 GB train set
8 cells from a 1x H100 run (LIK 0.4.1, ColQwen2 + LoRA): per-op forward/held/backward VRAM plus whole-step peaks; the vanilla B=128 cell records the fragmentation OOM message.
Keep the final tree lean: the harness and results stay reachable at the two prior commits, referenced from the PR description.
cacf178 to
1d3363c
Compare
- Cache ~/.cache/huggingface across runs, keyed on test file contents - Serve cached files without network calls in CI via a conftest shim - Surface .no_exist markers as EntryNotFoundError, matching online 404 handling
a758065 to
dd096be
Compare
There was a problem hiding this comment.
Pull request overview
Adds an opt-in MaxSim backend that can route late-interaction scoring through late-interaction-kernels (LIK) when eligible (CUDA Ampere+ / Apple Silicon), with a transparent pure-torch fallback otherwise. This integrates the dispatcher into scoring and ColBERT-style losses, adds an optional dependency extra, and introduces targeted tests and CI caching to keep the default CPU test suite stable.
Changes:
- Introduce
maxsim_inbatch/maxsim_kddispatchers with a lazy LIK backend and env-var selection (COLPALI_SCORES_BACKEND=auto|torch|lik). - Route
score_multi_vectorand ColBERT losses through the dispatcher (keeping smooth-max on torch). - Add CPU/CUDA parity + smoke tests and CI Hugging Face cache support; document the new optional
[lik]extra.
Reviewed changes
Copilot reviewed 12 out of 13 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
tests/utils/test_maxsim.py |
CPU-side dispatcher semantics + parity tests against the torch reference. |
tests/utils/test_maxsim_cuda.py |
CUDA-gated LIK forward/backward parity + training smoke tests for updated losses. |
tests/conftest.py |
CI-only Hugging Face cache-first monkeypatch to reduce flaky Hub/network dependency. |
README.md |
Document optional fused MaxSim kernels and COLPALI_SCORES_BACKEND. |
pyproject.toml |
Add [lik] optional extra and include it in [all]. |
colpali_engine/utils/processing_utils.py |
Switch score_multi_vector to use the MaxSim dispatcher. |
colpali_engine/utils/maxsim.py |
New MaxSim dispatcher + torch reference implementations. |
colpali_engine/utils/_lik_backend.py |
New lazily-imported LIK backend with eligibility checks and sentinel error. |
colpali_engine/trainer/contrastive_trainer.py |
Transformers 5.x compatibility fix + initialize dataset prefix fields for single-dataset path. |
colpali_engine/loss/late_interaction_losses.py |
Route hard-max paths through dispatcher for in-batch + negative-doc losses. |
CHANGELOG.md |
Changelog entry for the new optional [lik] backend and env var. |
.gitignore |
Ignore uv.lock. |
.github/workflows/test.yml |
Cache Hugging Face files in CI prior to running the test suite. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Hey @QuentinJGMace, tsm for the review! 🙌🏼 What experiment do you have in mind exactly? Imo running full end-to-end training on the entire train set is feasible or necessary here. I already have tests showing that, given the same inputs, LIK and Here’s what I’d suggest: check that, with and without LIK, we get the same training loss after a given number of steps. Would that work for you? |
I agree that training end to end is not completely necessary, I just thought you had done it. I think a convincing experiment would be to have 3 runs of 50-100 steps (same seed for each collator):
It would ensure that lik does not break anything (even if as I understand it, we're already pretty sure of that) and that the code changes done with this PR did not change the way training behaved with previous versions. wdyt ? |
Sounds good, I'm on it! I'll share the details on my experiments in a comment not to pollute the commit history. |
|
@QuentinJGMace Ran the 3 trainings you asked for. Same seed ( Settings
Results
max |LIK − torch| on the branch is 0.0039 (mean 0.0008). For scale: main vs branch-torch run the same math, and they differ by just as much. The LIK deltas sit at bf16 step-to-step noise level, and the three curves overlap:
One caveat to be transparent about Current main can't run this training as-is. With the declared dependency range (
This branch already fixes both, which is why runs 2 and 3 needed no changes. For run 1, I shimmed main in the harness only (adapted the sampler call convention and primed the prefix attributes from the collator), with zero changes to any scoring or loss math. The produced sampler is identical, since transformers passes One shim applied to all three runs equally: transformers 5.x leaves the tied The LIK run's wall time (54s vs 21s for torch) is the cold-start autotune sweep discussed earlier in this thread. It's a fixed one-off cost, irrelevant beyond a few steps. |
Co-authored-by: Aurélien Lac <56725662+uminaty@users.noreply.github.com>
|
Thanks ! merging now |

Summary
Adds an optional
[lik]extra that routes ColBERT MaxSim scoring throughlate-interaction-kernels(LIK), a fused Triton kernel, on CUDA Ampere+ / Apple Silicon, with a transparent pure-torch fallback everywhere else. It is opt-in and feature-flagged (COLPALI_SCORES_BACKENDselectsauto/torch/lik), with no change to the public API or training semantics: the kernel and the torch reference return the same scores and the same loss.On the MaxSim operation itself the kernel's speedup is unambiguous: isolating the loss head on ColPali-like shapes (no encoder), LIK runs the forward+backward up to 2.5–4.3× faster at large batch×negatives, with the win growing with
B × n_neg(LIK 0.4.1 benchmarks). In ColVision training, however, MaxSim is a few milliseconds inside a step dominated by the 2B-parameter model forward/backward, so the op speedup dilutes to per-step parity end-to-end. What survives the dilution is the memory win: measured at the op level, vanilla MaxSim costs 7.8 GiB of VRAM at B=128 where LIK costs 62 MiB (129×), and that B²-growing term is exactly what caps the trainable batch size. Removing it doubles the batch on an 80 GB H100 (vanilla OOMs at B=128, LIK trains it).What this PR adds
colpali_engine/utils/maxsim.py: the dispatcher (maxsim_inbatch,maxsim_kd) selecting between the LIK backend and the torcheinsum + amax + sumreference perCOLPALI_SCORES_BACKEND.colpali_engine/utils/_lik_backend.py: the lazily-imported LIK implementations, input validation, and theLIKUnsupportedErrorsentinel.score_multi_vector, the three in-batch ColBERT losses (ColbertLoss,ColbertPairwiseCELoss,ColbertSigmoidLoss), and the two negative-doc losses (ColbertNegativeCELoss,ColbertPairwiseNegativeCELoss, via LIK'skd_layout) through the dispatcher.pyproject.toml: optional extralik = ["late-interaction-kernels>=0.4.1,<0.5.0"]; README section documenting the extra and the env var._get_train_samplersignature; single-datasetcompute_lossprefixes).The benchmarking harnesses used for the numbers below were added and then removed within this PR's history, keeping the final tree lean: the op-level VRAM harness and its results live at
1717e37, the original batch-size sweep at2749bd5(a pre-rebase commit GitHub keeps accessible).Design
maxsim_inbatch(Q, D)handles the in-batch[B, Lq, d] x [B, Ld, d]grid (used byscore_multi_vectorand the in-batch losses);maxsim_kd(Q, D)handles the per-query candidate layout[B, N, Ld, d](negative-doc losses). The LIK implementations live in a lazily-imported_lik_backendmodule that validates each call (CUDA Ampere+ or MPS, embedding dim above the kernel's tile floor, matching devices) and raises aLIKUnsupportedErrorsentinel when the kernel cannot run; real kernel errors always propagate. Both paths treat padded tokens as exactly-zero embeddings rather than an explicit mask; ColQwen2 already zeroes padded positions through the attention mask, so the scores match.The design deliberately matches PyLate's integration (lightonai/pylate#222), so using both libraries means one mental model: the extra is
[lik], the backend module split is the same, andCOLPALI_SCORES_BACKEND(read per call) mirrorsPYLATE_SCORES_BACKENDwith the same three values:auto(default) uses LIK when eligible and silently falls back to torch,torchforces the reference, andlikis strict, raisingLIKUnsupportedErrorinstead of falling back.Results
The kernel itself is much faster than the einsum it replaces. LIK's 0.4.1 benchmarks isolate the loss head on ColPali-like shapes (
Lq=32, Ld=1030, no encoder, forward+backward at matched numerics): the speedup climbs 1.13× → 4.31× asB × n_neggrows (2.50× at B256×n8, 4.31× at B256×n16), with ~25–30% lower peak memory on the head. In LIK's own words, this is "the throughput the encoder hides" in end-to-end training.End-to-end, the model forward dilutes that speedup to parity. A ColQwen2 training step is dominated by the 2B-parameter doc/query towers; MaxSim is a few milliseconds of a >1 s step. We measure per-step parity (B=64: 7.19 vs 7.23 samples/s), and LIK's own end-to-end ColQwen2 table shows the same 0.97–1.02×. The dilution is mechanical, not a kernel property: on a 17M encoder, where MaxSim is a bigger slice of the step, the same kernel shows up as a 1.1–1.3× end-to-end speedup.
What survives at ColQwen2 scale is the memory win, measured at the op level. We instrumented the dispatcher during real ColQwen2 training steps, then replayed each recorded shape on an isolated graph where the op's backward can be measured exactly (the replayed forward numbers match the in-train ones to the MiB). The VRAM attributable to MaxSim:
The score grid is fp32 in practice (autocast computes the embedding L2-norm in fp32 and the division promotes, so the loss runs on fp32 embeddings). At B=128 the
[B,B,Lq,Ld]tensor is 2.4 GiB, held from the op's forward until its backward, where the op spikes another 2.25× that (the grid's gradient plus theamaxscatter): 5.4 GiB. LIK holds only the[B,B]output and its backward allocates only the input gradients (dominated bygrad_D), so its footprint grows linearly in B instead of quadratically: 62 MiB total at B=128, a 129× reduction.That op footprint is what sets the batch-size ceiling. Sweeping
per_device_train_batch_sizeuntil OOM (ColQwen2 + LoRA, grad-checkpointing, bf16, 80 GB H100): whole-step peak allocated VRAM is identical while both fit, then splits at B=128.Note
Vanilla and LIK look identical in the VRAM table up to B=64 because the score tensors are freed before the peak: the op's backward runs first, then the model backward where the peak lives. At B=128 vanilla dies not because it uses more peak memory, but because its score grid needs multi-GiB contiguous blocks that memory fragmentation makes impossible to satisfy (the observed OOM is a 1.81 GiB request failing while 25 GiB sit reserved but unallocated). LIK's 62 MiB fits in whatever scraps remain.
Vanilla maxes out at B=64, LIK at B=128 (2× headroom). At B=256 both paths OOM: pushing 256 pages of ~768 visual tokens each through the 2B doc tower is the limit, regardless of the score tensor. The first steps pay a one-time Triton autotune warmup that amortizes over a full run. The loss matches the torch reference within bf16 noise.
Full sweep table
1× H100,
vidore/colqwen2-base,ColbertPairwiseCELoss, grad-checkpointing on, LIK0.4.1. Fresh process per (B, backend) so an OOM is isolated.Throughput is from 4-step runs, so it is autotune-warmup-affected (LIK looks slower at B=16 only because warmup dominates 4 steps). The point that matters: LIK runs B=128 at 7.98 samples/s, which vanilla cannot reach.
Reproduce
The op-level VRAM harness and its result JSONs were added and removed within this PR's history; check out
1717e37to get both. The harness wrapsmaxsim_inbatchduring training to record the forward peak and the bytes held for backward, then replays each recorded shape on an isolated graph to bracket the op's backward exactly (a grad hook cannot bracket it in-train: it fires as a pre-hook of the producing node, after the whole doc-tower backward).The whole-step batch-size sweep (B up to 512) ran on an earlier harness iteration at
2749bd5(sky_batch_sweep.yaml+summarize_sweep.py, pre-rebase commit kept accessible by GitHub). The CUDA test runner (scripts/sky_test_lik.yaml) lives at7ae3402; the slow suite itself stays in-tree (pytest -m slow tests/utils/test_maxsim_cuda.pyon a CUDA Ampere+ host).Force a single run onto a backend with
COLPALI_SCORES_BACKEND=auto|torch|lik(likerrors instead of silently falling back).Next steps
late-interaction-kernelsrepository.colpali-enginerelease once merged (0.3.17if nothing else lands in between) so the[lik]extra is installable from PyPI.