Add MLX_NUMERICAL_STRICT_MODE for shape-independent quantized_matmul #3473

rakshith48 wants to merge 1 commit into ml-explore:main
Conversation
quantized_matmul currently dispatches to one of three GPU kernels based on input M (qmv, qmm_splitk, qmm), each using a different K-reduction tree. fp32 sum is non-associative, so the three paths produce slightly different bit patterns (~1.5–2.7e-5 per element) for the same dot product. For inference and training this is invisible. But it silently breaks any workload that compares two equivalent execution paths:

- prefix-cache reuse (vLLM-style engines)
- batched-vs-streaming eval comparison (lm-evaluation-harness)
- distillation/RLHF teacher-student forward-pass equality

Adds an opt-in env-var-gated flag matching MLX's existing convention (MLX_ENABLE_TF32, etc.). When set, QuantizedMatmul and GatherQMM force the no-split qmm/gather_qmm reference paths, so output is bit-identical regardless of M. Cost: ~1.5–2.3x slower decode at M=1 (qmv is heavily optimized for the single-token case). Off by default; users opting in are explicitly trading throughput for correctness.

Files changed:

- mlx/utils.h (+19 LoC): env::numerical_strict_mode() helper
- mlx/backend/metal/quantized.cpp (+53 LoC): gates in QuantizedMatmul::eval_gpu and GatherQMM::eval_gpu
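The non-associativity claim is easy to check outside MLX. A minimal NumPy sketch, illustrative only: the chunked sum stands in for a split-K-style reduction tree and is not the MLX kernel code.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)  # one K=4096 dot-product operand

# Order 1: strict sequential accumulation, like a single-thread loop.
seq = np.float32(0.0)
for v in x:
    seq = np.float32(seq + v)

# Order 2: split-K style, 8 fp32 partial sums combined at the end.
splitk = np.float32(sum(np.float32(chunk.sum()) for chunk in np.split(x, 8)))

# Same data, same math, different rounding: the results typically differ
# by a few fp32 ulps, the same ~1e-5-scale discrepancy described above.
print(abs(float(seq) - float(splitk)))
```

The difference is tiny relative to the sum, which is exactly why it goes unnoticed until two supposedly identical execution paths are compared bit-for-bit.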
I hit a severe manifestation of this bug. Setup: a 2-token verification pass (M=2).

Reproducer (no model required, tested on MLX 0.31.1 and 0.31.2):

```python
import mlx.core as mx
from mlx.utils import tree_map

B, n_kv_heads, n_repeats, D = 1, 4, 4, 256
n_q_heads = n_kv_heads * n_repeats
key_bits, value_bits, group_size = 8, 4, 64
mx.random.seed(42)

def run(N, M):
    keys_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    values_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    q_keys = mx.quantize(keys_f, group_size=group_size, bits=key_bits)
    q_values = mx.quantize(values_f, group_size=group_size, bits=value_bits)
    queries = (mx.random.normal((B, n_q_heads, M, D)) * D**-0.5).astype(mx.float16)
    mx.eval(q_keys, q_values, queries)

    # reference: dequantize then float matmul
    keys_dq = mx.dequantize(*q_keys, group_size=group_size, bits=key_bits)
    values_dq = mx.dequantize(*q_values, group_size=group_size, bits=value_bits)
    qr = queries.reshape(B, n_kv_heads, n_repeats, M, D)
    s_ref = mx.softmax(qr @ keys_dq[:, :, None, :, :].transpose(0, 1, 2, 4, 3), axis=-1)
    out_ref = (s_ref @ values_dq[:, :, None, :, :]).reshape(B, n_q_heads, M, D)

    # quantized_matmul with expand_dims broadcast (GQA)
    qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
    qv_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
    s_qmm = mx.softmax(mx.quantized_matmul(qr, *qk_e, transpose=True, group_size=group_size, bits=key_bits), axis=-1)
    out_qmm = mx.quantized_matmul(s_qmm, *qv_e, transpose=False, group_size=group_size, bits=value_bits).reshape(B, n_q_heads, M, D)
    mx.eval(out_ref, out_qmm)

    diff = mx.max(mx.abs(out_ref.astype(mx.float32) - out_qmm.astype(mx.float32))).item()
    print(f"N={N:5d} M={M}: max_diff={diff:.6f}")

for N in [512, 2048, 4096, 7358]:
    for M in [1, 2]:
        run(N, M)
```

Output on M4 Pro:

In actual inference logs (real model, N=7358), I measured

On the fix
For the GQA case specifically, a more targeted fix (detecting zero strides on the batch dimensions of
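For context on what zero-stride detection would look at: the reproducer's `expand_dims` plus implicit broadcast gives the weights a size-`n_repeats` axis that occupies no memory. In NumPy terms (illustrative sketch; MLX strides behave analogously, but this is not MLX code):

```python
import numpy as np

# Quantized key weights shaped like the reproducer's (B, n_kv_heads, N, D),
# with a broadcasted n_repeats axis inserted as expand_dims(axis=-3) produces.
k = np.zeros((1, 4, 7358, 256), dtype=np.float16)
k_b = np.broadcast_to(k[:, :, None], (1, 4, 4, 7358, 256))

# The broadcast axis has stride 0: no data is duplicated, which is the
# signature such a targeted GQA fix could detect before choosing a kernel.
print(k_b.strides)
```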
Upstream issue + PR for ml-explore/mlx
Issue title
`quantized_matmul` output depends on input shape (M dimension); add `MLX_NUMERICAL_STRICT_MODE` opt-in for path-independent output

Summary
`mx.quantized_matmul` dispatches to one of three GPU kernels based on the M (batch/sequence) dimension of the input:

- `qmv` for `M < vector_limit` (~10–32 depending on K, N, arch)
- `qmm_splitk` for `vector_limit ≤ M < ~65` and transposed weights (B=1)
- `qmm` for everything else

All three kernels accumulate K in fp32 internally (BlockMMA's `AccumType=float`; qmv uses fp32 accumulators). However, they use different reduction trees across K: qmm, qmm_splitk, and qmv (the latter via `simd_sum`, a hardware butterfly reduction) each combine partial products in a different order. fp32 floating-point sum is not associative: `(a + b) + c ≠ a + (b + c)` in general. So three kernels computing the same dot product over the same K=4096 produce three different bit patterns, differing by ~1.5–2.7 × 10⁻⁵ per element.

For straight inference and training this is invisible. But it silently breaks any workload that compares two equivalent execution paths:
- `q_proj(x_full)` vs `q_proj(x_full[:, -L:])` should match for the overlapping tokens. They don't.

Reproducer
Tested on MLX v0.31.2, Apple M1 16GB, Qwen3-8B-4bit (`mlx-community/Qwen3-8B-4bit`).

Output (without strict mode):
Same pattern in `k_proj` with the boundary at L=256→257 (because N=1024 instead of 4096 changes the splitk threshold).

Important: this is NOT just an fp16 ULP issue
An earlier analysis suggested the bug was fp16 partial-sum rounding in splitk's intermediate buffer (with the fix being: store partials in fp32 instead of fp16). I implemented that fix and verified its output is bit-identical to the pristine (unpatched) build on the actual reproducer above; in other words, it changes nothing here. The reason: with bf16 scales the model output is auto-promoted to fp32 by MLX, so there's no fp16 cast anywhere in the splitk path that would benefit from precision promotion. The ~2 × 10⁻⁵ diff is purely fp32 non-associativity from differing reduction trees.
The implication: promoting partial-sums to fp32 does not fix the bit-equivalence problem for fp32-output paths. Only matching the reduction tree (or skipping the fast paths entirely) gives bit-equivalence.
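The implication can be illustrated with the same NumPy-style sketch as before (illustrative only, not the Metal kernels): storing split-K partials in fp32 rather than fp16 shrinks the error, but it does not reproduce the sequential order's bit pattern.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal(4096).astype(np.float32)

# Reference order: strict sequential fp32 accumulation.
seq = np.float32(0.0)
for v in x:
    seq = np.float32(seq + v)

# Split-K with partials stored in fp16 vs fp32 before the final combine.
parts = np.split(x, 8)
splitk_f16 = np.float32(sum(np.float32(np.float16(p.sum())) for p in parts))
splitk_f32 = np.float32(sum(np.float32(p.sum()) for p in parts))

# fp32 partials remove the fp16 rounding step, but the reduction tree still
# differs from the sequential order, so bit-equality is not restored.
print(abs(float(seq) - float(splitk_f16)), abs(float(seq) - float(splitk_f32)))
```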
Proposed fix: `MLX_NUMERICAL_STRICT_MODE` opt-in

Add an opt-in env-var-controlled flag matching MLX's existing convention (`MLX_ENABLE_TF32`, `MLX_METAL_FAST_SYNCH`, etc.):

- `mlx/utils.h` — add helper:
- `mlx/backend/metal/quantized.cpp::QuantizedMatmul::eval_gpu` — gate at top of dispatch:

That single gate at the top of `eval_gpu` covers all three shape-dependent paths (qmv, qmm_splitk, qvm_split_k) because qmm is the canonical reference.

Validation
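A hedged sketch of the gating logic in Python (the real helper is C++ in `mlx/utils.h`, and the issue does not show its exact parsing rules, so the accepted values below are an assumption):

```python
import os

def numerical_strict_mode() -> bool:
    # Hypothetical Python mirror of env::numerical_strict_mode(): treat the
    # flag as on unless it is unset or "0" (assumed semantics).
    return os.environ.get("MLX_NUMERICAL_STRICT_MODE", "0") not in ("", "0")

# Inside eval_gpu the gate would then be: if numerical_strict_mode(),
# dispatch to the canonical qmm / gather_qmm path regardless of M.
print(numerical_strict_mode())
```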
With `MLX_NUMERICAL_STRICT_MODE=1`: bit-identical at every L for every projection.
Performance cost (honest numbers)
Measured on M1 16GB with Qwen3-8B-4bit:
The decode-loop slowdown is significant because `qmv` is heavily optimized for M=1 generation; bypassing it forces `qmm` to use a 32×32 tile for a single output row.

This is why the flag is opt-in, not default-on. For users running:
Files changed
- `mlx/utils.h` — +20 LoC (env helper + comment block)
- `mlx/backend/metal/quantized.cpp` — +18 LoC (gate at QuantizedMatmul::eval_gpu)

Total: ~40 LoC. No new kernels, no new tests of existing behavior, no breaking changes. Off by default → zero impact on any existing user.
Test
Reproducer in `mac-llm-bench/eval/test_numerical_strict_mode.py` — a script that runs the boundary check in both modes: `diff == 0.0` at every L; exits 1 if any L fails.

Reference
Hardware: Apple M1 (16 GB unified memory)
MLX version: 0.31.2
Model: mlx-community/Qwen3-8B-4bit
Discovered while building an Apple Silicon eval harness for Qwen3-8B (mac-llm-bench). The original symptom was a ~0.5pp MMLU accuracy regression when prefix-cache reuse was enabled; it was root-caused to this path dependence by sweeping L and observing that the boundaries at L=64 (q/o_proj) and L=256 (k/v_proj) match the dispatcher's split_k thresholds.