
Route fp16 HQNBIT_CompInt8 (4-bit and 8-bit) through fp32 MLAS path in MatMulNBits #27820

Merged
jambayk merged 5 commits into main from jambayk/mnb-4-16 on Mar 25, 2026
Conversation

jambayk (Contributor) commented Mar 24, 2026

Description

Routes fp16 HQNBIT_CompInt8 through the fp32 MLAS path (SQNBIT_CompInt8) at the operator level for both 4-bit and 8-bit MatMulNBits, then removes the ~370 lines of dead HQ CompInt8 wrapper code from MLAS.

Operator changes (matmul_nbits.cc):

  • PrePack: Uses SQNBIT_CompInt8 for sizing/packing, pre-converts fp16 scales and bias to fp32, computes BZpCorr for asymmetric KleidiAI on ARM64.
  • ComputeBPacked: Bulk fp16→fp32 conversion of A, calls MlasQNBitGemmBatch<float> with SQNBIT_CompInt8, bulk fp32→fp16 conversion of C.
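The bulk convert-compute-convert pattern in ComputeBPacked can be sketched in isolation. The fp16 helpers and the naive GEMM below are illustrative stand-ins (normal-range values only, truncating conversion), not the vectorized MLAS routines:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative fp16 <-> fp32 helpers (normals only; MLAS uses optimized
// vectorized conversions, not this scalar sketch).
inline float HalfToFloat(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    uint32_t bits = (exp == 0) ? sign  // zero/subnormal -> zero in this sketch
                               : sign | ((exp + 112u) << 23) | (mant << 13);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

inline uint16_t FloatToHalf(float f) {  // truncating; in-range normals only
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t sign = (bits >> 16) & 0x8000u;
    int32_t exp = (int32_t)((bits >> 23) & 0xFFu) - 112;  // rebias 127 -> 15
    uint32_t mant = (bits >> 13) & 0x3FFu;
    return exp <= 0 ? (uint16_t)sign : (uint16_t)(sign | (exp << 10) | mant);
}

// Bulk-convert A once, run the fp32 GEMM, bulk-convert C back once.
// The inner loop stands in for MlasQNBitGemmBatch<float> with SQNBIT_CompInt8.
void ComputeFp16ViaFp32(const std::vector<uint16_t>& A_fp16,
                        const std::vector<float>& B, size_t M, size_t K,
                        size_t N, std::vector<uint16_t>& C_fp16) {
    std::vector<float> A(A_fp16.size()), C(M * N, 0.0f);
    for (size_t i = 0; i < A.size(); ++i) A[i] = HalfToFloat(A_fp16[i]);  // bulk A
    for (size_t m = 0; m < M; ++m)                                        // fp32 GEMM
        for (size_t n = 0; n < N; ++n)
            for (size_t k = 0; k < K; ++k)
                C[m * N + n] += A[m * K + k] * B[k * N + n];
    C_fp16.resize(M * N);
    for (size_t i = 0; i < C.size(); ++i) C_fp16[i] = FloatToHalf(C[i]);  // bulk C
}
```

The key contrast with the removed HQ wrappers is that each conversion runs exactly once over the whole matrix instead of once per tile inside the GEMM.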

MLAS cleanup (qnbitgemm.cpp, qnbitgemm_kernel_neon.cpp):

  • Removed HQ4BitGemm_CompInt8, HQ8BitGemm_CompInt8, HQ8BitCompInt8PerGemmWorkspace, associated enum values, dispatch branches, workspace entries, and HQNBIT_CompInt8 NEON kernel conditions.
  • Added HQNBIT_CompInt8 → SQNBIT_CompInt8 redirect in MlasIsQNBitGemmAvailable for GetComputeType<MLFloat16> compatibility.

Motivation and Context

The HQ CompInt8 kernels are wrappers that convert fp16→fp32 per-tile before calling the same SQ fp32 kernels. This change:

  1. Eliminates per-tile overhead via bulk conversion at the operator level.
  2. Enables KleidiAI for fp16 4-bit — previously bypassed by the HQNBIT_CompInt8 path.
  3. Removes ~370 lines of dead wrapper code from MLAS.

Improvements

Measured on Snapdragon X Elite - X1E78100 - Qualcomm Oryon CPU

Asymmetric:

| Model     | Seq Len | Acc1/Acc4 (before) | Acc1/Acc4 (after) | Acc4 speedup | Acc4 latency (after) |
|-----------|---------|--------------------|-------------------|--------------|----------------------|
| Qwen 1.5B | 256     | 1.28×              | 1.55×             | 1.26×        | 1187.5 ms            |
| Qwen 1.5B | 512     | 1.14×              | 1.63×             | 1.55×        | 2257.2 ms            |
| Qwen 3B   | 256     | 1.32×              | 1.82×             | 1.29×        | 2351.3 ms            |
| Qwen 3B   | 512     | 1.38×              | 1.70×             | 1.28×        | 4777.2 ms            |
| Qwen 7B   | 256     | 1.58×              | 2.26×             | 1.40×        | 4094.5 ms            |
| Qwen 7B   | 512     | 1.49×              | 2.23×             | 1.52×        | 8002.6 ms            |

Symmetric:

| Model     | Seq Len | Acc1/Acc4 (before) | Acc1/Acc4 (after) | Acc4 speedup | Acc4 latency (after) |
|-----------|---------|--------------------|-------------------|--------------|----------------------|
| Qwen 1.5B | 256     | 0.95×              | 1.45×             | 1.67×        | 1255.5 ms            |
| Qwen 1.5B | 512     | 1.04×              | 1.52×             | 1.55×        | 2406.7 ms            |
| Qwen 3B   | 256     | 1.39×              | 1.88×             | 1.32×        | 2215.0 ms            |
| Qwen 3B   | 512     | 1.42×              | 1.85×             | 1.31×        | 4318.3 ms            |
| Qwen 7B   | 256     | 1.66×              | 2.58×             | 1.55×        | 3564.4 ms            |
| Qwen 7B   | 512     | 1.57×              | 2.60×             | 1.64×        | 7227.9 ms            |

NOTE: The 8-bit accuracy level 4 path shows some regression (5–25% on 1.5B/3B models, neutral on 7B) due to the bulk fp16↔fp32 conversion overhead replacing the old per-tile approach. The old HQ CompInt8 wrappers kept small tiles cache-hot, while the new unified path does full-matrix conversion passes. This trade-off is acceptable since 4-bit is the dominant quantization format (gaining 26–67%), 8-bit acc4 still outperforms acc1 by 1.7–2.2×, and the regression is most pronounced at smaller model sizes where absolute latencies are already low. A proper fix would be 8-bit KleidiAI-style kernels rather than restoring the wrapper code.

The HQ CompInt8 kernels (HQ4BitGemm_CompInt8, HQ8BitGemm_CompInt8) are
wrappers that convert fp16->fp32 per-tile and call the same SQ fp32
kernels. By doing bulk conversion at the operator level we:

1. Eliminate per-tile fp16<->fp32 conversion overhead
2. Automatically get KleidiAI support for 4-bit (SQ4BitGemm_CompInt8
   checks KleidiAI internally)
3. Unify the ARM64 fp16 path with the x64/Apple ARM64 approach

Changes in PrePack:
- B input: Use SQNBIT_CompInt8 for sizing/packing when HQNBIT_CompInt8
  4-bit. Convert constant fp16 scales to fp32 for KleidiAI packing.
  Compute BZpCorr for asymmetric KleidiAI path.
- Scales input: Unified handler for both 4-bit and 8-bit HQNBIT_CompInt8.
  Checks KleidiAI scales-packed for 4-bit on ARM64. Falls back to
  non-KleidiAI with fp32 scale conversion when zero_points are dynamic.
- Bias input: Pre-convert fp16 bias to fp32 for compute time.

Changes in ComputeBPacked:
- Added if constexpr block for MLFloat16 + HQNBIT_CompInt8 that does
  bulk A fp16->fp32, uses pre-converted scales/bias, calls
  MlasQNBitGemmBatch<float> with SQNBIT_CompInt8, bulk converts output
  fp32->fp16.
- Standard path unchanged for other compute types.
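The compile-time split described above can be sketched with `if constexpr`. The `MLFloat16` struct and compute-type enum here are simplified stand-ins for the ONNX Runtime types, not the real definitions:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>

// Simplified stand-ins for the ORT types (illustrative only).
struct MLFloat16 { uint16_t val; };
enum ComputeType { SQNBIT_CompInt8, HQNBIT_CompInt8 };

// Mirrors the shape of the ComputeBPacked dispatch: the fp16 + HQNBIT_CompInt8
// combination takes the fp32-delegated path; everything else is unchanged.
template <typename T>
std::string ComputePath(ComputeType compute_type) {
    if constexpr (std::is_same_v<T, MLFloat16>) {
        if (compute_type == HQNBIT_CompInt8) {
            // bulk A fp16->fp32, MlasQNBitGemmBatch<float> with
            // SQNBIT_CompInt8, bulk C fp32->fp16 (elided in this sketch)
            return "fp32-delegated";
        }
    }
    return "standard";
}
```

Because the branch is `if constexpr`, the fp16 delegation code is only instantiated for the `MLFloat16` specialization; the `float` path compiles exactly as before.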
github-actions bot left a comment

You can commit the suggested changes from lintrunner.

Copilot AI left a comment

Pull request overview

This PR updates the CPU MatMulNBits operator to route HQNBIT_CompInt8 (fp16 input) through the fp32 MLAS SQNBIT_CompInt8 path for both 4-bit and 8-bit weights, replacing the existing per-tile fp16↔fp32 conversions inside the HQ wrapper kernels.

Changes:

  • Adjusts PrePack for HQNBIT_CompInt8 (notably 4-bit) to size/pack using SQNBIT_CompInt8, including fp16→fp32 scale conversion for packing and ARM64 KleidiAI asymmetric BZpCorr handling.
  • Unifies HQNBIT_CompInt8 scales handling across 4-bit and 8-bit, including optional “scales already packed” behavior on ARM64 KleidiAI.
  • Updates compute for fp16 HQNBIT_CompInt8 to bulk-convert A to fp32, run MlasQNBitGemmBatch<float> with SQNBIT_CompInt8, then bulk-convert output back to fp16 (and uses preconverted scales/bias when available).


jambayk added 2 commits March 24, 2026 02:58
These functions were wrappers that converted fp16->fp32 per-tile and
called the same SQ fp32 kernels. Now that HQNBIT_CompInt8 delegates to
SQNBIT_CompInt8 at the operator level (matmul_nbits.cc), they are dead
code.

Removed:
- HQ4BitGemm_CompInt8 and HQ8BitGemm_CompInt8 functions
- HQ8BitCompInt8PerGemmWorkspace struct
- HQ4BitGemmVariant_CompInt8 and HQ8BitGemmVariant_CompInt8 enum values
- InitializeWorkspace_CompInt8<MLAS_FP16> specialization
- HQNBIT_CompInt8 branches in MlasQNBitGemmPackQuantBData
- HQ8BitGemmVariant_CompInt8 workspace extraction in MlasQNBitGemmBatch

Updated:
- MlasIsQNBitGemmAvailable: redirects HQNBIT_CompInt8 to SQNBIT_CompInt8
  (still called by GetComputeType<MLFloat16> on ARM64)
- HQNBIT_CompInt8 workspace size returns 0 (no longer needed at MLAS level)
- Simplified SQNBIT_CompInt8||HQNBIT_CompInt8 conditions to SQNBIT_CompInt8
Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp:110

  • This pack-size helper now only includes the extra SQNBIT_CompInt8 metadata (scales/blk-sums, KleidiAI layout). If any caller still passes HQNBIT_CompInt8 into the MLAS packing APIs, it will get a smaller size and a different packing format than the SQNBIT_CompInt8 compute path expects. To avoid inconsistent behavior, consider normalizing HQNBIT_CompInt8 → SQNBIT_CompInt8 before these dispatch helpers (or ensure HQNBIT_CompInt8 can never reach them).
        const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
        size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);

        if (ComputeType == SQNBIT_CompInt8) {
            const size_t ScaleSize = N * BlockCountK * sizeof(float);
            size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float);

            // align on a 32-byte boundary
            constexpr size_t PackedQuantBDataAlignment = 32;
            PackedQuantBDataSize += PackedQuantBDataAlignment - 1;
            constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment();
            BlkSumSize += BlkSumAlignment - 1;

            if constexpr (QuantAUnsigned) {
                // 2 block sum
                return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize;
            } else {
                return PackedQuantBDataSize + ScaleSize + BlkSumSize;
            }
        } else {
            return PackedQuantBDataSize;
        }

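The size arithmetic in the snippet above can be reproduced in isolation. The shapes below are examples, `DivRoundup` and the per-block data size are re-derived locally rather than taken from MLAS, and the 32-byte block-sum alignment is an assumption (the real value comes from MlasQNBitQuantBBlkSumAlignment()):

```cpp
#include <cassert>
#include <cstddef>

constexpr size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

// Packed B size for the SQNBIT_CompInt8 layout sketched above: quantized
// data + per-block fp32 scales + block sums, with alignment padding.
// Assumes 32-byte alignment for both data and block sums (illustrative).
constexpr size_t PackedBSize(size_t N, size_t K, size_t BlkLen, size_t BlkBitWidth) {
    const size_t BlockCountK = DivRoundup(K, BlkLen);
    const size_t BlkDataSize = BlkLen * BlkBitWidth / 8;   // bytes per quantized block
    const size_t DataSize = N * BlockCountK * BlkDataSize + 31;  // 32-byte-align pad
    const size_t ScaleSize = N * BlockCountK * sizeof(float);
    const size_t BlkSumSize = DivRoundup(N, 16) * BlockCountK * 16 * sizeof(float) + 31;
    return DataSize + ScaleSize + BlkSumSize;
}
```

For example, N=4, K=64, BlkLen=32, 4-bit gives 2 blocks per column of 16 bytes each: 128+31 data bytes, 32 scale bytes, and 128+31 block-sum bytes.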

Bug 1: effective_compute_type only redirected HQNBIT_CompInt8 to
SQNBIT_CompInt8 for 4-bit, leaving 8-bit with HQNBIT_CompInt8 as
the compute type. Since HQ CompInt8 code was removed from MLAS,
pack size returns 0 and B is never packed for 8-bit.

Bug 2: fp16 scales conversion during B packing was gated on
nbits_ == 4, so 8-bit got null scales passed to
SQ8BitGemmPackQuantBDataAndBlkSum.

Bug 3: The HQNBIT_CompInt8 scales/zero_points PrePack called
MlasQNBitGemmPackQuantBData with null QuantBData for separate
packing on ARM64 for 4-bit. The ARM64 NEON pack function does not
support incremental packing for 4-bit (the standard SQNBIT_CompInt8
path uses should_pack_scale_and_zp_inputs = (nbits_ == 8) on ARM64).
Guard separate packing to match standard behavior.

Also tightened the BZpCorr condition during B packing to check
MlasQNBitGemmScalesPacked before computing correction.
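A sketch of the corrected redirect from Bug 1. The function name is illustrative; the real logic lives in matmul_nbits.cc and uses the ORT compute-type enum:

```cpp
#include <cassert>
#include <cstddef>

enum ComputeType { SQNBIT_CompInt8, HQNBIT_CompInt8 };

// Bug 1 fix: redirect HQNBIT_CompInt8 to SQNBIT_CompInt8 regardless of bit
// width, since the HQ CompInt8 path no longer exists in MLAS. (The buggy
// version only redirected when nbits == 4, so 8-bit got a 0 pack size and
// B was never packed.)
ComputeType EffectiveComputeType(ComputeType compute_type, size_t /*nbits*/) {
    return compute_type == HQNBIT_CompInt8 ? SQNBIT_CompInt8 : compute_type;
}
```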
Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



@jambayk jambayk marked this pull request as ready for review March 24, 2026 16:30
@jambayk jambayk requested review from hariharans29 and vraspar March 24, 2026 16:31
@jambayk jambayk enabled auto-merge (squash) March 24, 2026 16:42
vraspar
vraspar previously approved these changes Mar 24, 2026
hariharans29
hariharans29 previously approved these changes Mar 24, 2026
@jambayk jambayk dismissed stale reviews from hariharans29 and vraspar via 2309452 March 24, 2026 22:56
vraspar
vraspar previously approved these changes Mar 24, 2026
…teType

- Replace graceful fallback with ORT_ENFORCE for scales_fp32_ in
  ComputeBPacked HQNBIT_CompInt8 path (PrePack always sets it).
- GetComputeType<MLFloat16> now checks MlasIsQNBitGemmAvailable with
  SQNBIT_CompInt8 directly instead of HQNBIT_CompInt8.
- Remove HQNBIT_CompInt8 -> SQNBIT_CompInt8 redirect from
  MlasIsQNBitGemmAvailable in MLAS layer.
@jambayk jambayk merged commit 36242c6 into main Mar 25, 2026
113 of 121 checks passed
@jambayk jambayk deleted the jambayk/mnb-4-16 branch March 25, 2026 23:59