Route fp16 HQNBIT_CompInt8 (4-bit and 8-bit) through fp32 MLAS path in MatMulNBits#27820
Conversation
The HQ CompInt8 kernels (HQ4BitGemm_CompInt8, HQ8BitGemm_CompInt8) are wrappers that convert fp16->fp32 per-tile and call the same SQ fp32 kernels. By doing bulk conversion at the operator level we:

1. Eliminate per-tile fp16<->fp32 conversion overhead
2. Automatically get KleidiAI support for 4-bit (SQ4BitGemm_CompInt8 checks KleidiAI internally)
3. Unify the ARM64 fp16 path with the x64/Apple ARM64 approach

Changes in PrePack:
- B input: Use SQNBIT_CompInt8 for sizing/packing when HQNBIT_CompInt8 is 4-bit. Convert constant fp16 scales to fp32 for KleidiAI packing. Compute BZpCorr for the asymmetric KleidiAI path.
- Scales input: Unified handler for both 4-bit and 8-bit HQNBIT_CompInt8. Checks KleidiAI scales-packed for 4-bit on ARM64. Falls back to non-KleidiAI with fp32 scale conversion when zero_points are dynamic.
- Bias input: Pre-convert fp16 bias to fp32 for compute time.

Changes in ComputeBPacked:
- Added an `if constexpr` block for MLFloat16 + HQNBIT_CompInt8 that bulk-converts A fp16->fp32, uses the pre-converted scales/bias, calls MlasQNBitGemmBatch<float> with SQNBIT_CompInt8, and bulk-converts the output fp32->fp16.
- Standard path unchanged for other compute types.
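The bulk-conversion pattern above can be sketched in isolation. This is an illustrative stand-in, not ORT code: `HalfToFloat`/`FloatToHalf` are minimal scalar IEEE binary16 converters substituting for MLAS's vectorized routines (this `FloatToHalf` flushes half-subnormals to zero and maps overflow/inf/NaN to infinity), and the inner loop is a naive fp32 GEMM standing in for `MlasQNBitGemmBatch<float>`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Minimal scalar IEEE binary16 -> binary32 conversion.
static float HalfToFloat(uint16_t h) {
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu;
    uint32_t man = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        if (man == 0) {
            bits = sign;  // signed zero
        } else {          // half subnormal: renormalize into a float normal
            int e = -1;
            do { man <<= 1; ++e; } while ((man & 0x400u) == 0);
            bits = sign | (static_cast<uint32_t>(112 - e) << 23) | ((man & 0x3FFu) << 13);
        }
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (man << 13);          // inf / NaN
    } else {
        bits = sign | ((exp + 112) << 23) | (man << 13);  // rebias 15 -> 127
    }
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Simplified binary32 -> binary16: round-to-nearest-even for normals,
// flush half-subnormals to zero, saturate overflow/inf/NaN to infinity.
static uint16_t FloatToHalf(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    uint16_t sign = static_cast<uint16_t>((bits >> 16) & 0x8000u);
    int32_t exp = static_cast<int32_t>((bits >> 23) & 0xFFu) - 112;  // rebias 127 -> 15
    uint32_t man = bits & 0x7FFFFFu;
    if (exp <= 0) return sign;                                    // flush tiny values to zero
    if (exp >= 31) return static_cast<uint16_t>(sign | 0x7C00u);  // saturate to inf
    uint32_t half = (static_cast<uint32_t>(exp) << 10) | (man >> 13);
    uint32_t rem = man & 0x1FFFu;  // round to nearest even; carry into exponent stays correct
    if (rem > 0x1000u || (rem == 0x1000u && (half & 1u))) ++half;
    return static_cast<uint16_t>(sign | half);
}

// Bulk pattern: convert all of A once, run the fp32 GEMM, convert C once --
// instead of converting per tile inside the kernel.
void GemmFp16ViaFp32(const uint16_t* A, const float* B, uint16_t* C,
                     size_t M, size_t K, size_t N) {
    std::vector<float> a32(M * K), c32(M * N, 0.0f);
    for (size_t i = 0; i < M * K; ++i) a32[i] = HalfToFloat(A[i]);  // bulk A fp16->fp32
    for (size_t m = 0; m < M; ++m)                                  // naive fp32 GEMM stand-in
        for (size_t k = 0; k < K; ++k)
            for (size_t n = 0; n < N; ++n)
                c32[m * N + n] += a32[m * K + k] * B[k * N + n];
    for (size_t i = 0; i < M * N; ++i) C[i] = FloatToHalf(c32[i]);  // bulk C fp32->fp16
}
```

The trade-off the PR's NOTE mentions is visible here: the two bulk loops touch the full A and C matrices once each, whereas the old wrappers converted small, cache-hot tiles inside the kernel.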
Pull request overview
This PR updates the CPU MatMulNBits operator to route HQNBIT_CompInt8 (fp16 input) through the fp32 MLAS SQNBIT_CompInt8 path for both 4-bit and 8-bit weights, replacing the existing per-tile fp16↔fp32 conversions inside the HQ wrapper kernels.
Changes:
- Adjusts PrePack for `HQNBIT_CompInt8` (notably 4-bit) to size/pack using `SQNBIT_CompInt8`, including fp16→fp32 scale conversion for packing and ARM64 KleidiAI asymmetric `BZpCorr` handling.
- Unifies `HQNBIT_CompInt8` scales handling across 4-bit and 8-bit, including optional “scales already packed” behavior on ARM64 KleidiAI.
- Updates compute for fp16 `HQNBIT_CompInt8` to bulk-convert A to fp32, run `MlasQNBitGemmBatch<float>` with `SQNBIT_CompInt8`, then bulk-convert output back to fp16 (and uses preconverted scales/bias when available).
These functions were wrappers that converted fp16->fp32 per-tile and called the same SQ fp32 kernels. Now that HQNBIT_CompInt8 delegates to SQNBIT_CompInt8 at the operator level (matmul_nbits.cc), they are dead code.

Removed:
- HQ4BitGemm_CompInt8 and HQ8BitGemm_CompInt8 functions
- HQ8BitCompInt8PerGemmWorkspace struct
- HQ4BitGemmVariant_CompInt8 and HQ8BitGemmVariant_CompInt8 enum values
- InitializeWorkspace_CompInt8<MLAS_FP16> specialization
- HQNBIT_CompInt8 branches in MlasQNBitGemmPackQuantBData
- HQ8BitGemmVariant_CompInt8 workspace extraction in MlasQNBitGemmBatch

Updated:
- MlasIsQNBitGemmAvailable: redirects HQNBIT_CompInt8 to SQNBIT_CompInt8 (still called by GetComputeType<MLFloat16> on ARM64)
- HQNBIT_CompInt8 workspace size returns 0 (no longer needed at the MLAS level)
- Simplified SQNBIT_CompInt8 || HQNBIT_CompInt8 conditions to SQNBIT_CompInt8
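The MlasIsQNBitGemmAvailable redirect described in this commit can be sketched as a normalization step before dispatch. The enum and the availability check below are illustrative stand-ins, not the actual MLAS definitions.

```cpp
#include <cassert>

// Illustrative stand-in for the MLAS compute-type enum (names match the PR,
// values are arbitrary).
enum QNBitGemmComputeType { SQNBIT_CompFp32, HQNBIT_CompFp16, SQNBIT_CompInt8, HQNBIT_CompInt8 };

// Stand-in for the per-platform availability check (assumed support set).
static bool IsAvailableImpl(QNBitGemmComputeType t) {
    return t == SQNBIT_CompFp32 || t == SQNBIT_CompInt8;
}

// fp16 CompInt8 availability is answered by the fp32 CompInt8 implementation,
// since the operator now delegates to it in matmul_nbits.cc.
bool IsQNBitGemmAvailable(QNBitGemmComputeType t) {
    if (t == HQNBIT_CompInt8) t = SQNBIT_CompInt8;  // normalize before dispatch
    return IsAvailableImpl(t);
}
```

Normalizing at the entry point keeps every downstream helper (pack sizing, workspace sizing, dispatch) agreeing on a single compute type, which is exactly the inconsistency the review comment below warns about.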
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)
onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp:110
- This pack-size helper now only includes the extra SQNBIT_CompInt8 metadata (scales/blk-sums, KleidiAI layout). If any caller still passes `HQNBIT_CompInt8` into the MLAS packing APIs, it will get a smaller size and a different packing format than the SQNBIT_CompInt8 compute path expects. To avoid inconsistent behavior, consider normalizing `HQNBIT_CompInt8` → `SQNBIT_CompInt8` before these dispatch helpers (or ensure HQNBIT_CompInt8 can never reach them).
```cpp
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
if (ComputeType == SQNBIT_CompInt8) {
    const size_t ScaleSize = N * BlockCountK * sizeof(float);
    size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float);

    // align on a 32-byte boundary
    constexpr size_t PackedQuantBDataAlignment = 32;
    PackedQuantBDataSize += PackedQuantBDataAlignment - 1;
    constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment();
    BlkSumSize += BlkSumAlignment - 1;

    if constexpr (QuantAUnsigned) {
        // 2 block sums
        return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize;
    } else {
        return PackedQuantBDataSize + ScaleSize + BlkSumSize;
    }
} else {
    return PackedQuantBDataSize;
}
```
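The sizing above combines rounded-up block counts with over-allocation padding for alignment. A worked example of the same arithmetic with illustrative shapes follows; the helper name `DivRoundup` and all the concrete numbers are hypothetical, standing in for `MlasDivRoundup` and real tensor shapes.

```cpp
#include <cassert>
#include <cstddef>

// Ceiling division, as used for BlockCountK = ceil(K / BlkLen).
constexpr size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

// Illustrative shapes: N = 64, K = 130, BlkLen = 32, and a hypothetical
// 2 bytes of packed quant data per block element.
constexpr size_t N = 64, K = 130, BlkLen = 32, BytesPerBlk = 2;
constexpr size_t BlockCountK = DivRoundup(K, BlkLen);          // 5, not 4: partial blocks count
constexpr size_t PackedSize = N * BlockCountK * BytesPerBlk;   // base packed-B bytes
constexpr size_t ScaleSize = N * BlockCountK * sizeof(float);  // one fp32 scale per block
constexpr size_t BlkSumSize = DivRoundup(N, 16) * BlockCountK * 16 * sizeof(float);

// Over-allocating by (alignment - 1) lets the consumer carve an aligned
// sub-buffer out of the allocation whatever its base address turns out to be.
constexpr size_t Aligned32 = PackedSize + 32 - 1;
```

This also makes the reviewer's concern concrete: a caller sizing with the non-`SQNBIT_CompInt8` branch gets only `PackedSize` bytes, with no room for scales or block sums.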
Bug 1: effective_compute_type only redirected HQNBIT_CompInt8 to SQNBIT_CompInt8 for 4-bit, leaving 8-bit with HQNBIT_CompInt8 as the compute type. Since the HQ CompInt8 code was removed from MLAS, the pack size returns 0 and B is never packed for 8-bit.

Bug 2: fp16 scales conversion during B packing was gated on nbits_ == 4, so 8-bit passed null scales to SQ8BitGemmPackQuantBDataAndBlkSum.

Bug 3: The HQNBIT_CompInt8 scales/zero_points PrePack called MlasQNBitGemmPackQuantBData with null QuantBData for separate packing on ARM64 for 4-bit, but the ARM64 NEON pack function does not support incremental packing for 4-bit (the standard SQNBIT_CompInt8 path uses should_pack_scale_and_zp_inputs = (nbits_ == 8) on ARM64). Guard separate packing to match the standard behavior.

Also tightened the BZpCorr condition during B packing to check MlasQNBitGemmScalesPacked before computing the correction.
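Bug 1 is a dispatch-gating issue: the redirect must apply regardless of bit width. A minimal sketch of the corrected selection, with hypothetical names (the real logic lives in matmul_nbits.cc):

```cpp
#include <cassert>

// Hypothetical mirror of the compute-type enum from the PR discussion.
enum ComputeType { SQNBIT_CompFp32, SQNBIT_CompInt8, HQNBIT_CompInt8 };

// Fixed version: redirect for both 4-bit and 8-bit. The buggy version gated
// this on nbits == 4, so 8-bit kept HQNBIT_CompInt8 and -- with the HQ path
// removed from MLAS -- got a pack size of 0 and an unpacked B.
ComputeType EffectiveComputeType(ComputeType requested, int nbits) {
    (void)nbits;  // intentionally unused after the fix
    return requested == HQNBIT_CompInt8 ? SQNBIT_CompInt8 : requested;
}
```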
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
fc606f5 to 2309452
…teType

- Replace graceful fallback with ORT_ENFORCE for scales_fp32_ in the ComputeBPacked HQNBIT_CompInt8 path (PrePack always sets it).
- GetComputeType<MLFloat16> now checks MlasIsQNBitGemmAvailable with SQNBIT_CompInt8 directly instead of HQNBIT_CompInt8.
- Remove the HQNBIT_CompInt8 -> SQNBIT_CompInt8 redirect from MlasIsQNBitGemmAvailable in the MLAS layer.
2309452 to 138318a
Description
Routes fp16 `HQNBIT_CompInt8` through the fp32 MLAS path (`SQNBIT_CompInt8`) at the operator level for both 4-bit and 8-bit MatMulNBits, then removes the ~370 lines of dead HQ CompInt8 wrapper code from MLAS.

Operator changes (matmul_nbits.cc):
- PrePack uses `SQNBIT_CompInt8` for sizing/packing, pre-converts fp16 scales and bias to fp32, and computes BZpCorr for asymmetric KleidiAI on ARM64.
- Compute bulk-converts A fp16→fp32, calls `MlasQNBitGemmBatch<float>` with `SQNBIT_CompInt8`, and bulk-converts C fp32→fp16.

MLAS cleanup (qnbitgemm.cpp, qnbitgemm_kernel_neon.cpp):
- Removes `HQ4BitGemm_CompInt8`, `HQ8BitGemm_CompInt8`, `HQ8BitCompInt8PerGemmWorkspace`, associated enum values, dispatch branches, workspace entries, and `HQNBIT_CompInt8` NEON kernel conditions.
- Keeps the `HQNBIT_CompInt8` → `SQNBIT_CompInt8` redirect in `MlasIsQNBitGemmAvailable` for `GetComputeType<MLFloat16>` compatibility.

Motivation and Context
The HQ CompInt8 kernels are wrappers that convert fp16→fp32 per-tile before calling the same SQ fp32 kernels. This change:

- Eliminates the per-tile fp16↔fp32 conversion overhead.
- Automatically gains KleidiAI support for the 4-bit `HQNBIT_CompInt8` path.
- Unifies the ARM64 fp16 path with the x64/Apple ARM64 approach.

Improvements
Measured on Snapdragon X Elite (X1E78100, Qualcomm Oryon CPU).

Asymmetric:

Symmetric:
NOTE: The 8-bit accuracy level 4 path shows some regression (5–25% on 1.5B/3B models, neutral on 7B) due to the bulk fp16↔fp32 conversion overhead replacing the old per-tile approach. The old HQ CompInt8 wrappers kept small tiles cache-hot, while the new unified path does full-matrix conversion passes. This trade-off is acceptable since 4-bit is the dominant quantization format (gaining 26–67%), 8-bit acc4 still outperforms acc1 by 1.7–2.2×, and the regression is most pronounced at smaller model sizes where absolute latencies are already low. A proper fix would be 8-bit KleidiAI-style kernels rather than restoring the wrapper code.