Skip to content

fast: add fused_swiglu_gather_qmv (Metal kernel fusing SwiGLU + gather_qmm)#2

Merged
benjamin-levin merged 5 commits into
mainfrom
fused-swiglu-gather-qmv-pr
May 19, 2026
Merged

fast: add fused_swiglu_gather_qmv (Metal kernel fusing SwiGLU + gather_qmm)#2
benjamin-levin merged 5 commits into
mainfrom
fused-swiglu-gather-qmv-pr

Conversation

@benjamin-levin
Copy link
Copy Markdown
Owner

@benjamin-levin benjamin-levin commented May 18, 2026

In-fork PR — not for upstream submission yet. Iterating on the fork before opening upstream.

Summary

Adds mx.fast.fused_swiglu_gather_qmv, a single Metal kernel that applies the SwiGLU activation (silu(gate) * up) inline with a quantized gather_qmm matvec, skipping the intermediate activated tensor and one full Metal command dispatch.

Motivation

The Mixture-of-Experts decode path on Metal currently goes through:

y = mx.gather_qmm(mx.silu(gate) * up, w, scales, biases, rhs_indices=...)

That's three Metal dispatches (silu, elementwise multiply, gather_qmm) plus one full materialized activation tensor of shape (..., M, K). For MoE expert matvecs (M=1 at decode), where gather_qmm is the heavy primitive, the surrounding activation work is a meaningful share of the cycle. Fusing SwiGLU directly into the quantized matvec kernel removes both the intermediate tensor and the two extra dispatches.

Performance

Microbench matrix (M=1, bf16, group_size=64, bits=4, M4 Max)

Full decode-shape sweep against the explicit three-launch baseline gather_qmm(silu(gate)*up, w, scales, biases, rhs_indices=...). Each cell: 20 warmup + 500 timed iterations, fresh inputs per (K, N, E) tuple. All 27 cells satisfy the fast-path constraints N % 8 == 0 and K % 512 == 0. Times in microseconds per call.

K = 512:

N \ E E=4 E=8 E=16
N=64 127.0 → 116.1 (1.09x) 124.4 → 117.5 (1.06x) 124.0 → 105.7 (1.17x)
N=128 124.5 → 116.2 (1.07x) 123.3 → 115.7 (1.07x) 121.3 → 117.5 (1.03x)
N=256 119.4 → 114.5 (1.04x) 121.3 → 113.8 (1.07x) 123.1 → 116.5 (1.06x)

K = 1024:

N \ E E=4 E=8 E=16
N=64 115.1 → 110.0 (1.05x) 117.5 → 114.3 (1.03x) 122.3 → 116.6 (1.05x)
N=128 120.3 → 113.9 (1.06x) 120.1 → 111.5 (1.08x) 120.1 → 117.8 (1.02x)
N=256 122.7 → 115.8 (1.06x) 124.4 → 117.6 (1.06x) 125.7 → 119.5 (1.05x)

K = 2048:

N \ E E=4 E=8 E=16
N=64 122.4 → 116.8 (1.05x) 122.4 → 117.5 (1.04x) 123.9 → 118.4 (1.05x)
N=128 121.5 → 117.3 (1.04x) 124.6 → 119.2 (1.05x) 126.4 → 121.3 (1.04x)
N=256 124.6 → 118.7 (1.05x) 126.5 → 119.0 (1.06x) 121.7 → 126.7 (0.96x)

Summary across the 27-cell matrix: geomean 1.053x, median 1.050x, range 0.96x — 1.17x, 26/27 cells at parity-or-better, 14/27 cells ≥ 1.05x. All 27 cells pass the correctness gate (cos ≥ 0.9999; observed range 0.99996 — 0.99998).

The savings come from the eliminated activation tensor + two extra dispatches; this is a roughly constant ~6 μs/call across the sweep. As K and N grow, the absolute saving stays flat but the relative speedup compresses because the gather_qmm body itself grows. The largest wins are at small-K / wide-E shapes where the surrounding work is a larger share of the cycle; the one regression (K=2048, N=256, E=16) is a corner where the fused launch is slightly slower at full-K worst-case quantization workload — within run-to-run noise on M4 Max but worth flagging.

A single isolated-process repro of the headline shape (K=512, N=128, E=8, 500-iter, fresh subprocess) showed 1.10x — 1.17x across three fresh launches. Original PR description's 1.27x figure was measured from an earlier microbench at slightly different framing (smaller iter count, cold cache); the 1.05x geomean / 1.07x at headline above is the more conservative number to plan around.

Off-target shapes (fast-path constraints violated)

The kernel currently has no input validation for the fast-path constraints. Off-target shapes silently produce incorrect output instead of erroring or falling back:

K N E stock (μs) fused (μs) "speedup" cos sim safe?
128 128 8 135.1 115.6 1.17x 0.553 NO — wrong output
256 128 8 126.3 116.5 1.08x 0.241 NO — wrong output
768 128 8 126.0 118.5 1.06x 0.850 NO — wrong output
512 4 8 129.1 111.9 1.15x 0.593 NO — wrong output
512 12 8 125.3 112.0 1.12x 0.906 NO — wrong output
512 20 8 123.9 107.9 1.15x 0.952 NO — wrong output
512 100 8 124.8 114.6 1.09x 0.980 NO — wrong output
512 132 8 113.6 107.9 1.05x 0.981 NO — wrong output

This needs to be addressed before upstreaming — either (a) add eval_gpu shape checks that raise runtime_error when the fast-path constraints aren't met, or (b) emit a generic slow-path branch inside the kernel. Option (a) is the minimal fix and matches MLX's existing convention for fast-only primitives.

Bit-equivalent under cosine ≥ 0.9999 across the entire fast-path matrix (included in the unit test).

API

