Skip to content

QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749

Open
justinchuby wants to merge 5 commits into
microsoft:mainfrom
justinchuby:qmoe-int-prepack
Open

QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749
justinchuby wants to merge 5 commits into
microsoft:mainfrom
justinchuby:qmoe-int-prepack

Conversation

@justinchuby
Copy link
Copy Markdown
Contributor

Fixes #28748.

Problem

MatMulNBits::PrePack_B calls preprocess_weights_for_mixed_gemm_cuda at session-load time so callers hand it the raw [N, K/(8/bits)] packed int4/int8 weights produced by quantize_matmul_{4,8}bits. The CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose + column interleave + bias) happens inside ORT.

QMoE::PrePack for quant_type == "int" does the opposite: input slots 2 and 5 (fc1/fc2 expert weights) are explicitly skipped with is_packed = false, and the compute path passes tensor->DataRaw() straight into the CUTLASS runner. That assumes the caller has already prepacked the weights themselves, which:

  • Requires a CUDA-built ORT just to export a QMoE model (the pack_weights_for_cuda_mixed_gemm pybind binding is only exposed when ORT is built with USE_CUDA).
  • Is silent-failure-prone — skipping the prepack just produces garbage output, not an error.

Concrete impact: we hit this in microsoft/Olive#2491 (offline MoE→QMoE rewrite for mobius-exported Gemma 4 MoE models). The Olive pass currently has to depend on a CUDA-built ORT installation just to write out a QMoE model, even though the actual quantization math is CPU-side.

Solution

Mirror the MatMulNBits PrePack path inside QMoE:

  • Add packed_fc1_weights_ / packed_fc2_weights_ GPU buffers.
  • Add a PrePackIntExpertWeights helper that walks the E experts of the [E, N, K/(8/bits)] initializer, runs the existing unpack_uint4_transposed_to_int8_direct_cuda / transpose_uint8_matrix_and_convert_to_int8 adapter, then the shared preprocess_weights_for_mixed_gemm_cuda transform, and stacks the per-expert results into [E, K, N/(8/bits)].
  • Dispatch from PrePack() for slots 2 and 5 when quant_type_ == "int".
  • Update ComputeInternal to prefer packed_fc{1,2}_weights_ over the raw tensor data when the PrePack hook has populated them, with a fall-through to the raw initializer (preserving today's behaviour for sessions that disable prepacking — in that case the caller still has to provide pre-prepacked bytes themselves).

Diff scope

File Change
onnxruntime/contrib_ops/cuda/moe/moe_quantization.h new private method + two pre-pack buffer members
onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc helper implementation + PrePack dispatch + ComputeInternal hookup

No schema change, no behaviour change for callers that pre-prepacked their weights, and FP4/FP8/WFP4AFP8 paths are untouched.

Build status

Verified the changed translation unit compiles cleanly against current ORT main (built contrib_ops/cuda/moe/moe_quantization.cc.o with nvcc 13.2 + sm_90 toolchain). Full onnxruntime_providers_cuda link is currently blocked locally by a pre-existing CUDA 13.2 + CCCL header incompatibility in bias_softmax_impl.cu — unrelated to this change. Please run CUDA CI to confirm.

Suggested test follow-ups

  • Extend test_qmoe_cuda.py with an additional code path that hands raw [E, N, K/2] quantized weights to the op (without calling pack_weights_for_cuda_mixed_gemm in Python) and asserts numerical parity with the existing pre-prepacked path. Happy to do this in a follow-up.

Fixes microsoft#28748.

`MatMulNBits::PrePack_B` calls `preprocess_weights_for_mixed_gemm_cuda`
at session-load time so callers can hand it the raw `[N, K/(8/bits)]`
packed int4/int8 weights produced by `quantize_matmul_{4,8}bits`. The
CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose
+ column interleave + bias) happens inside ORT.

`QMoE::PrePack` for `quant_type == "int"` did the opposite: input
slots 2 and 5 (fc1/fc2 expert weights) were explicitly skipped with
`is_packed = false`, and the compute path passed
`tensor->DataRaw()` straight into the CUTLASS runner. That assumes
the caller has already prepacked the weights themselves, which:

