v0.2.0: API cleanup, colpali_engine drop-in, GPU CI

h-aurelien-lac released this 22 May 10:14
· 57 commits to main since this release
d5aa60a

Added

  • Self-hosted GPU CI workflow (.github/workflows/gpu-ci.yml) that runs the
    CUDA-marked tests on push to main, on PRs touching kernel-related files,
    on workflow_dispatch, or on PRs labelled run-gpu-ci. CPU-only CI was
    split into .github/workflows/cpu-ci.yml; both workflows now trigger only
    when their path filters match.
  • Interactive kernel picker (docs/choose-a-kernel.html) and HTML playbook
    (docs/how-it-works.html) to help pick the right kernel for a workload.
  • patch_colpali_engine() / unpatch_colpali_engine() — colpali_engine
    drop-in mirroring patch_pylate. Monkey-patches
    BaseVisualRetrieverProcessor.score_multi_vector and the three in-batch
    loss heads (ColbertLoss, ColbertPairwiseCELoss, ColbertSigmoidLoss)
    to route their einsum("bnd,csd->bcns") + amax(-1) + sum(-2) through the
    fused kernel. Negative-mining siblings (ColbertNegativeCELoss,
    ColbertPairwiseNegativeCELoss) inherit the in-batch term through their
    self.inner_loss reference. Falls back to the original implementation
    for use_smooth_max=True, LIK_DISABLE=1, sub-Ampere CUDA, CPU
    tensors, and d < 8.
  • maxsim_padded — padded-input reranking helper (inspired by
    https://github.com/ErikKaum/maxsim). Takes [B, Lq, d] / [B, C, Ld, d]
    tensors with per-row lengths, returns [B, C] fp32. Autograd-aware on
    every device: CUDA dispatches to the fused pair-list scatter kernel
    (forward + backward), CPU / MPS fall back to the pure-PyTorch reference.
    The underlying pack_padded(...) building block (which converts to the
    packed cu_seqlens layout with a single combined max_seqlen_q /
    max_seqlen_d device→host sync) is available from
    late_interaction_kernels.padded.
  • Fused backward for score_pairs_packed. The pair-list scatter kernel
    now saves a [num_pairs, max_lq] argmax buffer when either input has
    requires_grad=True and produces grad_Q / grad_D directly on the
    packed layout via two atomic-add scatter kernels. Pair-list training is
    now O(num_pairs · max_lq · d) on both passes; no [Nq, Nd]
    materialisation, no varlen-style off-diagonal compute. Pure inference
    pays no overhead (save_argmax=False).
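The scoring rule and the argmax-based backward described above can be sketched for a single (query, document) pair. This is a dependency-free, pure-Python illustration, not the library's kernel: the helper names `maxsim_forward` and `maxsim_backward` are hypothetical, and the per-pair `argmax` list stands in for one row of the `[num_pairs, max_lq]` buffer. The forward pass is the per-pair equivalent of `einsum("bnd,csd->bcns")` followed by `amax(-1)` and `sum(-2)`; the backward pass uses only the saved winners, so no full similarity matrix is revisited.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # one row per token, one column per embedding dim


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def maxsim_forward(q: Matrix, d: Matrix) -> Tuple[float, List[int]]:
    """Score one pair and record the winning doc token for each query token."""
    score = 0.0
    argmax: List[int] = []
    for q_tok in q:
        sims = [dot(q_tok, d_tok) for d_tok in d]
        best = max(range(len(sims)), key=sims.__getitem__)  # first index on ties
        argmax.append(best)
        score += sims[best]
    return score, argmax


def maxsim_backward(
    q: Matrix, d: Matrix, argmax: List[int], grad_out: float
) -> Tuple[Matrix, Matrix]:
    """Gradients of the score w.r.t. q and d, using only the saved argmax."""
    dim = len(q[0])
    grad_q = [[0.0] * dim for _ in q]
    grad_d = [[0.0] * dim for _ in d]
    for i, j in enumerate(argmax):
        for k in range(dim):
            # Each query token has exactly one winner, so a plain write suffices.
            grad_q[i][k] = grad_out * d[j][k]
            # Several query tokens may share a winner, hence the accumulation
            # (the role an atomic add plays in a parallel kernel).
            grad_d[j][k] += grad_out * q[i][k]
    return grad_q, grad_d


# Toy inputs: 2 query tokens, 3 document tokens, d = 2.
q = [[1.0, 0.0], [0.0, 2.0]]
d = [[0.5, 0.5], [1.0, -1.0], [0.0, 1.0]]

score, argmax = maxsim_forward(q, d)
grad_q, grad_d = maxsim_backward(q, d, argmax, grad_out=1.0)
print(score, argmax)
```

Document tokens that never win a max receive a zero gradient, which is why saving the argmax is enough to reconstruct both gradients.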

Changed

  • maxsim now auto-skips the saved argmax buffer when neither input
    has requires_grad=True, matching the dispatch already shipped by
    maxsim_varlen and maxsim_residual. maxsim_inference is now a
    thin deprecation shim that forwards to maxsim.
  • maxsim_from_hidden is now autograd-aware (gradients flow into
    whichever of Q / H_d / W / b carry requires_grad=True); the
    forward-only path is auto-dispatched when none of them do.
    maxsim_from_hidden_train is now a thin deprecation shim.
  • [breaking] Bumped minimum PyTorch from 2.1 to 2.5. Older PyTorch
    releases are no longer tested, and the torch._assert_async bounds check
    in pack_padded now assumes the symbol is present unconditionally.
  • Replaced the unconditional CPU-only torch pin with explicit
    torch-cpu / torch-cuda optional extras so CUDA installs no longer
    pull a CPU wheel by default.
  • Cleaned the H100 autotune pool (_autotune.py::_small_d_hopper). Dropped
    the two warp_spec=True configs that have been silent no-ops since
    Triton 3.5 removed the num_consumer_groups / num_buffers_warp_spec
    kwargs (the API moved to compiler-driven warp specialization — without
    the kwargs, those entries duplicated other configs in the pool and
    occasionally won the autotune sample on noise alone). Also resized
    BLOCK_Q=32, BLOCK_D=128 from num_warps=8 to num_warps=4 so it
    matches the WGMMA warp-group size we actually want, and added the
    matching BLOCK_Q=64, BLOCK_D=128, num_warps=4, num_stages=3 row.

Deprecated

  • set_backward_method / get_backward_method now emit
    DeprecationWarning. The process-wide global has no functional
    advantage over the per-call backward= kwarg on maxsim /
    MaxSimScorer and complicates reasoning in multi-thread / multi-rank
    setups. Migration: replace set_backward_method("csr") with
    maxsim(..., backward="csr") (or MaxSimScorer(backward="csr")).
    The globals will be removed in the next breaking release.
  • late_interaction_kernels.maxsim_inference — use maxsim(...) directly.
  • late_interaction_kernels.fused_head.maxsim_from_hidden_train — use
    maxsim_from_hidden(...) directly.
  • [breaking] maxsim_inference_scatter → score_pairs_packed; module
    scatter.py → score_pairs.py. Shorter name, matches prior art in
    https://github.com/ErikKaum/maxsim. Kernel, signature, and semantics
    are identical.
  • [breaking] Trimmed the top-level public surface to everyday API only:
    MaxSimScorer, retrieve, patch_pylate / unpatch_pylate, maxsim,
    maxsim_inference, maxsim_varlen, maxsim_padded, and the reference
    module. Lower-level / niche kernels must now be imported from their
    submodule:
    • score_pairs_packed → late_interaction_kernels.score_pairs
    • pack_padded / PackedBatch → late_interaction_kernels.padded
    • maxsim_from_hidden / maxsim_from_hidden_train → late_interaction_kernels.fused_head
    • plaid_approx_score / maxsim_residual / maxsim_residual_varlen →
      late_interaction_kernels.plaid
    • maxsim_inference_fp8 → late_interaction_kernels.fp8
    • set_backward_method / get_backward_method → late_interaction_kernels.autograd
  • [breaking] Module relocations following the submodule reorganisation.
    Direct imports of the old paths now raise ImportError:
    • late_interaction_kernels._mps → late_interaction_kernels.mps.compile_dispatch
    • late_interaction_kernels.metal → late_interaction_kernels.mps.metal
    • late_interaction_kernels.backward_csr → late_interaction_kernels.backward.csr
    • late_interaction_kernels.backward_unified → late_interaction_kernels.backward.unified
    • late_interaction_kernels.{soft,smooth,matryoshka,xtr} →
      late_interaction_kernels.experimental.{soft,smooth,matryoshka,xtr}
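For codebases with many direct imports of the relocated modules, a small rewrite helper may be easier than fixing each ImportError by hand. The sketch below is a hypothetical helper, not something shipped with the package: `RELOCATED` and `migrate_import` are illustrative names, and the mapping is transcribed from the relocation list in these notes. The xtr entry is left out because the Removed section says that kernel is dropped in this release.

```python
import re

# Old module path -> new module path, transcribed from the relocation list.
RELOCATED = {
    "late_interaction_kernels._mps": "late_interaction_kernels.mps.compile_dispatch",
    "late_interaction_kernels.metal": "late_interaction_kernels.mps.metal",
    "late_interaction_kernels.backward_csr": "late_interaction_kernels.backward.csr",
    "late_interaction_kernels.backward_unified": "late_interaction_kernels.backward.unified",
    "late_interaction_kernels.soft": "late_interaction_kernels.experimental.soft",
    "late_interaction_kernels.smooth": "late_interaction_kernels.experimental.smooth",
    "late_interaction_kernels.matryoshka": "late_interaction_kernels.experimental.matryoshka",
}

# Matches the leading "import x.y" or "from x.y" part of an import line.
_IMPORT_RE = re.compile(r"^(\s*(?:from|import)\s+)([\w.]+)")


def migrate_import(line: str) -> str:
    """Rewrite one import line if it names a relocated module exactly."""
    match = _IMPORT_RE.match(line)
    if match is None:
        return line
    prefix, module = match.groups()
    new_module = RELOCATED.get(module)
    if new_module is None:
        return line  # not a relocated module; leave untouched
    return prefix + new_module + line[match.end():]


print(migrate_import("from late_interaction_kernels.backward_csr import x"))
```

Only exact module matches are rewritten, so unrelated imports and already-migrated paths pass through unchanged.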

Removed

  • [breaking] Top-level deprecation shims for maxsim_forward, maxsim_topk,
    maxsim_residual_inference, maxsim_varlen_inference,
    maxsim_matryoshka, maxsim_xtr, soft_maxsim, smooth_maxsim,
    quantize_fp8_per_tensor, quantize_fp8_per_token,
    dequantize_fp8_per_tensor, dequantize_fp8_per_token. Import from
    their submodules directly: late_interaction_kernels.{forward, topk, plaid, varlen, experimental, fp8}.
  • [breaking] maxsim_xtr (XTR top-K aggregation, the experimental kernel
    exposed at late_interaction_kernels.experimental.xtr). The kernel only
    ever shipped as a research curiosity; nothing in MaxSimScorer,
    retrieve, patch_pylate, or patch_colpali_engine used it. Users
    who still need XTR aggregation can take the kernel source from a
    pre-0.2.0 release or compose maxsim with a topk + sum on the
    output. The companion test (tests/test_xtr.py) is gone too.

Fixed

  • Interactive kernel picker (docs/choose-a-kernel.html) now surfaces
    maxsim, maxsim_varlen, score_pairs_packed, maxsim_residual, and
    maxsim_residual_varlen under the "My own training / inference code"
    branch (in addition to "Raw kernel functions"). Previously the combo
    custom code + training + packed cu_seqlens returned "No exact match".
  • Kernel picker shows a composition recipe when the combo
    varlen + top-k retrieval is selected (no single fused kernel covers
    that today — the answer is maxsim_varlen followed by torch.topk).
    The picker still falls back to the generic "No exact match" message
    for combinations no recipe covers.
  • late_interaction_kernels.backward.atomic referenced
    late_interaction_kernels.backward.backward_csr — a stale path from the
    submodule rename that would have raised ModuleNotFoundError the first
    time the auto backward path picked CSR on a real GPU. Pointed at the
    correct module (...backward.csr).
  • score_pairs_packed no longer recompiles or re-autotunes per distinct
    (max_lq, max_ld). Both were tl.constexpr and part of the autotune
    key, so each distinct max-seqlen bucket triggered a fresh compile +
    autotune sweep — the same trap Ld fell into on the dense forward in
    0.1.0. The kernel now keys only on d_pad. Pinned by
    tests/test_compile_cache.py (single autotune entry across 5 distinct
    max_ld / max_lq values).
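The recompile trap fixed in score_pairs_packed can be illustrated with a toy cache. This is a simulation only: `CompileCache` is a hypothetical class that does not use Triton, and the shape values are made up. It shows how including max_lq and max_ld in the cache key yields one entry per distinct bucket, while keying on d_pad alone yields a single entry.

```python
from typing import Dict, Tuple


class CompileCache:
    """Counts how many distinct specialisations a given cache key produces."""

    def __init__(self, key_fields: Tuple[str, ...]) -> None:
        self.key_fields = key_fields
        self.entries: Dict[tuple, str] = {}

    def launch(self, **shape: int) -> None:
        key = tuple(shape[name] for name in self.key_fields)
        if key not in self.entries:
            # In a real JIT, compilation and the autotune sweep would run here.
            self.entries[key] = f"compiled for {key}"


old = CompileCache(("d_pad", "max_lq", "max_ld"))  # behaviour before the fix
new = CompileCache(("d_pad",))                     # after the fix: d_pad only

# Five distinct (max_lq, max_ld) buckets, all with the same padded dimension.
for max_lq, max_ld in [(32, 128), (32, 256), (48, 256), (64, 512), (64, 1024)]:
    old.launch(d_pad=128, max_lq=max_lq, max_ld=max_ld)
    new.launch(d_pad=128, max_lq=max_lq, max_ld=max_ld)

print(len(old.entries), len(new.entries))  # 5 1
```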

Documentation

  • Spell out what the H100 forward table compares against (eager fp32
    reference) and why a torch.compile baseline isn't included on the CUDA
    side: Inductor still has to materialize the [Nq · Nd · Lq · Ld]
    similarity tensor before max(-1), which is exactly the HBM round-trip
    the fused kernel exists to avoid.
  • README banner, usage-context clarifications, and a restructured
    how-it-works walkthrough.
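To make the HBM round-trip argument concrete, the arithmetic below compares the size of the full similarity tensor with the size of the final score matrix. The workload numbers are hypothetical and chosen only to keep the arithmetic simple; both helper functions are illustrative and assume fp32 (4 bytes per element).

```python
def similarity_tensor_bytes(
    nq: int, nd: int, lq: int, ld: int, bytes_per_element: int = 4
) -> int:
    """Size of the full [Nq, Nd, Lq, Ld] similarity tensor (fp32 by default)."""
    return nq * nd * lq * ld * bytes_per_element


def score_matrix_bytes(nq: int, nd: int, bytes_per_element: int = 4) -> int:
    """Size of the final [Nq, Nd] score matrix."""
    return nq * nd * bytes_per_element


# Hypothetical workload: 32 queries of 32 tokens, 1024 documents of 256 tokens.
nq, nd, lq, ld = 32, 1024, 32, 256
intermediate = similarity_tensor_bytes(nq, nd, lq, ld)
output = score_matrix_bytes(nq, nd)

print(f"intermediate: {intermediate / 2**30:.2f} GiB")  # 1.00 GiB
print(f"output:       {output / 2**10:.0f} KiB")        # 128 KiB
print(f"ratio:        {intermediate // output}x")       # 8192x
```

The ratio is always Lq · Ld, so the intermediate grows with sequence length even when the number of scores stays fixed.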