[Feature Request] MatMulNBits 2-bits Weights + Float Zero Point kernel support #28162

@thpereir

Description

Describe the feature request

Summary

Add native support in the MatMulNBits operator for 2-bit quantized weights paired with float16/bfloat16 Zero Points.

Motivation

2-bit quantization (e.g., from QAD/Quark-quantized models) uses non-uniform quantization levels such as [-1, -1/3, 1/3, 1]. These levels can be expressed in the MatMulNBits dequantization formula `(index - zero_point) * scale` only with `zero_point = 1.5`: the four levels are symmetric around zero, so the zero point must sit at the midpoint of the index range 0–3, and `(index - 1.5) * (2/3 * s)` yields exactly [-s, -s/3, +s/3, +s]. A fractional value like 1.5 cannot be represented as an integer zero point.

Currently, onnxruntime-genai works around this by packing 2-bit weights into MatMulNBits nodes and passing a per-group float16 zero-point tensor with value 1.5. However, the MatMulNBits kernels do not implement support for:

  • bits = 2
  • float16/bfloat16 zero-points (as opposed to the packed uint8 integer zero-points used for 4-bit)

This creates a fragile dependency on undocumented runtime behavior.

Requested Changes

  1. Kernel support for bits = 2 — ensure the MatMulNBits CUDA, CPU, and other EP kernels correctly handle 2-bit packed weight tensors (4 values per byte, LSB-first); see the unpacking sketch after this list.

  2. Float16/float32 zero-point input — formally specify and implement support for float-typed zero-points in MatMulNBits. For 2-bit non-uniform quantization the zero-point is fractional (e.g., 1.5) and must be stored as a float.
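
For reference, here is a minimal numpy sketch of the LSB-first unpacking described in item 1; the function name is illustrative, not an existing ONNX Runtime API.

```python
import numpy as np

def unpack_2bit_lsb_first(packed: np.ndarray) -> np.ndarray:
    """Unpack uint8 bytes into 2-bit indices: 4 values per byte, LSB-first."""
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    # Each byte yields four indices, taken from bit offsets 0, 2, 4, 6.
    unpacked = (packed[..., None] >> shifts) & 0b11
    return unpacked.reshape(*packed.shape[:-1], -1)

# 0xE4 = 0b11100100 unpacks to [0, 1, 2, 3] when read LSB-first.
print(unpack_2bit_lsb_first(np.array([0xE4], dtype=np.uint8)))  # [0 1 2 3]
```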

Describe scenario use case

Example Dequantization for 2-bit QAD

```
dequant_value = (uint2_index - 1.5f16) * (original_scale * (2/3))
```

| uint2 index | dequant value (relative to scale s) |
| --- | --- |
| 0 | -s |
| 1 | -s/3 |
| 2 | +s/3 |
| 3 | +s |
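
For illustration, a minimal numpy check that the formula above reproduces this table (the scale value of 1.0 is arbitrary, chosen so the output reads directly in units of s):

```python
import numpy as np

indices = np.arange(4, dtype=np.float32)   # uint2 indices 0..3
zero_point = np.float32(np.float16(1.5))   # fractional zero point, stored as float16
scale = np.float32(1.0)                    # original_scale
dequant = (indices - zero_point) * (scale * (2.0 / 3.0))
print(dequant)  # [-1.0, -0.33333334, 0.33333334, 1.0] i.e. [-s, -s/3, +s/3, +s]
```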