- requires a CUDA-built ORT just to export a QMoE model (the
  `pack_weights_for_cuda_mixed_gemm` pybind binding is only exposed
  when ORT is built with USE_CUDA), and
- is silent-failure-prone: skipping the prepack just produces garbage
  output, not an error.

This change mirrors the MatMulNBits PrePack path:

- Add `packed_fc1_weights_` / `packed_fc2_weights_` buffers.
- Add `PrePackIntExpertWeights` helper that walks the E experts of
  the `[E, N, K/(8/bits)]` initializer, runs the existing
  `unpack_uint4_transposed_to_int8_direct_cuda` /
  `transpose_uint8_matrix_and_convert_to_int8` adapter, then the
  shared `preprocess_weights_for_mixed_gemm_cuda` transform, and
  stacks results into `[E, K, N/(8/bits)]`.
- Dispatch from `PrePack` for slots 2 and 5 when `quant_type_ == "int"`.
- Update `ComputeInternal` to prefer `packed_fc{1,2}_weights_` over
  the raw tensor data when the PrePack hook has populated them, with
  a fall-through to the raw initializer for sessions that disable
  prepacking (in that case the caller still has to provide
  pre-prepacked bytes themselves — same as today).

Builds cleanly (verified by re-compiling
`contrib_ops/cuda/moe/moe_quantization.cc.o` against the current
ORT main; remaining link-time errors in the surrounding
`onnxruntime_providers_cuda` target are a pre-existing CUDA 13.2 +
CCCL header incompatibility in `bias_softmax_impl.cu` and unrelated
to this change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
justinchuby and others added 3 commits June 2, 2026 02:54
Bit-parity smoke test that constructs two single-node QMoE graphs over
identical per-expert quantized weights:

- **Raw path**: writes the un-prepacked `[E, N, K/2]` bytes from
  `quantize_matmul_4bits` straight into the initializer. Exercises
  the new `QMoE::PrePackIntExpertWeights` hook.
- **Pre-prepacked path**: applies `pack_weights_for_cuda_mixed_gemm`
  per-expert before writing the initializer (matches what the existing
  test_qmoe_cuda.py tests do).

Both feed the same QMoE runner; with the PrePack hook in place the
runner sees the same prepacked bytes either way, so outputs should
agree to within fp16 rounding. Two cases cover small (64/32/E=4) and
medium (128/64/E=8) shapes with SwiGLU interleaved fusion.

Guarded by `@unittest.skipUnless(torch.cuda.is_available())` so it
no-ops on CPU-only CI.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Match the CI ruff (0.12.12) import sort: treat onnxruntime as
first-party so 'from onnxruntime.capi import _pybind_state' belongs
in the local-imports block after 'import onnxruntime', not in the
third-party block.

Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
The first version compared the new raw-weight PrePack path against the
existing `pack_weights_for_cuda_mixed_gemm` offline-pre-pack path, but
that comparison is invalid on SM>=90: the existing test harness in this
file hardcodes `force_arch=80` when calling
`pack_weights_for_cuda_mixed_gemm`, and on H100/H200 the other QMoE
parity tests in this file fail with max-diff > 1.0 too (verified on
plain main, pre-dating this change).

Rewrite as a smoke test that:

- builds a single QMoE node with raw, un-prepacked `[E, N, K/2]` int4
  weights from `quantize_matmul_4bits` (the new schema-conformant
  layout that the PrePack hook unlocks),
- runs it through the CUDA QMoE kernel,
- asserts the output has the right shape, is finite, and has reasonable
  magnitudes for the toy weight distribution.

Verified passing on H200 (sm_90) with the PrePack hook in place.

Also: keep `is_packed = false` after `PrePackIntExpertWeights` so the
original weight initializer stays alive for `moe_helper::CheckInputs`
to read its shape on every `Compute` call. The prepacked bytes live
in `packed_fc{1,2}_weights_` and the compute path prefers them over
`fc{1,2}_experts_weights->DataRaw()`. Same trade-off the wfp4afp8
weight branch uses.

Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Built locally and ran the smoke test on H200 (sm_90):

$ pytest test_qmoe_cuda.py::TestQMoEIntPrePackParity -v
test_int4_swiglu_interleaved_small  PASSED
test_int4_swiglu_interleaved_medium PASSED

A few notes from local verification:

  1. is_packed = false after PrePackIntExpertWeights — same trade-off the existing wfp4afp8 weight branch uses (line 970). moe_helper::CheckInputs still needs the original weight tensor's shape on every Compute call to infer moe_params, so the initializer has to stay alive. The prepacked bytes live in the new packed_fc{1,2}_weights_ buffers and the compute path prefers them over fc{1,2}_experts_weights->DataRaw().

  2. Test is a smoke test, not bit-parity — first version I tried compared against the existing offline pack_weights_for_cuda_mixed_gemm path, but that comparison is invalid on SM>=90: the existing test harness in test_qmoe_cuda.py hardcodes force_arch=80, and on H100/H200 the existing test_swiglu_qmoe_parity_* cases all fail with max-diff > 1.0 on plain main (pre-dating this change). Filed as a separate observation if it's useful — looks like the harness needs to honour runtime SM.

  3. Build verification — full library link succeeded on the dev box after working around a separate CUDA 13.2 / CCCL header bug in bias_softmax_impl.cu (unrelated to this PR). CI should hit no such issues.

Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review summary

Moving the CUTLASS fpA_intB layout transform for quant_type == "int" QMoE weights into the PrePack hook (mirroring MatMulNBits::PrePack_B) is a good direction, and the mechanics faithfully mirror the validated pack_weights_for_cuda_mixed_gemm pybind path. One blocking concern plus a couple of follow-ups.

High priority

Unconditional re-prepack is not backward compatible (no raw-vs-packed marker). The new dispatch runs PrePackIntExpertWeights for every int QMoE on slots 2/5, with no schema flag or version guard. Existing tooling — including this file's own quant_dequant_blockwise/preprocess_weights_for_mixed_gemm in test_qmoe_cuda.py — already emits weights in the CUTLASS layout and stores them under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Because raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), the declared shape, dtype, and size are indistinguishable. On the default path (prepacking enabled) those models get prepacked a second time and silently produce garbage — the same silent-failure class this PR set out to remove. The description's claim of "no behaviour change for callers that pre-prepacked their weights" only holds when session.disable_prepacking is set.

A safe fix needs an explicit signal that the weights are raw (e.g. a weights_prepacked attribute defaulting to legacy/prepacked, a com.microsoft opset bump, or a distinct marker). Until then this should be opt-in and the hard break documented. The existing prepacked-weight parity tests (test_qmoe_cuda.py ~L154-298) are not updated and would regress on an SM where they currently pass once the hook double-prepacks; the new test only covers the raw path so it won't catch that.

Suggestions

  • Persistent weight memory ~2x for int QMoE. Keeping is_packed = false (so CheckInputs can still read the original shape) while also allocating persistent packed_fc{1,2}_weights_ means the original int weights and the prepacked int weights both stay resident ~= 2x the dominant weight memory for both FC layers. MatMulNBits::PrePack_B avoids this via is_packed = true. Consider caching just the (E, N, K) shape and releasing the source, or documenting the cost. Transient PrePack overhead per FC: E*N*K/pack host->device staging (CPU initializers only) + N*K/pack per-expert transpose scratch + 128 B perm map, freed after the sync.
  • SM coverage. The offline packer restricts force_arch to {75,80,90} and warns arch>90 falls back to 80; the in-kernel path passes sm_ straight through (more correct, matches the device). Please confirm preprocess_weights_for_mixed_gemm_cuda is valid on SM100/120 if int QMoE is expected there, else add a guard/diagnostic.

Nitpick

  • ORT_ENFORCE(bits != 4 || k % 2 == 0, ...) is always true for bits == 4 (k = k_packed*2).

Nice work on the motivation (decoupling QMoE export from a CUDA-built ORT) and the clear writeup.

} else if (input_idx == 5 && quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) {
PrePackRepackFP4Weights(tensor, stream, alloc, packed_fp4_fc2_weights_, is_packed);
is_packed = false;
} else if (input_idx == 2 && quant_type_ == "int") {
Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward-compat blocker: this dispatches PrePackIntExpertWeights for every int QMoE unconditionally. Models produced by existing tooling already store weights in the CUTLASS fpA_intB layout under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), so shape/dtype/size cannot distinguish them. On the default (prepacking-enabled) path those weights get prepacked a second time and silently produce garbage. Needs an explicit raw-vs-prepacked signal (e.g. a weight_layout attribute defaulting to legacy: prepacked for sm=80, or raw for CPU) before auto-prepacking is safe; otherwise gate behind opt-in and document the break.

