Add NAX Split-K GEMM for large-K matmuls to improve performance #3018
awni merged 7 commits into ml-explore:main from
Conversation
awni
left a comment
This looks very nicely done, thanks for the addition! We should remove the env var prior to merging. I will leave it open for a little while in case @jagrit06 or @angeloskath want to take a look.
Thanks for the review! I removed the test env var. I also updated the benchmark comparison to MLX Split-K vs PyTorch MPS GEMM (instead of MLX fused GEMM), to compare against an external baseline. On M5, Split-K brings MLX into the same perf ballpark as Torch for large-K GEMMs. For bf16 / fp16, MLX is ~0.90–0.95× Torch across the tested shapes, and for fp32, MLX is ~2.5–2.8× faster than Torch (Torch appears not to be using TF32 downcast on MPS in this setup). There is likely additional headroom here, and further tuning could help close the remaining gap for bf16/fp16. Perf numbers (as produced by splitk_gemm_bench.py):
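For context, here is a minimal sketch of how this kind of MLX vs. PyTorch MPS timing comparison could be set up. This is not the actual benchmark script: the shape, dtype, and iteration count are illustrative assumptions, and it assumes a Metal-enabled PyTorch build is available.

```python
# Rough timing sketch: MLX matmul vs. PyTorch MPS matmul for one large-K shape.
# Shapes, dtype, and iteration count are placeholders, not the PR's benchmark settings.
import time
import mlx.core as mx
import torch

M, N, K = 1024, 1024, 65536  # large-K GEMM
ITERS = 20

# MLX
a = mx.random.normal((M, K), dtype=mx.bfloat16)
b = mx.random.normal((K, N), dtype=mx.bfloat16)
mx.eval(mx.matmul(a, b))  # warmup
tic = time.perf_counter()
for _ in range(ITERS):
    mx.eval(mx.matmul(a, b))
mlx_ms = (time.perf_counter() - tic) / ITERS * 1e3

# PyTorch on MPS
at = torch.randn(M, K, device="mps", dtype=torch.bfloat16)
bt = torch.randn(K, N, device="mps", dtype=torch.bfloat16)
_ = at @ bt  # warmup
torch.mps.synchronize()
tic = time.perf_counter()
for _ in range(ITERS):
    _ = at @ bt
torch.mps.synchronize()
torch_ms = (time.perf_counter() - tic) / ITERS * 1e3

print(f"MLX: {mlx_ms:.3f} ms/iter, Torch MPS: {torch_ms:.3f} ms/iter")
```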
jagrit06
left a comment
Runs well when tested on my M5 machine!
We can merge it for now once all the other comments are addressed!
Summary
Adds NAX Split-K GEMM support to address high runtime variance and slow tail cases in large-K matmuls on the M5 GPU. For GEMMs with very large K dimensions, partitioning work along K substantially reduces variance and eliminates slow tail cases, achieving up to ~1.6× speedup over the fused NAX GEMM.
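For reference, a host-side sketch of the Split-K decomposition the kernel performs on the GPU: the K dimension is partitioned into chunks, a partial GEMM is computed per chunk, and the partials are accumulated. The shapes, dtype, and number of splits below are illustrative, and the helper splitk_matmul_reference is a hypothetical reference function, not part of this PR.

```python
# Reference for the Split-K math only; the actual kernel distributes the K
# partitions across threadgroups on the GPU and accumulates their partial sums.
import mlx.core as mx

def splitk_matmul_reference(a, b, num_splits=4):
    k = a.shape[-1]
    assert k % num_splits == 0, "sketch assumes K divides evenly into splits"
    chunk = k // num_splits
    # Accumulate partial sums in float32, then cast back to the input dtype.
    acc = mx.zeros((a.shape[0], b.shape[1]), dtype=mx.float32)
    for p in range(num_splits):
        a_p = a[:, p * chunk : (p + 1) * chunk]
        b_p = b[p * chunk : (p + 1) * chunk, :]
        acc = acc + mx.matmul(a_p, b_p).astype(mx.float32)
    return acc.astype(a.dtype)

# Illustrative large-K shape.
a = mx.random.normal((256, 65536), dtype=mx.bfloat16)
b = mx.random.normal((65536, 256), dtype=mx.bfloat16)
out = splitk_matmul_reference(a, b)
mx.eval(out)
```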
Changes
- steel_gemm_splitk_nax: partitions K-dimension work across threadgroups, then accumulates the partial sums
- splitk_gemm_bench.py: benchmark to compare fused NAX vs Split-K NAX GEMM performance

Performance on M5
Design Notes
- MLX_DISABLE_SPLITK_NAX env var for benchmarking purposes, to compare against the regular GEMM (see the sketch below). Can be removed before merging.

Fixes #3017
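A possible way to exercise the flag from a benchmark harness, assuming the variable is read from the process environment and that the script name matches the one added in this PR (its exact path in the repo may differ):

```python
# Sketch: run the benchmark twice, once with Split-K disabled via the env var,
# to compare the two code paths in separate processes.
import os
import subprocess
import sys

script = "splitk_gemm_bench.py"  # illustrative path
base_env = dict(os.environ)

# Default path (Split-K enabled for large-K shapes).
subprocess.run([sys.executable, script], env=base_env, check=True)

# Split-K disabled, falling back to the regular fused NAX GEMM.
subprocess.run(
    [sys.executable, script],
    env={**base_env, "MLX_DISABLE_SPLITK_NAX": "1"},
    check=True,
)
```

Note that, per the review discussion above, this env var was removed before merging, so the toggle only applies to the pre-merge benchmarking setup.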
Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes