v0.4.4: Perf roundup — one-shot norm check, shape-stable compiles, autotuned backwards

h-aurelien-lac released this 10 Jun 21:45
cb359cd

Fixed

  • maxsim(normalize=False) ran a norm check that called .item(), forcing a
    device sync on every call; the check now runs once per process, which
    restores CUDA-graph capture on the PyLate / colpali-engine hot path.
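
A minimal sketch of the once-per-process guard described above, in plain Python so it runs without a GPU. Lists of floats stand in for tensors, and the names `_norm_checked`, `_looks_normalized`, `maxsim_sketch` and `norm_check_runs` are hypothetical. What the real check does on failure is not stated in these notes, so the warning here is an assumption.

```python
import math
import warnings

_norm_checked = False   # hypothetical process-wide flag
norm_check_runs = 0     # illustration only: counts how often the costly check ran


def _looks_normalized(vectors, tol=1e-3):
    """Stand-in for the norm check; on a GPU this is where .item() would sync."""
    global norm_check_runs
    norm_check_runs += 1
    return all(abs(math.sqrt(sum(x * x for x in v)) - 1.0) <= tol for v in vectors)


def maxsim_sketch(query, doc):
    """MaxSim: for each query vector, take the best dot product over doc vectors, then sum."""
    global _norm_checked
    if not _norm_checked:
        # Flip the flag first so the check runs at most once per process.
        _norm_checked = True
        if not _looks_normalized(query + doc):
            # Assumed behaviour: warn rather than raise.
            warnings.warn("inputs do not look L2-normalized")
    return sum(max(sum(a * b for a, b in zip(q, d)) for d in doc) for q in query)


query = [[1.0, 0.0]]
doc = [[1.0, 0.0], [0.0, 1.0]]
first = maxsim_sketch(query, doc)   # runs the check
second = maxsim_sketch(query, doc)  # skips it
```

After the first call the hot path contains no host-side read of device data, which is the property CUDA-graph capture needs.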

Changed

  • PLAID centroid codes are handled as int32 end to end (any integer dtype is
    still accepted; out-of-range codes in plaid_approx_score clamp to
    centroid 0).
  • Batch sizes (Nq, Nd) are runtime kernel arguments rather than constexpr, so
    dynamic batching no longer triggers a recompile per batch shape.
  • GPU family detection is keyed on compute capability, not the device name;
    Blackwell now gets first-class autotune configs.
  • The varlen, packed-pairs and residual backward launches are autotuned like
    the dense backwards (up to a 1.42× speedup on H100 at training shapes).
  • Backward gradient buffers that the kernels fully overwrite are now allocated
    with torch.empty instead of torch.zeros (atomic-scatter buffers stay zeroed).
  • The fp8 autotune key no longer splits on mask presence.
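
As an illustration of keying GPU family detection on compute capability, the sketch below maps a `(major, minor)` pair to a family label. The function name `gpu_family`, the label strings and the exact grouping are assumptions, not the library's actual table. In PyTorch the pair can be read with `torch.cuda.get_device_capability()`.

```python
def gpu_family(capability):
    """Map a (major, minor) compute capability to a coarse family label.

    Hypothetical helper: the labels and grouping are illustrative only.
    """
    major, minor = capability
    if major in (10, 12):          # sm_100 (datacenter) and sm_120 (consumer) Blackwell
        return "blackwell"
    if major == 9:                 # sm_90
        return "hopper"
    if (major, minor) == (8, 9):   # sm_89
        return "ada"
    if major == 8:                 # sm_80 / sm_86 / sm_87
        return "ampere"
    return "generic"               # older or unknown parts fall back to default configs


families = [gpu_family(cc) for cc in [(8, 0), (8, 9), (9, 0), (10, 0), (12, 0), (7, 5)]]
```

Device name strings vary across SKUs and form factors, whereas the capability pair is stable for a given architecture, so a lookup like this avoids string matching.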

Full rationale and H100 measurements in #111.