12 changes: 6 additions & 6 deletions README.md
@@ -27,13 +27,13 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,


## 🆕 What's New
[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
[2025/10] We proposed a fast algorithm to generate **mixed bits/datatypes** schemes in minutes. Please
refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.

[2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
[2025/09] AutoRound now includes experimental support for the **mxfp4 and nvfp4 dtypes**. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
. We currently recommend exporting to the LLM-Compressor format.

[2025/08] AutoRound now provides experimental support for an improved INT2 algorithm via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
[2025/08] AutoRound now provides experimental support for **an improved INT2 algorithm** via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
for some accuracy results.

[2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
@@ -67,7 +67,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De
✅ **Affordable Quantization Cost**
Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)

✅ **Fast mixed bits/data-types scheme generation**
✅ **Fast Mixed Bits/Dtypes Scheme Generation**
Automatically generate a mixed bits/dtypes scheme in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. See the accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).

✅ **10+ VLMs Support**
@@ -76,8 +76,8 @@ Out-of-the-box quantization for 10+ vision-language models [example models](http
✅ **Layerwise Mixed Bits Quantization**
Assign different bits per layer for fine-grained accuracy/performance trade-offs. Details are shown in [mixed bits quantization](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#mixed-bits-usage)

✅ **Round-to-Nearest Mode**
Use `--iters 0` for fast, calibration-free quantization with some accuracy drop. Details are shown in [rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#rtn-mode)
✅ **Optimized Round-to-Nearest Mode**
Use `--iters 0` for fast, calibration-free quantization with some accuracy drop at 4 bits. Details are shown in [opt_rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#opt-rtn-mode)

✅ **Multiple Recipes**
Choose from `auto-round-best`, `auto-round`, and `auto-round-light` to suit your needs. Details are shown in [quantization recipes](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#recipe-recommendation)
35 changes: 19 additions & 16 deletions auto_round/compressors/base.py
@@ -1244,7 +1244,7 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N
Returns:
None
"""
logger.info("start to compute imatrix for GGUF quantization")
logger.info("start to compute imatrix")

# Load dataset
from auto_round.calib_dataset import get_dataloader
@@ -1343,15 +1343,13 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
if _is_fp8_linear(m):
m = convert_fp8_layer_to_linear(m, self.amp_dtype)
set_module(self.model, name, m)

# Step 1: Use optimized RTN data type if available
if not self.disable_opt_rtn and not m.data_type.startswith("rtn_"):
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

rtn_dtype = "rtn_" + m.data_type
if rtn_dtype in QUANT_FUNC_WITH_DTYPE:
m.data_type = rtn_dtype
self.layer_config[name]["data_type"] = m.data_type
#
# # Step 1: Use optimized RTN data type if available
# if not self.disable_opt_rtn:
# rtn_data_type = self._check_rtn_dytpe(m.data_type, m.bits, m.sym)
# if rtn_data_type is not None:
# m.data_type = rtn_data_type
# self.layer_config[name]["data_type"] = m.data_type

# Step 2: Try quantization on GPU first, fall back to CPU if OOM
# if only exporting gguf, use gguf-packing instead of rtn
@@ -1367,6 +1365,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
enable_norm_bias_tuning=False,
enable_round_tuning=False,
enable_torch_compile=self.enable_torch_compile,
disable_opt_rtn=self.disable_opt_rtn,
)
m = m.unwrapper({})
m.to("cpu")
@@ -1457,7 +1456,14 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
self._quantize_embedding_layer()

self.model.to("cpu")

# Use imatrix-based RTN for GGUF k-quants when optimized RTN is enabled, or for symmetric int quantization.
enable_imatrix = False
if has_gguf_k and not self.disable_opt_rtn:
enable_imatrix = True
if self.data_type == "int" and self.sym:
enable_imatrix = True

if enable_imatrix:
self._quant_rtn_with_imatrix(all_to_quantized_module_names)
elif self.act_bits <= 8 and check_need_act_calibration(
self.act_dynamic, self.act_data_type, self.act_bits
@@ -1800,8 +1806,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
Returns:
None
"""
##TODO currently we take all the layers outside blocks as post block layers which is not optimal
## if there is no input for layer, we use rtn
# TODO: currently we take all the layers outside blocks as post-block layers, which is not optimal
# if there is no input for a layer, we fall back to rtn

for layer_name in copy.deepcopy(layer_names):
if layer_name not in layer_inputs:
@@ -1815,17 +1821,14 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
set_module(self.model, layer_name, new_layer)
layer = new_layer

if not self.disable_opt_rtn and "rtn_" + layer.data_type in QUANT_FUNC_WITH_DTYPE:
layer.data_type = "rtn_" + layer.data_type
logger.info("using optimized rtn method for quantizing %s", layer_name)
self.layer_config[layer_name]["data_type"] = layer.data_type
wrapper_layer = WrapperLinear(
layer,
enable_round_tuning=False,
enable_minmax_tuning=False,
enable_norm_bias_tuning=False,
enable_torch_compile=self.enable_torch_compile,
device=self.device,
disable_opt_rtn=self.disable_opt_rtn,
)
new_layer = wrapper_layer.unwrapper({})
set_module(self.model, layer_name, new_layer)
64 changes: 64 additions & 0 deletions auto_round/data_type/int.py
@@ -11,11 +11,75 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.utils import get_reciprocal


def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor:
"""Grid-search a per-group scale that minimizes the (optionally importance-weighted) squared reconstruction error."""
nmax = pow(2, bits - 1)
imax = abs(data).argmax(axis=-1, keepdims=True)
group_max = torch.take_along_dim(data, imax, dim=-1)
iscales = -nmax * get_reciprocal(group_max)
scales = get_reciprocal(iscales)
L = torch.round(1.0 * iscales * data).clip(-nmax, nmax - 1)
if qw is None:
qw = 1.0
best_loss = torch.sum(((scales * L - data).to(torch.float32)) ** 2 * qw, dim=-1)
for _is in range(-18 * 5, 18 * 5 + 1):
if _is == 0:
continue
iscales = -(nmax - 0.01 * _is) * get_reciprocal(group_max)
tmp_L = torch.round(iscales * data).clip(-nmax, nmax - 1)
tmp_scales = get_reciprocal(iscales)
loss = torch.sum(((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 * qw, dim=-1)
replace_id = loss < best_loss
scales[replace_id] = tmp_scales[replace_id]
best_loss[replace_id] = loss[replace_id]
return scales


@register_dtype("rtn_int_sym")
def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
"""Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community

Args:
tensor: Tensor containing the tensor to be quantized
bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
group_size: Number of elements to share scale for quantization
v: Rounding value perturbation
min_scale: Minimum scale coefficient for tensor
max_scale: Maximum scale coefficient for tensor
tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None.
tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None.
scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import
q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability

Returns:
Quantized and de-quantized tensor, scale, zero-point
"""

tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
maxq = 2 ** (bits - 1)
if imatrix is None:
imatrix = 1.0
else:
imatrix = imatrix.reshape(1, -1)

imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
imatrix = imatrix.reshape(tensor.shape)

scale = search_scales(tensor, bits, qw=imatrix)
scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
int_w = round_ste(tensor / scale + v)
q = torch.clamp(int_w, -maxq, maxq - 1)
qdq_result = (scale * q).to(tensor.dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
return qdq_result, scale, maxq


@register_dtype("int_sym")
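For illustration, a minimal sketch of exercising the new `rtn_int_sym` kernel above on a toy tensor. The weight shape, group size, and random importance weights are arbitrary placeholders; inside AutoRound the kernel is resolved through the dtype registry and the imatrix comes from calibration.

~~~python
# Illustrative sketch: exercise the new symmetric RTN kernel on random data.
import torch

from auto_round.data_type.int import quant_tensor_rnt_sym, search_scales

torch.manual_seed(0)
weight = torch.randn(8, 256)   # toy 2-D weight
imatrix = torch.rand(256)      # hypothetical per-column importance weights

# Per-group scale search: grid of candidate scales around -nmax/group_max,
# scored by the (optionally weighted) squared reconstruction error.
groups = weight.reshape(-1, 32)                # group_size = 32
scales = search_scales(groups, bits=4, qw=1.0)
print(scales.shape)                            # one scale per group: torch.Size([64, 1])

# Full quantize/de-quantize round trip through the registered dtype function.
qdq, scale, maxq = quant_tensor_rnt_sym(weight, bits=4, group_size=32, imatrix=imatrix)
print(f"mean squared reconstruction error: {(qdq - weight).pow(2).mean():.6f}")
~~~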
48 changes: 23 additions & 25 deletions auto_round/data_type/utils.py
@@ -87,7 +87,7 @@ def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int):
return data_new


def get_quant_func(dtype, bits, sym):
def get_quant_func(dtype: str, bits: int, sym: bool, disable_opt_rtn=False) -> tuple[callable, str]:
"""Retrieve the quantization function based on data type, bit width, and symmetry.

This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE
@@ -98,40 +98,38 @@ def get_quant_func(dtype, bits, sym):
dtype (str): The data type for the quantization (e.g., 'int', 'mxfp4').
bits (int): The bit width for the quantization (e.g., 2,4,8).
sym (bool): A flag indicating whether the quantization is symmetric (True) or asymmetric (False).
disable_opt_rtn (bool): Whether to disable the optimized RTN data types (the "rtn_"-prefixed variants).

Returns:
function: The quantization function corresponding to the specified parameters.
str: The resolved data type key (e.g., "int_sym" or "rtn_int_sym").
"""
key = dtype
if key in QUANT_FUNC_WITH_DTYPE.keys():
return QUANT_FUNC_WITH_DTYPE[key], key

if sym:
key = dtype + "_sym"
else:
key = dtype + "_asym"
def pad_sym(data_type):
if sym:
data_sym = data_type + "_sym"
else:
data_sym = data_type + "_asym"
return data_sym

if key in QUANT_FUNC_WITH_DTYPE.keys():
return QUANT_FUNC_WITH_DTYPE[key], key
def pad_bits(data_type):
return data_type + str(bits)

##need to add bits and sym infos
if sym:
key = dtype + str(bits) + "_sym"
else:
key = dtype + str(bits) + "_asym"
if not disable_opt_rtn:
rtn_data_type = "rtn_" + dtype
data_types = [rtn_data_type, pad_bits(rtn_data_type), pad_sym(rtn_data_type), pad_sym(pad_bits(rtn_data_type))]
for data_type in data_types:
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

if key in QUANT_FUNC_WITH_DTYPE.keys():
return QUANT_FUNC_WITH_DTYPE[key], key

if sym:
key = dtype + str(bits)
else:
key = dtype + str(bits)
if data_type in QUANT_FUNC_WITH_DTYPE:
return QUANT_FUNC_WITH_DTYPE[data_type], data_type

if key in QUANT_FUNC_WITH_DTYPE.keys():
return QUANT_FUNC_WITH_DTYPE[key], key
data_types = [dtype, pad_bits(dtype), pad_sym(dtype), pad_sym(pad_bits(dtype))]
for data_type in data_types:
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

raise ValueError(f"{dtype} is not supported")
if data_type in QUANT_FUNC_WITH_DTYPE:
return QUANT_FUNC_WITH_DTYPE[data_type], data_type


def round_ste(x: torch.Tensor):
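For reference, a small sketch of the new lookup order in `get_quant_func`: with `disable_opt_rtn=False` the `rtn_`-prefixed candidates are tried first, so symmetric INT resolves to the optimized kernel registered above. The expected keys in the comments assume only the `*_sym`/`*_asym` variants are registered for the plain and optimized int dtypes.

~~~python
# Illustrative sketch: the resolution order of the rewritten get_quant_func.
from auto_round.data_type.utils import get_quant_func

# Optimized RTN enabled: "rtn_"-prefixed candidates are tried first.
fn, key = get_quant_func("int", bits=4, sym=True, disable_opt_rtn=False)
print(key)  # expected: "rtn_int_sym" (assuming no plain "rtn_int" entry is registered)

# Optimized RTN disabled: falls back to the plain registered dtype.
fn, key = get_quant_func("int", bits=4, sym=True, disable_opt_rtn=True)
print(key)  # expected: "int_sym" (assuming only the *_sym/_asym int variants are registered)
~~~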
4 changes: 2 additions & 2 deletions auto_round/export/export_to_autoround/export.py
@@ -230,7 +230,7 @@ def pack_layer(layer_name, model, backend, device=None):
zp = int(zp.flatten()[0])

qlayer.to("cpu")
##force to float32 to be compatible with torch 2.0
# Force to float32 to be compatible with torch 2.0
sig = inspect.signature(qlayer.pack)
param_count = len(sig.parameters)
if param_count == 2:
@@ -296,7 +296,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)

##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
# If using sym, we change to the gptq sym kernel to avoid compiling from auto_round source
if (
(kwargs.get("sym") is None or kwargs.get("sym"))
and ("gptq" not in backend and "awq" not in backend)
1 change: 1 addition & 0 deletions auto_round/utils.py
@@ -2929,6 +2929,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
if name in all_module_names:
m = get_module(model, name)
if len(list(m.children())) == 0 and type(m) not in supported_types:
layer_config.pop(name)
logger.warning(f"{name} is not supported in current scheme, ignoring its setting in `layer_config`")
continue

8 changes: 6 additions & 2 deletions auto_round/wrapper.py
@@ -80,6 +80,7 @@ def __init__(
device="cpu",
enable_round_tuning=True,
enable_torch_compile=False,
disable_opt_rtn=True,
**kwargs,
):
"""Initializes the WrapperLinear module.
@@ -92,6 +93,7 @@ def __init__(
"""
super(WrapperLinear, self).__init__()
self.orig_layer = orig_layer
self.disable_opt_rtn = disable_opt_rtn
self.output_device = device
self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
self.enable_minmax_tuning = enable_minmax_tuning
@@ -146,13 +148,15 @@ def _init_tuning_params_and_quant_func(self):
self._init_params("min_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))
self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
self.weight_quant_func, self.data_type = get_quant_func(
orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn
)
if self.enable_torch_compile:
self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

if self.enable_act_quant:
self.act_quant_func, self.act_data_type = get_quant_func(
orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym, self.disable_opt_rtn
)
if self.enable_torch_compile:
self.act_quant_func = compile_func(self.act_quant_func, self.device)
47 changes: 47 additions & 0 deletions docs/opt_rtn.md
@@ -0,0 +1,47 @@
### 🧮 Evaluation Results (LM-Eval)
For 2/3-bit quantization, we strongly recommend not using `--iters 0`, except for GGUF:Q2_K_S, which uses a different quantization algorithm.

- 4BIT=W4A16
- 3BIT=W3A16
- 2BIT=W2A16G64

RTN mode

~~~bash
auto-round --model xxx --disable_opt_rtn --iters 0
~~~

OPT RTN mode

~~~bash
auto-round --model xxx --iters 0
~~~



| Model | RTN/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande |
|--------------------------------|----------|---------|-----------|---------|--------|--------|------------|
| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896 | 0.7013 | 0.6538 | 0.7987 | 0.7230 |
| | OPT-4BIT | 0.69560 | 0.5882 | 0.7074 | 0.6631 | 0.7916 | 0.7277 |
| | RTN-3BIT | 0.64562 | 0.5410 | 0.6695 | 0.5449 | 0.7742 | 0.6985 |
| | OPT-3BIT | 0.65970 | 0.5490 | 0.6893 | 0.5711 | 0.7677 | 0.7214 |
| | RTN-2BIT | 0.33008 | 0.2918 | 0.0474 | 0.2321 | 0.5740 | 0.5051 |
| | OPT-2BIT | 0.38908 | 0.3241 | 0.1560 | 0.2822 | 0.6235 | 0.5596 |
| **Qwen2.5-7B-Instruct** | RTN-4BIT | 0.69560 | 0.6114 | 0.6713 | 0.7011 | 0.7878 | 0.7064 |
| | OPT-4BIT | 0.70034 | 0.6143 | 0.6945 | 0.7115 | 0.7845 | 0.6969 |
| | RTN-3BIT | 0.64144 | 0.5585 | 0.6092 | 0.6455 | 0.7476 | 0.6464 |
| | OPT-3BIT | 0.66764 | 0.5756 | 0.7013 | 0.6597 | 0.7481 | 0.6535 |
| | RTN-2BIT | 0.31856 | 0.2804 | 0.0351 | 0.2379 | 0.5256 | 0.5138 |
| | OPT-2BIT | 0.45146 | 0.3645 | 0.2992 | 0.4043 | 0.6415 | 0.5478 |
| **Qwen3-8B** | RTN-4BIT | 0.66240 | 0.5619 | 0.6150 | 0.7077 | 0.7573 | 0.6701 |
| | OPT-4BIT | 0.66992 | 0.5619 | 0.6346 | 0.7102 | 0.7633 | 0.6796 |
| | RTN-3BIT | 0.57322 | 0.4992 | 0.4260 | 0.6002 | 0.7361 | 0.6046 |
| | OPT-3BIT | 0.63698 | 0.5226 | 0.5814 | 0.6718 | 0.7437 | 0.6654 |
| | RTN-2BIT | 0.31150 | 0.2679 | 0.0041 | 0.2536 | 0.5283 | 0.5036 |
| | OPT-2BIT | 0.44254 | 0.3749 | 0.2005 | 0.4202 | 0.6670 | 0.5501 |
| **Qwen3-14B** | RTN-4BIT | 0.70448 | 0.5999 | 0.6511 | 0.7565 | 0.7998 | 0.7151 |
| | OPT-4BIT | 0.70798 | 0.6031 | 0.6627 | 0.7534 | 0.8009 | 0.7198 |
| | RTN-3BIT | 0.65876 | 0.5746 | 0.5467 | 0.7065 | 0.7628 | 0.7032 |
| | OPT-3BIT | 0.68610 | 0.5683 | 0.6633 | 0.7258 | 0.7699 | 0.7032 |
| | RTN-2BIT | 0.39398 | 0.3764 | 0.0607 | 0.3836 | 0.6480 | 0.5012 |
| | OPT-2BIT | 0.50080 | 0.4554 | 0.2451 | 0.4899 | 0.7138 | 0.5998 |
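For completeness, a minimal Python sketch of the two modes above, assuming the Python API mirrors the CLI with `iters=0` selecting RTN. Whether `disable_opt_rtn` is exposed as a constructor argument is an assumption based on this PR, and the output directory name is a placeholder; please check the current API.

~~~python
# Minimal sketch of OPT RTN vs. plain RTN from Python; iters=0 selects RTN as on the CLI.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# OPT RTN mode (the default when iters=0)
ar = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True, iters=0)
ar.quantize_and_save("./Qwen3-8B-W4A16-opt-rtn")

# Plain RTN mode (hypothetical keyword mirroring the CLI flag --disable_opt_rtn)
# ar = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True, iters=0, disable_opt_rtn=True)
~~~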