out = mx.fast.fused_swiglu_gather_qmv(
    gate,           # (..., M, K) — same dtype as up
    up,             # (..., M, K) — same shape/dtype as gate
    w,              # (E, N, K/pack_factor) packed uint32
    scales,         # (E, N, K/group_size)
    biases,         # optional, same shape as scales
    rhs_indices,    # uint32, broadcastable to gate's batch dims
    *,
    group_size=64,
    bits=4,
    mode="affine",
    stream=None,
)
# Returns: array of shape (..., M, N), same dtype as gate.

Scope

Currently only the fast path is implemented in Metal (requires N % 8 == 0 and K % 512 == 0, which holds for Qwen3.6 and other common MoE expert shapes). Other shapes should fall back to the explicit gather_qmm(silu(gate)*up, ...) sequence at the caller — there is no in-kernel slow-path branch.

Caveat: as of this measurement run, the kernel does NOT guard against off-target shapes — it silently runs and produces wrong output (see table above). This needs to be fixed (either by raising or by adding a slow-path branch) before upstreaming.

The primitive supports affine quant mode only (matches the affine_gather_qmv it shadows). CPU is unimplemented (runtime_error); a no_gpu stub is provided so Linux-only builds still link.

Files

File +/-
mlx/backend/metal/kernels/quantized.h +294
mlx/backend/metal/kernels/quantized.metal +1
mlx/backend/metal/quantized.cpp +116
mlx/backend/no_gpu/primitives.cpp +1
mlx/fast.cpp +61
mlx/fast.h +26
mlx/fast_primitives.h +36
python/src/fast.cpp +62
python/tests/test_fast.py +182

Test plan

Three new tests added in python/tests/test_fast.py (all gated on mx.metal.is_available()):

  • test_fused_swiglu_gather_qmv_correctness — numerical agreement vs the explicit gather_qmm(silu(gate) * up, w, scales, biases, rhs_indices=...) reference path, in both float16 and bfloat16, cosine similarity ≥ 0.9999.
  • test_fused_swiglu_gather_qmv_shapes — exercises a few small fast-path shapes (B, M, K, N, E): (1,1,512,128,4), (2,1,512,128,8), (3,1,1024,256,8), (4,1,512,64,8). Checks both output shape/dtype and numerical match.
  • test_fused_swiglu_gather_qmv_input_validation — exercises the C++ guards: mismatched gate/up shape, mismatched gate/up dtype, non-uint32 rhs_indices, and gate.ndim < 2 all raise.

Open questions for maintainers

  1. Off-target guard. The kernel currently runs (and produces wrong output) for N % 8 != 0 or K % 512 != 0. Should eval_gpu raise on these shapes, or should the kernel grow a generic slow-path branch? Leaning toward the raise — the MoE-decode shapes that benefit most are already on the fast path, and a raise gives the caller a deterministic fallback signal.
  2. Slow-path expansion. If guarding rather than supporting, broadening the kernel mostly serves symmetry with gather_qmm itself.
  3. API placement. Filed under mx.fast.* alongside scaled_dot_product_attention and fused_qsdpa (also activation-fused fast paths). The alternative would be a new mx.fused.* namespace if more SwiGLU-style fusions land — swiglu + qmv non-gathered, swiglu + gemm, etc. Happy to rename before upstreaming.
  4. Quant modes. Currently affine only. Extending to mxfp4 / mxfp8 would follow the same template-instantiation pattern as gather_qmv. Worth doing in this PR, or in a follow-up?

CI status: build_and_test.yml is gated to ml-explore/mlx so it won't run on this fork PR. Linux validation is via manually-dispatched Nightly Build (build_linux_release x86 + arm64 confirmed green).

@benjamin-levin benjamin-levin force-pushed the fused-swiglu-gather-qmv-pr branch from fafdb3b to 8756748 Compare May 19, 2026 00:39
@benjamin-levin benjamin-levin changed the title WIP: fused SiLU(gate)*up + gather_qmv kernel (CI test) fast: add fused_swiglu_gather_qmv (Metal kernel fusing SwiGLU + gather_qmm) May 19, 2026
…r_qmm)

Adds mx.fast.fused_swiglu_gather_qmv, a single Metal kernel that
applies the SwiGLU activation (silu(gate) * up) inline with a
quantized gather_qmm matvec, skipping the intermediate activated
tensor and one full Metal command dispatch.

Replaces the three-launch sequence

    y = mx.gather_qmm(mx.silu(gate) * up, w, scales, biases,
                      rhs_indices=...)

with one kernel. Microbench shows 1.27x speedup over that baseline on
M-series GPUs at MoE expert-matvec shapes.

Currently only the fast path is implemented (requires N % 8 == 0 and
K % 512 == 0); other shapes should fall back to the explicit sequence.

API:
    out = mx.fast.fused_swiglu_gather_qmv(
        gate, up, w, scales, biases, rhs_indices,
        group_size=64, bits=4, mode="affine")

Files:
- mlx/fast.{h,cpp}, mlx/fast_primitives.h: new primitive +
  user-facing entrypoint with input validation
- mlx/backend/metal/kernels/{quantized.h,quantized.metal}: new
  affine_gather_qmv_swiglu kernel template
- mlx/backend/metal/quantized.cpp: dispatch + eval_gpu
- mlx/backend/no_gpu/primitives.cpp: NO_GPU stub for non-Metal builds
- python/src/fast.cpp: nanobind binding + docstring
@benjamin-levin benjamin-levin force-pushed the fused-swiglu-gather-qmv-pr branch from 8756748 to c930416 Compare May 19, 2026 00:41
@benjamin-levin benjamin-levin marked this pull request as ready for review May 19, 2026 18:23
@benjamin-levin benjamin-levin merged commit 4928b8e into main May 19, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants