QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749
QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749justinchuby wants to merge 5 commits into
Conversation
Fixes microsoft#28748. `MatMulNBits::PrePack_B` calls `preprocess_weights_for_mixed_gemm_cuda` at session-load time so callers can hand it the raw `[N, K/(8/bits)]` packed int4/int8 weights produced by `quantize_matmul_{4,8}bits`. The CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose + column interleave + bias) happens inside ORT. `QMoE::PrePack` for `quant_type == "int"` did the opposite: input slots 2 and 5 (fc1/fc2 expert weights) were explicitly skipped with `is_packed = false`, and the compute path passed `tensor->DataRaw()` straight into the CUTLASS runner. That assumes the caller has already prepacked the weights themselves, which: - requires a CUDA-built ORT just to export a QMoE model (the `pack_weights_for_cuda_mixed_gemm` pybind binding is only exposed when ORT is built with USE_CUDA), and - is silent-failure-prone: skipping the prepack just produces garbage output, not an error. This change mirrors the MatMulNBits PrePack path: - Add `packed_fc1_weights_` / `packed_fc2_weights_` buffers. - Add `PrePackIntExpertWeights` helper that walks the E experts of the `[E, N, K/(8/bits)]` initializer, runs the existing `unpack_uint4_transposed_to_int8_direct_cuda` / `transpose_uint8_matrix_and_convert_to_int8` adapter, then the shared `preprocess_weights_for_mixed_gemm_cuda` transform, and stacks results into `[E, K, N/(8/bits)]`. - Dispatch from `PrePack` for slots 2 and 5 when `quant_type_ == "int"`. - Update `ComputeInternal` to prefer `packed_fc{1,2}_weights_` over the raw tensor data when the PrePack hook has populated them, with a fall-through to the raw initializer for sessions that disable prepacking (in that case the caller still has to provide pre-prepacked bytes themselves — same as today). Builds cleanly (verified by re-compiling `contrib_ops/cuda/moe/moe_quantization.cc.o` against the current ORT main; remaining link-time errors in the surrounding `onnxruntime_providers_cuda` target are a pre-existing CUDA 13.2 + CCCL header incompatibility in `bias_softmax_impl.cu` and unrelated to this change). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Bit-parity smoke test that constructs two single-node QMoE graphs over identical per-expert quantized weights: - **Raw path**: writes the un-prepacked `[E, N, K/2]` bytes from `quantize_matmul_4bits` straight into the initializer. Exercises the new `QMoE::PrePackIntExpertWeights` hook. - **Pre-prepacked path**: applies `pack_weights_for_cuda_mixed_gemm` per-expert before writing the initializer (matches what the existing test_qmoe_cuda.py tests do). Both feed the same QMoE runner; with the PrePack hook in place the runner sees the same prepacked bytes either way, so outputs should agree to within fp16 rounding. Two cases cover small (64/32/E=4) and medium (128/64/E=8) shapes with SwiGLU interleaved fusion. Guarded by `@unittest.skipUnless(torch.cuda.is_available())` so it no-ops on CPU-only CI. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Match the CI ruff (0.12.12) import sort: treat onnxruntime as first-party so 'from onnxruntime.capi import _pybind_state' belongs in the local-imports block after 'import onnxruntime', not in the third-party block. Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
The first version compared the new raw-weight PrePack path against the
existing `pack_weights_for_cuda_mixed_gemm` offline-pre-pack path, but
that comparison is invalid on SM>=90: the existing test harness in this
file hardcodes `force_arch=80` when calling
`pack_weights_for_cuda_mixed_gemm`, and on H100/H200 the other QMoE
parity tests in this file fail with max-diff > 1.0 too (verified on
plain main, pre-dating this change).
Rewrite as a smoke test that:
- builds a single QMoE node with raw, un-prepacked `[E, N, K/2]` int4
weights from `quantize_matmul_4bits` (the new schema-conformant
layout that the PrePack hook unlocks),
- runs it through the CUDA QMoE kernel,
- asserts the output has the right shape, is finite, and has reasonable
magnitudes for the toy weight distribution.
Verified passing on H200 (sm_90) with the PrePack hook in place.
Also: keep `is_packed = false` after `PrePackIntExpertWeights` so the
original weight initializer stays alive for `moe_helper::CheckInputs`
to read its shape on every `Compute` call. The prepacked bytes live
in `packed_fc{1,2}_weights_` and the compute path prefers them over
`fc{1,2}_experts_weights->DataRaw()`. Same trade-off the wfp4afp8
weight branch uses.
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
|
Built locally and ran the smoke test on H200 (sm_90): A few notes from local verification:
|
tianleiwu
left a comment
There was a problem hiding this comment.
Review summary
Moving the CUTLASS fpA_intB layout transform for quant_type == "int" QMoE weights into the PrePack hook (mirroring MatMulNBits::PrePack_B) is a good direction, and the mechanics faithfully mirror the validated pack_weights_for_cuda_mixed_gemm pybind path. One blocking concern plus a couple of follow-ups.
High priority
Unconditional re-prepack is not backward compatible (no raw-vs-packed marker). The new dispatch runs PrePackIntExpertWeights for every int QMoE on slots 2/5, with no schema flag or version guard. Existing tooling — including this file's own quant_dequant_blockwise/preprocess_weights_for_mixed_gemm in test_qmoe_cuda.py — already emits weights in the CUTLASS layout and stores them under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Because raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), the declared shape, dtype, and size are indistinguishable. On the default path (prepacking enabled) those models get prepacked a second time and silently produce garbage — the same silent-failure class this PR set out to remove. The description's claim of "no behaviour change for callers that pre-prepacked their weights" only holds when session.disable_prepacking is set.
A safe fix needs an explicit signal that the weights are raw (e.g. a weights_prepacked attribute defaulting to legacy/prepacked, a com.microsoft opset bump, or a distinct marker). Until then this should be opt-in and the hard break documented. The existing prepacked-weight parity tests (test_qmoe_cuda.py ~L154-298) are not updated and would regress on an SM where they currently pass once the hook double-prepacks; the new test only covers the raw path so it won't catch that.
Suggestions
- Persistent weight memory ~2x for int QMoE. Keeping
is_packed = false(soCheckInputscan still read the original shape) while also allocating persistentpacked_fc{1,2}_weights_means the original int weights and the prepacked int weights both stay resident ~= 2x the dominant weight memory for both FC layers.MatMulNBits::PrePack_Bavoids this viais_packed = true. Consider caching just the (E, N, K) shape and releasing the source, or documenting the cost. Transient PrePack overhead per FC:E*N*K/packhost->device staging (CPU initializers only) +N*K/packper-expert transpose scratch + 128 B perm map, freed after the sync. - SM coverage. The offline packer restricts
force_archto {75,80,90} and warns arch>90 falls back to 80; the in-kernel path passessm_straight through (more correct, matches the device). Please confirmpreprocess_weights_for_mixed_gemm_cudais valid on SM100/120 if int QMoE is expected there, else add a guard/diagnostic.
Nitpick
ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)is always true forbits == 4(k = k_packed*2).
Nice work on the motivation (decoupling QMoE export from a CUDA-built ORT) and the clear writeup.
| } else if (input_idx == 5 && quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) { | ||
| PrePackRepackFP4Weights(tensor, stream, alloc, packed_fp4_fc2_weights_, is_packed); | ||
| is_packed = false; | ||
| } else if (input_idx == 2 && quant_type_ == "int") { |
There was a problem hiding this comment.
Backward-compat blocker: this dispatches PrePackIntExpertWeights for every int QMoE unconditionally. Models produced by existing tooling already store weights in the CUTLASS fpA_intB layout under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), so shape/dtype/size cannot distinguish them. On the default (prepacking-enabled) path those weights get prepacked a second time and silently produce garbage. Needs an explicit raw-vs-prepacked signal (e.g. a weight_layout attribute defaulting to legacy: prepacked for sm=80, or raw for CPU) before auto-prepacking is safe; otherwise gate behind opt-in and document the break.
| // matching the same trade-off used by the WFP4AFP8 weight branch | ||
| // above. ``packed_fc1_weights_`` carries the prepacked bytes. | ||
| PrePackIntExpertWeights(tensor, stream, alloc, packed_fc1_weights_, is_packed); | ||
| is_packed = false; |
There was a problem hiding this comment.
is_packed = false keeps the original int initializer resident (so CheckInputs can read its shape) in addition to the persistent packed_fc{1,2}_weights_ buffer, roughly doubling weight memory for both FC layers vs MatMulNBits::PrePack_B (which releases the source via is_packed = true). Consider caching just the (E, N, K) shape and releasing the source initializer, or documenting the ~2x cost.
…drop redundant assert Addresses tianleiwu's review on microsoft#28749: **Blocking — backward compatibility.** The previous version dispatched PrePackIntExpertWeights for every int QMoE unconditionally, which would double-prepack any model produced by existing tooling (quantize_matmul_4bits → pack_weights_for_cuda_mixed_gemm → CUTLASS layout) and silently corrupt its output. Add a new 'weights_prepacked' INT attribute on the QMoE schema, default value 1 (legacy behaviour: weights already in CUTLASS layout, kernel reads as-is). Setting it to 0 opts in to the new PrePack hook that takes raw [E, N, K/pack] quantize_matmul_{4,8}bits output and runs the layout transform inside ORT — matching MatMulNBits semantics and removing the offline pre-pack dependency from exporters. The PrePack dispatch and the compute-time weight-buffer override are both gated on '!weights_prepacked_'. Models without the attribute behave exactly as before. **SM coverage.** preprocess_weights_for_mixed_gemm_cuda only has tile / permutation tables for SM75/80/90; the offline pack_weights_for_cuda_mixed_gemm restricts force_arch to that set and falls back to 80 for newer archs. Mirror the same fallback inside PrePackIntExpertWeights so SM86/89 and SM100/120 callers get a defined Ampere-compiled layout rather than a silent path through the helper with an unknown SM. **Nit.** Drop 'ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)' — k is computed as k_packed * pack_factor, so for bits=4, k % 2 == 0 is a tautology. **Memory cost documented.** is_packed stays false (so CheckInputs can read the source weight shape on every Compute call). Persistent memory cost is therefore ~2x the int4/int8 weight footprint, ~4x smaller than the original fp16 baseline. Documented inline. MatMulNBits avoids the doubling by caching shape in N_/K_ at construction; folding the same into QMoE is a follow-up. Tests still pass on H200 (sm_90) with weights_prepacked=0 set in the new test cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
|
Thanks for the thorough review! Addressed all 4 points in 2fcb940:
H200 re-verification with the new attribute: Test was also updated to set |
Fixes #28748.
Problem
MatMulNBits::PrePack_Bcallspreprocess_weights_for_mixed_gemm_cudaat session-load time so callers hand it the raw[N, K/(8/bits)]packed int4/int8 weights produced byquantize_matmul_{4,8}bits. The CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose + column interleave + bias) happens inside ORT.QMoE::PrePackforquant_type == "int"does the opposite: input slots 2 and 5 (fc1/fc2 expert weights) are explicitly skipped withis_packed = false, and the compute path passestensor->DataRaw()straight into the CUTLASS runner. That assumes the caller has already prepacked the weights themselves, which:pack_weights_for_cuda_mixed_gemmpybind binding is only exposed when ORT is built withUSE_CUDA).Concrete impact: we hit this in microsoft/Olive#2491 (offline MoE→QMoE rewrite for mobius-exported Gemma 4 MoE models). The Olive pass currently has to depend on a CUDA-built ORT installation just to write out a QMoE model, even though the actual quantization math is CPU-side.
Solution
Mirror the MatMulNBits PrePack path inside QMoE:
packed_fc1_weights_/packed_fc2_weights_GPU buffers.PrePackIntExpertWeightshelper that walks theEexperts of the[E, N, K/(8/bits)]initializer, runs the existingunpack_uint4_transposed_to_int8_direct_cuda/transpose_uint8_matrix_and_convert_to_int8adapter, then the sharedpreprocess_weights_for_mixed_gemm_cudatransform, and stacks the per-expert results into[E, K, N/(8/bits)].PrePack()for slots 2 and 5 whenquant_type_ == "int".ComputeInternalto preferpacked_fc{1,2}_weights_over the raw tensor data when the PrePack hook has populated them, with a fall-through to the raw initializer (preserving today's behaviour for sessions that disable prepacking — in that case the caller still has to provide pre-prepacked bytes themselves).Diff scope
onnxruntime/contrib_ops/cuda/moe/moe_quantization.honnxruntime/contrib_ops/cuda/moe/moe_quantization.ccNo schema change, no behaviour change for callers that pre-prepacked their weights, and FP4/FP8/WFP4AFP8 paths are untouched.
Build status
Verified the changed translation unit compiles cleanly against current ORT main (built
contrib_ops/cuda/moe/moe_quantization.cc.owith nvcc 13.2 + sm_90 toolchain). Fullonnxruntime_providers_cudalink is currently blocked locally by a pre-existing CUDA 13.2 + CCCL header incompatibility inbias_softmax_impl.cu— unrelated to this change. Please run CUDA CI to confirm.Suggested test follow-ups
test_qmoe_cuda.pywith an additional code path that hands raw[E, N, K/2]quantized weights to the op (without callingpack_weights_for_cuda_mixed_gemmin Python) and asserts numerical parity with the existing pre-prepacked path. Happy to do this in a follow-up.