4-bit gather_qmm weight-reuse GEMM tops out ~80 GB/s at small MoE M on M5 Pro - is tile tuning of gather_qmm_rhs_nax feasible?
#3691
programVeins
started this conversation in
General
Replies: 0 comments
Measured-findings discussion, not a bug.
On an M5 Pro (48 GB, macOS 27 beta), single-stream 4-bit MoE is bottlenecked by streaming expert weights, and the weight-reuse GEMM at the M the MoE actually uses (~16 rows/expert) streams at only ~80 GB/s - vs ~220 GB/s for `qmv` (M≤2) and ~263 GB/s elementwise on the same machine.

Motivating workload: `diffusiongemma-26B-A4B-it-4bit` (128 experts/top-8, 256-token canvas → `M ≈ 256·8/128 ≈ 16` rows/expert; ~12.8 GB of 4-bit weights read per denoising step). MLX @ `a6ec712` with the #3632 kernel-name fix applied so the NAX gather path loads.

A 4-bit `quantized_matmul` weight-reuse sweep (fixed weights, vary M) gives a clean two-regime curve. The cliff lands exactly where reuse starts forcing threadgroup staging, and the plateau is precisely the M the MoE needs:

[Sweep table: only the kernel column survived extraction. Rows, in order: `qmv`, `qmv`, `qmm`, `qmm` (incl. NAX `matmul2d`), `qmm`.]

Important: the NAX tensor-op path is on this curve, not above it - the `mpp::tensor_ops::matmul2d` (`(16,32,16)` descriptor in `steel/gemm/nax.h`) is what's selected at M=16–64 and measures the same ~80 GB/s. So this isn't "the generic kernel is slow, NAX would fix it." I also tried 5 custom kernels (register-resident multi-row qmv, L2-weight-sharing, capacity-padded NAX gather, `simdgroup_half8x8`, a JIT `matmul2d`) to beat it; best was ~84 GB/s. The L2-sharing idea was disproven (simdgroups don't stay in lockstep, so each re-reads weights from DRAM - which is why reuse needs explicit staging).

`gather_qmm_rhs_nax` hardcodes `bm=bn=bk=64, wm=wn=2` with a literal `// TODO: Tune the block sizes.`

Question: is small-M tile tuning of `gather_qmm_rhs_nax` / `gather_qmm_t_nax` tractable for these shapes (M≈16, E=128, K=2816, N∈{1408,704}, gs=64) - e.g. smaller `bm` + more N-tiling, or a fused dequant→`matmul2d` staging that keeps a staged weight row resident across more activation rows - or is ~80 GB/s the understood `matmul2d` staged ceiling on this generation?

If it's the known ceiling, that's a useful answer too: it confirms single-stream 4-bit MoE on M5 is bandwidth-bound here and the real levers are lower-bit experts or multi-canvas batching, not kernel tuning. Happy to run any tile/shape sweep on M5 Pro and post numbers. (NAX path requires the #3632 `bk32→bk64` fix to load - this is a +1 confirmation of that PR, not a competing fix.)
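To put the gap in per-step terms, here is a rough back-of-envelope sketch, not a measurement. It uses only the figures quoted above (256 tokens, top-8, 128 experts, ~12.8 GB per step, and the three bandwidths). The helper names are made up for illustration, and it assumes uniform routing and that every weight byte streams at a single rate, which is a simplification.

```python
# Back-of-envelope only. Inputs are the figures quoted in this post;
# helper names are hypothetical and exist just for this illustration.

def rows_per_expert(tokens: int, top_k: int, num_experts: int) -> float:
    """Average activation rows per expert, ASSUMING uniform routing."""
    return tokens * top_k / num_experts


def step_time_ms(weight_gb: float, bandwidth_gb_s: float) -> float:
    """Milliseconds to stream weight_gb once at bandwidth_gb_s.

    ASSUMPTION: the step is purely weight-streaming-bound and all
    bytes move at the same rate.
    """
    return weight_gb / bandwidth_gb_s * 1000.0


# 256-token canvas, top-8 routing, 128 experts.
M = rows_per_expert(tokens=256, top_k=8, num_experts=128)

WEIGHT_GB = 12.8  # ~4-bit weights read per denoising step
RATES_GB_S = {
    "qmm plateau": 80.0,   # weight-reuse GEMM at M~16
    "qmv (M<=2)": 220.0,
    "elementwise": 263.0,
}
FLOORS_MS = {name: step_time_ms(WEIGHT_GB, rate)
             for name, rate in RATES_GB_S.items()}

if __name__ == "__main__":
    print(f"rows per expert: {M:.0f}")
    for name, ms in FLOORS_MS.items():
        print(f"{name:>12}: {ms:6.1f} ms/step")
```

Under those assumptions the ~80 GB/s plateau works out to roughly 160 ms per denoising step, against about 58 ms if the same bytes moved at the `qmv` rate (a 2.75x gap), which is why the plateau matters for this workload.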