Skip to content

Implement mxfp4 split-k gemm#958

Merged
willghatch merged 21 commits intomainfrom
users/willghatch/splitk-mxfp4
Apr 21, 2026
Merged

Implement mxfp4 split-k gemm#958
willghatch merged 21 commits intomainfrom
users/willghatch/splitk-mxfp4

Conversation

@willghatch
Copy link
Copy Markdown
Contributor

@willghatch willghatch commented Feb 24, 2026

The core things added are split-k gemm, and it is tested for (1) generation of the buffer_atomic_pk_add_bf16 instruction that we wanted to use, and (2) for gemm correctness.

Overview of changes unrelated to wave_asm:

  • remove_global_indexing in general_utils.py: Zeroes out tiling constraint starts (e.g. K_SPLIT_OFF) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales).

  • Fixing spurious bounds on split-K tiling that prevented scale vector merging: TilingConstraint.get_index_bound was conservatively generating bounds for the split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile)

    • tile <= K. These bounds prevented merge_contiguous_reads from combining scalar scale reads into vector<4xi8> loads (it skips reads that already have bounds). Add _work_may_exceed_dim() to structurally detect the aligned split-k pattern and prove no overshoot, avoiding the spurious bound. (This was necessary to get scale_preshuffle to have 4x vector loads when combined with split-k.)

@willghatch
Copy link
Copy Markdown
Contributor Author

@harsh-nod This has splitk with preshuffle_scales functional with the 4x vector load. I've done some basic cleanup, but as mentioned there are still parts of it that I haven't fully reviewed or understood.

@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 2 times, most recently from 88b0c99 to 8de6506 Compare February 24, 2026 18:44
@willghatch
Copy link
Copy Markdown
Contributor Author

@harsh-nod this is now rebased on top of main, which now has the wave_asm backend commit that you carved out of this one. So it should be ready to go.

Comment thread tests/kernel/wave_gemm_test.py Outdated
Comment thread wave_lang/kernel/wave/utils/general_utils.py Outdated
Comment thread wave_lang/kernel/wave/constraints.py Outdated
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 3 times, most recently from 65575f0 to ca5f8e8 Compare February 25, 2026 23:31
Comment thread examples/python/7.1_schedule.py Outdated
Comment thread examples/python/7.1_schedule.py Outdated
Comment thread examples/python/7.1_schedule.py Outdated
Comment thread tests/kernel/wave_gemm_test.py
Comment thread wave_lang/kernel/wave/constraints.py Outdated
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 4 times, most recently from 2f97e30 to 03ff9aa Compare March 9, 2026 21:43
Comment thread wave_lang/kernel/wave/compile.py
Comment thread examples/python/7.1_schedule.py
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from f6815a2 to a0a9afd Compare March 10, 2026 00:32
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 3 times, most recently from 438c41e to 5c73bff Compare March 30, 2026 21:35
Comment thread wave_lang/kernel/wave/constraints.py Outdated
Comment thread wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch 2 times, most recently from aef1874 to 64e14a9 Compare April 7, 2026 23:11
The core things added are split-k gemm, and it is tested for (1) generation of the `buffer_atomic_pk_add_bf16` instruction that we wanted to use, and (2) for gemm correctness.

Overview of some of the major changes:

- `remove_global_indexing` in `general_utils.py`: Zeroes out tiling constraint
  starts (e.g. `K_SPLIT_OFF`) alongside workgroup IDs before dimension scaling,
  so that the subtraction of the start offset doesn't mix scaled and unscaled
  units (K vs K/32 for MXFP4 scales).

- Fixing spurious bounds on split-K tiling that prevented scale vector merging:
  TilingConstraint.get_index_bound was conservatively generating bounds for the
  split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile)
  * tile <= K.  These bounds prevented merge_contiguous_reads from combining
  scalar scale reads into vector<4xi8> loads (it skips reads that already have
  bounds).  Add _work_may_exceed_dim() to structurally detect the aligned
  split-k pattern and prove no overshoot, avoiding the spurious bound.  (This
  was necessary to get scale_preshuffle to have 4x vector loads when combined
  with split-k.)