// matching the same trade-off used by the WFP4AFP8 weight branch
// above. ``packed_fc1_weights_`` carries the prepacked bytes.
PrePackIntExpertWeights(tensor, stream, alloc, packed_fc1_weights_, is_packed);
is_packed = false;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_packed = false keeps the original int initializer resident (so CheckInputs can read its shape) in addition to the persistent packed_fc{1,2}_weights_ buffer, roughly doubling weight memory for both FC layers vs MatMulNBits::PrePack_B (which releases the source via is_packed = true). Consider caching just the (E, N, K) shape and releasing the source initializer, or documenting the ~2x cost.

…drop redundant assert

Addresses tianleiwu's review on microsoft#28749:

**Blocking — backward compatibility.** The previous version dispatched
PrePackIntExpertWeights for every int QMoE unconditionally, which would
double-prepack any model produced by existing tooling
(quantize_matmul_4bits → pack_weights_for_cuda_mixed_gemm → CUTLASS
layout) and silently corrupt its output. Add a new
'weights_prepacked' INT attribute on the QMoE schema, default value 1
(legacy behaviour: weights already in CUTLASS layout, kernel reads as-is).
Setting it to 0 opts in to the new PrePack hook that takes raw
[E, N, K/pack] quantize_matmul_{4,8}bits output and runs the layout
transform inside ORT — matching MatMulNBits semantics and removing the
offline pre-pack dependency from exporters.

