v0.4.1 — Mask-invariant autotune key

@h-aurelien-lac released this 03 Jun 15:58
40b7c1d

Fixed

Variable-length training no longer pays repeated autotune sweeps.

On ColQwen2 / ColPali training with variable query lengths, a fresh 5–10 s Triton autotune sweep fired the first time a query batch switched to a mask present/absent combination that had not been seen before. This could happen as late as step 14 and slowed end-to-end training by up to 1.6× on vidore/docvqa_test_subsampled.

Two causes, both fixed:

  1. has_q_mask / has_d_mask were in the forward and backward autotune keys. They are constexpr toggles that change codegen but not the winning (BLOCK_Q, BLOCK_D, num_warps, num_stages) tile, so they only fragmented the cache.
  2. Triton's autotuner also keys on the dtype of every tensor argument, and the absent-mask placeholder was Q (bf16) rather than the real mask dtype (int8) — so present-vs-absent re-split the cache regardless of the named key.

Absent optional args now use a dtype-matched placeholder (autotune_placeholder), and the mask flags are out of the keys. Autotune reuses the cached config across mask combinations (Triton still JIT-compiles a correct, separately specialized kernel per constexpr value); steady-state numerics and selected configs are unchanged.
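
The cache fragmentation described above can be illustrated with a small pure-Python model. This is a simplified sketch, not the library's code or Triton's implementation: it assumes the autotune cache key is the named key values followed by the dtype of every tensor argument, and the names `autotune_cache_key`, `count_sweeps`, and the shape keys `Lq` / `Ld` are hypothetical.

```python
from itertools import product

def autotune_cache_key(named_values, tensor_dtypes):
    # Simplified model (assumption): named key values + dtype of each tensor arg.
    return tuple(named_values) + tuple(tensor_dtypes)

def count_sweeps(mask_flags_in_key, dtype_matched_placeholder):
    """Count distinct cache keys (one autotune sweep each) over all four
    present/absent combinations of the query and document masks."""
    keys = set()
    for has_q_mask, has_d_mask in product([False, True], repeat=2):
        named = ["Lq=32", "Ld=1024"]  # hypothetical shape keys, held fixed
        if mask_flags_in_key:
            named += [has_q_mask, has_d_mask]

        def mask_dtype(present):
            # A real mask is int8. An absent mask is stood in for either by a
            # dtype-matched placeholder (int8) or by Q itself (bf16).
            if present or dtype_matched_placeholder:
                return "int8"
            return "bfloat16"

        dtypes = ["bfloat16", "bfloat16",  # Q, D
                  mask_dtype(has_q_mask), mask_dtype(has_d_mask)]
        keys.add(autotune_cache_key(named, dtypes))
    return len(keys)

# Before the fix: flags in the key, Q used as the absent-mask placeholder.
print(count_sweeps(mask_flags_in_key=True, dtype_matched_placeholder=False))
# Removing only the flags: dtypes still differ between present and absent.
print(count_sweeps(mask_flags_in_key=False, dtype_matched_placeholder=False))
# Matching only the placeholder dtype: the flags still split the cache.
print(count_sweeps(mask_flags_in_key=True, dtype_matched_placeholder=True))
# Both fixes applied: every mask combination shares one cache entry.
print(count_sweeps(mask_flags_in_key=False, dtype_matched_placeholder=True))
```

Under this model, either cause alone is enough to split the cache across mask combinations, which is why both changes were needed to bring the count down to a single sweep per shape.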

Full changelog: https://github.com/hcompai/late-interaction-kernels/blob/v0.4.1/CHANGELOG.md