`mx.quantized_matmul(transpose=True)` shows a discontinuous cost increase at M=3 for large asymmetric MLP shapes. M=2 costs nearly the same as M=1, but M=3 costs 28-37% more. The step does not appear for square shapes (5120→5120).
## Per-op microbench (`group_size=64`, `bits=4`)
```python
import mlx.core as mx
import time, statistics

def bench_qmm(K, N, M, n_iters=200, warmup=30):
    w_f = mx.random.normal((N, K), dtype=mx.float32).astype(mx.float16)
    w_q, scales, biases = mx.quantize(w_f, group_size=64, bits=4)
    x = mx.random.normal((M, K), dtype=mx.float32).astype(mx.float16)
    mx.eval(w_q, scales, biases, x)
    times = []
    for i in range(warmup + n_iters):
        t0 = time.perf_counter()
        out = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4)
        mx.eval(out)
        if i >= warmup:
            times.append((time.perf_counter() - t0) * 1000)
    return statistics.mean(times)

shapes = [
    ('5120 → 5120 (attn q/o)', 5120, 5120),
    ('5120 → 1024 (attn k/v)', 5120, 1024),
    ('5120 → 13824 (MLP gate/up)', 5120, 13824),
    ('13824 → 5120 (MLP down)', 13824, 5120),
]
for label, K, N in shapes:
    ms = [bench_qmm(K, N, M) for M in [1, 2, 3, 4]]
    print(f'{label}: M=1={ms[0]:.3f}ms M=2={ms[1]:.3f}ms M=3={ms[2]:.3f}ms M=4={ms[3]:.3f}ms (M3/M1={ms[2]/ms[0]:.2f}x)')
```
Output on M4 Pro:
```
5120 → 5120 (attn q/o):      M=1=0.218ms M=2=0.200ms M=3=0.195ms M=4=0.233ms (M3/M1=0.89x)
5120 → 1024 (attn k/v):      M=1=0.112ms M=2=0.118ms M=3=0.127ms M=4=0.136ms (M3/M1=1.13x)
5120 → 13824 (MLP gate/up):  M=1=0.270ms M=2=0.272ms M=3=0.346ms M=4=0.425ms (M3/M1=1.28x)
13824 → 5120 (MLP down):     M=1=0.277ms M=2=0.292ms M=3=0.379ms M=4=0.477ms (M3/M1=1.37x)
```
The step is shape-dependent: it appears only when the output dimension is much larger than the input dimension (or vice versa), not on square shapes.
## Full model forward (Qwen3.6-27B 4-bit, post-prefill 512 tokens)
| M | Time | Ratio vs M=1 |
|---|------|--------------|
| 1 | 68.0ms | 1.00x |
| 2 | 69.3ms | 1.02x |
| 3 | 93.8ms | 1.38x |
| 4 | 121.9ms | 1.79x |
| 5 | 150.2ms | 2.21x |
| 6 | 179.3ms | 2.64x |
From M=3 onward, cost grows linearly at about +28ms/step. Non-quantized ops (RMSNorm, softmax) are flat across M=1-4, so the step is entirely in the linear projections.
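The "+28ms/step" figure falls straight out of the table; a two-line check of the successive deltas (numbers copied from the measurements above):

```python
# Full-model latencies (ms) by batch size M, from the table above.
times = {1: 68.0, 2: 69.3, 3: 93.8, 4: 121.9, 5: 150.2, 6: 179.3}

# Cost of each additional row: delta from M to M+1.
deltas = {m: round(times[m + 1] - times[m], 1) for m in range(1, 6)}
print(deltas)
```

M=1→2 costs about 1ms, while every step from M=2→3 onward costs roughly 24-29ms, which is where the ~28ms/step slope comes from.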
## Context
All shapes fall below `vector_limit` (10 for the M4 Pro at K, N > 4096, from `get_qmv_batch_limit` using `applegpu_g16s`), so M=3 dispatches to `qmv_fast_impl`, not `qmm_splitk`.
`qmv_fast_impl` uses `grid = (M, ceil(N/8), B)`: one threadgroup per M row per N-tile. The cause of the discontinuity at M=3 is not clear to me. I tried two things:
- Fused kernel (processing M rows within a single threadgroup to share weight loads): correct output but slower than stock. The M threadgroups run in parallel in the current dispatch, and serializing them loses more than it gains.
- Lowering `vector_limit` to route M=3+ to `qmm_splitk`: significantly worse (M3/M1 goes from 1.3x to 2.7-2.9x), which makes sense since `qmm_splitk` is designed for larger M.
The step appears systematically across all asymmetric projection shapes.
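For concreteness, here is my reading of the dispatch as a sketch. `qmm_route` is my own illustrative helper, not an MLX API; the grid formula is the one quoted above:

```python
import math

# Simplified routing, per my reading of the source: below vector_limit the
# op takes the qmv path, otherwise the split-K path.
def qmm_route(M, vector_limit=10):
    return "qmv_fast_impl" if M < vector_limit else "qmm_splitk"

# Threadgroups launched by qmv_fast_impl for grid = (M, ceil(N/8), B), B=1.
def qmv_threadgroups(M, N):
    return M * math.ceil(N / 8)

for N in (5120, 13824):
    counts = [qmv_threadgroups(M, N) for M in (1, 2, 3, 4)]
    print(f"N={N}: threadgroups for M=1..4 -> {counts}")
```

Note that the grid size itself grows linearly in M, so a raw threadgroup count shows no step at M=3; that is part of why I suspect a per-core residency or occupancy effect rather than the launch geometry.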
→ Does this look like a GPU scheduling effect (wave quantization, occupancy change)? Is there a profiling approach that could help narrow it down?
## Impact
This affects any workload batching M=3-8 tokens together: small-batch server inference (3 concurrent requests), speculative decoding verify passes (draft_length=2 produces M=3), beam search. The +28% jump at M=3 is abrupt and hard to work around at the application level.
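The speculative-decoding case is worth spelling out, since it lands exactly on the step (the helper below is just illustrative arithmetic, not an API):

```python
# A verify pass scores the draft tokens plus one bonus position, so the
# target model sees draft_length + 1 rows per step.
def verify_rows(draft_length):
    return draft_length + 1

print(verify_rows(2))  # draft_length=2 -> M=3, right at the cost step
```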
PRs #1861 (faster small-batch `qmv`) and #3120 (split-K `qmm`) address adjacent regimes, but M=3-9 on large asymmetric shapes seems to remain unaddressed.