fast: add fused_swiglu_gather_qmv (Metal kernel fusing SwiGLU + gather_qmm)#2
Merged
Merged
Conversation
fafdb3b to
8756748
Compare
…r_qmm)
Adds mx.fast.fused_swiglu_gather_qmv, a single Metal kernel that
applies the SwiGLU activation (silu(gate) * up) inline with a
quantized gather_qmm matvec, skipping the intermediate activated
tensor and one full Metal command dispatch.
Replaces the three-launch sequence
y = mx.gather_qmm(mx.silu(gate) * up, w, scales, biases,
rhs_indices=...)
with one kernel. Microbench shows 1.27x speedup over that baseline on
M-series GPUs at MoE expert-matvec shapes.
Currently only the fast path is implemented (requires N % 8 == 0 and
K % 512 == 0); other shapes should fall back to the explicit sequence.
API:
out = mx.fast.fused_swiglu_gather_qmv(
gate, up, w, scales, biases, rhs_indices,
group_size=64, bits=4, mode="affine")
Files:
- mlx/fast.{h,cpp}, mlx/fast_primitives.h: new primitive +
user-facing entrypoint with input validation
- mlx/backend/metal/kernels/{quantized.h,quantized.metal}: new
affine_gather_qmv_swiglu kernel template
- mlx/backend/metal/quantized.cpp: dispatch + eval_gpu
- mlx/backend/no_gpu/primitives.cpp: NO_GPU stub for non-Metal builds
- python/src/fast.cpp: nanobind binding + docstring
8756748 to
c930416
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
In-fork PR — not for upstream submission yet. Iterating on the fork before opening upstream.
Summary
Adds
mx.fast.fused_swiglu_gather_qmv, a single Metal kernel that applies the SwiGLU activation (silu(gate) * up) inline with a quantizedgather_qmmmatvec, skipping the intermediate activated tensor and one full Metal command dispatch.Motivation
The Mixture-of-Experts decode path on Metal currently goes through:
That's three Metal dispatches (
silu, elementwise multiply,gather_qmm) plus one full materialized activation tensor of shape(..., M, K). For MoE expert matvecs (M=1at decode), wheregather_qmmis the heavy primitive, the surrounding activation work is a meaningful share of the cycle. Fusing SwiGLU directly into the quantized matvec kernel removes both the intermediate tensor and the two extra dispatches.Performance
Microbench matrix (M=1, bf16, group_size=64, bits=4, M4 Max)
Full decode-shape sweep against the explicit three-launch baseline
gather_qmm(silu(gate)*up, w, scales, biases, rhs_indices=...). Each cell: 20 warmup + 500 timed iterations, fresh inputs per (K, N, E) tuple. All 27 cells satisfy the fast-path constraintsN % 8 == 0andK % 512 == 0. Times in microseconds per call.K = 512:
K = 1024:
K = 2048:
Summary across the 27-cell matrix: geomean 1.053x, median 1.050x, range 0.96x — 1.17x, 26/27 cells at parity-or-better, 14/27 cells ≥ 1.05x. All 27 cells pass the correctness gate (cos ≥ 0.9999; observed range 0.99996 — 0.99998).
The savings come from the eliminated activation tensor + two extra dispatches; this is a roughly constant ~6 μs/call across the sweep. As
KandNgrow, the absolute saving stays flat but the relative speedup compresses because thegather_qmmbody itself grows. The largest wins are at small-K / wide-E shapes where the surrounding work is a larger share of the cycle; the one regression (K=2048, N=256, E=16) is a corner where the fused launch is slightly slower at full-K worst-case quantization workload — within run-to-run noise on M4 Max but worth flagging.A single isolated-process repro of the headline shape (
K=512, N=128, E=8, 500-iter, fresh subprocess) showed 1.10x — 1.17x across three fresh launches. Original PR description's 1.27x figure was measured from an earlier microbench at slightly different framing (smaller iter count, cold cache); the 1.05x geomean / 1.07x at headline above is the more conservative number to plan around.Off-target shapes (fast-path constraints violated)
The kernel currently has no input validation for the fast-path constraints. Off-target shapes silently produce incorrect output instead of erroring or falling back:
This needs to be addressed before upstreaming — either (a) add
eval_gpushape checks that raiseruntime_errorwhen the fast-path constraints aren't met, or (b) emit a generic slow-path branch inside the kernel. Option (a) is the minimal fix and matches MLX's existing convention for fast-only primitives.Bit-equivalent under cosine ≥ 0.9999 across the entire fast-path matrix (included in the unit test).
API
Scope
Currently only the fast path is implemented in Metal (requires
N % 8 == 0andK % 512 == 0, which holds for Qwen3.6 and other common MoE expert shapes). Other shapes should fall back to the explicitgather_qmm(silu(gate)*up, ...)sequence at the caller — there is no in-kernel slow-path branch.Caveat: as of this measurement run, the kernel does NOT guard against off-target shapes — it silently runs and produces wrong output (see table above). This needs to be fixed (either by raising or by adding a slow-path branch) before upstreaming.
The primitive supports
affinequant mode only (matches theaffine_gather_qmvit shadows). CPU is unimplemented (runtime_error); ano_gpustub is provided so Linux-only builds still link.Files
mlx/backend/metal/kernels/quantized.hmlx/backend/metal/kernels/quantized.metalmlx/backend/metal/quantized.cppmlx/backend/no_gpu/primitives.cppmlx/fast.cppmlx/fast.hmlx/fast_primitives.hpython/src/fast.cpppython/tests/test_fast.pyTest plan
Three new tests added in
python/tests/test_fast.py(all gated onmx.metal.is_available()):test_fused_swiglu_gather_qmv_correctness— numerical agreement vs the explicitgather_qmm(silu(gate) * up, w, scales, biases, rhs_indices=...)reference path, in bothfloat16andbfloat16, cosine similarity ≥ 0.9999.test_fused_swiglu_gather_qmv_shapes— exercises a few small fast-path shapes(B, M, K, N, E):(1,1,512,128,4),(2,1,512,128,8),(3,1,1024,256,8),(4,1,512,64,8). Checks both output shape/dtype and numerical match.test_fused_swiglu_gather_qmv_input_validation— exercises the C++ guards: mismatchedgate/upshape, mismatchedgate/updtype, non-uint32rhs_indices, andgate.ndim < 2all raise.Open questions for maintainers
N % 8 != 0orK % 512 != 0. Shouldeval_gpuraise on these shapes, or should the kernel grow a generic slow-path branch? Leaning toward the raise — the MoE-decode shapes that benefit most are already on the fast path, and a raise gives the caller a deterministic fallback signal.gather_qmmitself.mx.fast.*alongsidescaled_dot_product_attentionandfused_qsdpa(also activation-fused fast paths). The alternative would be a newmx.fused.*namespace if more SwiGLU-style fusions land —swiglu + qmvnon-gathered,swiglu + gemm, etc. Happy to rename before upstreaming.affineonly. Extending tomxfp4/mxfp8would follow the same template-instantiation pattern asgather_qmv. Worth doing in this PR, or in a follow-up?CI status:
build_and_test.ymlis gated toml-explore/mlxso it won't run on this fork PR. Linux validation is via manually-dispatched Nightly Build (build_linux_releasex86 + arm64 confirmed green).