
Add NAX Split-K GEMM for large-K matmuls to improve performance#3018

Merged
awni merged 7 commits into ml-explore:main from hxu296:splitk-nax-pr
Jan 26, 2026

Conversation

@hxu296
Contributor

@hxu296 commented Jan 19, 2026

Summary

Adds NAX Split-K GEMM support to address high runtime variance and slow tail cases in large-K matmuls on the M5 GPU. For GEMMs with very large K dimensions, partitioning work along K substantially reduces variance and eliminates slow tail cases, achieving up to ~1.6× speedup over the fused NAX GEMM.

Changes

  • New kernel steel_gemm_splitk_nax partitions the K dimension across threadgroups, then accumulates the partial sums (see the sketch after this list)
  • Dispatch heuristic: batch_size==1 AND M×N >= 2048^2 AND K >= 10240 AND K >= 3×max(M,N) AND NAX available
  • Added a benchmark script splitk_gemm_bench.py to compare fused NAX vs Split-K NAX GEMM performance
  • Refactored the NAX gemm_loop parameters so the loop is reusable across the fused and split-K paths
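
The split-K scheme is numerically just a partitioned sum of partial products over K. A minimal NumPy sketch of the decomposition (the partition size and function name here are illustrative, not the kernel's internals):

```python
import numpy as np

def splitk_matmul(A, B, partition_size=3072):
    """Illustrative split-K GEMM: each K-partition computes a partial
    product; the partials are then summed into the final result."""
    K = A.shape[1]
    partials = []
    for k0 in range(0, K, partition_size):
        k1 = min(k0 + partition_size, K)
        # On the GPU, each partition maps to its own threadgroup(s),
        # so large-K work is spread across the chip instead of
        # serialized inside one threadgroup's K-loop.
        partials.append(A[:, k0:k1] @ B[k0:k1, :])
    return np.sum(partials, axis=0)

A = np.random.rand(64, 10240).astype(np.float32)
B = np.random.rand(10240, 64).astype(np.float32)
assert np.allclose(splitk_matmul(A, B), A @ B, rtol=1e-4)
```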

Performance on M5

Performance (bfloat16):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    156.11ms    142.52ms      1.10x
 2048  3072  10240    236.23ms    211.46ms      1.12x
 3072  3072  10240    376.75ms    315.02ms      1.20x
 3072  3072  12288    500.96ms    372.27ms      1.35x
 3072  4096  12288    621.92ms    494.85ms      1.26x
 4096  4096  12288    844.15ms    659.48ms      1.28x
 4096  4096  18432   1472.46ms    984.36ms      1.50x
 4096  4096  21504   1781.44ms   1148.95ms      1.55x
 4096  6144  21504   2807.22ms   1734.34ms      1.62x
 6144  6144  21504   4078.63ms   2635.42ms      1.55x

Performance (float16):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    155.20ms    147.03ms      1.06x
 2048  3072  10240    226.05ms    217.93ms      1.04x
 3072  3072  10240    378.34ms    325.91ms      1.16x
 3072  3072  12288    469.55ms    386.22ms      1.22x
 3072  4096  12288    618.64ms    521.50ms      1.19x
 4096  4096  12288    904.95ms    702.39ms      1.29x
 4096  4096  18432   1474.75ms   1079.02ms      1.37x
 4096  4096  21504   1865.51ms   1281.52ms      1.46x
 4096  6144  21504   2755.68ms   1965.28ms      1.40x
 6144  6144  21504   4086.48ms   3017.21ms      1.35x

Performance (float32):
    M     N      K     Regular     Split-K     Speedup
----------------------------------------------------------------------
 2048  2048  10240    193.15ms    196.28ms      0.98x
 2048  3072  10240    293.95ms    302.57ms      0.97x
 3072  3072  10240    449.79ms    473.58ms      0.95x
 3072  3072  12288    547.38ms    566.81ms      0.97x
 3072  4096  12288    727.40ms    758.03ms      0.96x
 4096  4096  12288    989.01ms   1014.35ms      0.98x
 4096  4096  18432   1808.03ms   1629.34ms      1.11x
 4096  4096  21504   2692.04ms   1901.86ms      1.42x
 4096  6144  21504   4260.02ms   2867.28ms      1.49x
 6144  6144  21504   6762.67ms   4345.39ms      1.56x

Design Notes

  1. Some dispatch thresholds (e.g. M×N >= 2048^2 AND K >= 10240) and split_k_partition_size (3072) were determined empirically on the M5 GPU; see the sketch after this list. Future work is needed to tune them for additional NAX devices as they become available.
  2. Added an MLX_DISABLE_SPLITK_NAX env var for benchmarking purposes, to compare against the regular GEMM. It can be removed before merging.
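
For concreteness, here is the dispatch condition and the resulting partition count as a Python sketch (function and variable names are illustrative; the real check lives in the C++ dispatch code):

```python
import math

SPLIT_K_PARTITION_SIZE = 3072  # empirically chosen for M5

def should_use_splitk_nax(batch_size, M, N, K, nax_available):
    # Mirrors the heuristic listed under "Changes" above.
    return (
        batch_size == 1
        and M * N >= 2048**2
        and K >= 10240
        and K >= 3 * max(M, N)
        and nax_available
    )

# The 4096x4096x21504 case from the tables qualifies and splits K
# into ceil(21504 / 3072) = 7 partitions.
assert should_use_splitk_nax(1, 4096, 4096, 21504, True)
print(math.ceil(21504 / SPLIT_K_PARTITION_SIZE))  # 7
```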

Fixes #3017

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@hxu296 changed the title from "Splitk nax pr" to "Add NAX Split-K GEMM for large-K matrix multiplications" on Jan 19, 2026
@hxu296 changed the title from "Add NAX Split-K GEMM for large-K matrix multiplications" to "Add NAX Split-K GEMM for large-K matmuls to improve performance on M5" on Jan 19, 2026
@hxu296 changed the title from "Add NAX Split-K GEMM for large-K matmuls to improve performance on M5" to "Add NAX Split-K GEMM for large-K matmuls to improve performance" on Jan 19, 2026
@awni
Member

This looks very nicely done, thanks for the addition! We should remove the env var prior to merging. I will leave it open for a little while in case @jagrit06 or @angeloskath want to take a look.

@awni requested review from angeloskath and jagrit06 on January 22, 2026
@angeloskath
Member

Unless @jagrit06 has comments, we should merge this. It looks great!

@hxu296 feel free to remove the env-var.

@hxu296
Contributor Author

@hxu296 commented Jan 26, 2026

Thanks for the review! I removed the test env var. I also updated the benchmark comparison to MLX Split-K vs PyTorch MPS GEMM (instead of MLX fused GEMM) to compare against an external baseline. On M5, Split-K brings MLX into the same performance ballpark as Torch for large-K GEMMs. For bf16/fp16, MLX is ~0.90–0.95× Torch across the tested shapes, and for fp32, MLX is ~2.5–2.8× faster than Torch (Torch appears not to be using a TF32 downcast on MPS in this setup). There is likely additional headroom here, and further tuning could help close the remaining gap for bf16/fp16.

Perf numbers (as produced by python3 benchmarks/python/large_gemm_bench.py):

Performance (bfloat16):
    M     N      K         MLX (ms)       Torch (ms)     Speedup
--------------------------------------------------------------------------------
 2048  2048  10240     7.07± 0.02     6.55± 0.05      0.93x
 2048  3072  10240    10.61± 0.01     9.71± 0.05      0.92x
 3072  3072  10240    15.88± 0.02    14.62± 0.21      0.92x
 3072  3072  12288    18.80± 0.02    17.52± 0.06      0.93x
 3072  4096  12288    25.04± 0.03    23.71± 0.04      0.95x
 4096  4096  12288    33.16± 0.04    31.42± 0.04      0.95x
 4096  4096  18432    50.62± 0.36    46.94± 0.07      0.93x
 4096  4096  21504    59.03± 0.22    54.21± 0.08      0.92x
 4096  6144  21504    88.85± 0.13    80.80± 0.40      0.91x
 6144  6144  21504   134.78± 1.19   123.98± 0.91      0.92x

Performance (float16):
    M     N      K         MLX (ms)       Torch (ms)     Speedup
--------------------------------------------------------------------------------
 2048  2048  10240     7.37± 0.10     6.74± 0.10      0.92x
 2048  3072  10240    11.22± 0.16    10.09± 0.12      0.90x
 3072  3072  10240    17.11± 0.26    15.28± 0.30      0.89x
 3072  3072  12288    20.35± 0.25    18.16± 0.12      0.89x
 3072  4096  12288    26.96± 0.20    24.54± 0.10      0.91x
 4096  4096  12288    36.14± 0.13    32.81± 0.47      0.91x
 4096  4096  18432    54.66± 0.27    49.71± 0.85      0.91x
 4096  4096  21504    64.04± 1.49    57.99± 1.21      0.91x
 4096  6144  21504    97.17± 1.58    87.69± 1.23      0.90x
 6144  6144  21504   147.75± 0.66   132.77± 1.32      0.90x

Performance (float32):
    M     N      K         MLX (ms)       Torch (ms)     Speedup
--------------------------------------------------------------------------------
 2048  2048  10240    10.13± 0.07    27.34± 0.53      2.70x
 2048  3072  10240    16.21± 0.39    44.42± 3.27      2.74x
 3072  3072  10240    24.51± 0.76    64.79± 3.74      2.64x
 3072  3072  12288    30.08± 1.49    80.33± 5.05      2.67x
 3072  4096  12288    41.64± 2.44   116.17± 6.22      2.79x
 4096  4096  12288    62.35± 7.56   155.60± 7.61      2.50x
 4096  4096  18432   104.85± 9.08   270.32± 9.93      2.58x
 4096  4096  21504   129.25±12.24   355.18±10.21      2.75x
 4096  6144  21504   214.74±14.29   590.13±24.97      2.75x
 6144  6144  21504   335.93±22.29   988.37±19.87      2.94x
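
A minimal sketch of the timing methodology such a cross-framework comparison needs (this is not the actual large_gemm_bench.py; the key detail is forcing/awaiting GPU work before reading the clock):

```python
import time
import mlx.core as mx
import torch

def bench_mlx(M, N, K, dtype=mx.bfloat16, iters=20):
    a = mx.random.normal((M, K)).astype(dtype)
    b = mx.random.normal((K, N)).astype(dtype)
    mx.eval(a @ b)  # warmup; MLX is lazy, eval forces execution
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(a @ b)
    return (time.perf_counter() - tic) / iters * 1e3  # ms

def bench_torch(M, N, K, dtype=torch.bfloat16, iters=20):
    a = torch.randn(M, K, dtype=dtype, device="mps")
    b = torch.randn(K, N, dtype=dtype, device="mps")
    _ = a @ b
    torch.mps.synchronize()  # warmup and drain the queue
    tic = time.perf_counter()
    for _ in range(iters):
        _ = a @ b
    torch.mps.synchronize()  # wait for queued GPU work before timing
    return (time.perf_counter() - tic) / iters * 1e3  # ms
```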

@jagrit06
Member

Runs well when tested on my M5 machine!
We can merge it once all the other comments are addressed!

@awni merged commit 7ed2b6b into ml-explore:main on Jan 26, 2026
16 checks passed

Development

Successfully merging this pull request may close these issues.

[Enhancement] Investigate NAX Split-K for large-K GEMM stability on M5