Signed-off-by: William G Hatch <william@hatch.uno>
But not for _cpp variants (waveasm backend), which has some issues.

Signed-off-by: William G Hatch <william@hatch.uno>
Remove undefined skip_if_no_gpu() and skip_if_no_wave_lang() calls,
remove undefined backend fixture parameter from test signatures, and
remove test_compare_backends_copy_kernel which referenced multiple
undefined functions (compare_with_python_backend, get_target_arch).

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
The work_bound for TilingConstraint with a nonzero start is
start + tile * ceiling(...), an Add expression.  _extract_tile_and_ceiling
only matched pure Mul, so it always returned (None, None) for split-K,
forcing unnecessary bounds checks that prevent read merging.

Strip the additive start offset before matching the tile*ceiling core.
Also use as_numer_denom() instead of .simplify() as suggested in review.

Signed-off-by: William G Hatch <william@hatch.uno>
Each split must tile evenly by BLOCK_K for correctness.  Add explicit
validation alongside the existing k_per_split >= BLOCK_K check.

Signed-off-by: William G Hatch <william@hatch.uno>
V_LSHRREV_B32 only shifts right but does not clear the upper bits,
unlike V_BFE_U32 which extracts a specific bitfield.  Add a V_AND_B32
with a mask of (1 << elemBits) - 1 after the shift to match the
semantics of bitfield extraction.

Applies to both handleVectorExtract and handleVectorExtractStridedSlice.

Signed-off-by: William G Hatch <william@hatch.uno>
Delete the duplicate split-K MXFP4 kernel builder from gemm.py and
migrate all callers to the tagged variants in tagged_mxfp4_gemm.py.
The tagged kernels already use SHARED_ADDRESS_SPACE with
use_global_to_shared=True, which is the preferred configuration.

Signed-off-by: William G Hatch <william@hatch.uno>
The Add-stripping logic incorrectly counted all Mul terms instead of
only those containing a ceiling factor, causing it to miss the split-K
pattern where work_bound = start_mul + tile*ceiling(...).

Also replace the direct isinstance(numer, Min) check with a recursive
search, since sympy distributes the division to produce
Min(dim, ...) + other_terms rather than a bare Min as the numerator.

Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
- Add k_partitions to MXFP4 dbuf and pingpong schedules; k_partitions=1
  matches split-K kernels (single K expansion id) and uses a two-cluster
  layout so no empty clusters are passed to reorder_graph.
- E2E splitk bf16 tests: use K=512 with two splits so k_per_split meets
  default BLOCK_K; pass manual schedules into capture_wave_kernel_info.
- wave_gemm_test SplitKMxfp4Gemm: compile with get_mxfp4_dbuf_schedule
  k_partitions=1.

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from 5586133 to f8d5fd5 Compare April 13, 2026 15:52
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from c62a3cd to 307bc75 Compare April 21, 2026 14:14
Motivation and details in short paragraphs.
After rebasing splitk-mxfp4 onto main, the split-K C++ example test uses
only tagged template helpers that return WaveCompileOptions from the
template layer; the explicit compile import is unused.

Change Details:

- Remove unused symbol from the import line to match actual usage and
  avoid linter noise.

- Single-line import cleanup in examples/python/7.1_schedule.py

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
Signed-off-by: William G Hatch <william@hatch.uno>
After the SCC-modeling rework, waveasm.if requires an SCC or SGPR
condition.  Kernels whose scf.if condition comes from a vector-path
arith.cmpi (e.g. the split-K MXFP4 bounds check) end up with a
!waveasm.vreg condition, which trips the verifier.  Capture this case
as a red-phase lit test so the subsequent fix in buildIfFromSCFIf is
guarded against regression.

Change Details:

- Add scf_if_vgpr_cond_to_wave_if to region-based-translation.mlir.
  The reason is to exercise the VGPR -> SCC coercion path (expected
  v_readfirstlane_b32 + s_cmp_ne_u32) that the failing split-K MXFP4
  tests need.

Signed-off-by: William G Hatch <william@hatch.uno>
Match the shape of the existing scf_if_to_wave_if test so the CHECK
for waveasm.if can pin the result type (!waveasm.vreg) and prove the
full condition-lowering pipeline round-trips.  With no yield the op
printer omits the `-> result_ty` suffix, which leaves the CHECK too
loose to distinguish the VGPR-cond case from the SCC-cond case.

