
v0.4.3: Bug-fix roundup — chunked top-k, empty batches, varlen autotune


@tonywu71 released this 10 Jun 08:58
11cf671

Fixed

  • Chunked top-k no longer crashes when top_k exceeds the merge width.
    retrieve(Q, D, top_k=10, chunk=4) (and the CUDA maxsim_topk it wraps)
    raised "selected index k out of range": each chunk contributes at most
    chunk columns, but the running merge always asked torch.topk for
    top_k, even when fewer columns had been merged so far. The merge now
    clamps k to the merged width; the final output still honors the
    documented min(top_k, Nd) contract on both paths.
  • MaxSimScorer.forward() respects a caller's torch.no_grad(). The
    non-inference branch of _score forced torch.enable_grad(), so
    MaxSimScorer()(Q, D) inside torch.no_grad() returned a tensor with
    requires_grad=True. It now uses a null context and leaves the caller's
    grad mode untouched.
  • maxsim_residual_varlen handles an empty corpus (Nd == 0). The
    launcher still launched a grid of Nq * max(Nd, 1) programs, and each
    computed pid // Nd with a constexpr Nd=0, a division by zero. The
    empty [Nq, 0] result is now returned before any launch.
  • pack_padded handles an empty batch (B == 0). It crashed on
    qlen.max() of a zero-element tensor; it now returns the trivially-empty
    PackedBatch up front, so maxsim_padded agrees with the CPU reference
    ([0, C]).
  • PyLate legacy mask= is forwarded on the fallback path. The patched
    colbert_scores / colbert_kd_scores / colbert_scores_pairwise mapped
    a legacy mask= into the fused path but called the original function
    with a still-None documents_mask when deferring to PyLate (CPU,
    sub-Ampere, LIK_DISABLE=1), silently dropping the doc mask.
  • maxsim_backward_lowmem raises a clean RuntimeError without
    Triton/CUDA instead of failing inside the launch (the same guard as
    maxsim_backward_unified), so backward/__init__'s import-anywhere
    promise holds.
  • prune_forward falls back to the smallest-footprint configs. When
    every autotune config overflowed the shared-memory budget, the fallback
    returned the first two list entries — configs just pruned for exceeding
    the budget. It now returns the two with the smallest estimated SMEM
    footprint.
  • The KD path (maxsim with 4-D D) validates device agreement for
    Q / D / masks like the 3-D path, instead of surfacing an internal
    kernel error on mixed devices.
  • FP8 fallback really works without Triton. maxsim_inference_fp8's
    documented "transparent fallback" imported the Triton-backed maxsim
    unconditionally; it now dispatches like retrieve (Triton on CUDA,
    compiled reference on MPS, eager reference elsewhere), so the dequantized
    bf16 fallback runs on CPU-only installs.
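The chunked top-k fix comes down to clamping k at every merge. Below is a minimal plain-Python sketch, not the library's code: `topk` is a hypothetical helper that reproduces torch.topk's bound check, and heapq stands in for both the CUDA kernel and torch.topk.

```python
import heapq


def topk(items, k):
    """Mimic torch.topk's bound check: k may not exceed the input width."""
    if k > len(items):
        raise RuntimeError("selected index k out of range")
    return heapq.nlargest(k, items)  # sorted, largest first


def chunked_topk(scores, top_k, chunk):
    """Running top-k over column chunks of one query's document scores."""
    running = []  # (score, doc_index) pairs kept so far
    for start in range(0, len(scores), chunk):
        block = [(s, start + j) for j, s in enumerate(scores[start:start + chunk])]
        merged = running + block
        # The fix: never ask for more entries than the merge currently holds.
        running = topk(merged, min(top_k, len(merged)))
    return running  # length is min(top_k, len(scores))


scores = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2]
print([i for _, i in chunked_topk(scores, top_k=10, chunk=4)])
```

Without the `min(...)`, the first merge would request 10 entries from a 4-column chunk and raise, which is the failure described above.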
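The prune_forward fallback can be sketched as a filter with a footprint-sorted fallback. This is a hedged illustration, not the library's pruner: the config dicts, `estimate_smem` (one Q tile plus one D tile per pipeline stage) and its head_dim / element-size defaults are all hypothetical.

```python
def estimate_smem(cfg, head_dim=128, elem_bytes=2):
    """Hypothetical shared-memory estimate: Q and D tiles per pipeline stage."""
    return (cfg["BLOCK_M"] + cfg["BLOCK_N"]) * head_dim * elem_bytes * cfg["num_stages"]


def prune_configs(configs, smem_budget, estimate=estimate_smem):
    """Keep configs that fit the budget; if none fit, keep the two smallest."""
    fitting = [c for c in configs if estimate(c) <= smem_budget]
    if fitting:
        return fitting
    # Old behavior returned configs[:2], i.e. entries just pruned for overflowing.
    return sorted(configs, key=estimate)[:2]


configs = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "num_stages": 4},
    {"BLOCK_M": 128, "BLOCK_N": 64, "num_stages": 3},
    {"BLOCK_M": 64, "BLOCK_N": 64, "num_stages": 2},
    {"BLOCK_M": 32, "BLOCK_N": 64, "num_stages": 2},
]
print(prune_configs(configs, smem_budget=32 * 1024))
```

With a 32 KiB budget nothing fits, so the sketch returns the two configs with the smallest estimated footprint rather than the two largest ones at the head of the list.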
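The two empty-input fixes (maxsim_residual_varlen with Nd == 0, pack_padded with B == 0) share one pattern: return the trivially empty result before any code divides by, or reduces over, the empty dimension. A plain-Python sketch with hypothetical helper names; a Python loop plays the kernel grid and a constant placeholder stands in for the MaxSim score.

```python
def program_coords(pid, nd):
    """Decode a flat program id into (query, doc) indices, as a kernel grid would."""
    return pid // nd, pid % nd  # ZeroDivisionError when nd == 0


def scores_grid(nq, nd):
    """Fill an nq x nd result by 'launching' one program per (query, doc) pair."""
    if nd == 0:
        # Guard first: the empty [nq, 0] result needs no launch at all.
        return [[] for _ in range(nq)]
    out = [[0.0] * nd for _ in range(nq)]
    for pid in range(nq * nd):
        q, d = program_coords(pid, nd)
        out[q][d] = 1.0  # placeholder for the real score
    return out


def pack_lengths(qlen):
    """Packing sketch: handle B == 0 before calling max() on the lengths."""
    if not qlen:
        return 0, [0]  # trivially empty batch: max length 0, cu_seqlens [0]
    cu_seqlens = [0]
    for n in qlen:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return max(qlen), cu_seqlens


print(scores_grid(3, 0), pack_lengths([]))
```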

Changed

  • maxsim_varlen autotune sweeps are amortized across batch shapes.
    max_lq / max_ld were constexpr autotune keys with no bucketing, so —
    contrary to the previous docs — every distinct (max_lq, max_ld) pair
    re-triggered the full autotune sweep and recompiled the kernel. They
    are now exact runtime loop bounds (like the dense kernel's Ld), and
    the autotune cache is keyed on power-of-two-bucketed copies (floor 16)
    passed as key-only arguments: one sweep per bucket, zero masked
    iterations added, argmax buffer still sized on the exact max_lq.
    Results are unchanged. The PLAID kernels (maxsim_residual /
    maxsim_residual_varlen) deliberately keep their exact max_Ld key: it
    is a property of the compressed index and stable across calls, so the
    sweep amortizes to one. Bucketing the loop bound there cost 30-50%
    steady-state throughput on Ld=300 corpora in measurements, for no
    benefit; the same cost is why the varlen loop bounds stay exact.
  • max_seqlen_* arguments are documented as hard loop bounds, not
    hints, in maxsim_varlen, score_pairs_packed, and
    maxsim_residual_varlen: a too-small value silently truncated tokens and
    returned wrong scores. Caller-supplied values are now checked against the
    cu_seqlens maxima with an on-device torch._assert_async (no D2H sync,
    same trade-off as pack_padded's length checks).
  • maxsim_residual squeezes a 2-D Q back to [Nd] instead of
    returning [1, Nd], matching maxsim_residual_varlen and the maxsim
    wrapper's convention. Behavior change for callers that relied on the
    un-squeezed shape.
  • Removed the dead B constexpr from _plaid_approx_score_kernel (it
    forced a recompile per batch size), the unused Lq / Ld placeholder
    params from the varlen forward kernel, and the dead Nq constexpr from
    _varlen_bwd_dQ_kernel. The varlen backward kernels also take max_lq
    as a runtime arg instead of a constexpr (mirroring the score_pairs
    backward kernels), so a new query-length maximum no longer recompiles
    them.
  • Docstring corrections: maxsim_inference_fp8 no longer mentions argmax
    ties (the inference kernel never computes an argmax), and the lowmem
    backward docs say grads are written in the input dtype (fp16 / bf16 /
    fp32), not "bf16 grads".
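The varlen autotune change separates the cache key from the loop bound. A plain-Python sketch, assuming a dict stands in for Triton's autotune cache; `bucket` and `BucketedAutotuneCache` are hypothetical names, not the library's API. Only the key is rounded up to a power of two (floor 16); the exact maxima would still be passed as the kernel's runtime loop bounds.

```python
def bucket(n, floor=16):
    """Round n up to the next power of two, never below `floor`."""
    b = floor
    while b < n:
        b *= 2
    return b


class BucketedAutotuneCache:
    """Toy autotune cache keyed on bucketed copies of (max_lq, max_ld)."""

    def __init__(self):
        self.configs = {}
        self.sweeps = 0

    def lookup(self, max_lq, max_ld):
        key = (bucket(max_lq), bucket(max_ld))  # key-only copies
        if key not in self.configs:
            self.sweeps += 1  # stand-in for one full autotune sweep
            self.configs[key] = {"tuned_for": key}
        # A launch would still loop exactly max_lq x max_ld times.
        return self.configs[key]


cache = BucketedAutotuneCache()
for max_lq, max_ld in [(20, 180), (31, 250), (32, 256), (33, 256)]:
    cache.lookup(max_lq, max_ld)
print(cache.sweeps)
```

The first three shapes share the (32, 256) bucket, so four distinct shapes cost two sweeps instead of four.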
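The max_seqlen_* check compares the caller's value with the longest sequence implied by cu_seqlens. The library does this on device with torch._assert_async to avoid a device-to-host sync; the sketch below is a synchronous plain-Python version of the same comparison, and `check_max_seqlen` is a hypothetical name.

```python
def check_max_seqlen(cu_seqlens, max_seqlen):
    """Reject a max_seqlen smaller than the longest sequence in cu_seqlens."""
    # cu_seqlens holds cumulative offsets, e.g. [0, 5, 12, 15] -> lengths 5, 7, 3.
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    actual = max(lengths, default=0)
    if max_seqlen < actual:
        # Used as a loop bound, a too-small value would silently drop tokens.
        raise ValueError(
            f"max_seqlen={max_seqlen} < longest sequence ({actual}); tokens would be truncated"
        )
    return actual


print(check_max_seqlen([0, 5, 12, 15], 8))
```

A value larger than the true maximum is still accepted in this sketch; it only costs extra masked iterations, whereas a smaller one changes the scores.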

Removed

  • The interactive kernel picker page (docs/choose-a-kernel.html). It predates
    the API consolidation: with maxsim dispatching on layout and the patchers /
    native PyLate & colpali-engine backends covering the framework paths, a
    decision tree over many entry points no longer reflects the library. The
    README "Choose a kernel" section and links are gone with it.