
[Performance] qmv kernel: non-linear cost step at M=3 for large MLP shapes #3553

@AirRunner

Description


mx.quantized_matmul(transpose=True) shows a discontinuous cost increase at M=3 for large asymmetric MLP shapes: M=2 costs about the same as M=1, but M=3 costs 28-37% more than M=1. The step does not appear for square shapes (5120→5120).

Per-op microbench (group_size=64, bits=4)

import mlx.core as mx
import time, statistics

def bench_qmm(K, N, M, n_iters=200, warmup=30):
    # Quantize an (N, K) weight, then time the (M, K) activation projection.
    w_f = mx.random.normal((N, K), dtype=mx.float32).astype(mx.float16)
    w_q, scales, biases = mx.quantize(w_f, group_size=64, bits=4)
    x = mx.random.normal((M, K), dtype=mx.float32).astype(mx.float16)
    mx.eval(w_q, scales, biases, x)  # materialize inputs before timing
    times = []
    for i in range(warmup + n_iters):
        t0 = time.perf_counter()
        out = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4)
        mx.eval(out)  # force execution; MLX is lazy
        if i >= warmup:  # discard warmup iterations
            times.append((time.perf_counter() - t0) * 1000)
    return statistics.mean(times)  # mean latency in ms

shapes = [
    ('5120 → 5120  (attn q/o)', 5120, 5120),
    ('5120 → 1024  (attn k/v)', 5120, 1024),
    ('5120 → 13824 (MLP gate/up)', 5120, 13824),
    ('13824 → 5120 (MLP down)', 13824, 5120),
]

for label, K, N in shapes:
    ms = [bench_qmm(K, N, M) for M in [1, 2, 3, 4]]
    print(f'{label}: M=1={ms[0]:.3f}ms M=2={ms[1]:.3f}ms M=3={ms[2]:.3f}ms M=4={ms[3]:.3f}ms  (M3/M1={ms[2]/ms[0]:.2f}x)')

Output on M4 Pro:

5120 → 5120  (attn q/o):    M=1=0.218ms  M=2=0.200ms  M=3=0.195ms  M=4=0.233ms  (M3/M1=0.89x)
5120 → 1024  (attn k/v):    M=1=0.112ms  M=2=0.118ms  M=3=0.127ms  M=4=0.136ms  (M3/M1=1.13x)
5120 → 13824 (MLP gate/up): M=1=0.270ms  M=2=0.272ms  M=3=0.346ms  M=4=0.425ms  (M3/M1=1.28x)
13824 → 5120 (MLP down):    M=1=0.277ms  M=2=0.292ms  M=3=0.379ms  M=4=0.477ms  (M3/M1=1.37x)

The step is shape-dependent: it appears only when the output dimension is much larger than the input dimension (or vice versa), not on square shapes.

Full model forward (Qwen3.6-27B 4-bit, post-prefill 512 tokens)

M   Time      Ratio vs M=1
1   68.0ms    1.00x
2   69.3ms    1.02x
3   93.8ms    1.38x
4   121.9ms   1.79x
5   150.2ms   2.21x
6   179.3ms   2.64x

From M=3 onward, cost grows linearly at about +28ms/step. Non-quantized ops (RMSNorm, softmax) are flat across M=1-4, so the step is entirely in the linear projections.
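
For reference, the flatness of the non-quantized ops can be reproduced with the same harness; a minimal sketch timing mx.fast.rms_norm across M (D=5120 matches the model width, and the loop mirrors bench_qmm above):

import mlx.core as mx
import time, statistics

# Same timing harness as bench_qmm, applied to RMSNorm to check that
# non-quantized ops scale smoothly in M.
def bench_rmsnorm(D, M, n_iters=200, warmup=30):
    x = mx.random.normal((M, D)).astype(mx.float16)
    w = mx.ones((D,), dtype=mx.float16)
    mx.eval(x, w)
    times = []
    for i in range(warmup + n_iters):
        t0 = time.perf_counter()
        out = mx.fast.rms_norm(x, w, 1e-5)
        mx.eval(out)  # force the lazy op to run
        if i >= warmup:
            times.append((time.perf_counter() - t0) * 1000)
    return statistics.mean(times)

for M in [1, 2, 3, 4]:
    print(f'rms_norm D=5120 M={M}: {bench_rmsnorm(5120, M):.4f}ms')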

Context

All shapes fall below vector_limit (=10 on M4 Pro for K, N > 4096, per get_qmv_batch_limit for the applegpu_g16s architecture), so M=3 dispatches to qmv_fast_impl, not qmm_splitk.
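
Paraphrased as code, my reading of the routing (an assumption based on the observed limit, not a copy of the MLX source):

# My reading of the dispatch, paraphrased (assumption, not the MLX source):
# with transpose=True, M below the per-device batch limit takes the qmv path.
VECTOR_LIMIT = 10  # observed on M4 Pro for K, N > 4096; other shapes may differ

def routes_to_qmv_fast(M):
    return M < VECTOR_LIMIT

for M in [1, 2, 3, 4, 10]:
    print(f'M={M}: {"qmv_fast_impl" if routes_to_qmv_fast(M) else "qmm_splitk"}')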

qmv_fast_impl uses grid = (M, ceil(N/8), B): one threadgroup per M row per N-tile (back-of-envelope threadgroup counts follow the list below). The cause of the discontinuity at M=3 is not clear to me. I tried two things:

  • Fused kernel (process M rows within a single threadgroup to share weight loads): correct output but slower than stock. The M threadgroups run in parallel in the current dispatch, and serializing them loses more than is gained.
  • Lowering vector_limit to route M=3+ to qmm_splitk: significantly worse (M3/M1 goes from 1.3x to 2.7-2.9x), which makes sense since qmm_splitk is designed for larger M.
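
For scale, the threadgroup counts implied by that grid for the benchmarked output dimensions (the GPU core count is an assumption: M4 Pro ships with 16- or 20-core GPUs):

import math

# Back-of-envelope threadgroup counts for qmv_fast_impl's
# grid = (M, ceil(N/8), B), with B=1 here.
GPU_CORES = 20  # assumption: 20-core M4 Pro; the 16-core part would differ

for label, N in [('N=5120 ', 5120), ('N=13824', 13824), ('N=1024 ', 1024)]:
    n_tiles = math.ceil(N / 8)
    for M in [1, 2, 3, 4]:
        tgs = M * n_tiles
        print(f'{label} M={M}: {tgs:5d} threadgroups (~{tgs / GPU_CORES:6.1f} per core)')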

The step appears systematically across all asymmetric projection shapes.

Does this look like a GPU scheduling effect (wave quantization, an occupancy change)? Is there a profiling approach that could help narrow it down?
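
On the profiling side, the one thing I can easily produce locally is a Metal GPU trace of the isolated M=3 op (mx.metal.start_capture per MLX's Metal debugger docs; requires running with MTL_CAPTURE_ENABLED=1 and opening the .gputrace in Xcode):

import mlx.core as mx

# Capture a Metal GPU trace of the problematic M=3 MLP-down shape
# (K=13824, N=5120). Run with MTL_CAPTURE_ENABLED=1.
w_f = mx.random.normal((5120, 13824)).astype(mx.float16)  # (N, K)
w_q, scales, biases = mx.quantize(w_f, group_size=64, bits=4)
x = mx.random.normal((3, 13824)).astype(mx.float16)  # M=3
mx.eval(w_q, scales, biases, x)

mx.metal.start_capture('qmv_m3.gputrace')
out = mx.quantized_matmul(x, w_q, scales, biases, transpose=True,
                          group_size=64, bits=4)
mx.eval(out)
mx.metal.stop_capture()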

Impact

This affects any workload that batches M=3-8 tokens: small-batch server inference (e.g., 3 concurrent requests), speculative-decoding verify passes (draft_length=2 produces M=3), and beam search. The +28% jump at M=3 is abrupt and hard to work around at the application level.

PRs #1861 (faster small-batch qmv) and #3120 (split-K qmm) address adjacent regimes, but M=3-9 on large asymmetric shapes seems to remain unaddressed.
