
v0.3.0: Autotune & dispatch overhaul across CUDA kernels + faster training on Apple Silicon


@tonywu71 released this 28 May 14:19 · 95ed05b

Added

  • Training on Apple Silicon. New maxsim_train_metal (forward +
    saved argmax) and maxsim_backward_metal Metal kernels mirror the
    Triton maxsim_backward_unified API; a _MaxSimFnMetal(autograd.Function)
    wires them into maxsim_mps so the full training path (forward,
    backward, L2-normalize Jacobian) runs on Metal instead of falling
    back to torch.compile. Inference and unsupported-dtype paths are
    unchanged.
  • KD / pairs layout on Metal (4-D D). The MPS Metal kernel now
    accepts D as [Nq, K, Ld, d] directly — PyLate's
    colbert_kd_scores and colbert_scores_pairwise shapes hit the
    Metal kernel instead of falling back to torch.compile.
  • maxsim() 4-D dispatch + maxsim_pairs() on CUDA. maxsim(Q, D)
    now dispatches on D.dim(): 3-D stays on the in-batch cross-product
    path, while 4-D [Nq, K, Ld, d] runs as a single fused KD launch. The new
    maxsim_pairs(Q, D) covers the [B, Lq, d] × [B, Ld, d] → [B]
    case. PyLate's colbert_kd_scores Python for-loop (one
    maxsim() call per query) collapses to one kernel launch,
    ~10× faster than the loop at B=64, K=32. The pairwise
    B=5000 shape went from 355 ms (old maxsim_varlen packing
    path) to 0.18 ms, faster than flash-maxsim.
  • Every benchmark now reports peak VRAM (max_memory_allocated()) per
    variant in stdout and JSON; bench_fp8.py and bench_fused_head_train.py
    also gained --outdir JSON + Markdown sidecars. benchmarks/README.md
    documents the unified CLI and the SkyPilot driver.
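To sanity-check the three layouts against your own code, here is a minimal pure-Python reference of the scores they compute, assuming the standard ColBERT MaxSim definition (for each query token take the best dot product over doc tokens, then sum). maxsim_one, maxsim_ref, maxsim_pairs_ref and the nested-list "tensors" are hypothetical stand-ins, not the library's API; the real kernels take torch tensors and also handle masks and normalization.

```python
from typing import List, Tuple

Vec = List[float]
Mat = List[Vec]  # one sequence of token embeddings, [L, d]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim_one(q: Mat, doc: Mat) -> Tuple[float, List[int]]:
    """Score one (query, doc) pair and return the per-query-token argmax.

    The argmax list is the kind of buffer a training forward saves for the
    backward pass. Ties go to the lowest doc-token index.
    """
    score, argmax = 0.0, []
    for q_tok in q:
        sims = [dot(q_tok, d_tok) for d_tok in doc]
        best = max(range(len(sims)), key=sims.__getitem__)
        score += sims[best]
        argmax.append(best)
    return score, argmax


def depth(x) -> int:
    """Nesting depth of a nested list, a stand-in for Tensor.dim()."""
    n = 0
    while isinstance(x, list):
        n += 1
        x = x[0]
    return n


def maxsim_ref(Q: List[Mat], D):
    """Dispatch on the rank of D.

    3-D D [Nd, Ld, d]    -> in-batch cross product, scores [Nq][Nd]
    4-D D [Nq, K, Ld, d] -> KD layout, query i vs its own K docs, scores [Nq][K]
    """
    if depth(D) == 3:
        return [[maxsim_one(q, doc)[0] for doc in D] for q in Q]
    if depth(D) == 4:
        return [[maxsim_one(q, doc)[0] for doc in D[i]] for i, q in enumerate(Q)]
    raise ValueError("D must be 3-D or 4-D")


def maxsim_pairs_ref(Q: List[Mat], D: List[Mat]) -> List[float]:
    """Pairwise layout [B, Lq, d] x [B, Ld, d] -> [B]: query b vs doc b only."""
    return [maxsim_one(q, doc)[0] for q, doc in zip(Q, D)]
```

A KD call on a 4-D D yields the same rows as looping maxsim_ref over queries one at a time, which is the loop the fused launch replaces; the pairwise layout is the K=1 special case of the KD layout.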

Changed

  • CUDA autotune & dispatch overhaul — broadly faster across kernels.
    Five independent wins on the Triton path:
    • Lq bucketed to the next power of 2 in the maxsim() wrapper
      (#62). Variable-Lq training (ColBERT / ColPali, where the
      tokenizer's per-batch max(Lq) floats step-to-step) used to
      re-trigger the full autotune sweep on every novel Lq; it now
      collapses to ≤ 9 cache entries. Measured 4.7× faster end-to-end
      on 30 H100 steps with Lq ∈ [8, 32] (median step 588 ms → 0.19 ms).
    • KD + pairwise folded into the fast forward + backward kernels
      (#66). A single kd_layout: tl.constexpr switches
      d_global = pid % Nd (in-batch) vs pid (KD / pairs), so
      PyLate's colbert_kd_scores (4-D D) and
      colbert_scores_pairwise shapes use the same dense fast path
      as in-batch instead of routing through score_pairs_packed's
      packing layer. KD B=64, K=32 now beats flash-maxsim
      (lik/flash 0.94×), down from 1.4–7.7× slower pre-PR.
    • Persistent on-disk autotune cache via Triton ≥ 3.4's
      cache_results=True on all eight autotuned kernels (#64).
      First run on a machine still pays the ~4 s sweep; every
      subsequent process / CI job / training restart loads the JSON
      winner and skips the sweep. 10.6× faster cold start on the
      second process. Feature-detected — older Triton (3.0–3.3)
      silently keeps the in-memory-only behaviour, no dependency
      floor bump.
    • Small-input forward bypass for Nq*Nd ≤ 500 && d ≤ 256
      with save_argmax=False (#64). Fixed-config launch
      (BLOCK_Q=32, BLOCK_D=64, num_warps=4, num_stages=2); cold
      call drops from ~0.5–1 ms (autotune sweep) to sub-millisecond.
      Closes the gap to flash-maxsim's _maxsim_fwd_kernel_small
      on REPL / unit-test shapes.
    • New Hopper autotune config BLOCK_Q=128, BLOCK_D=128,
      num_warps=8, num_stages=3 (#57). Closes a 0.85× regression
      vs flash-maxsim on the compute-bound colpali rerank shape
      (Nq=1, Nd=500, Lq=Ld=1024, d=128; now 1.23×). Only picked
      on the shape it was designed for; no regression elsewhere.
    • normalize removed from the forward autotune key (#64). The two
      constexpr branches still compile to distinct binaries, but they
      now share one autotune entry instead of two, halving cache
      cardinality.
  • Benchmark CLI unified. Experiment subsets are now --only NAME ...
    on every script (replacing the legacy --shape / --shapes flags and
    the older --only for variant selection, which moved to --variants).
    scripts/sky_run_all_benchmarks.yaml and the per-domain Sky yamls
    accept a RUN_ONLY env to pick a subset of tags.
  • SkyPilot bench yamls consolidated. Four operator-facing files now
    cover every bench run: sky_benchmark_smoke_test.yaml (was
    sky_run_benchmarks.yaml), sky_run_all_benchmarks.yaml,
    sky_pylate_benchmark.yaml (new, folds the three previous
    sky_lateon_edge.yaml / sky_pylate_realdata{,_long}.yaml), and
    sky_colpali_benchmark.yaml (was sky_colpali_training.yaml).
  • MPS range refreshed (M4 2025). Inference metal vs eager is now
    1.9–3.5× and metal vs compile 2.2–14.3× (vs 1.9–3.2× and
    1.1–2.0× in 0.2.0); the gap vs torch.compile widened because MPS
    Inductor regressed sharply on long-Ld inputs. New training (fwd +
    bwd) table lands alongside, with the Metal backward 1.2–1.7× over
    eager and 3–4× over torch.compile once shapes amortise launch
    overhead. Full tables in docs/benchmarks.md Apple Silicon section.
  • JSON peak-VRAM keys standardized to <variant>_peak_mb across
    bench_pylate_lateon, bench_pylate_realdata, bench_colpali_training,
    bench_colpali_realdata, bench_cached_maxsim, bench_fastplaid, and
    bench_lateon. Breaking for anyone parsing benchmarks/results/*.json
    directly (previously a mix of peak_gb / _peak / mem_*_MB).
  • Removed unreferenced exports pylate_compat._bool_mask (shadowed by
    _mask_as_bool) and mps.is_mps_tensor.
  • GPU CI moved from the GitHub-hosted runner to AWS CodeBuild (A10G) and
    no longer auto-runs on push to main; opt-in via the run-gpu-tests
    PR label or workflow_dispatch.
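The dispatch-side pieces of the CUDA autotune overhaul (power-of-2 Lq bucketing, the small-input bypass, dropping normalize from the key, and the kd_layout doc index) can be sketched in plain Python. This is a hypothetical illustration, not the wrapper's code: bucket_lq, use_small_bypass, autotune_key and doc_index are made-up names, and a key of (Lq, Ld, d) is an assumption. Only the thresholds and the pid mapping come from the notes above.

```python
def bucket_lq(lq: int) -> int:
    """Round Lq up to the next power of two (1 -> 1, 8 -> 8, 9 -> 16)."""
    if lq < 1:
        raise ValueError("Lq must be positive")
    return 1 << (lq - 1).bit_length()


def use_small_bypass(nq: int, nd: int, d: int, save_argmax: bool) -> bool:
    """Fixed-config launch only for tiny inference inputs (thresholds from #64)."""
    return nq * nd <= 500 and d <= 256 and not save_argmax


def autotune_key(lq: int, ld: int, d: int, normalize: bool) -> tuple:
    """Hypothetical cache key. `normalize` is accepted but deliberately left out,
    so both constexpr branches share one autotune entry."""
    return (bucket_lq(lq), ld, d)


def doc_index(pid: int, nd: int, kd_layout: bool) -> int:
    """In-batch: program pid scores doc pid % Nd; KD / pairs: the doc index is pid."""
    return pid if kd_layout else pid % nd
```

For Lq ∈ [8, 32] this yields 3 distinct keys instead of 25 (or 50 with normalize in the key). If Lq is capped at 256, which is an assumption here, the buckets 1, 2, 4, ..., 256 give exactly the 9 entries quoted for #62.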

Removed

  • [breaking] Experimental kernels. late_interaction_kernels.experimental
    and its three research variants (soft_maxsim, smooth_maxsim,
    maxsim_matryoshka) are gone, along with reference.maxsim_reference_soft,
    tests/test_{soft,smooth,matryoshka}.py, and the two soft-maxsim cases
    in tests/test_robustness.py. None of them shipped to PyLate,
    colpali_engine, FastPlaid, or NextPlaid; folding research kernels into
    prod was the same mistake as maxsim_xtr in 0.2.0. Users on a research
    path can vendor the kernel source from the pre-0.3.0 git history.

  • [breaking] Deprecated *_inference shims and maxsim_from_hidden_train.
    The four DeprecationWarning shims from 0.2.0 are removed:

    • late_interaction_kernels.maxsim_inference → maxsim(...)
    • late_interaction_kernels.fused_head.maxsim_from_hidden_train →
      maxsim_from_hidden(...)
    • late_interaction_kernels.varlen.maxsim_varlen_inference →
      maxsim_varlen(...)
    • late_interaction_kernels.plaid.maxsim_residual_inference →
      maxsim_residual(...)

    Each surviving function already auto-skips the saved argmax buffer
    when no input has requires_grad=True, so behaviour is unchanged.

  • [breaking] set_backward_method / get_backward_method removed
    (deprecated in 0.2.0). Migration: replace set_backward_method("csr")
    with maxsim(..., backward="csr") (or MaxSimScorer(backward="csr")).
    maxsim()'s backward=None now resolves directly to "auto" instead
    of reading a module-level global.

  • [breaking] reference.xtr_reference and its three CPU-only tests in
    tests/test_reference_cpu.py. The XTR Triton kernel itself was already
    deleted in 0.2.0.
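For codebases still calling the removed names, the four renames can be applied mechanically. This is a hypothetical helper, not something shipped with the package: it only rewrites the identifiers listed in the removal notes above, and it merely flags set_backward_method / get_backward_method call sites, since those need a per-call backward= argument and are left for a manual edit.

```python
import re
from typing import List

# Removed shim -> surviving function, per the 0.3.0 removal notes.
RENAMES = {
    "maxsim_inference": "maxsim",
    "maxsim_from_hidden_train": "maxsim_from_hidden",
    "maxsim_varlen_inference": "maxsim_varlen",
    "maxsim_residual_inference": "maxsim_residual",
}

# Whole-identifier matches only; longest names first so no name is cut short.
_RENAME_RE = re.compile(
    r"\b("
    + "|".join(re.escape(name) for name in sorted(RENAMES, key=len, reverse=True))
    + r")\b"
)
_MANUAL_RE = re.compile(r"\b(?:set|get)_backward_method\b")


def migrate_source(src: str) -> str:
    """Rewrite removed shim names to their surviving equivalents."""
    return _RENAME_RE.sub(lambda m: RENAMES[m.group(1)], src)


def find_manual_migrations(src: str) -> List[int]:
    """Return 1-based line numbers that still use the removed backward-method globals."""
    return [i for i, line in enumerate(src.splitlines(), start=1) if _MANUAL_RE.search(line)]
```

Because the surviving functions already skip the saved argmax when nothing requires grad, the rename alone should not change behaviour. A flagged line such as set_backward_method("csr") still has to become maxsim(..., backward="csr") by hand, and a file that imported both maxsim and maxsim_inference will end up with a harmless duplicate import.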

Fixed

  • Masked-row gradient poisoning (correctness). When every doc token
    was masked for a (query, doc) pair, the Triton forward saved an
    argmax of 0 instead of a sentinel, and the unified / atomic
    backwards atomic-added a spurious grad_scores[i, j] * Q[i, s, :]
    into grad_D[d_global, 0, :]. Forward now initialises the running
    argmax to -1; the unified / atomic backwards skip the scatter on
    t < 0. CSR backward is unaffected (sorting naturally drops the
    sentinel). Matches the equivalent fix on the MPS Metal backward.
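The sentinel fix can be illustrated on a single (query, doc) pair in pure Python. This is a sketch, not the Triton code: maxsim_fwd_masked and scatter_backward are hypothetical names, and treating a query token with no valid doc token as contributing 0 to the score is an assumption (the notes only specify the argmax). The forward leaves the argmax at -1 when every doc token is masked, and the backward skips t < 0, so nothing is scattered into doc token 0.

```python
from typing import List, Sequence, Tuple

Vec = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim_fwd_masked(q: List[Vec], doc: List[Vec], doc_mask: List[bool]) -> Tuple[float, List[int]]:
    """Forward for one pair; the running argmax starts at the -1 sentinel
    and only moves when a valid (unmasked) doc token is seen."""
    score, argmax = 0.0, []
    for q_tok in q:
        best_t, best = -1, float("-inf")
        for t, (d_tok, keep) in enumerate(zip(doc, doc_mask)):
            if keep:
                sim = dot(q_tok, d_tok)
                if sim > best:
                    best_t, best = t, sim
        # Assumption: a fully masked query token adds nothing to the score.
        score += best if best_t >= 0 else 0.0
        argmax.append(best_t)
    return score, argmax


def scatter_backward(g: float, q: List[Vec], doc: List[Vec], argmax: List[int]) -> Tuple[List[Vec], List[Vec]]:
    """Atomic-style backward: query token s adds g * D[t] to grad_Q[s] and g * Q[s] to grad_D[t]."""
    dim = len(q[0])
    grad_q = [[0.0] * dim for _ in q]
    grad_doc = [[0.0] * dim for _ in doc]
    for s, t in enumerate(argmax):
        if t < 0:  # sentinel: no doc token was selected, so there is nothing to scatter
            continue
        for k in range(dim):
            grad_q[s][k] += g * doc[t][k]
            grad_doc[t][k] += g * q[s][k]
    return grad_q, grad_doc
```

With the pre-fix argmax of 0 for a fully masked pair, the same scatter would have written g * Q[s] into grad_D[0], which is exactly the poisoning described above.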

Known regressions (queued for 0.3.1)

Four shape-specific perf regressions vs 0.2.0 on the H100 sweep
(NGC 25.06 / torch 2.8 / triton 3.x). All correctness assertions hold;
the likely root cause for each is a winning autotune config rejected
by the tighter prune_forward SRAM model (#73).

  • bench_forward text-long (Nq=1, Nd=1k, Lq=32, Ld=1024):
    0.093 → 0.121 ms (+30 %).
  • bench_inference_edge LateOn-Code-edge Nd=1k, Ld=1024, d=48:
    0.072 → 0.114 ms (+58 %).
  • bench_pylate_realdata Contrastive bs=16, Lq=32, Ld=256: e2e step
    52.3 → 61.6 ms (+18 %); vanilla baseline unchanged.
  • bench_pylate_realdata CachedContrastive bs=64, mini=16, Lq=32,
    Ld=300, grad-ckpt: e2e step 305.5 → 318.5 ms; the former 1.15× win
    vs vanilla collapsed to 1.00×. Highest-priority of the four.
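To make the suspected mechanism concrete: an autotune pruner drops every config whose estimated shared-memory footprint exceeds the per-block limit, so a stricter estimate can silently remove a config that used to win. The sketch below is hypothetical and does not reproduce prune_forward or the model from #73. It assumes fp16/bf16 Q and D tiles buffered num_stages deep, treats an extra fp32 BLOCK_Q × BLOCK_D score tile as the "tighter" term, and uses H100's 227 KB per-block shared-memory limit; the config and d=64 are purely illustrative.

```python
H100_SMEM_PER_BLOCK = 227 * 1024  # bytes of shared memory a single block may use on H100


def smem_estimate(block_q: int, block_d: int, d: int, num_stages: int,
                  elem_bytes: int = 2, count_score_tile: bool = False) -> int:
    """Hypothetical footprint model: Q and D tiles per pipeline stage,
    plus an optional fp32 score tile."""
    tiles = (block_q * d + block_d * d) * elem_bytes * num_stages
    scores = block_q * block_d * 4 if count_score_tile else 0
    return tiles + scores


def keep_config(cfg: dict, d: int, tighter: bool, budget: int = H100_SMEM_PER_BLOCK) -> bool:
    """Keep a config only if its estimated footprint fits in the budget."""
    est = smem_estimate(cfg["BLOCK_Q"], cfg["BLOCK_D"], d, cfg["num_stages"],
                        count_score_tile=tighter)
    return est <= budget


cfg = {"BLOCK_Q": 128, "BLOCK_D": 256, "num_stages": 3}
print(keep_config(cfg, d=64, tighter=False), keep_config(cfg, d=64, tighter=True))
```

Under these assumptions the config needs 147,456 bytes with the looser model and 278,528 bytes with the tighter one, so it flips from kept to pruned. Whether #73 added this particular term is not stated here; the point is only that any added term moves the keep/prune boundary.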