Change Details:

- Add a trivial yield (i32 add/sub) to scf_if_vgpr_cond_to_wave_if.
  The reason is to exercise the yielded-result path and keep CHECK
  anchored to `waveasm.if ... : !waveasm.scc -> !waveasm.vreg`.

Signed-off-by: William G Hatch <william@hatch.uno>
Post-rebase main enforces that waveasm.if's condition must be SCC or
SGPR.  Split-K MXFP4 kernels generate scf.if where the condition comes
from arith.cmpi on VGPR-typed operands (the bounds check uses
affine.apply on gpu.block_id, and affine.apply always produces VGPR),
so arith.cmpi falls through the vector path and yields a boolean
VGPR.  That VGPR tripped the verifier as
  'waveasm.if' op operand #0 must be SCC or SGPR, but got '!waveasm.vreg'
in 4 split-K MXFP4 waveasm e2e tests.

buildIfFromSCFIf previously only converted ImmType -> SGPR (via
s_mov_b32), which also silently violated the new verifier rule (SGPR
is accepted, but only because ImmType-source cmpi is rare in practice;
the verifier tolerates a plain SGPR condition from s_mov).  We now
generalise: any non-SCC condition is normalised to SGPR and then
tested against 0 with s_cmp_ne_u32 to produce SCC, mirroring the VGPR
upper-bound coercion in buildLoopFromSCFFor.

Change Details:

- Replace the ImmType-only path in buildIfFromSCFIf with a general
  coercion: VGPR -> v_readfirstlane_b32, Imm -> s_mov_b32, then
  s_cmp_ne_u32 against 0 to materialise SCC.  The reason is that the
  new verifier contract requires SCC/SGPR and the simplest uniform
  coercion is via the readfirstlane + s_cmp idiom already used for
  VGPR loop upper bounds.

