Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel#28589
Open
Copilot wants to merge 4 commits into
Open
Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel#28589Copilot wants to merge 4 commits into
Copilot wants to merge 4 commits into
Conversation
…ti-threaded kernel Replace the naive single-threaded scalar loop for 2-bit quantization with float/MLFloat16 zero points with a multi-threaded implementation using TrySimpleParallelFor. The new DequantizeBlockwise2Bits function processes 16 elements (one uint32 of packed 2-bit values) per iteration and distributes work across available threads, matching the parallelism pattern used by the existing 4-bit DequantizeBlockwise path. Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/76231b1d-cdea-427a-8824-29293b1d02eb Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copilot
AI
changed the title
[WIP] Optimize MatMulNBits performance for bits=2 and float zero_point
Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel
May 20, 2026
Contributor
There was a problem hiding this comment.
Pull request overview
Optimizes the CPU fallback path for MatMulNBits when bits=2 and zero_points are float/MLFloat16 by replacing a scalar single-thread dequantization loop with a threaded, blockwise dequantization kernel, aiming to remove the large performance regression reported for this configuration.
Changes:
- Adds a multi-threaded
DequantizeBlockwise2Bitskernel that processes 16 values per iteration and parallelizes viaTrySimpleParallelFor. - Switches the
MatMulNBits<float>andMatMulNBits<MLFloat16>unpacked compute paths to use the new 2-bit dequant kernel. - Adds a Python benchmark script for measuring 2-bit vs 4-bit performance and thread scaling.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/test/python/quantization/bench_matmul_2bits.py | Adds a standalone benchmark script for MatMulNBits 2-bit float-ZP performance on CPU. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Replaces naive 2-bit float/fp16-ZP dequant loops with DequantizeBlockwise2Bits calls. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h | Declares DequantizeBlockwise2Bits template API. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc | Implements the threaded 2-bit dequantization kernel and adds explicit instantiations. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tianleiwu
approved these changes
May 20, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Replace the naive single-threaded scalar loop for 2-bit dequantization with float/MLFloat16 zero points with a multi-threaded kernel (
DequantizeBlockwise2Bits) that:TrySimpleParallelFor— distributes work across all intra-op threads (previously single-threaded)uint32_t= 16 packed 2-bit values, reducing per-element overheadFollows the same threading pattern as the existing 4-bit
DequantizeBlockwisepath for consistency.Files changed:
matmul_nbits_impl.h— declareDequantizeBlockwise2Bitsmatmul_nbits_impl.cc— implementDequantize2BitsKernel+DequantizeBlockwise2Bitswith instantiations for<float,float>and<float,MLFloat16>matmul_nbits.cc— replace naive loops in bothMatMulNBits<float>andMatMulNBits<MLFloat16>ComputeBUnpackedMotivation and Context
The
bits=2+ float zero_point path (added in #28354) was flagged with// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!. It ran ~20× slower than thebits=4MLAS path because it was a tight scalarfor n × for kloop with no threading — the entire N×K dequantization ran on a single core before callingMlasGemmBatch. With 8 intra-op threads this should recover most of that gap.Benchmark Results
Tested on a 96-core x86_64 Linux machine, ORT 1.27.0 CPU Release build, using typical LLM matrix shapes with
block_size=128and float zero points.Multi-thread speedup (2-bit dequantization, 1 thread → 8 threads)
2-bit vs 4-bit comparison (ratio = 2-bit / 4-bit; <1.0 means 2-bit is faster)
Key findings: