Fixed
Triton kernels no longer recompile or re-autotune per distinct Ld. Ld was declared tl.constexpr and, for the autotuned forwards, was part of the autotune key. Inside the kernels, however, it only bounds a runtime range(0, Ld, BLOCK_D) loop, so variable-length training was paying one Triton recompile plus one autotune sweep per distinct document length. Ld is now a runtime argument and is excluded from the autotune key across the forward, soft, smooth, fp8, fused_head, and matryoshka kernels and the three backward kernels. Measured 9.3× faster cold start on H100 (4 distinct Ld values, fp16); steady-state per-call performance is unchanged. Pinned by tests/test_compile_cache.py.
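The cold-start cost above comes from how a JIT cache keys its specialisations. The toy model below is not Triton's real cache; it is a minimal sketch, assuming a cache keyed only on the values of compile-time (constexpr) arguments, with a hypothetical ToyJit class standing in for a compiled kernel. It shows why moving Ld out of the key collapses several distinct document lengths into a single compile.

```python
from dataclasses import dataclass, field


@dataclass
class ToyJit:
    """Toy JIT cache: one 'compile' per distinct tuple of compile-time args.

    Hypothetical illustration only; not Triton's actual cache implementation.
    """

    constexpr_names: tuple[str, ...]
    compiles: int = 0
    _cache: dict[tuple, str] = field(default_factory=dict)

    def __call__(self, **kwargs):
        # Only compile-time arguments participate in the cache key;
        # runtime arguments (such as a loop bound) do not.
        key = tuple(kwargs[name] for name in self.constexpr_names)
        if key not in self._cache:
            self.compiles += 1  # stands in for a recompile + autotune sweep
            self._cache[key] = f"binary{key}"
        return self._cache[key]


doc_lengths = [180, 256, 300, 512, 256, 180]  # illustrative batch lengths

# Before: Ld is a compile-time constant, so every distinct Ld compiles.
before = ToyJit(constexpr_names=("BLOCK_D", "Ld"))
# After: Ld is a runtime argument; only BLOCK_D keys the cache.
after = ToyJit(constexpr_names=("BLOCK_D",))

for ld in doc_lengths:
    before(BLOCK_D=64, Ld=ld)
    after(BLOCK_D=64, Ld=ld)

print(before.compiles, after.compiles)  # 4 1
```

The same reasoning applies to an autotuner key: any argument in the key multiplies the number of tuning sweeps by its number of distinct values.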
Added
Apple Silicon (MPS): fused simdgroup_matrix Metal kernel for forward MaxSim (late_interaction_kernels.metal.maxsim_inference_metal), JIT-compiled via torch.mps.compile_shader. Persistent threadgroups serve 8 consecutive j values per launch, Q stays register-resident across every (j, d-chunk) pair, and the cooperative D load stages each row through per-thread registers, so the optional L2-normalize fold pays one threadgroup write per element instead of three. Forward-only; the kernel never materialises the [Nq · Nd · Lq · Ld] similarity tensor.
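To make "never materialises the similarity tensor" concrete, here is a scalar Python sketch of the forward computation for one (query, document) pair. It assumes the standard ColBERT MaxSim definition (sum over query tokens of the maximum dot product over document tokens) and walks the document in block_d-sized chunks with a running per-query-token maximum, mirroring the runtime range(0, Ld, BLOCK_D) loop described for the Triton kernels. The function names and the normalize flag (a stand-in for the optional L2-normalize fold) are hypothetical, not the package API, and the sketch says nothing about how the Metal kernel lays out its threadgroups.

```python
import math


def _l2_normalize(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / norm for x in v]


def maxsim_pair(
    q: list[list[float]],
    d: list[list[float]],
    block_d: int = 2,
    normalize: bool = False,
) -> float:
    """MaxSim for one query (Lq x dim) and one document (Ld x dim).

    score = sum_i max_j <q_i, d_j>, computed chunk by chunk over Ld so the
    full Lq x Ld similarity matrix is never built.
    """
    if normalize:
        q = [_l2_normalize(row) for row in q]
    running_max = [-math.inf] * len(q)
    ld = len(d)
    # Ld is only a runtime loop bound here, mirroring range(0, Ld, BLOCK_D).
    for start in range(0, ld, block_d):
        chunk = d[start:start + block_d]
        if normalize:
            # Normalize each document row as it is loaded (the "fold").
            chunk = [_l2_normalize(row) for row in chunk]
        for i, q_row in enumerate(q):
            for d_row in chunk:
                sim = sum(a * b for a, b in zip(q_row, d_row))
                running_max[i] = max(running_max[i], sim)
    return sum(running_max)


q = [[1.0, 0.0], [0.0, 1.0]]
d = [[0.5, 0.5], [2.0, 0.0], [0.0, 3.0]]
print(maxsim_pair(q, d))  # 5.0
```

Only one chunk of similarities is live at a time, and the result does not depend on block_d, which is why the chunk size can be a tuning knob while Ld stays a plain runtime value.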
MPS dispatch (late_interaction_kernels._mps): training calls use a torch.compile-fused, autograd-aware reference; inference calls use the Metal kernel when its envelope holds (fp16 / bf16, d ≤ 128, d % 8 == 0, Nq · Nd ≥ 64, and Ld ≥ 192). MSL compile errors and device-side faults fall back transparently to the torch.compile path.
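The dispatch rule reads naturally as a predicate plus a guarded call. A minimal sketch, assuming hypothetical helpers metal_envelope_ok and dispatch (not the package's internals), dtypes passed as strings, and that both MSL compile errors and device-side faults surface as RuntimeError:

```python
from collections.abc import Callable


def metal_envelope_ok(dtype: str, d: int, nq: int, nd: int, ld: int) -> bool:
    """Envelope from the release notes: fp16/bf16, d <= 128, d % 8 == 0,
    Nq * Nd >= 64 and Ld >= 192."""
    return (
        dtype in {"float16", "bfloat16"}
        and d <= 128
        and d % 8 == 0
        and nq * nd >= 64
        and ld >= 192
    )


def dispatch(
    requires_grad: bool,
    dtype: str,
    d: int,
    nq: int,
    nd: int,
    ld: int,
    metal_fn: Callable[[], str],
    compiled_fn: Callable[[], str],
) -> str:
    # Training (autograd) always takes the torch.compile-fused reference,
    # since the Metal kernel is forward-only.
    if requires_grad or not metal_envelope_ok(dtype, d, nq, nd, ld):
        return compiled_fn()
    try:
        return metal_fn()
    except RuntimeError:
        # Assumed stand-in for MSL compile errors / device-side faults.
        return compiled_fn()


def metal() -> str:
    return "metal"


def compiled() -> str:
    return "compile"


def broken_metal() -> str:
    raise RuntimeError("simulated MSL compile error")


print(dispatch(False, "float16", 128, 8, 8, 256, metal, compiled))         # metal
print(dispatch(True, "float16", 128, 8, 8, 256, metal, compiled))          # compile
print(dispatch(False, "float16", 128, 8, 8, 256, broken_metal, compiled))  # compile
```

Keeping the envelope as a pure function makes the heuristic easy to pin in tests independently of whether a Metal device is present.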
patch_pylate() MPS routing: pylate.scores.colbert_scores / colbert_kd_scores now route MPS tensors through maxsim_mps. The Triton maxsim import is now lazy, so pylate_compat can be imported on machines without Triton (e.g. macOS).
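The lazy-import change amounts to moving the Triton import from module top level into the code path that needs it. A minimal sketch, assuming a hypothetical score_backend helper; the actual layout of pylate_compat is not described in these notes:

```python
import importlib
from types import ModuleType


def _triton_module() -> ModuleType:
    # Imported on first CUDA use only, so importing the enclosing module
    # never requires Triton to be installed.
    return importlib.import_module("triton")


def score_backend(device_type: str) -> str:
    """Pick a backend name for a device type (hypothetical helper)."""
    if device_type == "cuda":
        _triton_module()  # raises ImportError only when Triton is actually needed
        return "triton"
    if device_type == "mps":
        return "mps"
    return "reference"


print(score_backend("mps"))  # mps
print(score_backend("cpu"))  # reference
```

On a Mac without Triton, the MPS and CPU branches run normally; only a CUDA call would surface the missing dependency.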
benchmarks/bench_mps.py benchmarks the Metal, torch.compile, and eager paths side by side and reports metal vs eager, metal vs compile, and compile vs eager ratios. On an Apple M4 at fp16, Metal is 1.9–3.2× faster than eager (1.1–2.0× faster than torch.compile) on realistic inference shapes.
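The three reported ratios are plain quotients of per-path timings. A minimal sketch, assuming hypothetical median_ms and speedups helpers rather than the bench script's own code; with real MPS tensors each timed callable should also call torch.mps.synchronize() before returning, because kernel launches are asynchronous:

```python
import statistics
import time
from collections.abc import Callable


def median_ms(fn: Callable[[], object], iters: int = 20, warmup: int = 3) -> float:
    """Median wall-clock time of fn() in milliseconds, after warm-up calls."""
    for _ in range(warmup):
        fn()  # warm-up absorbs one-time compile / shader-JIT cost
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def speedups(t_metal: float, t_compile: float, t_eager: float) -> dict[str, float]:
    # A ratio > 1 means the first-named path is faster than the second.
    return {
        "metal vs eager": t_eager / t_metal,
        "metal vs compile": t_compile / t_metal,
        "compile vs eager": t_eager / t_compile,
    }


# Placeholder timings (ms) for illustration only, not measured results.
print(speedups(t_metal=1.0, t_compile=2.0, t_eager=3.0))
```

Note that "metal vs eager" equals "metal vs compile" times "compile vs eager", so the three ratios are consistent by construction.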
benchmarks/bench_flash_maxsim.py is back in the runner script and the documentation, pinned to flash-maxsim==0.2.0 so the published numbers are reproducible.
87 new MPS tests across tests/test_mps.py, tests/test_mps_metal.py, and tests/test_pylate_compat_mps.py (parity, masks, autograd, dispatch fallbacks, env overrides, KD layout, PyLate routing).
Changed
MaxSimScorer / _score() now raises an explicit ValueError when Q.device != D.device, instead of silently falling through to the eager reference and surfacing an opaque RuntimeError from torch.matmul. This matches the contract retrieve() already enforced.
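The new contract is a fail-fast guard before any kernel selection. A minimal sketch, assuming a hypothetical check_same_device helper; the error text is illustrative, not the library's actual message, and SimpleNamespace objects stand in for tensors (real torch.device objects compare with != the same way):

```python
from types import SimpleNamespace


def check_same_device(q, d) -> None:
    """Fail fast with a readable message instead of an opaque matmul error."""
    if q.device != d.device:
        raise ValueError(
            f"Q and D must be on the same device, got Q on {q.device} and D on {d.device}"
        )


q = SimpleNamespace(device="mps")
d = SimpleNamespace(device="cpu")
try:
    check_same_device(q, d)
except ValueError as err:
    print(err)  # Q and D must be on the same device, got Q on mps and D on cpu
```

Checking up front keeps the error independent of which backend would have been chosen.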
The Apple Silicon section of docs/benchmarks.md has been rewritten with the Metal numbers, the dispatch heuristic, and the headline metal vs eager ratio.
Minimum Python is now 3.10 (was 3.9). pyproject.toml bumps requires-python to ">=3.10", updates the Python classifiers, and sets tool.ruff.target-version = "py310". uv.lock was regenerated; the CI matrix was already 3.10 / 3.11 / 3.12.
Removed
All from __future__ import annotations lines were removed across the package, tests, benchmarks, and examples (67 files). Annotations now use the native PEP 604 / PEP 585 syntax (X | Y, list[X], dict[K, V]) that Python 3.10 supports at runtime, so no compatibility shim is needed.
Fixed
patch_pylate() previously gated on q.is_cuda and silently fell through to PyLate's reference implementation for MPS tensors. It now uses a per-call _device_path(Q, D) → {"cuda", "mps", None} switch, so Mac users get the same one-line upgrade as CUDA users.
benchmarks/results/ is now .gitignored; benchmark outputs are no longer tracked in version control.
The MPS torch.compile path no longer trips the inductor symbolic-shape bug. On torch 2.8 / nightly, MPS inductor fails to lower S.max(dim=-1) when the reduction axis is symbolic (cannot determine truth value of Relational: s12 <= 1024 from codegen_iteration_ranges_entry). The compile call now uses dynamic=False instead of dynamic=True, so PyTorch's dynamo cache transparently recompiles per (Nq, Nd, Lq, Ld) tuple. That is fine for typical inference, where shapes are stable; shape-varying workloads can fall back to the Metal kernel. This unblocks the 28 MPS tests that were skipped on this bug; 167 / 167 pass on macOS with no skips on the dispatch / metal / pylate_compat_mps suites.
README.md header restored. The Python 3.10 bump (#24) accidentally stripped <div align="center">, the badge ![]() image syntax, and reformatted a handful of markdown tables, leaving the landing page un-centered with plain-text "badges" instead of the shields.io images. Restored from the pre-#24 revision and re-applied the python-3.10–3.12 shield change. No other content changes.