[MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback#28416
[MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback#28416melkap01-Arm wants to merge 26 commits into
Conversation
Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU kernel implementation.
Adds MLAS FP8 GEMM API surface and scalar fallback implementation in qgemm_fp8.cpp.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…tion enforced Signed-off-by: melkap01 <melike.kaptan@arm.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8 backed by a new MLAS FP8 GEMM API, including a scalar fallback implementation and provider/MLAS unit tests.
Changes:
- Introduces
DynamicQuantMatMulFp8schema, CPU kernel registration, and a CPU opkernel implementation with prepack support for constant non-FP8 B. - Adds MLAS FP8 GEMM public API (
MlasFp8GemmBatch) and scalar fallback implementation, plus a sharedsize_toverflow helper. - Adds provider tests for the new contrib op and MLAS unit tests for the FP8 GEMM path.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp | Adds MLAS unit tests for the FP8 GEMM batch API (threaded + edge cases). |
| onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc | Adds CPU provider tests covering op contract, prepack, scale/zero-point validation, and edge cases. |
| onnxruntime/core/mlas/lib/qgemm_fp8.cpp | Implements scalar fallback for MlasFp8GemmBatch with validation and parallelism. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds MlasMultiplyOverflowsSizeT helper used for overflow-safe size computations. |
| onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Renames a parameter and removes the local overflow helper in favor of shared MLAS helper. |
| onnxruntime/core/mlas/inc/mlas.h | Adds public MLAS FP8 GEMM structs/API and the FP8 mode enum. |
| onnxruntime/core/graph/contrib_ops/quantization_defs.cc | Adds the DynamicQuantMatMulFp8 contrib operator schema + shape inference. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Registers the new contrib schema in the Microsoft opset (gated on float8 support). |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h | Declares the new CPU contrib kernel with prepack/shared-prepack support. |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc | Implements the CPU kernel, including PrePack quantization of constant B and MLAS dispatch. |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Registers the CPU kernel when float8 types are enabled. |
| cmake/onnxruntime_mlas.cmake | Wires qgemm_fp8.cpp into the MLAS build. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…es implemented Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Review — PR #28416: [MLAS] Add CPU
|
Description
Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback
This MR adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8. The path supports dynamic
block-wise quantization of float/float16/bfloat16 activations to FP8, FP8 runtime B, constant non-FP8 B pre-
quantization, block-wise scales, configurable block sizes, and float/float16/bfloat16 outputs.
Main Changes
Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds MlasFp8GemmBatch scalar qgemm_fp8.cpp fallback implementation, which performs the FP8 GEMM compute path used by the CPU kernel.
Adds provider tests for the FP8 opkernel path.
Operator Contract
A supports float, float16, and bfloat16.
Runtime B supports FP8 only and must be rank-2.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is dynamically quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Output Y supports float, float16, and bfloat16.
Optional Y_scale and Y_zero_point are supported.
FP8 formats supported:
FLOAT8E4M3FN
FLOAT8E4M3FNUZ
FLOAT8E5M2
FLOAT8E5M2FNUZ
Quantization Semantics
The implementation enforces symmetric quantization.
All provided zero-point inputs must encode 0.0.
Non-zero zero points are rejected.
Scale values are validated as finite and positive before use.
A scales are computed dynamically by the kernel.
For non-FP8 constant B, B scales are computed during PrePack.
For FP8 runtime/constant B, B_scale is required and validated.
Y_scale, when provided, must be scalar and is applied to the final accumulation.
Y_zero_point, when provided, must be scalar and zero-valued.
Block Layout
Adds block_size_m, block_size_k, and block_size_n.
block_size_m defaults to 1 and is currently constrained to 1.
block_size_k and block_size_n default to 128.
A scale layout is row/block-K based and generated internally by the kernel.
B_scale and B_zero_point use [N / block_size_n, K / block_size_k] layout.
Shape inference was tightened to match runtime behavior, including rank-2 B enforcement.
Kernel Behavior
Runtime FP8 B is consumed directly.
Constant non-FP8 B is quantized to FP8 in PrePack.
Constant FP8 B keeps its FP8 type metadata.
Prepacked metadata restores B shape, FP8 type, quantized B size, and B scale count for shared prepack reuse.
B/B-zero-point FP8 type consistency is validated regardless of whether B type came from runtime B or prepack
metadata.
K == 0 produces zero-filled output instead of returning uninitialized data.
M == 0 and N == 0 empty outputs return cleanly after cheap runtime contract validation.
MLAS FP8 Fallback
Adds MlasFp8GemmBatch.
Implements FP8 decode, block-wise scale application, float accumulation, and optional output scaling.
Supports all four FP8 modes listed above.
Parallelizes fallback work over BatchN * M.
Adds defensive validation before threaded execution:
valid FP8 mode
non-zero block sizes
required pointers only when actually dereferenced
leading dimensions only when used
strided offset overflow checks
block scale offset overflow checks
caller-provided block counts must match the GEMM shape and block sizes
This is a functional scalar fallback, not a hardware-optimized FP8 GEMM backend.
Tests
Provider tests cover:
Constant non-FP8 B prepack path
Runtime FP8 B path
All four FP8 formats
Omitted optional output quantization inputs
Optional Y_scale
Float16 and bfloat16 outputs
Bfloat16 scale tensors
Symmetric zero-point rejection for B/Y
FP8 B / B-zero-point type mismatch rejection
Non-default block sizes
Shared prepacked B metadata restore
Shared prepack semantic correctness with different B scales
Rejection of unsupported dynamic non-FP8 B
Runtime B rank > 2 schema/runtime rejection
Malformed B scale shape validation before scale reads
M == 0, N == 0, and K == 0 edge cases
Invalid Y_scale shape, value, and type on the K == 0 path
Known Limitations
Dynamic non-FP8 B is not supported by design.
No packed-B optimized FP8 backend is exposed in this MR.
No KleidiAI FP8 dispatch is included in this path.
MLAS FP8 GEMM is currently correctness-oriented scalar fallback code, not a production performance kernel.
Full MatMul broadcast semantics for batched B are intentionally not supported; runtime/schema validation is
restricted to rank-2 B.
Verification
Built onnxruntime_provider_test.
Ran FP8 provider tests successfully.
Ran the converted Qwen3 ONNX model successfully.
All DynamicQuantMatMulFp8 tests passed.
Motivation and Context