Added
lowmem backward: roughly half the training peak memory, and deterministic.
A new destination-owned backward that accumulates grad_Q / grad_D in
fp32 registers and writes them straight out in the input dtype (bf16/fp16),
with no full-size fp32 gradient buffer, no fp32→bf16 transient, and no
atomics (so it is bitwise-reproducible). auto now routes the gradient-heavy
shapes to it: knowledge-distillation / hard-negative layouts (where grad_D
is inflated by n_neg) and large, high-contention in-batch squares. unified
still handles the common and long-query cross-products, where it is
fastest. Measured on H100 (bf16, fwd+bwd), a B256 × 16-neg ColPali step
drops from 4.3 GB / 2.36 ms to 2.2 GB / 1.37 ms (≈½ the memory, ~1.7×
faster), and pylate-text B256 drops from 96 MB to 52 MB. Gradients match
unified up to bf16 rounding. Select it per call with backward="lowmem".
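The destination-owned idea can be sketched in pure Python, assuming a plain MaxSim score (the sum over query tokens of the max dot product over document tokens). maxsim_forward, maxsim_backward_dest_owned and the helpers are hypothetical names, not the library's API, and Python floats stand in for the fp32 register accumulators. Each gradient row is built by exactly one loop that gathers all of its contributions in a fixed order and writes once, which is why no atomics are needed and reruns are bitwise identical.

```python
# Toy, destination-owned MaxSim backward (hypothetical names; not the library kernel).
# Python floats stand in for the fp32 accumulators; the real kernel writes bf16/fp16.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim_forward(Q, D):
    """Q: [B][Lq][d], D: [N][Ld][d] -> scores [B][N], winners [B][N][Lq]."""
    scores, winners = [], []
    for q in Q:
        s_row, w_row = [], []
        for doc in D:
            total, idx = 0.0, []
            for qt in q:
                sims = [dot(qt, dt) for dt in doc]
                best = max(range(len(sims)), key=sims.__getitem__)  # first max wins ties
                total += sims[best]
                idx.append(best)
            s_row.append(total)
            w_row.append(idx)
        scores.append(s_row)
        winners.append(w_row)
    return scores, winners


def axpy(acc, g, x):
    """acc + g * x, element-wise."""
    return [a + g * xi for a, xi in zip(acc, x)]


def maxsim_backward_dest_owned(Q, D, winners, grad_scores):
    """Each output row is produced by one loop ("program") that gathers its
    contributions in a fixed order and writes once: no atomics, reproducible."""
    d = len(Q[0][0])
    # grad_Q[i][t] = sum_j g[i][j] * D[j][winner[i][j][t]]
    grad_Q = []
    for i, q in enumerate(Q):
        rows = []
        for t in range(len(q)):
            acc = [0.0] * d
            for j, doc in enumerate(D):
                acc = axpy(acc, grad_scores[i][j], doc[winners[i][j][t]])
            rows.append(acc)  # single write per destination row
        grad_Q.append(rows)
    # grad_D[j][s] = sum of g[i][j] * Q[i][t] over every (i, t) whose winner in doc j is s
    grad_D = []
    for j, doc in enumerate(D):
        rows = []
        for s in range(len(doc)):
            acc = [0.0] * d
            for i, q in enumerate(Q):
                for t, qt in enumerate(q):
                    if winners[i][j][t] == s:
                        acc = axpy(acc, grad_scores[i][j], qt)
            rows.append(acc)
        grad_D.append(rows)
    return grad_Q, grad_D
```

By contrast, a scatter-style backward would have every (query, doc) pair add its share into shared grad_D rows concurrently, which needs atomics and lets the floating-point summation order depend on scheduling.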
colpali_engine explicit-negative loss heads now fused.
patch_colpali_engine() previously accelerated only the in-batch CE term
of ColbertNegativeCELoss / ColbertPairwiseNegativeCELoss; their
explicit positive and per-query negative scoring stayed on the unfused
einsum. The positive term now routes through maxsim_pairs (diagonal
q[i]·d[i]) and the [B, n_neg, Ld, d] negative slab through the 4-D maxsim
(KD layout), so neither materialises the similarity tensor. The in-batch
term keeps reusing the already-patched inner head, so no work is done
twice. Pos/neg fusion is CUDA-only: on MPS / CPU those two terms fall back
to the original einsum, while the in-batch term is still accelerated. The
4-D negative backward auto-routes to lowmem, which makes the fused heads a
training memory win too (peak ~13–28% below vanilla, with the gap widening
as B × n_neg grows) on top of the speedup (up to 4.31× at B256 × 16-neg in
the MaxSim-isolation bench).
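The two fused score layouts can be illustrated with a toy sketch, assuming plain MaxSim scoring and leaving out the loss itself. pos_scores_pairs, neg_scores_kd and neg_scores_unfused are hypothetical helpers, not maxsim_pairs or the library's 4-D maxsim, whose exact signatures are not given here. The unfused reference builds the full sim[i][k][t][s] tensor before reducing it, which is the materialisation the fused path avoids.

```python
# Toy illustration of the two fused score layouts (hypothetical names, pure Python).

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim_one(q, doc):
    """MaxSim of one query [Lq][d] against one document [Ld][d]."""
    return sum(max(dot(qt, dt) for dt in doc) for qt in q)


def pos_scores_pairs(Q, D_pos):
    """Diagonal pairs: query i is scored only against its own positive -> [B]."""
    return [maxsim_one(q, d) for q, d in zip(Q, D_pos)]


def neg_scores_kd(Q, D_neg):
    """KD layout: D_neg is [B][n_neg][Ld][d]; query i sees only its own negatives -> [B][n_neg]."""
    return [[maxsim_one(q, d) for d in negs] for q, negs in zip(Q, D_neg)]


def neg_scores_unfused(Q, D_neg):
    """Einsum-style reference: builds the full sim[i][k][t][s] tensor first, then reduces."""
    sim = [[[[dot(qt, dt) for dt in d] for qt in q] for d in negs]
           for q, negs in zip(Q, D_neg)]
    return [[sum(max(row) for row in per_doc) for per_doc in per_query]
            for per_query in sim]
```

Both negative paths return the same [B][n_neg] scores; only the unfused one holds every token-to-token similarity at once.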
colpali install extra.
pip install "late-interaction-kernels[colpali]" pulls in
colpali-engine>=0.3.10,<1 for patch_colpali_engine(). Like the pylate
extra, it is exercised CPU-only in CI: colpali_engine's torchvision
dependency tree conflicts with the CUDA torch wheel, so the two are never
installed together, and the GPU parity tests install colpali-engine
out-of-band.
Changed
Long-query forward chunking: broadly faster at ColPali scale.
maxsim() now splits queries with Lq > 512 into fixed 128-token chunks,
scores each chunk as an independent query through the shared
_maxsim_cross core, and sums the per-chunk MaxSim scores back per original
query. MaxSim is itself a sum over query tokens of each token's max over
document tokens, so regrouping that sum by chunk is exact: forward and
backward are numerically identical to the un-chunked path (autograd flows
through the reshape + sum). Long queries launch more, shorter programs
that fill the GPU instead of serialising one long static_range loop, and
the kernel always sees Lq == 128, so the autotune cache collapses to a
small constant (one entry, plus one more for tail-padded has_q_mask=True)
instead of one entry per length bucket. Measured on H100 (bf16) with
bench_chunking.py against the un-chunked path: +49–77% at Lq=768; at
Lq=1024 the gain ranges from +24% in-batch to roughly break-even for
rerank.
Shorter queries (ColBERT Lq≤32, long-doc Lq≤512) fall through to the
existing core unchanged, so there is no regression. Chunking applies only
to the cross-product path; the KD / pairs path (4-D D) is unaffected, and
long-Lq KD should use maxsim_varlen.
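Why chunking is exact can be checked with a small pure-Python sketch. maxsim_one and maxsim_chunked are hypothetical helpers; the 512-token threshold and the 128-token chunk come from the entry above, and the tail chunk is zero-padded and masked out. The toy data is integer-valued, so floating-point addition order cannot affect the comparison.

```python
# Toy check that chunking the query axis leaves MaxSim unchanged (hypothetical helpers).

LONG_Q_THRESHOLD = 512  # queries longer than this are chunked, per the entry
CHUNK = 128             # fixed chunk length, per the entry


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim_one(q, doc, q_mask=None):
    """Sum over valid query tokens of the max similarity over doc tokens."""
    q_mask = q_mask or [True] * len(q)
    return sum(max(dot(qt, dt) for dt in doc) for qt, ok in zip(q, q_mask) if ok)


def maxsim_chunked(q, doc, chunk=CHUNK, threshold=LONG_Q_THRESHOLD):
    if len(q) <= threshold:
        return maxsim_one(q, doc)  # short queries fall through unchanged
    n_chunks = -(-len(q) // chunk)            # ceil division
    pad = n_chunks * chunk - len(q)
    q_pad = q + [[0.0] * len(q[0])] * pad     # zero-pad the tail chunk ...
    mask = [True] * len(q) + [False] * pad    # ... and mask the padding out
    # Each chunk is scored as an independent Lq == chunk query; chunk scores are summed.
    return sum(maxsim_one(q_pad[c * chunk:(c + 1) * chunk], doc,
                          mask[c * chunk:(c + 1) * chunk])
               for c in range(n_chunks))
```

A 600-token query, for instance, becomes five 128-token chunks with 40 masked padding rows in the last one, and the summed chunk scores equal the un-chunked score.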
Autotuned backward launch params: faster training step.
The backward kernels previously launched with Triton's stock
num_warps=4. Each runs one program per output row, streaming a single
d_pad vector through a doc loop, so 4 warps over-subscribe such a narrow
program; the H100 optimum is 1–2 warps. Every backward kernel is now
wrapped in @triton.autotune over a small (num_warps, num_stages) grid
defined in a shared backward/_autotune.py config module. The key mirrors
the forward autotuner (Lq, d_pad, layout flags, with Nd / Ld left out), so
the cache holds one entry per regime rather than one per batch size, and
atomic-accumulating kernels use reset_to_zero so that successive autotune
trials do not accumulate into the same output.
Measured on H100 (bf16), tuning speeds up auto by ~1.2–1.45× across the
training shapes (see the backward table in benchmarks.md), with the
largest gain on the high-contention train-256 reduction, all at lower
peak memory.
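Two of the autotuning details above can be modelled in a toy sketch: a cache key that ignores Nd / Ld, and why an accumulating kernel needs its output zeroed between trials. BWD_CONFIGS, bwd_autotune_key, accumulate_kernel and benchmark_all are hypothetical, and the grid values are illustrative only, not the contents of backward/_autotune.py. In Triton itself, triton.autotune takes a reset_to_zero list of argument names that are zeroed before configs are benchmarked.

```python
# Toy model of two autotune details (hypothetical names; not Triton's API).
from itertools import product

# Illustrative (num_warps, num_stages) grid; the real grid lives in backward/_autotune.py.
BWD_CONFIGS = [{"num_warps": w, "num_stages": s} for w, s in product((1, 2, 4), (1, 2))]


def bwd_autotune_key(Lq, d_pad, layout_flags, Nd, Ld):
    """Nd and Ld are accepted but ignored, so all batch sizes in a regime share one entry."""
    return (Lq, d_pad, tuple(layout_flags))


def accumulate_kernel(out, config):
    """Stands in for an atomic-accumulating backward kernel: it adds into `out`."""
    for k in range(len(out)):
        out[k] += 1.0


def benchmark_all(kernel, out, configs, reset_to_zero):
    """Runs every config once on the same buffer, as an autotuner's trials would."""
    for cfg in configs:
        if reset_to_zero:
            out[:] = [0.0] * len(out)  # zero the output before each trial
        kernel(out, cfg)
    return out
```

Without the reset, six trials leave six contributions piled into the buffer; with it, only the last trial's single contribution remains.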
Removed
[breaking] Backward methods atomic and csr.
The dense grad_D strategies collapse to two: unified (fastest, fp32
atomics) and lowmem (memory-optimal, deterministic). The legacy two-pass
atomic path was strictly dominated by unified, and csr's determinism niche
is now covered by lowmem, so both were deleted along with the CSR
build/sort machinery. backward= now accepts "auto" | "unified" | "lowmem";
passing "atomic" or "csr" raises ValueError. The auto default is
unaffected.
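A caller-side view of the new contract, written as a hypothetical helper: check_backward and the suggested replacements are not the library's code, and the error messages are made up. Only the accepted values and the ValueError come from the entry; the mapping follows its reasoning (unified dominated atomic, lowmem covers csr's determinism).

```python
# Hypothetical validation helper mirroring the new backward= contract (not the library's code).
ACCEPTED_BACKWARDS = ("auto", "unified", "lowmem")
# Suggested migrations, following the entry's reasoning for each removal.
REMOVED_BACKWARDS = {"atomic": "unified", "csr": "lowmem"}


def check_backward(backward: str) -> str:
    if backward in ACCEPTED_BACKWARDS:
        return backward
    if backward in REMOVED_BACKWARDS:
        raise ValueError(
            f"backward={backward!r} was removed; use {REMOVED_BACKWARDS[backward]!r} or 'auto'"
        )
    raise ValueError(f"unknown backward={backward!r}; expected one of {ACCEPTED_BACKWARDS}")
```

Code that pinned backward="csr" for reproducibility would move to backward="lowmem", which is deterministic by construction.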
Fixed
maxsim_from_hidden backward leaked a spurious gradient for fully d-masked documents.
A document with every token masked out scores 0 in the forward, but the
backward gathered a stale index-0 winner and added a non-zero
contribution to grad_Q / grad_H_d / grad_W / grad_b. The fused-head kernel
now writes a -1 argmax sentinel for query rows with no valid winner, and
the backward gates on it, matching the main maxsim path and the unfused
reference (zero gradient).
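The sentinel fix can be sketched for a single query token. NO_WINNER, masked_argmax, forward_token and grad_q_token are hypothetical names; the point is that the backward checks the sentinel instead of blindly gathering doc[idx], so a fully masked document contributes a zero gradient.

```python
# Toy version of the -1 "no winner" sentinel for one query token (hypothetical names).
NO_WINNER = -1


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def masked_argmax(sims, d_mask):
    """Index of the best valid doc token, or NO_WINNER when every token is masked."""
    best_val, best_idx = float("-inf"), NO_WINNER
    for s, (v, valid) in enumerate(zip(sims, d_mask)):
        if valid and v > best_val:
            best_val, best_idx = v, s
    return best_idx


def forward_token(q_tok, doc, d_mask):
    sims = [dot(q_tok, dt) for dt in doc]
    idx = masked_argmax(sims, d_mask)
    score = sims[idx] if idx != NO_WINNER else 0.0  # fully masked doc scores 0
    return score, idx


def grad_q_token(g, doc, idx, dim):
    """Backward gated on the sentinel: no winner means zero gradient, never doc[0]."""
    if idx == NO_WINNER:
        return [0.0] * dim
    return [g * x for x in doc[idx]]
```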
Forward-kernel autotune config pruning now sizes its shared-memory
estimate with the padded embedding dim (next_pow2(d)) instead of the
raw d. For non-power-of-2 d, the old estimate undercounted SMEM by up to
~2× and could admit configs that overflow shared memory at launch.
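A toy version of the pruning check makes the ~2× figure concrete. The cost model (one Q tile plus num_stages buffered D tiles, 2-byte elements) and the names smem_bytes, fits, BLOCK_Q and BLOCK_D are assumptions for illustration, not the library's estimator. With d = 65, padding to 128 nearly doubles the estimate (128 / 65 ≈ 1.97), so a config that looks safe against the raw d can overflow once the padded width is counted.

```python
# Toy shared-memory estimate for config pruning (hypothetical cost model and names).

def next_pow2(n: int) -> int:
    """Smallest power of two >= n (n >= 1)."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()


def smem_bytes(block_q, block_d, d, num_stages, elem_bytes=2, pad_d=True):
    """Assumed model: one Q tile plus num_stages buffered D tiles, each d_eff wide."""
    d_eff = next_pow2(d) if pad_d else d
    return (block_q + num_stages * block_d) * d_eff * elem_bytes


def fits(config, d, smem_limit, pad_d=True):
    """True if the config's estimated SMEM use is within smem_limit bytes."""
    return smem_bytes(config["BLOCK_Q"], config["BLOCK_D"], d,
                      config["num_stages"], pad_d=pad_d) <= smem_limit
```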
maxsim_residual now raises on zero-length documents when Q requires
grad. An empty doc has no MaxSim winner, so there is no correct gradient
for the backward to return, and it would instead gather a stale index-0
winner; it now fails fast. Inference (no grad) is unchanged and still
scores an empty doc 0.
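A minimal sketch of the fail-fast rule, assuming plain MaxSim scoring. score_docs is a hypothetical helper, and the entry does not say which exception type maxsim_residual raises, so ValueError here is an assumption.

```python
# Toy fail-fast guard for empty documents (hypothetical names; the exception type is an assumption).

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def score_docs(q, docs, q_requires_grad=False):
    """MaxSim of one query against several docs; an empty doc scores 0 at inference."""
    empty = [i for i, doc in enumerate(docs) if len(doc) == 0]
    if q_requires_grad and empty:
        # No MaxSim winner exists, so there is no correct gradient to return.
        raise ValueError(f"zero-length documents at indices {empty} cannot be backpropagated")
    return [sum(max(dot(qt, dt) for dt in doc) for qt in q) if doc else 0.0
            for doc in docs]
```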