Signed-off-by: William G Hatch <william@hatch.uno>
The 3 split-K MXFP4 bf16 waveasm e2e tests assert that the generated
assembly contains buffer_atomic_pk_add_bf16.  Before the migration to
the tagged template (commit 4d87ee9 "Remove untagged
get_splitk_mxfp4_gemm_kernel"), the underlying implementation
defaulted to bf16 output.  The tagged template defaults to f32, which
takes the buffer_atomic_add_f32 path in handleMemRefAtomicRMW and
never emits the bf16-packed atomic.

These tests explicitly set up bfloat16 output tensors and assert the
bf16 atomic instruction is emitted, so the bf16 dtype is the correct
selection for the output_type kwarg.

Change Details:

- Pass output_type=tkl.bf16 in test_splitk_mxfp4_bf16_atomic_cpp_backend,
  test_splitk_mxfp4_bf16_asm_emission, and
  test_splitk_mxfp4_preshuffle_scales_cpp_backend to restore the
  previous (pre-migration) bf16 output behavior the tests were
  written against.

Signed-off-by: William G Hatch <william@hatch.uno>
…tion

The assembly emitter's peak SGPR scan was counting precolored VCC
registers (s[106:107] on GFX9 Wave64) as general-purpose SGPRs.
This inflated peakSGPRs, causing the loop back-edge swap temporary
to be allocated at s108 -- beyond the hardware limit of s105.

VCC is an architectural register emitted as "vcc" in assembly, not
as s[106:107], so it should not contribute to the general SGPR peak.
Skip PSReg values at or above target.getMaxSGPRs() (the VCC boundary)
when computing peakSGPRs.

Fixes SGPR overflow in split-K MXFP4 bf16 atomic kernels where the
v_cndmask_b32 VCC dependency tracking created precolored s[106:107]
values that were incorrectly counted.

Made-with: Cursor

Signed-off-by: William G Hatch <william@hatch.uno>
Guards against the issue fixed by the preceding commit (VCC exclusion
from peakSGPRs): a loop with SGPR iter_args that swap, combined with
a live precolored VCC (s[106:107]), previously allocated the swap
scratch at s108 -- past the user SGPR range -- tripping the
assembler's "register index is out of range" error.

Change Details:

- New lit test sgpr-swap-emit-vcc.mlir:  asserts the emitter does not
  choose an s1XX swap temp (s106+ is VCC or TTMP, not user-addressable
  in the general SGPR pool).  Fails against the pre-fix emitter with
  "s_mov_b32 s108, s1"; passes with the fix.

Signed-off-by: William G Hatch <william@hatch.uno>
Addresses remaining numerical failures in split-K MXFP4 GEMM tests on
gfx950 that persisted after the bf16 atomic and VCC-exclusion fixes.
Bundles correctness fixes from `wip/streamk-mxfp4-explore` (dcbb090)
that target codegen paths exercised by the split-K kernel.

Change Details:

Cherry-picked commit `dcbb090a` onto the current branch.  Conflicts
resolved as follows:

- `tests/kernel/wave/asm/test_waveasm_e2e.py` — kept HEAD, which already
  contains the `output_type=tkl.bf16` fix from commit `d8b3d589` on this
  branch.
- `waveasm/lib/Transforms/RegionBuilder.cpp` — kept HEAD, which has a
  more complete `buildIfFromSCFIf` coercion path (handles SCC, VGPR,
  ImmType, and SGPR conditions, all uniformly coerced to SCC via
  `s_cmp_ne_u32`) than dcbb090's partial version.

Included changes:

- `AssemblyEmitter.cpp`: emit IfOp yield-to-result register copies per
  branch (VGPR/SGPR/AGPR) when the allocated yield register differs
  from the allocated result register; emit a do-while loop guard
  (pre-loop `s_cmp_ge_u32` + `s_cbranch_scc1`) that skips the body
  when the trip count is zero (scf.for can have zero iterations but
  waveasm.loop is do-while); AGPR copies via `v_accvgpr_read/write_b32`
  using `kScratchVGPR`.
- `LinearScanPass.cpp`: remove IfOp results from the linear-scan
  worklist and assign their physical register post-allocation from
  the then-yield operand via `getEffectivePhysReg`; handle `PARegType`
  in loop init-arg block-arg assignment.
- `Liveness.cpp`: Pass 3c removes IfOp result ranges from the worklist
  and extends the yield-operand range to cover the IfOp result
  lifetime.
- `ArithHandlers.cpp`: when all users of `arith.cmpi` feed `scf.if`,
  promote VGPR operands via `V_READFIRSTLANE_B32` and emit `S_CMP_*`
  directly (EQ/NE/LT/LE/GT/GE in I32 and U32 variants), so the
  `waveasm.if` condition is already in SCC form.
- `tests/kernel/wave_gemm_test.py`: bump the (512,512,1024) split-K
  test shape from 2 splits to 4 splits, matching upstream.
- `waveasm/test/Transforms/linear-scan-if-feeds-loop.mlir` and
  `linear-scan-ifop-bug.mlir`: CHECK-line updates for the new IfOp
  allocation behavior (yield carries the if-result register instead
  of a separate allocation).

Signed-off-by: William G Hatch <william@hatch.uno>
The prior CHECK lines expected the VGPR-coercion path via v_cndmask_b32
-> v_readfirstlane_b32 -> s_cmp_ne_u32.  After the cherry-pick of
dcbb090, VGPR-operand arith.cmpi with all-scf.if users promotes its
operands directly through v_readfirstlane_b32 and emits the s_cmp_*
variant for the cmpi predicate (here s_cmp_lt_i32), producing SCC
without an intermediate boolean VGPR.  The new path is what
post-rebase split-K kernels exercise on gfx950.

Change Details:

- Replace the v_cndmask_b32 / s_cmp_ne_u32 CHECK lines with the new
  readfirstlane + s_cmp_lt_i32 expectation.
- Update the surrounding block comment to describe the direct promotion
  mechanism rather than the boolean-VGPR coercion path.

Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch willghatch force-pushed the users/willghatch/splitk-mxfp4 branch from 307bc75 to ef0c63b Compare April 21, 2026 14:16
@willghatch willghatch merged commit f08f0cc into main Apr 21, 2026
18 of 19 checks passed
@willghatch willghatch deleted the users/willghatch/splitk-mxfp4 branch April 21, 2026 15:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants