
v0.4.0: Low-Memory Backward & Fused ColPali Training


tonywu71 released this 03 Jun 13:08 (commit 1d1e42a)

Added

  • lowmem backward — ~half the training peak memory, deterministic.
    A new destination-owned backward that accumulates grad_Q / grad_D in
    fp32 registers and writes them straight in the input dtype (bf16/fp16) — no
    full-size fp32 gradient buffer, no fp32→bf16 transient, and no atomics
    (so it is bitwise-reproducible). auto now routes the gradient-heavy
    shapes to it: knowledge-distillation / hard-negative layouts (where
    grad_D is n_neg-inflated) and large high-contention in-batch squares;
    unified still handles the common and long-query cross-products, where it
    is fastest. Measured on H100 (bf16, fwd+bwd): a B256 × 16-neg ColPali
    step drops from 4.3 GB / 2.36 ms to 2.2 GB / 1.37 ms (≈½ memory,
    ~1.7× faster); pylate-text B256 from 96 MB to 52 MB. Gradients match
    unified to bf16 rounding. Select per-call with backward="lowmem".
  • colpali_engine explicit-negative loss heads now fused.
    patch_colpali_engine() previously accelerated only the in-batch CE term
    of ColbertNegativeCELoss / ColbertPairwiseNegativeCELoss; their
    explicit positive and per-query negative scoring stayed on the unfused
    einsum. The positive term now routes through maxsim_pairs (diagonal
    q[i]·d[i]) and the [B, n_neg, Ld, d] negative slab through 4-D
    maxsim (KD layout), so neither materialises the similarity tensor; the
    in-batch term keeps reusing the already-patched inner head (no double
    work). Pos/neg fusion is CUDA-only — MPS / CPU fall back to the original
    einsum for those terms while the in-batch term still accelerates. The 4-D
    negative backward auto-routes to lowmem, making the fused heads a
    training memory win too (peak ~13–28% below vanilla, widening with
    B × n_neg) on top of the speedup (up to 4.31× at B256 × 16-neg
    in the MaxSim-isolation bench).
  • colpali install extra. pip install "late-interaction-kernels[colpali]" pulls colpali-engine>=0.3.10,<1
    for patch_colpali_engine(). CPU-only in CI (colpali_engine's
    torchvision tree conflicts with the CUDA torch wheel, so it is never
    co-activated with torch-cuda; the GPU parity tests install it
    out-of-band), mirroring the pylate extra.
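
The destination-owned idea behind lowmem can be illustrated on a single query/document pair. The sketch below is plain Python, not the Triton kernel: it assumes no masks, uses Python floats in place of fp32 registers and bf16 outputs, and its helper names (maxsim_forward, grad_D_scatter, grad_D_destination_owned) are hypothetical. It contrasts a source-owned scatter into grad_D, which on a GPU needs atomics, with a per-row gather that writes each grad_D row exactly once in a fixed order, which is what makes the result reproducible.

```python
# Pure-Python sketch of the destination-owned idea behind backward="lowmem".
# Assumptions: one query/document pair, no masks, Python floats instead of
# fp32 registers and bf16 outputs; none of these helpers are library APIs.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim_forward(Q, D):
    """Score = sum over query tokens of the max dot product over doc tokens.
    Also returns each query token's winning doc-token index (its argmax)."""
    score, argmax = 0.0, []
    for q in Q:
        sims = [dot(q, d) for d in D]
        j = max(range(len(D)), key=sims.__getitem__)
        argmax.append(j)
        score += sims[j]
    return score, argmax


def grad_D_scatter(Q, D, argmax, g=1.0):
    """Source-owned: each query token pushes into grad_D[winner].
    On a GPU, concurrent pushes to one row need atomics, and the order in
    which they land (hence the float rounding) can vary from run to run."""
    dim = len(Q[0])
    grad_D = [[0.0] * dim for _ in D]
    for i, j in enumerate(argmax):
        for k in range(dim):
            grad_D[j][k] += g * Q[i][k]
    return grad_D


def grad_D_destination_owned(Q, D, argmax, g=1.0):
    """Destination-owned: each grad_D row is produced by one "program" that
    gathers its contributors in a fixed order into a local accumulator and
    writes the row once. No atomics, fixed summation order."""
    dim = len(Q[0])
    grad_D = []
    for j in range(len(D)):
        acc = [0.0] * dim  # stands in for the fp32 register accumulator
        for i, winner in enumerate(argmax):
            if winner == j:
                for k in range(dim):
                    acc[k] += g * Q[i][k]
        grad_D.append(acc)  # single write per row
    return grad_D


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
D = [[2.0, 0.0], [0.0, 3.0]]
score, argmax = maxsim_forward(Q, D)
print(score, argmax)                           # 8.0 [0, 1, 1]
print(grad_D_destination_owned(Q, D, argmax))  # [[1.0, 0.0], [1.0, 2.0]]
```

Both strategies compute the same sums; the difference is who owns the write. Because every grad_D row has exactly one writer, no full-size shared buffer or atomic accumulation is needed.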

Changed

  • Long-query forward chunking — broadly faster at ColPali scale.
    maxsim() now splits queries with Lq > 512 into fixed 128-token
    chunks, scores each chunk as an independent query through the shared
    _maxsim_cross core, and sums the per-chunk MaxSim back per original
    query. Summing a per-token max over query tokens is exact, so forward
    and backward are numerically identical to the un-chunked path
    (autograd flows through the reshape + sum). Long queries launch more,
    shorter programs that fill the GPU instead of serialising one long
    static_range loop, and the kernel always sees Lq == 128, so the
    autotune cache collapses onto a small constant (one entry, plus one
    more for tail-padded has_q_mask=True) instead of one per length
    bucket. Measured on H100 (bf16) with bench_chunking.py, vs the
    un-chunked path: +49–77% at Lq=768, and at Lq=1024 from +24%
    in-batch to roughly break-even for rerank.
    Shorter queries (ColBERT Lq≤32, long-doc Lq≤512) fall through to
    the existing core unchanged — no regression. Chunking is
    cross-product-only; the KD / pairs path (4-D D) is unaffected and
    long-Lq KD should use maxsim_varlen.
  • Autotuned backward launch params — faster training step. The
    backward kernels previously launched with Triton's stock
    num_warps=4. Each is one program per output row streaming a single
    d_pad vector through a doc loop, so 4 warps over-subscribe the
    narrow program — the H100 optimum is 1–2 warps. Every backward kernel
    is now autotuned with @triton.autotune over a small (num_warps, num_stages)
    grid via a shared backward/_autotune.py config module. The key
    mirrors the forward autotuner (Lq, d_pad, layout flags; Nd /
    Ld stay out), so the cache holds one entry per regime rather than
    one per batch size, and atomic-accumulating kernels use
    reset_to_zero so autotune trials don't pile onto each other.
    Measured on H100 (bf16), tuning lifts auto by ~1.2–1.45× across
    the training shapes (see the backward table in benchmarks.md), with
    the largest gain on the high-contention train-256 reduction, all at
    lower peak memory.
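
The exactness claim for long-query chunking comes down to regrouping an outer sum over query tokens. A minimal sketch, assuming a list-of-lists layout and hypothetical helpers (maxsim, maxsim_chunked) that mirror only the Lq > 512 threshold and the 128-token chunk size named above:

```python
# Pure-Python sketch of long-query chunking: split Lq > 512 into 128-token
# chunks, score each chunk as its own query, and sum back per original query.
# Helper names and the list-of-lists layout are illustrative assumptions.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def maxsim(Q, D):
    """Sum over query tokens of the per-token max over doc tokens."""
    return sum(max(dot(q, d) for d in D) for q in Q)


def maxsim_chunked(Q, D, chunk=128, threshold=512):
    if len(Q) <= threshold:
        return maxsim(Q, D)  # short queries take the unchanged path
    # A query token's max over D does not depend on which chunk it sits in,
    # so summing per-chunk scores just regroups the same per-token terms.
    return sum(maxsim(Q[s:s + chunk], D) for s in range(0, len(Q), chunk))


# 700 query tokens -> five full 128-token chunks plus a 60-token tail.
Q = [[(i * 7) % 5 - 2, (i * 3) % 4 - 1] for i in range(700)]
D = [[1, 2], [-1, 0], [3, -2]]
assert maxsim_chunked(Q, D) == maxsim(Q, D)  # integer inputs keep this check exact
```

Here the 60-token tail is simply a shorter Python slice; the kernel instead pads the tail, which is why the notes mention one extra autotune entry for has_q_mask=True.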

Removed

  • [breaking] Backward methods atomic and csr. The dense grad_D
    strategies collapse to two: unified (fastest, fp32 atomics) and
    lowmem (memory-optimal, deterministic). The legacy two-pass atomic
    path was strictly dominated by unified, and csr's determinism niche
    is now covered by lowmem, so both were deleted along with the CSR
    build/sort machinery. backward= now accepts
    "auto" | "unified" | "lowmem"; passing "atomic" or "csr" now
    raises ValueError. The auto default is unaffected.
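
For code that still passes the removed names, a caller-side shim can map them onto their stated replacements (atomic to unified, since unified dominated it; csr to lowmem, which now covers the determinism niche). migrate_backward is hypothetical and not part of the library; it only illustrates the v0.4.0 accepted set.

```python
# Hypothetical caller-side shim for code written against v0.3.x; not part of
# late-interaction-kernels. Maps removed names to their replacements and
# otherwise accepts only the v0.4.0 values.
_REPLACEMENTS = {"atomic": "unified", "csr": "lowmem"}
_ACCEPTED = ("auto", "unified", "lowmem")


def migrate_backward(name: str) -> str:
    name = _REPLACEMENTS.get(name, name)
    if name not in _ACCEPTED:
        raise ValueError(f"backward={name!r} is not one of {_ACCEPTED}")
    return name


print(migrate_backward("csr"))     # lowmem
print(migrate_backward("atomic"))  # unified
```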

Fixed

  • maxsim_from_hidden backward leaked a spurious gradient for fully
    d-masked documents.
    A document with every token masked out scores 0
    in the forward, but the backward gathered a stale index-0 winner and
    added a non-zero contribution to grad_Q / grad_H_d / grad_W /
    grad_b. The fused-head kernel now writes a -1 argmax sentinel for
    query rows with no valid winner and the backward gates on it, matching
    the main maxsim path and the unfused reference (zero gradient).
  • Forward-kernel autotune config pruning now sizes its shared-memory
    estimate with the padded embedding dim (next_pow2(d)) instead of the
    raw d. For non-power-of-2 d the old estimate undercounted SMEM by
    up to ~2x and could admit configs that overflow at launch.
  • maxsim_residual now raises on zero-length documents when Q requires
    grad. An empty doc has no MaxSim winner, so the backward had no correct
    gradient and would gather a stale index-0 winner; it now fails fast.
    Inference (no grad) is unchanged and still scores an empty doc 0.
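
The maxsim_from_hidden fix can be illustrated with a minimal masked MaxSim in plain Python. The helpers (masked_maxsim_forward, grad_Q) and the boolean d_mask layout are assumptions, not the fused-head kernel, and only grad_Q is shown; per the note above, the same gate applies to grad_H_d, grad_W and grad_b in the fused head.

```python
# Pure-Python sketch of the -1 argmax sentinel for d-masked documents.
# Assumptions: one query/document pair, a boolean d_mask per doc token, and
# illustrative helper names; this is not the library's fused-head kernel.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def masked_maxsim_forward(Q, D, d_mask):
    score, argmax = 0.0, []
    for q in Q:
        best, best_j = 0.0, -1  # -1 = no valid winner found yet
        for j, (d, valid) in enumerate(zip(D, d_mask)):
            if valid:
                s = dot(q, d)
                if best_j == -1 or s > best:
                    best, best_j = s, j
        argmax.append(best_j)
        score += best  # stays 0.0 when every doc token is masked
    return score, argmax


def grad_Q(D, argmax, dim, g=1.0):
    """grad_Q[i] = g * D[argmax[i]], gated to zero on the -1 sentinel.
    Without the gate, a stale index 0 would leak g * D[0] into the gradient."""
    return [[0.0] * dim if j == -1 else [g * x for x in D[j]] for j in argmax]


D = [[5.0, 0.0], [1.0, 1.0]]
Q = [[1.0, 0.0]]
print(masked_maxsim_forward(Q, D, [False, False]))  # (0.0, [-1])
print(masked_maxsim_forward(Q, D, [False, True]))   # (1.0, [1])
```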