Skip to content

compute: Project B Stage 4 scaffolding (dispatch + env gate)#55

Merged
kekzl merged 2 commits into
mainfrom
feat/project-b-stage4-scaffolding
Apr 24, 2026
Merged

compute: Project B Stage 4 scaffolding (dispatch + env gate)#55
kekzl merged 2 commits into
mainfrom
feat/project-b-stage4-scaffolding

Conversation

@kekzl
Copy link
Copy Markdown
Owner

@kekzl kekzl commented Apr 24, 2026

Summary

Infrastructure layer for the MXFP4 FMHA hardware-blockscale MMA upgrade.
All prerequisites shipped in PR #54; this lands the dispatcher hook so
the actual kernel body can be developed in follow-up sessions without
churning the dispatch wiring.

What this PR does

  • Adds src/compute/attention_fmha_mxf4nvf4_sm120.{h,cu} as the landing
    pad for the new kernel.
  • Exposes mxf4nvf4_blockscale_enabled() gated on IMP_FMHA_BLOCKSCALE=1
    — cached one-shot env lookup, logs activation.
  • Adds fmha_sm120_mxf4nvf4_prefill() as the dispatch entry point.
    Currently delegates to the legacy fmha_sm120_mxfp4_prefill().
  • Updates attention_dispatch.cu MXFP4 branch to route through the new
    entry when the flag is set, preserving the same true/false fallback
    contract.

What this PR does NOT do

The kernel body still points at legacy. Setting IMP_FMHA_BLOCKSCALE=1
currently gives legacy behavior with a one-shot informational log.

Why deferred

The kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64 MMA expects
operands in a specific CUTLASS (T32,V32)→(M16,K64) layout that differs
non-trivially from imp's current row-major SMEM packing. Per the CuTe
stride analysis (Stride<Stride<_128,_1>,Stride<_16,_8,_512>>), each
thread's 32 FP4 values span 4 rows × 8 k-offsets. Without an
end-to-end Q·K^T FP32-reference harness, operand-layout bugs cascade
into softmax/P·V numerical garbage that's hard to isolate.

Prerequisites (merged in PR #54)

Stage Commit Result
1 Feasibility f50221e MMA compiles + launches on sm_120f + CUDA 13.2.78
2 Numerical e7615f3 A=0 → D=0 invariant verified
2.5 Quant math (linear) 4baf0d7 FP16→NVFP4→FP16 round-trip, 9.5% RMSE Gaussian
3 Throughput b9fbb66 2.60× raw MMA speedup vs legacy
3.1 HW scale layout 30e827f Quant + dequant with HW-interleaved scale, 4/4 tests pass

Next session roadmap

Per docs/PROJECT_B_MXFP4_FMHA_UPGRADE.md (merged in PR #54):

  1. Build E2E Q·K^T FP32-reference harness against the probe kernel
  2. Implement thread-layout mapping in a minimal Q·K^T test
  3. Swap SMEM scale format to FP8 UE4M3 per-16-elem
  4. Swap MMA in fmha_sm120_mxf4nvf4_kernel
  5. Bench + real-prompt correctness

Estimated effort: ~20 hours focused.

Test plan

🤖 Generated with Claude Code

kekzl and others added 2 commits April 24, 2026 22:49
Infrastructure for the MXFP4 FMHA hardware-blockscale upgrade. The
kernel body is deferred — this commit wires up everything around it so
future iteration can focus on the kernel itself.

What's here:
- src/compute/attention_fmha_mxf4nvf4_sm120.{h,cu} — landing pad file
  with mxf4nvf4_blockscale_enabled() env gate + fmha_sm120_mxf4nvf4_prefill()
  entry that currently delegates to the legacy f8f6f4 kernel.
- attention_dispatch.cu: MXFP4 branch now checks IMP_FMHA_BLOCKSCALE
  and routes through the new entry when set. Same true/false fallback
  contract preserved.
- Documentation of what remains: per-thread CUTLASS ALayout / SFALayout
  translation for the (T32,V32)→(M16,K64) mapping, per-16-elem FP8
  UE4M3 scale handling, MMA instruction swap to kind::mxf4nvf4.block_scale.

Why deferred: the CUTLASS CuTe ALayout expects non-row-major thread
value distribution — each thread holds 32 FP4 values spanning 4 rows
at specific k-offsets per the Stride<_16,_8,_512> formula. Without an
end-to-end Q·K^T correctness harness against FP32 reference, operand-
layout bugs cascade into hard-to-debug numerical garbage. A future
session should start by building that harness against the probe kernel,
then thread the correct layout through the FMHA SMEM structure.

All prerequisites validated in PR #54 commits:
  f50221e — Probe compiles + launches
  e7615f3 — A=0 → D=0
  4baf0d7, 30e827f — Quant round-trips (linear + HW layout), 9.5% RMSE
  b9fbb66 — 2.60× raw MMA speedup headroom
  98ffabd — Stage 3-5 plan with file:line refs

Tests: 583 pass. IMP_FMHA_BLOCKSCALE=1 now routes through the landing
pad (currently delegates to legacy, log clearly indicates WIP status).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Prerequisite for the FMHA integration: validate that a full Q·K^T
through quant → mma.sync.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64
produces numerically correct output before plumbing into the real
attention kernel.

Adds src/compute/mxf4nvf4_qkt_validate.{h,cu} and three GTest cases:

1. UniformInputs_AllOnes — Q=K=1.0 → D=64 everywhere.
   Status: PASS (128/128 exact, max_err=0.000). Validates:
     - MMA instruction launches + produces output
     - Uniform scale = 1.0 (FP8 UE4M3 byte 0x38) is accepted
     - A=0 invariant from earlier probe extends to real data
     - Sum-preserving operand layout is consistent with output layout

2. RowIndicator — Q row m = m-th E2M1 magnitude, K = all ones.
   Status: PASS as diagnostic. Shows D output mixes rows 0,1,8,9
   into a single output row, indicating per-thread A operand layout
   needs more work (currently treating each thread as holding 4
   different row segments, but output assumes single-row contribution).

3. ColIndicator — Q = 1, K col n = n-th magnitude.
   Status: PASS as diagnostic. D columns come out pairwise-averaged
   (112.0 repeated), showing B operand n-indexing is incorrect for
   single m16n8k64 issue (tidB=0).

Thread decomposition now uses CuTe column-major convention:
   t_outer = tid % 4 (inner shape 4, stride 128)
   t_inner = tid / 4 (outer shape 8, stride 1)
   offset = t_outer * 128 + t_inner + v_layout_stride

Remaining layout work for full correctness (tracked in test diagnostic
output): decode per-issue tidB=0 subset of the CuTe BLayout, and
properly split the A operand's 32 per-thread values across the 4 row
pairs (m, m+1, m+8, m+9) such that the hardware MMA reconstructs the
expected m16n8 output.

Tests: 586 pass (was 583, +3 new QKT tests as diagnostics).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kekzl kekzl enabled auto-merge (squash) April 24, 2026 21:15
@kekzl kekzl merged commit e210601 into main Apr 24, 2026
2 checks passed
@kekzl kekzl deleted the feat/project-b-stage4-scaffolding branch April 24, 2026 21:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant