diff --git a/README.md b/README.md
index fe5441e2e..8522468c7 100644
--- a/README.md
+++ b/README.md
@@ -27,13 +27,13 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,
 ## 🆕 What's New
 
-[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
+[2025/10] We proposed a fast algorithm to generate **mixed bits/datatypes** schemes in minutes. Please
 refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md)
 and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.
 
-[2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
+[2025/09] AutoRound now includes experimental support for the **mxfp4 and nvfp4 dtypes**. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
 . We currently recommend exporting to the LLM-Compressor format.
 
-[2025/08] AutoRound now provides experimental support for an improved INT2 algorithm via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
+[2025/08] AutoRound now provides experimental support for **an improved INT2 algorithm** via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
 for some accuracy results.
 
 [2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
@@ -67,7 +67,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De
 ✅ **Affordable Quantization Cost**
 Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)
 
-✅ **Fast mixed bits/data-types scheme generation**
+✅ **Fast Mixed Bits/Dtypes Scheme Generation**
 Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. Accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).
 
 ✅ **10+ VLMs Support**
@@ -76,8 +76,8 @@ Out-of-the-box quantization for 10+ vision-language models [example models](htt
 ✅ **Layerwise Mixed Bits Quantization**
 Assign different bits per layer for fine-grained accuracy/performance trade-offs. Details are shown in [mixed bits quantization](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#mixed-bits-usage)
 
-✅ **Round-to-Nearest Mode**
-Use `--iters 0` for fast, calibration-free quantization with some accuracy drop. Details are shown in [rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#rtn-mode)
+✅ **Optimized Round-to-Nearest Mode**
+Use `--iters 0` for fast, calibration-free quantization with a modest accuracy drop at 4 bits. Details are shown in [opt_rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#opt-rtn-mode)
 
 ✅ **Multiple Recipes**
 Choose from `auto-round-best`, `auto-round`, and `auto-round-light` to suit your needs.
Details are shown in [quantization recipes](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#recipe-recommendation) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 4f4d42dfa..d948de65b 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1244,7 +1244,7 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N Returns: None """ - logger.info("start to compute imatrix for GGUF quantization") + logger.info("start to compute imatrix") # Load dataset from auto_round.calib_dataset import get_dataloader @@ -1343,15 +1343,13 @@ def _quantize_layer_via_rtn(self, name: str) -> None: if _is_fp8_linear(m): m = convert_fp8_layer_to_linear(m, self.amp_dtype) set_module(self.model, name, m) - - # Step 1: Use optimized RTN data type if available - if not self.disable_opt_rtn and not m.data_type.startswith("rtn_"): - from auto_round.data_type import QUANT_FUNC_WITH_DTYPE - - rtn_dtype = "rtn_" + m.data_type - if rtn_dtype in QUANT_FUNC_WITH_DTYPE: - m.data_type = rtn_dtype - self.layer_config[name]["data_type"] = m.data_type + # + # # Step 1: Use optimized RTN data type if available + # if not self.disable_opt_rtn: + # rtn_data_type = self._check_rtn_dytpe(m.data_type, m.bits, m.sym) + # if rtn_data_type is not None: + # m.data_type = rtn_data_type + # self.layer_config[name]["data_type"] = m.data_type # Step 2: Try quantization on GPU first, fall back to CPU if OOM # if only export gguf, using gguf-packing instead of rtn @@ -1367,6 +1365,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: enable_norm_bias_tuning=False, enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, + disable_opt_rtn=self.disable_opt_rtn, ) m = m.unwrapper({}) m.to("cpu") @@ -1457,7 +1456,14 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: self._quantize_embedding_layer() self.model.to("cpu") + + enable_imatrix = False if has_gguf_k and not self.disable_opt_rtn: + enable_imatrix = True + if self.data_type == "int" and self.sym: + enable_imatrix = True + + if enable_imatrix: self._quant_rtn_with_imatrix(all_to_quantized_module_names) elif self.act_bits <= 8 and check_need_act_calibration( self.act_dynamic, self.act_data_type, self.act_bits @@ -1800,8 +1806,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: Returns: None """ - ##TODO currently we take all the layers outside blocks as post block layers which is not optimal - ## if there is no input for layer, we use rtn + # TODO currently we take all the layers outside blocks as post block layers which is not optimal + # if there is no input for layer, we use rtn for layer_name in copy.deepcopy(layer_names): if layer_name not in layer_inputs: @@ -1815,10 +1821,6 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: set_module(self.model, layer_name, new_layer) layer = new_layer - if not self.disable_opt_rtn and "rtn_" + layer.data_type in QUANT_FUNC_WITH_DTYPE: - layer.data_type = "rtn_" + layer.data_type - logger.info("using optimized rtn method for quantizing %s", layer_name) - self.layer_config[layer_name]["data_type"] = layer.data_type wrapper_layer = WrapperLinear( layer, enable_round_tuning=False, @@ -1826,6 +1828,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: enable_norm_bias_tuning=False, enable_torch_compile=self.enable_torch_compile, device=self.device, + disable_opt_rtn=self.disable_opt_rtn, ) new_layer = wrapper_layer.unwrapper({}) 
             set_module(self.model, layer_name, new_layer)
diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index 5b151fdfe..fb6c5f3db 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -11,11 +11,75 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
 
 import torch
 
 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.utils import get_reciprocal
+
+
+def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor:
+    nmax = pow(2, bits - 1)
+    imax = abs(data).argmax(axis=-1, keepdims=True)
+    group_max = torch.take_along_dim(data, imax, dim=-1)
+    iscales = -nmax * get_reciprocal(group_max)
+    scales = get_reciprocal(iscales)
+    L = torch.round(1.0 * iscales * data).clip(-nmax, nmax - 1)
+    if qw is None:
+        qw = 1.0
+    best_loss = torch.sum(((scales * L - data).to(torch.float32)) ** 2 * qw, dim=-1)
+    for _is in range(-18 * 5, 18 * 5 + 1):
+        if _is == 0:
+            continue
+        iscales = -(nmax - 0.01 * _is) * get_reciprocal(group_max)
+        tmp_L = torch.round(iscales * data).clip(-nmax, nmax - 1)
+        tmp_scales = get_reciprocal(iscales)
+        loss = torch.sum(((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 * qw, dim=-1)
+        replace_id = loss < best_loss
+        scales[replace_id] = tmp_scales[replace_id]
+        best_loss[replace_id] = loss[replace_id]
+    return scales
+
+
+@register_dtype("rtn_int_sym")
+def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
+    """Quantize and de-quantize a tensor symmetrically over the full range; credit goes to the llama.cpp community.
+
+    Args:
+        tensor: Tensor containing the tensor to be quantized
+        bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
+        group_size: Number of elements to share scale for quantization
+        v: Rounding value perturbation
+        q_scale_thresh: Clip the quantized scale's magnitude to this value to improve numerical stability
+        imatrix (Tensor, optional): Importance weights used in the scale search. Defaults to None.
+
+    Returns:
+        Quantized and de-quantized tensor, scale, zero-point
+    """
+
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    maxq = 2 ** (bits - 1)
+    if imatrix is None:
+        imatrix = 1.0
+    else:
+        imatrix = imatrix.reshape(1, -1)
+
+        imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
+        imatrix = imatrix.reshape(tensor.shape)
+
+    scale = search_scales(tensor, bits, qw=imatrix)
+    scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
+    int_w = round_ste(tensor / scale + v)
+    q = torch.clamp(int_w, -maxq, maxq - 1)
+    qdq_result = (scale * q).to(tensor.dtype)
+    qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
+    return qdq_result, scale, maxq
 
 
 @register_dtype("int_sym")
diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py
index bdec2e0b4..1bb53a14b 100644
--- a/auto_round/data_type/utils.py
+++ b/auto_round/data_type/utils.py
@@ -87,7 +87,7 @@ def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int):
     return data_new
 
 
-def get_quant_func(dtype, bits, sym):
+def get_quant_func(dtype: str, bits: int, sym: bool, disable_opt_rtn=False) -> tuple[callable, str]:
     """Retrieve the quantization function based on data type, bit width, and symmetry.
 
     This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE
@@ -98,40 +98,38 @@
     dictionary based on the provided data type (`dtype`), bit width (`bits`), and whether the
     quantization is symmetric (`sym`).
 
     Args:
         dtype (str): The data type for the quantization (e.g., 'int', 'mxfp4').
         bits (int): The bit width for the quantization (e.g., 2,4,8).
         sym (bool): A flag indicating whether the quantization is symmetric (True) or asymmetric (False).
+        disable_opt_rtn (bool): Whether to disable the optimized "rtn_"-prefixed quantization variants.
 
     Returns:
         function: The quantization function corresponding to the specified parameters.
+        str: The resolved data type name (a key of QUANT_FUNC_WITH_DTYPE).
     """
-    key = dtype
-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
+    from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
 
-    if sym:
-        key = dtype + "_sym"
-    else:
-        key = dtype + "_asym"
+    def pad_sym(data_type):
+        if sym:
+            data_sym = data_type + "_sym"
+        else:
+            data_sym = data_type + "_asym"
+        return data_sym
 
-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
+    def pad_bits(data_type):
+        return data_type + str(bits)
 
-    ##need to add bits and sym infos
-    if sym:
-        key = dtype + str(bits) + "_sym"
-    else:
-        key = dtype + str(bits) + "_asym"
+    if not disable_opt_rtn:
+        rtn_data_type = "rtn_" + dtype
+        data_types = [rtn_data_type, pad_bits(rtn_data_type), pad_sym(rtn_data_type), pad_sym(pad_bits(rtn_data_type))]
+        for data_type in data_types:
+            if data_type in QUANT_FUNC_WITH_DTYPE:
+                return QUANT_FUNC_WITH_DTYPE[data_type], data_type
 
-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
-
-    if sym:
-        key = dtype + str(bits)
-    else:
-        key = dtype + str(bits)
+    data_types = [dtype, pad_bits(dtype), pad_sym(dtype), pad_sym(pad_bits(dtype))]
+    for data_type in data_types:
+        if data_type in QUANT_FUNC_WITH_DTYPE:
+            return QUANT_FUNC_WITH_DTYPE[data_type], data_type
 
-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
-
-    raise ValueError(f"{dtype} is not supported")
+    raise ValueError(f"{dtype} is not supported")
 
 
 def round_ste(x: torch.Tensor):
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 365523950..b31b81a16 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -230,7 +230,7 @@ def pack_layer(layer_name, model, backend, device=None):
             zp = int(zp.flatten()[0])
 
         qlayer.to("cpu")
-        ##force to float32 to be compatible with torch 2.0
+        # Force to float32 to be compatible with torch 2.0
         sig = inspect.signature(qlayer.pack)
         param_count = len(sig.parameters)
         if param_count == 2:
@@ -296,7 +296,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
         return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)
 
-    ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
+    # If using sym, we change to the gptq sym kernel to avoid compiling from the auto_round source
     if (
         (kwargs.get("sym") is None or kwargs.get("sym"))
         and ("gptq" not in backend and "awq" not in backend)
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 6742011fa..c73d6a210 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -2929,6 +2929,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
         if name in all_module_names:
             m = get_module(model, name)
             if len(list(m.children())) == 0 and type(m) not in supported_types:
+                layer_config.pop(name)
                 logger.warning(f"{name} is not supported in current scheme, ignoring its setting in `layer_config`")
                 continue
 
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index a9c0f5cb2..6a6227023 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -80,6 +80,7 @@ def __init__(
         device="cpu",
         enable_round_tuning=True,
         enable_torch_compile=False,
+        disable_opt_rtn=True,
         **kwargs,
     ):
         """Initializes the WrapperLinear module.
@@ -92,6 +93,7 @@ def __init__(
         """
         super(WrapperLinear, self).__init__()
         self.orig_layer = orig_layer
+        self.disable_opt_rtn = disable_opt_rtn
         self.output_device = device
         self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
         self.enable_minmax_tuning = enable_minmax_tuning
@@ -146,13 +148,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("min_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))
 
-        self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        self.weight_quant_func, self.data_type = get_quant_func(
+            orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn
+        )
         if self.enable_torch_compile:
             self.weight_quant_func = compile_func(self.weight_quant_func, self.device)
 
         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
-                orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
+                orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym, self.disable_opt_rtn
             )
             if self.enable_torch_compile:
                 self.act_quant_func = compile_func(self.act_quant_func, self.device)
diff --git a/docs/opt_rtn.md b/docs/opt_rtn.md
new file mode 100644
index 000000000..354a397c3
--- /dev/null
+++ b/docs/opt_rtn.md
@@ -0,0 +1,47 @@
+### 🧮 Evaluation Results (LM-Eval)
+For 2/3-bit quantization, we strongly recommend not using `--iters 0`, except for GGUF:Q2_K_S, which uses a different quantization algorithm.
+
+4BIT=W4A16
+3BIT=W3A16
+2BIT=W2A16G64
+
+RTN mode
+
+~~~bash
+auto-round --model xxx --disable_opt_rtn --iters 0
+~~~
+
+OPT RTN mode
+
+~~~bash
+auto-round --model xxx --iters 0
+~~~
+
+
+
+| Model                          | RTN/OPT  | AVG     | HellaSwag | LAMBADA | MMLU   | PIQA   | WinoGrande |
+|--------------------------------|----------|---------|-----------|---------|--------|--------|------------|
+| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896    | 0.7013  | 0.6538 | 0.7987 | 0.7230     |
+|                                | OPT-4BIT | 0.69560 | 0.5882    | 0.7074  | 0.6631 | 0.7916 | 0.7277     |
+|                                | RTN-3BIT | 0.64562 | 0.5410    | 0.6695  | 0.5449 | 0.7742 | 0.6985     |
+|                                | OPT-3BIT | 0.65970 | 0.5490    | 0.6893  | 0.5711 | 0.7677 | 0.7214     |
+|                                | RTN-2BIT | 0.33008 | 0.2918    | 0.0474  | 0.2321 | 0.5740 | 0.5051     |
+|                                | OPT-2BIT | 0.38908 | 0.3241    | 0.1560  | 0.2822 | 0.6235 | 0.5596     |
+| **Qwen2.5-7B-Instruct**        | RTN-4BIT | 0.69560 | 0.6114    | 0.6713  | 0.7011 | 0.7878 | 0.7064     |
+|                                | OPT-4BIT | 0.70034 | 0.6143    | 0.6945  | 0.7115 | 0.7845 | 0.6969     |
+|                                | RTN-3BIT | 0.64144 | 0.5585    | 0.6092  | 0.6455 | 0.7476 | 0.6464     |
+|                                | OPT-3BIT | 0.66764 | 0.5756    | 0.7013  | 0.6597 | 0.7481 | 0.6535     |
+|                                | RTN-2BIT | 0.31856 | 0.2804    | 0.0351  | 0.2379 | 0.5256 | 0.5138     |
+|                                | OPT-2BIT | 0.45146 | 0.3645    | 0.2992  | 0.4043 | 0.6415 | 0.5478     |
+| **Qwen3-8B**                   | RTN-4BIT | 0.66240 | 0.5619    | 0.6150  | 0.7077 | 0.7573 | 0.6701     |
+|                                | OPT-4BIT | 0.66992 | 0.5619    | 0.6346  | 0.7102 | 0.7633 | 0.6796     |
+|                                | RTN-3BIT | 0.57322 | 0.4992    | 0.4260  | 0.6002 | 0.7361 | 0.6046     |
+|                                | OPT-3BIT | 0.63698 | 0.5226    | 0.5814  | 0.6718 | 0.7437 | 0.6654     |
+|                                | RTN-2BIT | 0.31150 | 0.2679    | 0.0041  | 0.2536 | 0.5283 | 0.5036     |
+|                                | OPT-2BIT | 0.44254 | 0.3749    | 0.2005  | 0.4202 | 0.6670 | 0.5501     |
+| **Qwen3-14B**                  | RTN-4BIT | 0.70448 | 0.5999    | 0.6511  | 0.7565 | 0.7998 | 0.7151     |
+|                                | OPT-4BIT | 0.70798 | 0.6031    | 0.6627  | 0.7534 | 0.8009 | 0.7198     |
+|                                | RTN-3BIT | 0.65876 | 0.5746    | 0.5467  | 0.7065 | 0.7628 | 0.7032     |
+|                                | OPT-3BIT | 0.68610 | 0.5683    | 0.6633  | 0.7258 | 0.7699 | 0.7032     |
+|                                | RTN-2BIT | 0.39398 | 0.3764    | 0.0607  | 0.3836 | 0.6480 | 0.5012     |
+|                                | OPT-2BIT | 0.50080 | 0.4554    | 0.2451  | 0.4899 | 0.7138 | 0.5998     |
\ No newline at end of file
diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index ba27ef5d1..9a518e31f 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -23,7 +23,7 @@ This document presents step-by-step instructions for auto-round llm quantization
     - [CLI Usage](#cli-usage)
     - [API Usage](#api-usage-1)
     - [Hyperparameters in AutoScheme](#hyperparameters-in-autoscheme)
-  + [RTN mode](#rtn-mode)
+  + [OPT RTN mode](#opt-rtn-mode)
   + [GGUF format](#gguf-format)
   + [Quantization Costs](#quantization-costs)
   + [Device/Multi-GPU setting in Quantization](#device-multi-gpu-setting-in-quantization)
@@ -368,8 +368,8 @@ We will try to optimize the VRAM usage in the future.
 Embedding layer is not supported in AutoScheme, it will use the best scheme in options.
 
-### RTN mode
-AutoRound also supports RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. try setting `iters=0` and use `group_size=32` for better results.
+### OPT RTN Mode
+AutoRound also supports an optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Set `iters=0` to enable it; we recommend `group_size=32` for better results. See the [accuracy comparison](./opt_rtn.md) between the RTN and OPT RTN modes.
 
 For the GGUF format, we have optimized the RTN algorithm inspired by llamacpp. To use the original (pure) RTN algorithm instead, enable the `--disable_opt_rtn` option.
 
 ```python
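As a quick, self-contained illustration of the lookup order that the reworked `get_quant_func` follows, the sketch below mirrors the candidate-key construction from the patch against a toy registry. `TOY_REGISTRY` and `resolve` are hypothetical stand-ins used only for illustration; they are not part of auto_round.

```python
# Minimal sketch of the patched get_quant_func resolution order.
# Assumption: a plain dict stands in for the real QUANT_FUNC_WITH_DTYPE registry.
TOY_REGISTRY = {
    "rtn_int_sym": "optimized RTN int-sym quantizer",
    "int_sym": "plain int-sym quantizer",
    "int_asym": "plain int-asym quantizer",
}


def resolve(dtype: str, bits: int, sym: bool, disable_opt_rtn: bool = False) -> str:
    """Return the first registered key, preferring rtn_-prefixed (optimized) variants."""
    suffix = "_sym" if sym else "_asym"
    candidates = []
    if not disable_opt_rtn:
        # Optimized RTN variants are tried first: bare, +bits, +sym, +bits+sym.
        rtn = "rtn_" + dtype
        candidates += [rtn, rtn + str(bits), rtn + suffix, rtn + str(bits) + suffix]
    # Fall back to the plain data type in the same order.
    candidates += [dtype, dtype + str(bits), dtype + suffix, dtype + str(bits) + suffix]
    for key in candidates:
        if key in TOY_REGISTRY:
            return key
    raise ValueError(f"{dtype} is not supported")


print(resolve("int", 4, sym=True))                        # -> rtn_int_sym
print(resolve("int", 4, sym=True, disable_opt_rtn=True))  # -> int_sym
```

With `disable_opt_rtn=True` (the default the patched `WrapperLinear` uses unless a caller passes `disable_opt_rtn=self.disable_opt_rtn`, as the RTN paths above do), the `rtn_*` candidates are skipped and the plain data type is resolved, matching the behavior of the `--disable_opt_rtn` option described in the docs.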