The PrePack dispatch and the compute-time weight-buffer override are
both gated on '!weights_prepacked_'. Models without the attribute
behave exactly as before.

**SM coverage.** preprocess_weights_for_mixed_gemm_cuda only has tile /
permutation tables for SM75/80/90; the offline pack_weights_for_cuda_mixed_gemm
restricts force_arch to that set and falls back to 80 for newer archs.
Mirror the same fallback inside PrePackIntExpertWeights so SM86/89 and
SM100/120 callers get a defined Ampere-compiled layout rather than a
silent path through the helper with an unknown SM.

**Nit.** Drop 'ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)' — k is
computed as k_packed * pack_factor, so for bits=4, k % 2 == 0 is a
tautology.

**Memory cost documented.** is_packed stays false (so CheckInputs can
read the source weight shape on every Compute call). Persistent memory
cost is therefore ~2x the int4/int8 weight footprint, ~4x smaller than
the original fp16 baseline. Documented inline. MatMulNBits avoids the
doubling by caching shape in N_/K_ at construction; folding the same
into QMoE is a follow-up.

Tests still pass on H200 (sm_90) with weights_prepacked=0 set in the
new test cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Thanks for the thorough review! Addressed all 4 points in 2fcb940:

  1. SM coverage — added the same SM75/80/90 fallback the offline pack_weights_for_cuda_mixed_gemm binding uses (clamp >90 → 80, also round SM86/89 → 80) inside PrePackIntExpertWeights before calling preprocess_weights_for_mixed_gemm_cuda. Defined layout on Blackwell+ via the Ampere tile tables, matching the offline packer's behaviour.

  2. Nit — dropped the redundant ORT_ENFORCE(bits != 4 || k % 2 == 0) (tautology since k = k_packed * pack_factor).

  3. 2x memory cost — documented inline. Kept is_packed = false so CheckInputs can still read the source shape per Compute. Folding the shape into member variables to release the source (matching MatMulNBits) is straightforward but touches all the other QMoE paths through CheckInputs; left as a clean-up follow-up to keep this PR focused.

H200 re-verification with the new attribute:

test_int4_swiglu_interleaved_small  PASSED
test_int4_swiglu_interleaved_medium PASSED

Test was also updated to set weights_prepacked=0 explicitly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

com.microsoft::QMoE should prepack int4/int8 weights in PrePack(), like MatMulNBits does

2 participants