compute: Project B Stage 4 scaffolding (dispatch + env gate)#55
Merged
Conversation
Infrastructure for the MXFP4 FMHA hardware-blockscale upgrade. The
kernel body is deferred — this commit wires up everything around it so
future iteration can focus on the kernel itself.
What's here:
- src/compute/attention_fmha_mxf4nvf4_sm120.{h,cu} — landing pad file
with mxf4nvf4_blockscale_enabled() env gate + fmha_sm120_mxf4nvf4_prefill()
entry that currently delegates to the legacy f8f6f4 kernel.
- attention_dispatch.cu: MXFP4 branch now checks IMP_FMHA_BLOCKSCALE
and routes through the new entry when set. Same true/false fallback
contract preserved.
- Documentation of what remains: per-thread CUTLASS ALayout / SFALayout
translation for the (T32,V32)→(M16,K64) mapping, per-16-elem FP8
UE4M3 scale handling, MMA instruction swap to kind::mxf4nvf4.block_scale.
Why deferred: the CUTLASS CuTe ALayout expects non-row-major thread
value distribution — each thread holds 32 FP4 values spanning 4 rows
at specific k-offsets per the Stride<_16,_8,_512> formula. Without an
end-to-end Q·K^T correctness harness against FP32 reference, operand-
layout bugs cascade into hard-to-debug numerical garbage. A future
session should start by building that harness against the probe kernel,
then thread the correct layout through the FMHA SMEM structure.
All prerequisites validated in PR #54 commits:
f50221e — Probe compiles + launches
e7615f3 — A=0 → D=0
4baf0d7, 30e827f — Quant round-trips (linear + HW layout), 9.5% RMSE
b9fbb66 — 2.60× raw MMA speedup headroom
98ffabd — Stage 3-5 plan with file:line refs
Tests: 583 pass. IMP_FMHA_BLOCKSCALE=1 now routes through the landing
pad (currently delegates to legacy, log clearly indicates WIP status).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Prerequisite for the FMHA integration: validate that a full Q·K^T
through quant → mma.sync.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64
produces numerically correct output before plumbing into the real
attention kernel.
Adds src/compute/mxf4nvf4_qkt_validate.{h,cu} and three GTest cases:
1. UniformInputs_AllOnes — Q=K=1.0 → D=64 everywhere.
Status: PASS (128/128 exact, max_err=0.000). Validates:
- MMA instruction launches + produces output
- Uniform scale = 1.0 (FP8 UE4M3 byte 0x38) is accepted
- A=0 invariant from earlier probe extends to real data
- Sum-preserving operand layout is consistent with output layout
2. RowIndicator — Q row m = m-th E2M1 magnitude, K = all ones.
Status: PASS as diagnostic. Shows D output mixes rows 0,1,8,9
into a single output row, indicating per-thread A operand layout
needs more work (currently treating each thread as holding 4
different row segments, but output assumes single-row contribution).
3. ColIndicator — Q = 1, K col n = n-th magnitude.
Status: PASS as diagnostic. D columns come out pairwise-averaged
(112.0 repeated), showing B operand n-indexing is incorrect for
single m16n8k64 issue (tidB=0).
Thread decomposition now uses CuTe column-major convention:
t_outer = tid % 4 (inner shape 4, stride 128)
t_inner = tid / 4 (outer shape 8, stride 1)
offset = t_outer * 128 + t_inner + v_layout_stride
Remaining layout work for full correctness (tracked in test diagnostic
output): decode per-issue tidB=0 subset of the CuTe BLayout, and
properly split the A operand's 32 per-thread values across the 4 row
pairs (m, m+1, m+8, m+9) such that the hardware MMA reconstructs the
expected m16n8 output.
Tests: 586 pass (was 583, +3 new QKT tests as diagnostics).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kekzl
added a commit
that referenced
this pull request
Apr 30, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Infrastructure layer for the MXFP4 FMHA hardware-blockscale MMA upgrade.
All prerequisites shipped in PR #54; this lands the dispatcher hook so
the actual kernel body can be developed in follow-up sessions without
churning the dispatch wiring.
What this PR does
src/compute/attention_fmha_mxf4nvf4_sm120.{h,cu}as the landingpad for the new kernel.
mxf4nvf4_blockscale_enabled()gated onIMP_FMHA_BLOCKSCALE=1— cached one-shot env lookup, logs activation.
fmha_sm120_mxf4nvf4_prefill()as the dispatch entry point.Currently delegates to the legacy
fmha_sm120_mxfp4_prefill().attention_dispatch.cuMXFP4 branch to route through the newentry when the flag is set, preserving the same true/false fallback
contract.
What this PR does NOT do
The kernel body still points at legacy. Setting
IMP_FMHA_BLOCKSCALE=1currently gives legacy behavior with a one-shot informational log.
Why deferred
The
kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64MMA expectsoperands in a specific CUTLASS (T32,V32)→(M16,K64) layout that differs
non-trivially from imp's current row-major SMEM packing. Per the CuTe
stride analysis (
Stride<Stride<_128,_1>,Stride<_16,_8,_512>>), eachthread's 32 FP4 values span 4 rows × 8 k-offsets. Without an
end-to-end Q·K^T FP32-reference harness, operand-layout bugs cascade
into softmax/P·V numerical garbage that's hard to isolate.
Prerequisites (merged in PR #54)
f50221ee7615f34baf0d7b9fbb6630e827fNext session roadmap
Per
docs/PROJECT_B_MXFP4_FMHA_UPGRADE.md(merged in PR #54):fmha_sm120_mxf4nvf4_kernelEstimated effort: ~20 hours focused.
Test plan
IMP_FMHA_BLOCKSCALE=1routes through new path (currently delegates)🤖 Generated with Claude Code