From a6f95ac12b18beeb380dbea53c76ff00d3e7a669 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 12:49:38 +0800 Subject: [PATCH 01/14] optimize rtn for int woq --- auto_round/compressors/base.py | 58 +++++++++---- auto_round/data_type/int.py | 84 +++++++++++++++++++ auto_round/data_type/utils.py | 54 ++++++------ .../export/export_to_autoround/export.py | 8 +- auto_round/wrapper.py | 6 +- 5 files changed, 161 insertions(+), 49 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 4f4d42dfa..09216cd5c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1244,7 +1244,7 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N Returns: None """ - logger.info("start to compute imatrix for GGUF quantization") + logger.info("start to compute imatrix") # Load dataset from auto_round.calib_dataset import get_dataloader @@ -1343,15 +1343,14 @@ def _quantize_layer_via_rtn(self, name: str) -> None: if _is_fp8_linear(m): m = convert_fp8_layer_to_linear(m, self.amp_dtype) set_module(self.model, name, m) + # + # # Step 1: Use optimized RTN data type if available + # if not self.disable_opt_rtn: + # rtn_data_type = self._check_rtn_dytpe(m.data_type, m.bits, m.sym) + # if rtn_data_type is not None: + # m.data_type = rtn_data_type + # self.layer_config[name]["data_type"] = m.data_type - # Step 1: Use optimized RTN data type if available - if not self.disable_opt_rtn and not m.data_type.startswith("rtn_"): - from auto_round.data_type import QUANT_FUNC_WITH_DTYPE - - rtn_dtype = "rtn_" + m.data_type - if rtn_dtype in QUANT_FUNC_WITH_DTYPE: - m.data_type = rtn_dtype - self.layer_config[name]["data_type"] = m.data_type # Step 2: Try quantization on GPU first, fall back to CPU if OOM # if only export gguf, using gguf-packing instead of rtn @@ -1367,6 +1366,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: enable_norm_bias_tuning=False, enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, + disable_opt_rtn=self.disable_opt_rtn ) m = m.unwrapper({}) m.to("cpu") @@ -1457,7 +1457,14 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: self._quantize_embedding_layer() self.model.to("cpu") + + enable_imatrix = False if has_gguf_k and not self.disable_opt_rtn: + enable_imatrix = True + if self.data_type=="int" and self.sym: + enable_imatrix = True + + if enable_imatrix: self._quant_rtn_with_imatrix(all_to_quantized_module_names) elif self.act_bits <= 8 and check_need_act_calibration( self.act_dynamic, self.act_data_type, self.act_bits @@ -1790,6 +1797,30 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.quantized = True return self.model, self.layer_config + + def _check_rtn_dytpe(self,data_type,bit, sym): + """Check if the given data type is an RTN (Round-To-Nearest) type. + + Args: + data_type (str): The data type to check. + """ + def pad_sym(dtype): + if sym: + data_sym = dtype + "_sym" + else: + data_sym = dtype+ "_asym" + return data_sym + def pad_bits(dtype): + return dtype+str(bit) + data_type = "rtn_"+data_type + data_types=[data_type, pad_bits(data_type), pad_sym(data_type), pad_sym(pad_bits(data_type))] + for data_type in data_types: + from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: + return data_type + return None + + def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: """Quantizes specified layers based on inputs and configuration. 
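The name-resolution order that the new `_check_rtn_dytpe` helper above and the reworked `get_quant_func` in `auto_round/data_type/utils.py` both follow can be summarized outside the diff. The sketch below is illustrative only: `REGISTRY` and `resolve_quant_dtype` are hypothetical stand-ins for `QUANT_FUNC_WITH_DTYPE` and the real lookup (which also lives partly in `_check_rtn_dytpe`), merged here into one function for readability.

```python
# Illustrative sketch of the lookup order this patch introduces; hypothetical
# helper, not the project's API.
from typing import Optional

# Stand-in for registered dtype names; the real QUANT_FUNC_WITH_DTYPE maps
# such names to quantization functions.
REGISTRY = {"int_sym", "int_asym", "rtn_int_sym"}


def resolve_quant_dtype(dtype: str, bits: int, sym: bool, disable_opt_rtn: bool = False) -> Optional[str]:
    """Return the first registered dtype name, preferring optimized-RTN ("rtn_*") variants."""
    suffix = "_sym" if sym else "_asym"

    def variants(base: str) -> list:
        # Order matches the patch: plain name, then +bits, then +sym/asym, then both.
        return [base, f"{base}{bits}", f"{base}{suffix}", f"{base}{bits}{suffix}"]

    candidates = [] if disable_opt_rtn else variants(f"rtn_{dtype}")
    candidates += variants(dtype)
    return next((name for name in candidates if name in REGISTRY), None)


# Example: symmetric 4-bit int resolves to the optimized "rtn_int_sym" kernel
# when opt-RTN is enabled, and falls back to plain "int_sym" otherwise.
assert resolve_quant_dtype("int", 4, sym=True) == "rtn_int_sym"
assert resolve_quant_dtype("int", 4, sym=True, disable_opt_rtn=True) == "int_sym"
```

In other words, passing `disable_opt_rtn` simply skips the `rtn_*` candidates, which is how `--disable_opt_rtn` recovers the original (pure) RTN behavior while the default path picks up the optimized kernels added in `auto_round/data_type/int.py`.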
@@ -1800,8 +1831,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: Returns: None """ - ##TODO currently we take all the layers outside blocks as post block layers which is not optimal - ## if there is no input for layer, we use rtn + # TODO currently we take all the layers outside blocks as post block layers which is not optimal + # if there is no input for layer, we use rtn for layer_name in copy.deepcopy(layer_names): if layer_name not in layer_inputs: @@ -1815,10 +1846,6 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: set_module(self.model, layer_name, new_layer) layer = new_layer - if not self.disable_opt_rtn and "rtn_" + layer.data_type in QUANT_FUNC_WITH_DTYPE: - layer.data_type = "rtn_" + layer.data_type - logger.info("using optimized rtn method for quantizing %s", layer_name) - self.layer_config[layer_name]["data_type"] = layer.data_type wrapper_layer = WrapperLinear( layer, enable_round_tuning=False, @@ -1826,6 +1853,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: enable_norm_bias_tuning=False, enable_torch_compile=self.enable_torch_compile, device=self.device, + disable_opt_rtn=self.disable_opt_rtn, ) new_layer = wrapper_layer.unwrapper({}) set_module(self.model, layer_name, new_layer) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 5b151fdfe..c6bad388a 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -16,6 +16,90 @@ from auto_round.data_type.register import register_dtype from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste +from auto_round.utils import get_reciprocal + + +def search_scales(data, bits, qw=None): + nmax = pow(2, bits - 1) + imax = abs(data).argmax(axis=-1, keepdims=True) + group_max = torch.take_along_dim(data, imax, dim=-1) + iscales = -nmax * get_reciprocal(group_max) + scales = get_reciprocal(iscales) + L = torch.round(1.0 * iscales * data).clip(-nmax, nmax - 1) + if qw is None: + qw = 1.0 + best_loss = torch.sum(((scales * L - data).to(torch.float32)) ** 2 * qw, dim=-1) + for _is in range(-18 * 5, 18 * 5 + 1): + if _is == 0: + continue + iscales = -(nmax - 0.01 * _is) * get_reciprocal(group_max) + tmp_L = torch.round(iscales * data).clip(-nmax, nmax - 1) + tmp_scales = get_reciprocal(iscales) + loss = torch.sum(((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 * qw, dim=-1) + replace_id = loss < best_loss + scales[replace_id] = tmp_scales[replace_id] + best_loss[replace_id] = loss[replace_id] + return scales + + +@register_dtype("rtn_int_sym") +def quant_tensor_sym( + tensor, + bits=4, + group_size=-1, + v=0, + min_scale=1.0, + max_scale=1.0, + scale_dtype=torch.float16, + tensor_min=None, + tensor_max=None, + q_scale_thresh=1e-5, + imatrix=None, + **kwargs +): + """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + maxq = 2 ** (bits - 1) + # if tensor_min is None or tensor_max is None: + # wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + # wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + # else: + # wmin_tmp = tensor_min + # wmax_tmp = tensor_max + # + # wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130 + # wmax_abs = wmax_tmp * max_scale + # max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs) + # scale = (max_v / maxq).to(scale_dtype) + imatrix = imatrix.reshape(1, -1) + + imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) + imatrix = imatrix.reshape(tensor.shape) + + scale = search_scales(tensor, bits, qw=imatrix) + scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) + int_w = round_ste(tensor / scale + v) + q = torch.clamp(int_w, -maxq, maxq - 1) + qdq_result = (scale * q).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + return qdq_result, scale, maxq @register_dtype("int_sym") diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index bdec2e0b4..7507e67f4 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -87,7 +87,7 @@ def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int): return data_new -def get_quant_func(dtype, bits, sym): +def get_quant_func(dtype:str, bits:int, sym:bool, disable_opt_rtn=False)->tuple[callable,str]: """Retrieve the quantization function based on data type, bit width, and symmetry. This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE @@ -98,40 +98,36 @@ def get_quant_func(dtype, bits, sym): dtype (str): The data type for the quantization (e.g., 'int', 'mxfp4'). bits (int): The bit width for the quantization (e.g., 2,4,8). sym (bool): A flag indicating whether the quantization is symmetric (True) or asymmetric (False). + disable_opt_rtn(bool): whether to disable optimized rtn. Returns: function: The quantization function corresponding to the specified parameters. 
+ str """ - key = dtype - if key in QUANT_FUNC_WITH_DTYPE.keys(): - return QUANT_FUNC_WITH_DTYPE[key], key - if sym: - key = dtype + "_sym" - else: - key = dtype + "_asym" - - if key in QUANT_FUNC_WITH_DTYPE.keys(): - return QUANT_FUNC_WITH_DTYPE[key], key - - ##need to add bits and sym infos - if sym: - key = dtype + str(bits) + "_sym" - else: - key = dtype + str(bits) + "_asym" - - if key in QUANT_FUNC_WITH_DTYPE.keys(): - return QUANT_FUNC_WITH_DTYPE[key], key + def pad_sym(data_type): + if sym: + data_sym = data_type + "_sym" + else: + data_sym = data_type + "_asym" + return data_sym - if sym: - key = dtype + str(bits) - else: - key = dtype + str(bits) + def pad_bits(data_type): + return data_type + str(bits) - if key in QUANT_FUNC_WITH_DTYPE.keys(): - return QUANT_FUNC_WITH_DTYPE[key], key + if not disable_opt_rtn: + rtn_data_type = "rtn_" + dtype + data_types = [rtn_data_type, pad_bits(rtn_data_type), pad_sym(rtn_data_type), pad_sym(pad_bits(rtn_data_type))] + for data_type in data_types: + from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: + return QUANT_FUNC_WITH_DTYPE[data_type],data_type - raise ValueError(f"{dtype} is not supported") + data_types = [dtype, pad_bits(dtype), pad_sym(dtype), pad_sym(pad_bits(dtype))] + for data_type in data_types: + from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: + return QUANT_FUNC_WITH_DTYPE[data_type], data_type def round_ste(x: torch.Tensor): @@ -254,12 +250,12 @@ def update_fused_layer_global_scales(submodule: torch.nn.Module, base_name="weig def _is_attention_module(module: Module): return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") ) def _is_mlp_module(module: Module): return "mlp" in module.__class__.__name__.lower() and ( - hasattr(module, "gate_proj") or hasattr(module, "up_proj") + hasattr(module, "gate_proj") or hasattr(module, "up_proj") ) if _is_attention_module(submodule): diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 30fcb2bd6..60c61cd82 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -228,7 +228,7 @@ def pack_layer(layer_name, model, backend, device=None): zp = int(zp.flatten()[0]) qlayer.to("cpu") - ##force to float32 to be compatible with torch 2.0 + # Force to float32 to be compatible with torch 2.0 sig = inspect.signature(qlayer.pack) param_count = len(sig.parameters) if param_count == 2: @@ -294,7 +294,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs) - ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source + # IF using sym, we change to gptq sym kernel to avoid compiling from auto_round source if ( (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend) @@ -327,7 +327,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex for layer_name in layer_config: if ( not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8 - ): ##lm head ##TODO fix act and so on + ): # lm head ##TODO fix act and so on extra_config[layer_name] = {} 
extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"] extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"] @@ -344,6 +344,8 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex for key in neq_keys: if layer_config[layer_name][key] is not None: extra_config[layer_name][key] = layer_config[layer_name][key] + if not extra_config[layer_name]: # Pop empty dict + extra_config.pop(layer_name) if len(extra_config) > 0: quantization_config["extra_config"] = extra_config names = list(layer_config.keys()) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index a9c0f5cb2..6adc5f1e3 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -80,6 +80,7 @@ def __init__( device="cpu", enable_round_tuning=True, enable_torch_compile=False, + disable_opt_rtn=True, **kwargs, ): """Initializes the WrapperLinear module. @@ -92,6 +93,7 @@ def __init__( """ super(WrapperLinear, self).__init__() self.orig_layer = orig_layer + self.disable_opt_rtn = disable_opt_rtn self.output_device = device self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device self.enable_minmax_tuning = enable_minmax_tuning @@ -146,13 +148,13 @@ def _init_tuning_params_and_quant_func(self): self._init_params("min_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16)) self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16)) - self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym) + self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn) if self.enable_torch_compile: self.weight_quant_func = compile_func(self.weight_quant_func, self.device) if self.enable_act_quant: self.act_quant_func, self.act_data_type = get_quant_func( - orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym + orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym, self.disable_opt_rtn ) if self.enable_torch_compile: self.act_quant_func = compile_func(self.act_quant_func, self.device) From 005a63e660d20c54af4e9752d2237cfaaae43eda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Oct 2025 04:52:48 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 21 ++++++++++--------- auto_round/data_type/utils.py | 10 +++++---- .../export/export_to_autoround/export.py | 2 +- auto_round/wrapper.py | 4 +++- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 09216cd5c..14d119907 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1351,7 +1351,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None: # m.data_type = rtn_data_type # self.layer_config[name]["data_type"] = m.data_type - # Step 2: Try quantization on GPU first, fall back to CPU if OOM # if only export gguf, using gguf-packing instead of rtn if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn: @@ -1366,7 +1365,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: enable_norm_bias_tuning=False, enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, - disable_opt_rtn=self.disable_opt_rtn + 
disable_opt_rtn=self.disable_opt_rtn, ) m = m.unwrapper({}) m.to("cpu") @@ -1461,7 +1460,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: enable_imatrix = False if has_gguf_k and not self.disable_opt_rtn: enable_imatrix = True - if self.data_type=="int" and self.sym: + if self.data_type == "int" and self.sym: enable_imatrix = True if enable_imatrix: @@ -1797,30 +1796,32 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.quantized = True return self.model, self.layer_config - - def _check_rtn_dytpe(self,data_type,bit, sym): + def _check_rtn_dytpe(self, data_type, bit, sym): """Check if the given data type is an RTN (Round-To-Nearest) type. Args: data_type (str): The data type to check. """ + def pad_sym(dtype): if sym: data_sym = dtype + "_sym" else: - data_sym = dtype+ "_asym" + data_sym = dtype + "_asym" return data_sym + def pad_bits(dtype): - return dtype+str(bit) - data_type = "rtn_"+data_type - data_types=[data_type, pad_bits(data_type), pad_sym(data_type), pad_sym(pad_bits(data_type))] + return dtype + str(bit) + + data_type = "rtn_" + data_type + data_types = [data_type, pad_bits(data_type), pad_sym(data_type), pad_sym(pad_bits(data_type))] for data_type in data_types: from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: return data_type return None - def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: """Quantizes specified layers based on inputs and configuration. diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index 7507e67f4..1bb53a14b 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -87,7 +87,7 @@ def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int): return data_new -def get_quant_func(dtype:str, bits:int, sym:bool, disable_opt_rtn=False)->tuple[callable,str]: +def get_quant_func(dtype: str, bits: int, sym: bool, disable_opt_rtn=False) -> tuple[callable, str]: """Retrieve the quantization function based on data type, bit width, and symmetry. 
This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE @@ -120,12 +120,14 @@ def pad_bits(data_type): data_types = [rtn_data_type, pad_bits(rtn_data_type), pad_sym(rtn_data_type), pad_sym(pad_bits(rtn_data_type))] for data_type in data_types: from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: - return QUANT_FUNC_WITH_DTYPE[data_type],data_type + return QUANT_FUNC_WITH_DTYPE[data_type], data_type data_types = [dtype, pad_bits(dtype), pad_sym(dtype), pad_sym(pad_bits(dtype))] for data_type in data_types: from auto_round.data_type import QUANT_FUNC_WITH_DTYPE + if data_type in QUANT_FUNC_WITH_DTYPE: return QUANT_FUNC_WITH_DTYPE[data_type], data_type @@ -250,12 +252,12 @@ def update_fused_layer_global_scales(submodule: torch.nn.Module, base_name="weig def _is_attention_module(module: Module): return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") ) def _is_mlp_module(module: Module): return "mlp" in module.__class__.__name__.lower() and ( - hasattr(module, "gate_proj") or hasattr(module, "up_proj") + hasattr(module, "gate_proj") or hasattr(module, "up_proj") ) if _is_attention_module(submodule): diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 60c61cd82..2578a3702 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -344,7 +344,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex for key in neq_keys: if layer_config[layer_name][key] is not None: extra_config[layer_name][key] = layer_config[layer_name][key] - if not extra_config[layer_name]: # Pop empty dict + if not extra_config[layer_name]: # Pop empty dict extra_config.pop(layer_name) if len(extra_config) > 0: quantization_config["extra_config"] = extra_config diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 6adc5f1e3..6a6227023 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -148,7 +148,9 @@ def _init_tuning_params_and_quant_func(self): self._init_params("min_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16)) self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16)) - self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn) + self.weight_quant_func, self.data_type = get_quant_func( + orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn + ) if self.enable_torch_compile: self.weight_quant_func = compile_func(self.weight_quant_func, self.device) From bda0dd27cc06352a3d2329ec8c526b52c92a93f1 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 13:53:50 +0800 Subject: [PATCH 03/14] revert change --- auto_round/export/export_to_autoround/export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 60c61cd82..a2f364d0f 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -344,8 +344,6 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex for key in neq_keys: if layer_config[layer_name][key] is not 
None: extra_config[layer_name][key] = layer_config[layer_name][key] - if not extra_config[layer_name]: # Pop empty dict - extra_config.pop(layer_name) if len(extra_config) > 0: quantization_config["extra_config"] = extra_config names = list(layer_config.keys()) From fc4222d82237644cad6c1086f2909bd242617c93 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 17:02:30 +0800 Subject: [PATCH 04/14] fix and update accuracy --- README.md | 4 +- auto_round/data_type/int.py | 9 +- auto_round/utils.py | 1 + docs/opt_rtn.md | 42 + docs/step_by_step.md | 6 +- test/test_cpu/test_autoround.py | 1394 +++++++++++++++---------------- 6 files changed, 751 insertions(+), 705 deletions(-) create mode 100644 docs/opt_rtn.md diff --git a/README.md b/README.md index fe5441e2e..115d2152e 100644 --- a/README.md +++ b/README.md @@ -76,8 +76,8 @@ Out-of-the-box quantization for 10+ vision-language models [example models](http โœ… **Layerwise Mixed Bits Quantization** Assign different bits per layer for fine-grained accuracy/performance trade-offs. Details are shown in [mixed bits quantization](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#mixed-bits-usage) -โœ… **Round-to-Nearest Mode** -Use `--iters 0` for fast, calibration-free quantization with some accuracy drop. Details are shown in [rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#rtn-mode) +โœ… **Optimized Round-to-Nearest Mode** +Use `--iters 0` for fast, calibration-free quantization with some accuracy drop for 4 bits. Details are shown in [opt_rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#opt-rtn-mode) โœ… **Multiple Recipes** Choose from `auto-round-best`, `auto-round`, and `auto-round-light` to suit your needs. Details are shown in [quantization recipes](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#recipe-recommendation) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index b940201ca..eaafd2aee 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -77,10 +77,13 @@ def quant_tensor_sym( tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = 2 ** (bits - 1) - imatrix = imatrix.reshape(1, -1) + if imatrix is None: + imatrix = 1.0 + else: + imatrix = imatrix.reshape(1, -1) - imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) - imatrix = imatrix.reshape(tensor.shape) + imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) + imatrix = imatrix.reshape(tensor.shape) scale = search_scales(tensor, bits, qw=imatrix) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) diff --git a/auto_round/utils.py b/auto_round/utils.py index 8c8c9acc5..9285af51a 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -2929,6 +2929,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str if name in all_module_names: m = get_module(model, name) if len(list(m.children())) == 0 and type(m) not in supported_types: + layer_config.pop(name) logger.warning(f"{name} is not supported in current scheme, ignoring its setting in `layer_config`") continue diff --git a/docs/opt_rtn.md b/docs/opt_rtn.md new file mode 100644 index 000000000..d674d2b9d --- /dev/null +++ b/docs/opt_rtn.md @@ -0,0 +1,42 @@ +### ๐Ÿงฎ Evaluation Results (LM-Eval) +4BIT=W4A16 +3BIT=W3A16 +2BIT=W2A16G64 + +RTN mode +~~~bash +auto-round --model xxx --disable_opt_rtn --iters 0 +~~~ + +OPT RTN mode +~~~bash 
+auto-round --model xxx --iters 0 +~~~ +For 2/3bit, we strongly recommend not using iter=0. + +| Model | RNT/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande | +|-------|----------|-----|------------|-----------------|------|------|-------------| +| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896 | 0.7013 | 0.6538 | 0.7987 | 0.7230 | +| | OPT-4BIT | 0.69560 | 0.5882 | 0.7074 | 0.6631 | 0.7916 | 0.7277 | +| | RTN-3BIT | 0.64562 | 0.5410 | 0.6695 | 0.5449 | 0.7742 | 0.6985 | +| | OPT-3BIT | 0.65970 | 0.5490 | 0.6893 | 0.5711 | 0.7677 | 0.7214 | +| | RTN-2BIT | 0.33008 | 0.2918 | 0.0474 | 0.2321 | 0.5740 | 0.5051 | +| | OPT-2BIT | 0.38908 | 0.3241 | 0.1560 | 0.2822 | 0.6235 | 0.5596 | +| **Qwen2.5-7B-Instruct** | RTN-4BIT | 0.69560 | 0.6114 | 0.6713 | 0.7011 | 0.7878 | 0.7064 | +| | OPT-4BIT | 0.70034 | 0.6143 | 0.6945 | 0.7115 | 0.7845 | 0.6969 | +| | RTN-3BIT | 0.64144 | 0.5585 | 0.6092 | 0.6455 | 0.7476 | 0.6464 | +| | OPT-3BIT | 0.66764 | 0.5756 | 0.7013 | 0.6597 | 0.7481 | 0.6535 | +| | RTN-2BIT | 0.31856 | 0.2804 | 0.0351 | 0.2379 | 0.5256 | 0.5138 | +| | OPT-2BIT | 0.45146 | 0.3645 | 0.2992 | 0.4043 | 0.6415 | 0.5478 | +| **Qwen3-8B** | RTN-4BIT | 0.66240 | 0.5619 | 0.6150 | 0.7077 | 0.7573 | 0.6701 | +| | OPT-4BIT | 0.66992 | 0.5619 | 0.6346 | 0.7102 | 0.7633 | 0.6796 | +| | RTN-3BIT | 0.57322 | 0.4992 | 0.4260 | 0.6002 | 0.7361 | 0.6046 | +| | OPT-3BIT | 0.63698 | 0.5226 | 0.5814 | 0.6718 | 0.7437 | 0.6654 | +| | RTN-2BIT | 0.31150 | 0.2679 | 0.0041 | 0.2536 | 0.5283 | 0.5036 | +| | OPT-2BIT | 0.44254 | 0.3749 | 0.2005 | 0.4202 | 0.6670 | 0.5501 | +| **Qwen3-14B** | RTN-4BIT | 0.70448 | 0.5999 | 0.6511 | 0.7565 | 0.7998 | 0.7151 | +| | OPT-4BIT | 0.70798 | 0.6031 | 0.6627 | 0.7534 | 0.8009 | 0.7198 | +| | RTN-3BIT | 0.65876 | 0.5746 | 0.5467 | 0.7065 | 0.7628 | 0.7032 | +| | OPT-3BIT | 0.68610 | 0.5683 | 0.6633 | 0.7258 | 0.7699 | 0.7032 | +| | RTN-2BIT | 0.39398 | 0.3764 | 0.0607 | 0.3836 | 0.6480 | 0.5012 | +| | OPT-2BIT | 0.50080 | 0.4554 | 0.2451 | 0.4899 | 0.7138 | 0.5998 | \ No newline at end of file diff --git a/docs/step_by_step.md b/docs/step_by_step.md index ba27ef5d1..097b97549 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -23,7 +23,7 @@ This document presents step-by-step instructions for auto-round llm quantization - [CLI Usage](#cli-usage) - [API Usage](#api-usage-1) - [Hyperparameters in AutoScheme](#hyperparameters-in-autoscheme) - + [RTN mode](#rtn-mode) + + [OPT RTN mode](#opt-rtn-mode) + [GGUF format](#gguf-format) + [Quantization Costs](#quantization-costs) + [Device/Multi-GPU setting in Quantization](#device-multi-gpu-setting-in-quantization) @@ -368,8 +368,8 @@ We will try to optimize the VRAM usage in the future. Embedding layer is not supported in AutoScheme, it will use the best scheme in options. -### RTN mode -AutoRound also supports RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. try setting `iters=0` and use `group_size=32` for better results. +### OPT RTN Mode +AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. try setting `iters=0` and use `group_size=32` for better results. For the GGUF format, we have optimized the RTN algorithm inspired by llamacpp. To use the original (pure) RTN algorithm instead, enable the `--disable_opt_rtn` option. 
```python diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index cbd0583df..4336bd771 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -29,8 +29,8 @@ class TestAutoRound(unittest.TestCase): @classmethod def setUpClass(self): model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) self.llm_dataloader = LLMDataLoader() self.save_folder = "./saved" @@ -39,404 +39,404 @@ def tearDownClass(self): shutil.rmtree(self.save_folder, ignore_errors=True) shutil.rmtree("runs", ignore_errors=True) - def test_bits_setting(self): - layer_config = {"model.decoder.layers.0.self_attn.k_proj": {"data_type": "mx_fp8", "group_size": 32}} - autoround = AutoRound( - "/tf_dataset/auto_round/models/facebook/opt-125m", iters=2, seqlen=2, nsamples=1, layer_config=layer_config - ) - autoround.quantize() - module = get_module(autoround.model, "model.decoder.layers.0.self_attn.k_proj") - if module.bits != 8: - raise ValueError(f"Expected bits to be 8, but got {module.bits}") - - def test_layer_config(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - layer_config = {"self_attn": {"bits": 4, "data_type": "nv_fp", "act_bits": 16, "group_size": 16}} - autoround = AutoRound( - model_name, - self.tokenizer, - scheme="NVFP4", - iters=0, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - amp=False, - ) - autoround.quantize_and_save(self.save_folder, inplace=False, format="fake") - shutil.rmtree(self.save_folder) - - def test_remove_whole_block(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - layer_config = { - "model.decoder.layers.0.self_attn.k_proj": {"bits": 32}, - "model.decoder.layers.0.self_attn.v_proj": {"bits": 32}, - "model.decoder.layers.0.self_attn.q_proj": {"bits": 32}, - "model.decoder.layers.0.self_attn.out_proj": {"bits": 32}, - "model.decoder.layers.0.fc1": {"bits": 32}, - "model.decoder.layers.0.fc2": {"bits": 32}, - } - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - - def test_consecutive_quant(self): - bits, group_size, sym = 4, -1, False - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - model = AutoModelForCausalLM.from_pretrained( - "/tf_dataset/auto_round/models/microsoft/phi-2", torch_dtype="auto", trust_remote_code=True - ) - tokenizer = AutoTokenizer.from_pretrained( - "/tf_dataset/auto_round/models/microsoft/phi-2", trust_remote_code=True - ) - autoround = AutoRound( - model, - tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_mx_fp4(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, 32, False - autoround = AutoRound( - model_name, - bits=bits, - act_bits=bits, - group_size=group_size, - sym=sym, - 
iters=2, - nsamples=2, - seqlen=128, - data_type="mx_fp4", - act_data_type="mx_fp_rceil", - ) - model, _ = autoround.quantize() - result = simple_evaluate_user_model( - model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - ) - print(result["results"]["lambada_openai"]["acc,none"]) - self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) # 0.375 - - def test_nv_fp4(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, 16, False - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - data_type="nv_fp4", - ) - model, _ = autoround.quantize() - result = simple_evaluate_user_model( - model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - ) - print(result["results"]["lambada_openai"]["acc,none"]) - self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.35) - - def test_default(self): - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - autoround.save_quantized(output_dir="./saved", inplace=False, format="itrex") - try: - import auto_gptq - except: - return - if torch.cuda.is_available(): - autoround.save_quantized(output_dir="./saved", inplace=False) - - def test_w4g1(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, -1, True - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - @parameterized.expand([(2,), (3,), (4,)]) - def test_g128(self, bits): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - group_size, sym = 128, True - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - dataset=self.llm_dataloader, - ) - model, _ = autoround.quantize() - if bits > 2: - result = simple_evaluate_user_model( - model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - ) - print(result["results"]["lambada_openai"]["acc,none"]) - self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) - - def test_disable_quanted_input(self): - bits, group_size, sym = 4, -1, True - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_quanted_input=False, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_enable_norm_bias_tuning_qwen3(self): - bits, group_size, sym = 4, 128, True - model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - autoround = AutoRound( - model, - tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_norm_bias_tuning=True, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_enable_norm_bias_tuning(self): - bits, group_size, sym = 4, -1, True - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_quanted_input=False, - enable_norm_bias_tuning=True, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - 
def test_disable_minmax_tuning(self): - bits, group_size, sym = 4, -1, True - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_minmax_tuning=False, - dataset=self.llm_dataloader, - ) - autoround.quantize() - + # def test_bits_setting(self): + # layer_config = {"model.decoder.layers.0.self_attn.k_proj": {"data_type": "mx_fp8", "group_size": 32}} + # autoround = AutoRound( + # "/tf_dataset/auto_round/models/facebook/opt-125m", iters=2, seqlen=2, nsamples=1, layer_config=layer_config + # ) + # autoround.quantize() + # module = get_module(autoround.model, "model.decoder.layers.0.self_attn.k_proj") + # if module.bits != 8: + # raise ValueError(f"Expected bits to be 8, but got {module.bits}") + # + # def test_layer_config(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # layer_config = {"self_attn": {"bits": 4, "data_type": "nv_fp", "act_bits": 16, "group_size": 16}} + # autoround = AutoRound( + # model_name, + # self.tokenizer, + # scheme="NVFP4", + # iters=0, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # amp=False, + # ) + # autoround.quantize_and_save(self.save_folder, inplace=False, format="fake") + # shutil.rmtree(self.save_folder) + # + # def test_remove_whole_block(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # layer_config = { + # "model.decoder.layers.0.self_attn.k_proj": {"bits": 32}, + # "model.decoder.layers.0.self_attn.v_proj": {"bits": 32}, + # "model.decoder.layers.0.self_attn.q_proj": {"bits": 32}, + # "model.decoder.layers.0.self_attn.out_proj": {"bits": 32}, + # "model.decoder.layers.0.fc1": {"bits": 32}, + # "model.decoder.layers.0.fc2": {"bits": 32}, + # } + # bits, group_size, sym = 4, 128, False + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # + # def test_consecutive_quant(self): + # bits, group_size, sym = 4, -1, False + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # model = AutoModelForCausalLM.from_pretrained( + # "/tf_dataset/auto_round/models/microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + # ) + # tokenizer = AutoTokenizer.from_pretrained( + # "/tf_dataset/auto_round/models/microsoft/phi-2", trust_remote_code=True + # ) + # autoround = AutoRound( + # model, + # tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_mx_fp4(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 32, False + # autoround = AutoRound( + # model_name, + # bits=bits, + # act_bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # nsamples=2, + # seqlen=128, + # data_type="mx_fp4", + # act_data_type="mx_fp_rceil", + # ) + # model, _ = autoround.quantize() + # result = simple_evaluate_user_model( + # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + # ) + # print(result["results"]["lambada_openai"]["acc,none"]) + # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) # 0.375 + # + # def test_nv_fp4(self): + # model_name = 
"/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 16, False + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # data_type="nv_fp4", + # ) + # model, _ = autoround.quantize() + # result = simple_evaluate_user_model( + # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + # ) + # print(result["results"]["lambada_openai"]["acc,none"]) + # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.35) + # + # def test_default(self): + # bits, group_size, sym = 4, 128, False + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # autoround.save_quantized(output_dir="./saved", inplace=False, format="itrex") + # try: + # import auto_gptq + # except: + # return + # if torch.cuda.is_available(): + # autoround.save_quantized(output_dir="./saved", inplace=False) + # + # def test_w4g1(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, -1, True + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # @parameterized.expand([(2,), (3,), (4,)]) + # def test_g128(self, bits): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # group_size, sym = 128, True + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # dataset=self.llm_dataloader, + # ) + # model, _ = autoround.quantize() + # if bits > 2: + # result = simple_evaluate_user_model( + # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + # ) + # print(result["results"]["lambada_openai"]["acc,none"]) + # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) + # + # def test_disable_quanted_input(self): + # bits, group_size, sym = 4, -1, True + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # enable_quanted_input=False, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_enable_norm_bias_tuning_qwen3(self): + # bits, group_size, sym = 4, 128, True + # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # autoround = AutoRound( + # model, + # tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # enable_norm_bias_tuning=True, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_enable_norm_bias_tuning(self): + # bits, group_size, sym = 4, -1, True + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # enable_quanted_input=False, + # enable_norm_bias_tuning=True, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_disable_minmax_tuning(self): + # bits, group_size, sym = 4, -1, True + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, 
+ # sym=sym, + # iters=2, + # seqlen=10, + # enable_minmax_tuning=False, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # # + # def test_signround(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, -1, False + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # enable_minmax_tuning=False, + # enable_quanted_input=False, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_lm_head_layer_config_way(self): + # bits, group_size, sym = 4, -1, False + # layer_config = {"lm_head": {"data_type": "int"}} + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=10, + # enable_minmax_tuning=False, + # enable_quanted_input=False, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # + # def test_wa_quant(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym, act_bits = 4, 128, False, 4 + # autoround = AutoRound( + # model_name, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # act_bits=act_bits, + # ) + # autoround.quantize() + # + # def test_auto_device_map(self): + # bits, group_size, sym = 4, 128, False + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # model = AutoModelForCausalLM.from_pretrained( + # model_name, torch_dtype="auto", trust_remote_code=True, device_map="auto" + # ) + # autoround = AutoRound( + # model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_device_map_dict(self): + # bits, group_size, sym = 4, 128, False + # device_map = {".*": "cpu"} + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # device_map=device_map, + # ) + # autoround.quantize() + # + # # test model_name + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # autoround = AutoRound( + # model_name, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # device_map=device_map, + # ) + # autoround.quantize() + # + # def test_fp32(self): + # bits, group_size, sym = 4, 128, False + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # model = AutoModelForCausalLM.from_pretrained( + # model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" + # ) + # autoround = AutoRound( + # model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # amp=False, + # ) + # autoround.quantize() + # + # def test_tensor_reshape(self): + # bits, group_size, sym = 4, 100, False + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # + # def test_rtn(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = 
AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # + # bits, group_size, sym = 4, 128, True + # autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=0, nsamples=1) + # quantized_model_path = self.save_folder + # autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + # model = AutoModelForCausalLM.from_pretrained( + # self.save_folder, + # torch_dtype=torch.float16, + # device_map="auto", + # ) + # + # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + # model_infer(model, tokenizer) + # shutil.rmtree(self.save_folder) # - def test_signround(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, -1, False - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_minmax_tuning=False, - enable_quanted_input=False, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_lm_head_layer_config_way(self): - bits, group_size, sym = 4, -1, False - layer_config = {"lm_head": {"data_type": "int"}} - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=10, - enable_minmax_tuning=False, - enable_quanted_input=False, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - - def test_wa_quant(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym, act_bits = 4, 128, False, 4 - autoround = AutoRound( - model_name, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - act_bits=act_bits, - ) - autoround.quantize() - - def test_auto_device_map(self): - bits, group_size, sym = 4, 128, False - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype="auto", trust_remote_code=True, device_map="auto" - ) - autoround = AutoRound( - model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_device_map_dict(self): - bits, group_size, sym = 4, 128, False - device_map = {".*": "cpu"} - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - device_map=device_map, - ) - autoround.quantize() - - # test model_name - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - autoround = AutoRound( - model_name, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - device_map=device_map, - ) - autoround.quantize() - - def test_fp32(self): - bits, group_size, sym = 4, 128, False - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" - ) - autoround = AutoRound( - model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - amp=False, - ) - autoround.quantize() - - def test_tensor_reshape(self): - bits, group_size, sym = 4, 100, False - autoround = AutoRound( - self.model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - ) - autoround.quantize() - - def test_rtn(self): - 
model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - - bits, group_size, sym = 4, 128, True - autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=0, nsamples=1) - quantized_model_path = self.save_folder - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") - model = AutoModelForCausalLM.from_pretrained( - self.save_folder, - torch_dtype=torch.float16, - device_map="auto", - ) - - tokenizer = AutoTokenizer.from_pretrained(self.save_folder) - model_infer(model, tokenizer) - shutil.rmtree(self.save_folder) - def test_embed_quant(self): bits, group_size, sym = 4, 128, True - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model_name = "/models/opt-125m" layer_config = { "model.decoder.embed_tokens": {"bits": 4}, } @@ -452,308 +452,308 @@ def test_embed_quant(self): layer_config=layer_config, ) autoround.quantize() - - def test_fallback_layers(self): - bits, group_size, sym = 4, 128, True - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" - ) - layer_config = { - "model.decoder.layers.0.self_attn.q_proj": {"bits": 16}, - "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - "model.decoder.embed_tokens": {"bits": 16}, - } - autoround = AutoRound( - model, - self.tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - quantized_model_path = self.save_folder - - autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) - quantization_config = AutoRoundConfig(backend="ipex") - - model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, device_map="cpu", quantization_config=quantization_config - ) - tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - res = tokenizer.decode(model.generate(**inputs, max_new_tokens=1)[0]) - shutil.rmtree(self.save_folder, ignore_errors=True) - - def test_not_convert_modules(self): - import requests - from PIL import Image - from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - - from auto_round_extension.ipex.qlinear_ipex_awq import QuantLinear - - model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2-VL-2B-Instruct-AWQ" - quantization_config = AutoRoundConfig() - model = Qwen2VLForConditionalGeneration.from_pretrained( - model_name, quantization_config=quantization_config, device_map="cpu", torch_dtype=torch.float16 - ) - self.assertTrue(isinstance(model.visual.blocks[0].attn.qkv, torch.nn.Linear)) - self.assertFalse(isinstance(model.visual.merger.mlp[0], QuantLinear)) - if hasattr(model.model, "language_model"): - self.assertTrue(isinstance(model.model.language_model.layers[0].self_attn.v_proj, QuantLinear)) - else: - self.assertTrue(isinstance(model.model.layers[0].self_attn.v_proj, QuantLinear)) - - processor = AutoProcessor.from_pretrained(model_name, size=None) - image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": 
image_url, - }, - {"type": "text", "text": "Describe this image."}, - ], - } - ] - - # Preparation for inference - text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - image_inputs = Image.open(requests.get(image_url, stream=True).raw) - inputs = processor( - text=[text], - images=image_inputs, - padding=True, - return_tensors="pt", - ) - - # Inference: Generation of the output - generated_ids = model.generate(**inputs, max_new_tokens=1) - generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - print(output_text) - - def test_fallback_layers_regex_awq(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, 128, True - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - layer_config = { - r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - } - autoround = AutoRound( - model, - tokenizer=tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - quantized_model_path = self.save_folder - - autoround.save_quantized(output_dir=quantized_model_path, format="auto_awq", inplace=True) - quantization_config = AutoRoundConfig() - - model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, device_map="auto", quantization_config=quantization_config - ) - tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - print(res) - shutil.rmtree(self.save_folder, ignore_errors=True) - - def test_fallback_layers_regex_gptq(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, 128, True - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - layer_config = { - r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - ##"model.decoder.layers.1.self_attn.k_proj": {"bits": 16} - } - autoround = AutoRound( - model, - tokenizer=tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - quantized_model_path = self.save_folder - - autoround.save_quantized(output_dir=quantized_model_path, format="auto_gptq", inplace=True) - quantization_config = AutoRoundConfig() - - model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, device_map="auto", quantization_config=quantization_config - ) - tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - print(res) - shutil.rmtree(self.save_folder, ignore_errors=True) - - def test_fallback_layers_regex_round(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - 
bits, group_size, sym = 4, 128, True - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - layer_config = { - r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - r"model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - } - autoround = AutoRound( - model, - tokenizer=tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - quantized_model_path = self.save_folder - - autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) - quantization_config = AutoRoundConfig() - - model = AutoModelForCausalLM.from_pretrained( - quantized_model_path, device_map="auto", quantization_config=quantization_config - ) - tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - text = "There is a girl who likes adventure," - inputs = tokenizer(text, return_tensors="pt").to(model.device) - res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - print(res) - shutil.rmtree(self.save_folder, ignore_errors=True) - - def test_fallback_layers_regex_exception(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - bits, group_size, sym = 4, 128, True - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - layer_config = {"model.decoder.layers.12.self_attn.k_proj": {"bits": 16}} - with self.assertRaises(ValueError): - autoround = AutoRound( - model, - tokenizer=tokenizer, - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - layer_config=layer_config, - ) - autoround.quantize() - - # def test_fp8_model_input_rtn_generation(self): - # model_name = "Qwen/Qwen3-0.6B-FP8" - # ar = AutoRound(model=model_name, iters=0) - # ar.quantize_and_save(output_dir=self.save_folder) - # model = AutoModelForCausalLM.from_pretrained(self.save_folder, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + # + # def test_fallback_layers(self): + # bits, group_size, sym = 4, 128, True + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # model = AutoModelForCausalLM.from_pretrained( + # model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" + # ) + # layer_config = { + # "model.decoder.layers.0.self_attn.q_proj": {"bits": 16}, + # "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + # "model.decoder.embed_tokens": {"bits": 16}, + # } + # autoround = AutoRound( + # model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # quantized_model_path = self.save_folder + # + # autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + # quantization_config = AutoRoundConfig(backend="ipex") + # + # model = AutoModelForCausalLM.from_pretrained( + # quantized_model_path, device_map="cpu", quantization_config=quantization_config + # ) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) # text = "There is a girl who likes adventure," # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # 
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - - def test_dequant_fp8_weight(self): - from auto_round.utils import dequant_block_fp8_weight - - # test pad and unpad - weight = torch.randn(587, 7168) - weight_scale = torch.randn(5, 56) - block_size = [128, 128] - dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) - self.assertEqual(dequant_weight.shape.numel(), 4207616) - - # test experts are stacked. - weight = torch.randn([32, 5760, 1440]) - weight_scale = torch.randn([32, 5760, 90]) - block_size = [1, 16] - dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) - self.assertEqual(len(dequant_weight.shape), 3) - self.assertEqual(dequant_weight.shape[0], 32) - self.assertEqual(dequant_weight.shape.numel(), 32 * 5760 * 1440) - - def test_mixed_bit_setting(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - layer_config = {"model.decoder.layers.7.fc1": {"bits": 8, "act_bits": 8}} - ar = AutoRound(model_name, data_type="mx_fp4", act_bits=4, iters=0, layer_config=layer_config) - ar.quantize() - layer_config = ar.layer_config - if ( - layer_config["model.decoder.layers.7.fc1"]["bits"] != 8 - or layer_config["model.decoder.layers.7.fc1"]["act_bits"] != 8 - ): - raise ValueError("mixed bits is not correct") - - def test_alg_ext(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True) - ar.quantize() - - def test_invalid_layer_config(self): - with self.assertRaises(ValueError): - layer_config = {"model.decoder.layers.2.self_attnx": {"bits": 2}} - ar = AutoRound( - "/tf_dataset/auto_round/models/facebook/opt-125m", - scheme="W3A16", - nsamples=1, - iters=1, - layer_config=layer_config, - ) - ar.quantize() - with self.assertRaises(ValueError): - layer_config = {"model.decoder.layers.2.self_attn": {"bit": 2}} # should be bits - ar = AutoRound( - "/tf_dataset/auto_round/models/facebook/opt-125m", - scheme="W3A16", - nsamples=1, - iters=1, - layer_config=layer_config, - ) - ar.quantize() - - def test_quant_lm_head(self): - model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32) - ar.quantize() - - def test_compressor(self): - model_name = "Qwen/Qwen2-VL-2B-Instruct" - ar = AutoRound(model_name, enable_adam=True) - self.assertEqual(ar.optimizer, torch.optim.AdamW) - self.assertTrue(ar.mllm) - - # test old api - from auto_round import AutoRoundMLLM - - ar = AutoRoundMLLM(model_name) - self.assertTrue(ar.mllm) + # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=1)[0]) + # shutil.rmtree(self.save_folder, ignore_errors=True) + # + # def test_not_convert_modules(self): + # import requests + # from PIL import Image + # from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + # + # from auto_round_extension.ipex.qlinear_ipex_awq import QuantLinear + # + # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2-VL-2B-Instruct-AWQ" + # quantization_config = AutoRoundConfig() + # model = Qwen2VLForConditionalGeneration.from_pretrained( + # model_name, quantization_config=quantization_config, device_map="cpu", torch_dtype=torch.float16 + # ) + # self.assertTrue(isinstance(model.visual.blocks[0].attn.qkv, torch.nn.Linear)) + # self.assertFalse(isinstance(model.visual.merger.mlp[0], QuantLinear)) + # if hasattr(model.model, "language_model"): + # 
self.assertTrue(isinstance(model.model.language_model.layers[0].self_attn.v_proj, QuantLinear)) + # else: + # self.assertTrue(isinstance(model.model.layers[0].self_attn.v_proj, QuantLinear)) + # + # processor = AutoProcessor.from_pretrained(model_name, size=None) + # image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + # messages = [ + # { + # "role": "user", + # "content": [ + # { + # "type": "image", + # "image": image_url, + # }, + # {"type": "text", "text": "Describe this image."}, + # ], + # } + # ] + # + # # Preparation for inference + # text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # image_inputs = Image.open(requests.get(image_url, stream=True).raw) + # inputs = processor( + # text=[text], + # images=image_inputs, + # padding=True, + # return_tensors="pt", + # ) + # + # # Inference: Generation of the output + # generated_ids = model.generate(**inputs, max_new_tokens=1) + # generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + # output_text = processor.batch_decode( + # generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + # ) + # print(output_text) + # + # def test_fallback_layers_regex_awq(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 128, True + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # layer_config = { + # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + # "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + # } + # autoround = AutoRound( + # model, + # tokenizer=tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # quantized_model_path = self.save_folder + # + # autoround.save_quantized(output_dir=quantized_model_path, format="auto_awq", inplace=True) + # quantization_config = AutoRoundConfig() + # + # model = AutoModelForCausalLM.from_pretrained( + # quantized_model_path, device_map="auto", quantization_config=quantization_config + # ) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + # text = "There is a girl who likes adventure," + # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + # print(res) + # shutil.rmtree(self.save_folder, ignore_errors=True) + # + # def test_fallback_layers_regex_gptq(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 128, True + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # layer_config = { + # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + # ##"model.decoder.layers.1.self_attn.k_proj": {"bits": 16} + # } + # autoround = AutoRound( + # model, + # tokenizer=tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # quantized_model_path = self.save_folder + # + # autoround.save_quantized(output_dir=quantized_model_path, format="auto_gptq", 
inplace=True) + # quantization_config = AutoRoundConfig() + # + # model = AutoModelForCausalLM.from_pretrained( + # quantized_model_path, device_map="auto", quantization_config=quantization_config + # ) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + # text = "There is a girl who likes adventure," + # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + # print(res) + # shutil.rmtree(self.save_folder, ignore_errors=True) + # + # def test_fallback_layers_regex_round(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 128, True + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # layer_config = { + # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + # r"model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + # } + # autoround = AutoRound( + # model, + # tokenizer=tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # quantized_model_path = self.save_folder + # + # autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + # quantization_config = AutoRoundConfig() + # + # model = AutoModelForCausalLM.from_pretrained( + # quantized_model_path, device_map="auto", quantization_config=quantization_config + # ) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + # text = "There is a girl who likes adventure," + # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + # print(res) + # shutil.rmtree(self.save_folder, ignore_errors=True) + # + # def test_fallback_layers_regex_exception(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # bits, group_size, sym = 4, 128, True + # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + # layer_config = {"model.decoder.layers.12.self_attn.k_proj": {"bits": 16}} + # with self.assertRaises(ValueError): + # autoround = AutoRound( + # model, + # tokenizer=tokenizer, + # bits=bits, + # group_size=group_size, + # sym=sym, + # iters=2, + # seqlen=2, + # dataset=self.llm_dataloader, + # layer_config=layer_config, + # ) + # autoround.quantize() + # + # # def test_fp8_model_input_rtn_generation(self): + # # model_name = "Qwen/Qwen3-0.6B-FP8" + # # ar = AutoRound(model=model_name, iters=0) + # # ar.quantize_and_save(output_dir=self.save_folder) + # # model = AutoModelForCausalLM.from_pretrained(self.save_folder, torch_dtype="auto", trust_remote_code=True) + # # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + # # text = "There is a girl who likes adventure," + # # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + # + # def test_dequant_fp8_weight(self): + # from auto_round.utils import dequant_block_fp8_weight + # + # # test pad and unpad + # weight = torch.randn(587, 7168) + # weight_scale = torch.randn(5, 56) + # block_size = [128, 128] + # dequant_weight = dequant_block_fp8_weight(weight, weight_scale, 
block_size) + # self.assertEqual(dequant_weight.shape.numel(), 4207616) + # + # # test experts are stacked. + # weight = torch.randn([32, 5760, 1440]) + # weight_scale = torch.randn([32, 5760, 90]) + # block_size = [1, 16] + # dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) + # self.assertEqual(len(dequant_weight.shape), 3) + # self.assertEqual(dequant_weight.shape[0], 32) + # self.assertEqual(dequant_weight.shape.numel(), 32 * 5760 * 1440) + # + # def test_mixed_bit_setting(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # layer_config = {"model.decoder.layers.7.fc1": {"bits": 8, "act_bits": 8}} + # ar = AutoRound(model_name, data_type="mx_fp4", act_bits=4, iters=0, layer_config=layer_config) + # ar.quantize() + # layer_config = ar.layer_config + # if ( + # layer_config["model.decoder.layers.7.fc1"]["bits"] != 8 + # or layer_config["model.decoder.layers.7.fc1"]["act_bits"] != 8 + # ): + # raise ValueError("mixed bits is not correct") + # + # def test_alg_ext(self): + # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + # ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True) + # ar.quantize() + # + # def test_invalid_layer_config(self): + # with self.assertRaises(ValueError): + # layer_config = {"model.decoder.layers.2.self_attnx": {"bits": 2}} + # ar = AutoRound( + # "/tf_dataset/auto_round/models/facebook/opt-125m", + # scheme="W3A16", + # nsamples=1, + # iters=1, + # layer_config=layer_config, + # ) + # ar.quantize() + # with self.assertRaises(ValueError): + # layer_config = {"model.decoder.layers.2.self_attn": {"bit": 2}} # should be bits + # ar = AutoRound( + # "/tf_dataset/auto_round/models/facebook/opt-125m", + # scheme="W3A16", + # nsamples=1, + # iters=1, + # layer_config=layer_config, + # ) + # ar.quantize() + # + # def test_quant_lm_head(self): + # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" + # ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32) + # ar.quantize() + # + # def test_compressor(self): + # model_name = "Qwen/Qwen2-VL-2B-Instruct" + # ar = AutoRound(model_name, enable_adam=True) + # self.assertEqual(ar.optimizer, torch.optim.AdamW) + # self.assertTrue(ar.mllm) + # + # # test old api + # from auto_round import AutoRoundMLLM + # + # ar = AutoRoundMLLM(model_name) + # self.assertTrue(ar.mllm) if __name__ == "__main__": From 815b616567b0b452ff67e703f9dc5d6d01617346 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Oct 2025 09:05:01 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_cpu/test_autoround.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 4336bd771..171aa0c33 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -452,6 +452,7 @@ def test_embed_quant(self): layer_config=layer_config, ) autoround.quantize() + # # def test_fallback_layers(self): # bits, group_size, sym = 4, 128, True From 9dfa78d017f4f86549067eda7f5eddad9b6cd8ac Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 17:07:24 +0800 Subject: [PATCH 06/14] fix --- auto_round/data_type/int.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index eaafd2aee..917e0e129 100644 --- 
a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -43,16 +43,11 @@ def search_scales(data, bits, qw=None): @register_dtype("rtn_int_sym") -def quant_tensor_sym( +def quant_tensor_rnt_sym( tensor, bits=4, group_size=-1, v=0, - min_scale=1.0, - max_scale=1.0, - scale_dtype=torch.float16, - tensor_min=None, - tensor_max=None, q_scale_thresh=1e-5, imatrix=None, **kwargs From 2444764a4107857614cb0d7da7a9e7884252e571 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Oct 2025 09:08:15 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/data_type/int.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 917e0e129..c7343d927 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -43,15 +43,7 @@ def search_scales(data, bits, qw=None): @register_dtype("rtn_int_sym") -def quant_tensor_rnt_sym( - tensor, - bits=4, - group_size=-1, - v=0, - q_scale_thresh=1e-5, - imatrix=None, - **kwargs -): +def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs): """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community Args: From 2fc7021411886585ec1fb7cbd884bebbb2b1d6c2 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 17:10:56 +0800 Subject: [PATCH 08/14] refine --- auto_round/data_type/int.py | 3 +- docs/opt_rtn.md | 56 ++++++++++++++++++++----------------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 917e0e129..a949e4db8 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch @@ -19,7 +20,7 @@ from auto_round.utils import get_reciprocal -def search_scales(data, bits, qw=None): +def search_scales(data:torch.Tensor, bits:int, qw:Union[None,torch.Tensor,float]=None)->torch.Tensor: nmax = pow(2, bits - 1) imax = abs(data).argmax(axis=-1, keepdims=True) group_max = torch.take_along_dim(data, imax, dim=-1) diff --git a/docs/opt_rtn.md b/docs/opt_rtn.md index d674d2b9d..f694c38a2 100644 --- a/docs/opt_rtn.md +++ b/docs/opt_rtn.md @@ -1,42 +1,46 @@ ### ๐Ÿงฎ Evaluation Results (LM-Eval) + 4BIT=W4A16 3BIT=W3A16 2BIT=W2A16G64 RTN mode + ~~~bash auto-round --model xxx --disable_opt_rtn --iters 0 ~~~ OPT RTN mode + ~~~bash auto-round --model xxx --iters 0 ~~~ + For 2/3bit, we strongly recommend not using iter=0. 
-| Model | RNT/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande | -|-------|----------|-----|------------|-----------------|------|------|-------------| -| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896 | 0.7013 | 0.6538 | 0.7987 | 0.7230 | -| | OPT-4BIT | 0.69560 | 0.5882 | 0.7074 | 0.6631 | 0.7916 | 0.7277 | -| | RTN-3BIT | 0.64562 | 0.5410 | 0.6695 | 0.5449 | 0.7742 | 0.6985 | -| | OPT-3BIT | 0.65970 | 0.5490 | 0.6893 | 0.5711 | 0.7677 | 0.7214 | -| | RTN-2BIT | 0.33008 | 0.2918 | 0.0474 | 0.2321 | 0.5740 | 0.5051 | -| | OPT-2BIT | 0.38908 | 0.3241 | 0.1560 | 0.2822 | 0.6235 | 0.5596 | -| **Qwen2.5-7B-Instruct** | RTN-4BIT | 0.69560 | 0.6114 | 0.6713 | 0.7011 | 0.7878 | 0.7064 | -| | OPT-4BIT | 0.70034 | 0.6143 | 0.6945 | 0.7115 | 0.7845 | 0.6969 | -| | RTN-3BIT | 0.64144 | 0.5585 | 0.6092 | 0.6455 | 0.7476 | 0.6464 | -| | OPT-3BIT | 0.66764 | 0.5756 | 0.7013 | 0.6597 | 0.7481 | 0.6535 | -| | RTN-2BIT | 0.31856 | 0.2804 | 0.0351 | 0.2379 | 0.5256 | 0.5138 | -| | OPT-2BIT | 0.45146 | 0.3645 | 0.2992 | 0.4043 | 0.6415 | 0.5478 | -| **Qwen3-8B** | RTN-4BIT | 0.66240 | 0.5619 | 0.6150 | 0.7077 | 0.7573 | 0.6701 | -| | OPT-4BIT | 0.66992 | 0.5619 | 0.6346 | 0.7102 | 0.7633 | 0.6796 | -| | RTN-3BIT | 0.57322 | 0.4992 | 0.4260 | 0.6002 | 0.7361 | 0.6046 | -| | OPT-3BIT | 0.63698 | 0.5226 | 0.5814 | 0.6718 | 0.7437 | 0.6654 | -| | RTN-2BIT | 0.31150 | 0.2679 | 0.0041 | 0.2536 | 0.5283 | 0.5036 | -| | OPT-2BIT | 0.44254 | 0.3749 | 0.2005 | 0.4202 | 0.6670 | 0.5501 | -| **Qwen3-14B** | RTN-4BIT | 0.70448 | 0.5999 | 0.6511 | 0.7565 | 0.7998 | 0.7151 | -| | OPT-4BIT | 0.70798 | 0.6031 | 0.6627 | 0.7534 | 0.8009 | 0.7198 | -| | RTN-3BIT | 0.65876 | 0.5746 | 0.5467 | 0.7065 | 0.7628 | 0.7032 | -| | OPT-3BIT | 0.68610 | 0.5683 | 0.6633 | 0.7258 | 0.7699 | 0.7032 | -| | RTN-2BIT | 0.39398 | 0.3764 | 0.0607 | 0.3836 | 0.6480 | 0.5012 | -| | OPT-2BIT | 0.50080 | 0.4554 | 0.2451 | 0.4899 | 0.7138 | 0.5998 | \ No newline at end of file +| Model | RNT/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande | +|--------------------------------|----------|---------|-----------|---------|--------|--------|------------| +| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896 | 0.7013 | 0.6538 | 0.7987 | 0.7230 | +| | OPT-4BIT | 0.69560 | 0.5882 | 0.7074 | 0.6631 | 0.7916 | 0.7277 | +| | RTN-3BIT | 0.64562 | 0.5410 | 0.6695 | 0.5449 | 0.7742 | 0.6985 | +| | OPT-3BIT | 0.65970 | 0.5490 | 0.6893 | 0.5711 | 0.7677 | 0.7214 | +| | RTN-2BIT | 0.33008 | 0.2918 | 0.0474 | 0.2321 | 0.5740 | 0.5051 | +| | OPT-2BIT | 0.38908 | 0.3241 | 0.1560 | 0.2822 | 0.6235 | 0.5596 | +| **Qwen2.5-7B-Instruct** | RTN-4BIT | 0.69560 | 0.6114 | 0.6713 | 0.7011 | 0.7878 | 0.7064 | +| | OPT-4BIT | 0.70034 | 0.6143 | 0.6945 | 0.7115 | 0.7845 | 0.6969 | +| | RTN-3BIT | 0.64144 | 0.5585 | 0.6092 | 0.6455 | 0.7476 | 0.6464 | +| | OPT-3BIT | 0.66764 | 0.5756 | 0.7013 | 0.6597 | 0.7481 | 0.6535 | +| | RTN-2BIT | 0.31856 | 0.2804 | 0.0351 | 0.2379 | 0.5256 | 0.5138 | +| | OPT-2BIT | 0.45146 | 0.3645 | 0.2992 | 0.4043 | 0.6415 | 0.5478 | +| **Qwen3-8B** | RTN-4BIT | 0.66240 | 0.5619 | 0.6150 | 0.7077 | 0.7573 | 0.6701 | +| | OPT-4BIT | 0.66992 | 0.5619 | 0.6346 | 0.7102 | 0.7633 | 0.6796 | +| | RTN-3BIT | 0.57322 | 0.4992 | 0.4260 | 0.6002 | 0.7361 | 0.6046 | +| | OPT-3BIT | 0.63698 | 0.5226 | 0.5814 | 0.6718 | 0.7437 | 0.6654 | +| | RTN-2BIT | 0.31150 | 0.2679 | 0.0041 | 0.2536 | 0.5283 | 0.5036 | +| | OPT-2BIT | 0.44254 | 0.3749 | 0.2005 | 0.4202 | 0.6670 | 0.5501 | +| **Qwen3-14B** | RTN-4BIT | 0.70448 | 
0.5999 | 0.6511 | 0.7565 | 0.7998 | 0.7151 | +| | OPT-4BIT | 0.70798 | 0.6031 | 0.6627 | 0.7534 | 0.8009 | 0.7198 | +| | RTN-3BIT | 0.65876 | 0.5746 | 0.5467 | 0.7065 | 0.7628 | 0.7032 | +| | OPT-3BIT | 0.68610 | 0.5683 | 0.6633 | 0.7258 | 0.7699 | 0.7032 | +| | RTN-2BIT | 0.39398 | 0.3764 | 0.0607 | 0.3836 | 0.6480 | 0.5012 | +| | OPT-2BIT | 0.50080 | 0.4554 | 0.2451 | 0.4899 | 0.7138 | 0.5998 | \ No newline at end of file From 3605b4c4689c52dc4ef4c595caccfaf42e985aee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Oct 2025 09:12:23 +0000 Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/data_type/int.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index fd28a2175..fb6c5f3db 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -20,7 +20,7 @@ from auto_round.utils import get_reciprocal -def search_scales(data:torch.Tensor, bits:int, qw:Union[None,torch.Tensor,float]=None)->torch.Tensor: +def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor: nmax = pow(2, bits - 1) imax = abs(data).argmax(axis=-1, keepdims=True) group_max = torch.take_along_dim(data, imax, dim=-1) From 67b8e897bc301c259b5fed0912dbe9f44c48a67f Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 17:14:52 +0800 Subject: [PATCH 10/14] rm useless code --- auto_round/compressors/base.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 14d119907..d948de65b 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1796,32 +1796,6 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.quantized = True return self.model, self.layer_config - def _check_rtn_dytpe(self, data_type, bit, sym): - """Check if the given data type is an RTN (Round-To-Nearest) type. - - Args: - data_type (str): The data type to check. - """ - - def pad_sym(dtype): - if sym: - data_sym = dtype + "_sym" - else: - data_sym = dtype + "_asym" - return data_sym - - def pad_bits(dtype): - return dtype + str(bit) - - data_type = "rtn_" + data_type - data_types = [data_type, pad_bits(data_type), pad_sym(data_type), pad_sym(pad_bits(data_type))] - for data_type in data_types: - from auto_round.data_type import QUANT_FUNC_WITH_DTYPE - - if data_type in QUANT_FUNC_WITH_DTYPE: - return data_type - return None - def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: """Quantizes specified layers based on inputs and configuration. 
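Editor's note: for readers following the RTN vs. OPT RTN modes described in docs/opt_rtn.md above, a minimal Python sketch of the two paths follows. It assumes the `AutoRound` constructor accepts an `iters` and a `disable_opt_rtn` keyword mirroring the `--iters 0` and `--disable_opt_rtn` CLI flags shown in that doc; the model path and output directories are placeholders, not values from the patches.

~~~python
# Minimal sketch, assuming AutoRound(..., iters=0, disable_opt_rtn=...) mirrors
# the CLI flags above; model path and output dirs are placeholders.
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model

# Plain RTN: no tuning iterations, optimized scale search disabled.
rtn = AutoRound(model_name, bits=4, group_size=128, sym=True, iters=0, disable_opt_rtn=True)
rtn.quantize_and_save(output_dir="./opt-125m-w4-rtn", format="auto_round")

# Optimized RTN: still iters=0, but the registered "rtn_int_sym" data type
# searches per-group scales instead of using the plain max-based scale.
opt_rtn = AutoRound(model_name, bits=4, group_size=128, sym=True, iters=0)
opt_rtn.quantize_and_save(output_dir="./opt-125m-w4-opt-rtn", format="auto_round")
~~~

Both runs skip the iterative tuning loop; the only difference is whether the optimized RTN scale search introduced in these patches is applied.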
From fbd392affb5340391de179e1b376b6825da7015e Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 21 Oct 2025 17:56:26 +0800 Subject: [PATCH 11/14] revert ut change --- test/test_cpu/test_autoround.py | 1394 +++++++++++++++---------------- 1 file changed, 697 insertions(+), 697 deletions(-) diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 171aa0c33..3cf33ef7a 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -29,8 +29,8 @@ class TestAutoRound(unittest.TestCase): @classmethod def setUpClass(self): model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) self.llm_dataloader = LLMDataLoader() self.save_folder = "./saved" @@ -39,404 +39,404 @@ def tearDownClass(self): shutil.rmtree(self.save_folder, ignore_errors=True) shutil.rmtree("runs", ignore_errors=True) - # def test_bits_setting(self): - # layer_config = {"model.decoder.layers.0.self_attn.k_proj": {"data_type": "mx_fp8", "group_size": 32}} - # autoround = AutoRound( - # "/tf_dataset/auto_round/models/facebook/opt-125m", iters=2, seqlen=2, nsamples=1, layer_config=layer_config - # ) - # autoround.quantize() - # module = get_module(autoround.model, "model.decoder.layers.0.self_attn.k_proj") - # if module.bits != 8: - # raise ValueError(f"Expected bits to be 8, but got {module.bits}") - # - # def test_layer_config(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # layer_config = {"self_attn": {"bits": 4, "data_type": "nv_fp", "act_bits": 16, "group_size": 16}} - # autoround = AutoRound( - # model_name, - # self.tokenizer, - # scheme="NVFP4", - # iters=0, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # amp=False, - # ) - # autoround.quantize_and_save(self.save_folder, inplace=False, format="fake") - # shutil.rmtree(self.save_folder) - # - # def test_remove_whole_block(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # layer_config = { - # "model.decoder.layers.0.self_attn.k_proj": {"bits": 32}, - # "model.decoder.layers.0.self_attn.v_proj": {"bits": 32}, - # "model.decoder.layers.0.self_attn.q_proj": {"bits": 32}, - # "model.decoder.layers.0.self_attn.out_proj": {"bits": 32}, - # "model.decoder.layers.0.fc1": {"bits": 32}, - # "model.decoder.layers.0.fc2": {"bits": 32}, - # } - # bits, group_size, sym = 4, 128, False - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # - # def test_consecutive_quant(self): - # bits, group_size, sym = 4, -1, False - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # model = AutoModelForCausalLM.from_pretrained( - # "/tf_dataset/auto_round/models/microsoft/phi-2", torch_dtype="auto", trust_remote_code=True - # ) - # tokenizer = AutoTokenizer.from_pretrained( - # "/tf_dataset/auto_round/models/microsoft/phi-2", trust_remote_code=True 
- # ) - # autoround = AutoRound( - # model, - # tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_mx_fp4(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 32, False - # autoround = AutoRound( - # model_name, - # bits=bits, - # act_bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # nsamples=2, - # seqlen=128, - # data_type="mx_fp4", - # act_data_type="mx_fp_rceil", - # ) - # model, _ = autoround.quantize() - # result = simple_evaluate_user_model( - # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - # ) - # print(result["results"]["lambada_openai"]["acc,none"]) - # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) # 0.375 - # - # def test_nv_fp4(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 16, False - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # data_type="nv_fp4", - # ) - # model, _ = autoround.quantize() - # result = simple_evaluate_user_model( - # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - # ) - # print(result["results"]["lambada_openai"]["acc,none"]) - # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.35) - # - # def test_default(self): - # bits, group_size, sym = 4, 128, False - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # autoround.save_quantized(output_dir="./saved", inplace=False, format="itrex") - # try: - # import auto_gptq - # except: - # return - # if torch.cuda.is_available(): - # autoround.save_quantized(output_dir="./saved", inplace=False) - # - # def test_w4g1(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, -1, True - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # @parameterized.expand([(2,), (3,), (4,)]) - # def test_g128(self, bits): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # group_size, sym = 128, True - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # dataset=self.llm_dataloader, - # ) - # model, _ = autoround.quantize() - # if bits > 2: - # result = simple_evaluate_user_model( - # model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 - # ) - # print(result["results"]["lambada_openai"]["acc,none"]) - # self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) - # - # def test_disable_quanted_input(self): - # bits, group_size, sym = 4, -1, True - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_quanted_input=False, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_enable_norm_bias_tuning_qwen3(self): - # bits, group_size, sym = 4, 128, True - # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" - # model = 
AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # autoround = AutoRound( - # model, - # tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_norm_bias_tuning=True, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_enable_norm_bias_tuning(self): - # bits, group_size, sym = 4, -1, True - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_quanted_input=False, - # enable_norm_bias_tuning=True, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_disable_minmax_tuning(self): - # bits, group_size, sym = 4, -1, True - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_minmax_tuning=False, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # # - # def test_signround(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, -1, False - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_minmax_tuning=False, - # enable_quanted_input=False, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_lm_head_layer_config_way(self): - # bits, group_size, sym = 4, -1, False - # layer_config = {"lm_head": {"data_type": "int"}} - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=10, - # enable_minmax_tuning=False, - # enable_quanted_input=False, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # - # def test_wa_quant(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym, act_bits = 4, 128, False, 4 - # autoround = AutoRound( - # model_name, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # act_bits=act_bits, - # ) - # autoround.quantize() - # - # def test_auto_device_map(self): - # bits, group_size, sym = 4, 128, False - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # model = AutoModelForCausalLM.from_pretrained( - # model_name, torch_dtype="auto", trust_remote_code=True, device_map="auto" - # ) - # autoround = AutoRound( - # model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_device_map_dict(self): - # bits, group_size, sym = 4, 128, False - # device_map = {".*": "cpu"} - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # device_map=device_map, - # ) - # autoround.quantize() - # - # # test model_name - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # autoround = AutoRound( - # model_name, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # device_map=device_map, - # ) - # autoround.quantize() - # - # def test_fp32(self): - # bits, group_size, sym = 4, 128, False 
- # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # model = AutoModelForCausalLM.from_pretrained( - # model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" - # ) - # autoround = AutoRound( - # model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # amp=False, - # ) - # autoround.quantize() - # - # def test_tensor_reshape(self): - # bits, group_size, sym = 4, 100, False - # autoround = AutoRound( - # self.model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # ) - # autoround.quantize() - # - # def test_rtn(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # - # bits, group_size, sym = 4, 128, True - # autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=0, nsamples=1) - # quantized_model_path = self.save_folder - # autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") - # model = AutoModelForCausalLM.from_pretrained( - # self.save_folder, - # torch_dtype=torch.float16, - # device_map="auto", - # ) - # - # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) - # model_infer(model, tokenizer) - # shutil.rmtree(self.save_folder) + def test_bits_setting(self): + layer_config = {"model.decoder.layers.0.self_attn.k_proj": {"data_type": "mx_fp8", "group_size": 32}} + autoround = AutoRound( + "/tf_dataset/auto_round/models/facebook/opt-125m", iters=2, seqlen=2, nsamples=1, layer_config=layer_config + ) + autoround.quantize() + module = get_module(autoround.model, "model.decoder.layers.0.self_attn.k_proj") + if module.bits != 8: + raise ValueError(f"Expected bits to be 8, but got {module.bits}") + + def test_layer_config(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + layer_config = {"self_attn": {"bits": 4, "data_type": "nv_fp", "act_bits": 16, "group_size": 16}} + autoround = AutoRound( + model_name, + self.tokenizer, + scheme="NVFP4", + iters=0, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + amp=False, + ) + autoround.quantize_and_save(self.save_folder, inplace=False, format="fake") + shutil.rmtree(self.save_folder) + + def test_remove_whole_block(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + layer_config = { + "model.decoder.layers.0.self_attn.k_proj": {"bits": 32}, + "model.decoder.layers.0.self_attn.v_proj": {"bits": 32}, + "model.decoder.layers.0.self_attn.q_proj": {"bits": 32}, + "model.decoder.layers.0.self_attn.out_proj": {"bits": 32}, + "model.decoder.layers.0.fc1": {"bits": 32}, + "model.decoder.layers.0.fc2": {"bits": 32}, + } + bits, group_size, sym = 4, 128, False + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + + def test_consecutive_quant(self): + bits, group_size, sym = 4, -1, False + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + model = AutoModelForCausalLM.from_pretrained( + 
"/tf_dataset/auto_round/models/microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained( + "/tf_dataset/auto_round/models/microsoft/phi-2", trust_remote_code=True + ) + autoround = AutoRound( + model, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_mx_fp4(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 32, False + autoround = AutoRound( + model_name, + bits=bits, + act_bits=bits, + group_size=group_size, + sym=sym, + iters=2, + nsamples=2, + seqlen=128, + data_type="mx_fp4", + act_data_type="mx_fp_rceil", + ) + model, _ = autoround.quantize() + result = simple_evaluate_user_model( + model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + ) + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) # 0.375 + + def test_nv_fp4(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 16, False + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + data_type="nv_fp4", + ) + model, _ = autoround.quantize() + result = simple_evaluate_user_model( + model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + ) + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.35) + + def test_default(self): + bits, group_size, sym = 4, 128, False + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + autoround.save_quantized(output_dir="./saved", inplace=False, format="itrex") + try: + import auto_gptq + except: + return + if torch.cuda.is_available(): + autoround.save_quantized(output_dir="./saved", inplace=False) + + def test_w4g1(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, -1, True + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + @parameterized.expand([(2,), (3,), (4,)]) + def test_g128(self, bits): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + group_size, sym = 128, True + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + dataset=self.llm_dataloader, + ) + model, _ = autoround.quantize() + if bits > 2: + result = simple_evaluate_user_model( + model, self.tokenizer, batch_size="auto:8", tasks="lambada_openai", limit=32 + ) + print(result["results"]["lambada_openai"]["acc,none"]) + self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.3) + + def test_disable_quanted_input(self): + bits, group_size, sym = 4, -1, True + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_quanted_input=False, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_enable_norm_bias_tuning_qwen3(self): + bits, group_size, sym = 4, 128, True + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", 
trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + autoround = AutoRound( + model, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_norm_bias_tuning=True, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_enable_norm_bias_tuning(self): + bits, group_size, sym = 4, -1, True + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_quanted_input=False, + enable_norm_bias_tuning=True, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_disable_minmax_tuning(self): + bits, group_size, sym = 4, -1, True + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_minmax_tuning=False, + dataset=self.llm_dataloader, + ) + autoround.quantize() + # + def test_signround(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, -1, False + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_minmax_tuning=False, + enable_quanted_input=False, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_lm_head_layer_config_way(self): + bits, group_size, sym = 4, -1, False + layer_config = {"lm_head": {"data_type": "int"}} + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=10, + enable_minmax_tuning=False, + enable_quanted_input=False, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + + def test_wa_quant(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym, act_bits = 4, 128, False, 4 + autoround = AutoRound( + model_name, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + act_bits=act_bits, + ) + autoround.quantize() + + def test_auto_device_map(self): + bits, group_size, sym = 4, 128, False + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype="auto", trust_remote_code=True, device_map="auto" + ) + autoround = AutoRound( + model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_device_map_dict(self): + bits, group_size, sym = 4, 128, False + device_map = {".*": "cpu"} + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + device_map=device_map, + ) + autoround.quantize() + + # test model_name + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + autoround = AutoRound( + model_name, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + device_map=device_map, + ) + autoround.quantize() + + def test_fp32(self): + bits, group_size, sym = 4, 128, False + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" + ) + autoround = AutoRound( + model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + 
amp=False, + ) + autoround.quantize() + + def test_tensor_reshape(self): + bits, group_size, sym = 4, 100, False + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + + def test_rtn(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + bits, group_size, sym = 4, 128, True + autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=0, nsamples=1) + quantized_model_path = self.save_folder + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + model = AutoModelForCausalLM.from_pretrained( + self.save_folder, + torch_dtype=torch.float16, + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.save_folder) + model_infer(model, tokenizer) + shutil.rmtree(self.save_folder) + def test_embed_quant(self): bits, group_size, sym = 4, 128, True - model_name = "/models/opt-125m" + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" layer_config = { "model.decoder.embed_tokens": {"bits": 4}, } @@ -453,308 +453,308 @@ def test_embed_quant(self): ) autoround.quantize() - # - # def test_fallback_layers(self): - # bits, group_size, sym = 4, 128, True - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # model = AutoModelForCausalLM.from_pretrained( - # model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" - # ) - # layer_config = { - # "model.decoder.layers.0.self_attn.q_proj": {"bits": 16}, - # "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - # "model.decoder.embed_tokens": {"bits": 16}, - # } - # autoround = AutoRound( - # model, - # self.tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # quantized_model_path = self.save_folder - # - # autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) - # quantization_config = AutoRoundConfig(backend="ipex") - # - # model = AutoModelForCausalLM.from_pretrained( - # quantized_model_path, device_map="cpu", quantization_config=quantization_config - # ) - # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - # text = "There is a girl who likes adventure," - # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=1)[0]) - # shutil.rmtree(self.save_folder, ignore_errors=True) - # - # def test_not_convert_modules(self): - # import requests - # from PIL import Image - # from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - # - # from auto_round_extension.ipex.qlinear_ipex_awq import QuantLinear - # - # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2-VL-2B-Instruct-AWQ" - # quantization_config = AutoRoundConfig() - # model = Qwen2VLForConditionalGeneration.from_pretrained( - # model_name, quantization_config=quantization_config, device_map="cpu", torch_dtype=torch.float16 - # ) - # self.assertTrue(isinstance(model.visual.blocks[0].attn.qkv, torch.nn.Linear)) - # self.assertFalse(isinstance(model.visual.merger.mlp[0], QuantLinear)) - # if hasattr(model.model, "language_model"): - # 
self.assertTrue(isinstance(model.model.language_model.layers[0].self_attn.v_proj, QuantLinear)) - # else: - # self.assertTrue(isinstance(model.model.layers[0].self_attn.v_proj, QuantLinear)) - # - # processor = AutoProcessor.from_pretrained(model_name, size=None) - # image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" - # messages = [ - # { - # "role": "user", - # "content": [ - # { - # "type": "image", - # "image": image_url, - # }, - # {"type": "text", "text": "Describe this image."}, - # ], - # } - # ] - # - # # Preparation for inference - # text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - # image_inputs = Image.open(requests.get(image_url, stream=True).raw) - # inputs = processor( - # text=[text], - # images=image_inputs, - # padding=True, - # return_tensors="pt", - # ) - # - # # Inference: Generation of the output - # generated_ids = model.generate(**inputs, max_new_tokens=1) - # generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] - # output_text = processor.batch_decode( - # generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - # ) - # print(output_text) - # - # def test_fallback_layers_regex_awq(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 128, True - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # layer_config = { - # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - # "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - # } - # autoround = AutoRound( - # model, - # tokenizer=tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # quantized_model_path = self.save_folder - # - # autoround.save_quantized(output_dir=quantized_model_path, format="auto_awq", inplace=True) - # quantization_config = AutoRoundConfig() - # - # model = AutoModelForCausalLM.from_pretrained( - # quantized_model_path, device_map="auto", quantization_config=quantization_config - # ) - # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - # text = "There is a girl who likes adventure," - # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - # print(res) - # shutil.rmtree(self.save_folder, ignore_errors=True) - # - # def test_fallback_layers_regex_gptq(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 128, True - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # layer_config = { - # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - # ##"model.decoder.layers.1.self_attn.k_proj": {"bits": 16} - # } - # autoround = AutoRound( - # model, - # tokenizer=tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # quantized_model_path = self.save_folder - # - # autoround.save_quantized(output_dir=quantized_model_path, format="auto_gptq", 
inplace=True) - # quantization_config = AutoRoundConfig() - # - # model = AutoModelForCausalLM.from_pretrained( - # quantized_model_path, device_map="auto", quantization_config=quantization_config - # ) - # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) - # text = "There is a girl who likes adventure," - # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - # print(res) - # shutil.rmtree(self.save_folder, ignore_errors=True) - # - # def test_fallback_layers_regex_round(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 128, True - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # layer_config = { - # r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, - # r"model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, - # } - # autoround = AutoRound( - # model, - # tokenizer=tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # quantized_model_path = self.save_folder - # - # autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) - # quantization_config = AutoRoundConfig() - # - # model = AutoModelForCausalLM.from_pretrained( - # quantized_model_path, device_map="auto", quantization_config=quantization_config - # ) - # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + + def test_fallback_layers(self): + bits, group_size, sym = 4, 128, True + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, trust_remote_code=True, device_map="auto" + ) + layer_config = { + "model.decoder.layers.0.self_attn.q_proj": {"bits": 16}, + "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + "model.decoder.embed_tokens": {"bits": 16}, + } + autoround = AutoRound( + model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + quantized_model_path = self.save_folder + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + quantization_config = AutoRoundConfig(backend="ipex") + + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, device_map="cpu", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=1)[0]) + shutil.rmtree(self.save_folder, ignore_errors=True) + + def test_not_convert_modules(self): + import requests + from PIL import Image + from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + from auto_round_extension.ipex.qlinear_ipex_awq import QuantLinear + + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2-VL-2B-Instruct-AWQ" + quantization_config = AutoRoundConfig() + model = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, quantization_config=quantization_config, device_map="cpu", torch_dtype=torch.float16 + ) + 
self.assertTrue(isinstance(model.visual.blocks[0].attn.qkv, torch.nn.Linear)) + self.assertFalse(isinstance(model.visual.merger.mlp[0], QuantLinear)) + if hasattr(model.model, "language_model"): + self.assertTrue(isinstance(model.model.language_model.layers[0].self_attn.v_proj, QuantLinear)) + else: + self.assertTrue(isinstance(model.model.layers[0].self_attn.v_proj, QuantLinear)) + + processor = AutoProcessor.from_pretrained(model_name, size=None) + image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image_url, + }, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + # Preparation for inference + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs = Image.open(requests.get(image_url, stream=True).raw) + inputs = processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + # Inference: Generation of the output + generated_ids = model.generate(**inputs, max_new_tokens=1) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print(output_text) + + def test_fallback_layers_regex_awq(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 128, True + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + layer_config = { + r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + "model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + } + autoround = AutoRound( + model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + quantized_model_path = self.save_folder + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_awq", inplace=True) + quantization_config = AutoRoundConfig() + + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, device_map="auto", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + print(res) + shutil.rmtree(self.save_folder, ignore_errors=True) + + def test_fallback_layers_regex_gptq(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 128, True + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + layer_config = { + r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + ##"model.decoder.layers.1.self_attn.k_proj": {"bits": 16} + } + autoround = AutoRound( + model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + quantized_model_path = self.save_folder + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_gptq", 
inplace=True) + quantization_config = AutoRoundConfig() + + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, device_map="auto", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + print(res) + shutil.rmtree(self.save_folder, ignore_errors=True) + + def test_fallback_layers_regex_round(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 128, True + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + layer_config = { + r"model\.decoder\.layers\.(?:[0-9]|1[0-1])\.self_attn\.q_proj": {"bits": 16}, + r"model.decoder.layers.1.self_attn.k_proj": {"bits": 16}, + } + autoround = AutoRound( + model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + quantized_model_path = self.save_folder + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_round", inplace=True) + quantization_config = AutoRoundConfig() + + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, device_map="auto", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + print(res) + shutil.rmtree(self.save_folder, ignore_errors=True) + + def test_fallback_layers_regex_exception(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + bits, group_size, sym = 4, 128, True + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + layer_config = {"model.decoder.layers.12.self_attn.k_proj": {"bits": 16}} + with self.assertRaises(ValueError): + autoround = AutoRound( + model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=2, + seqlen=2, + dataset=self.llm_dataloader, + layer_config=layer_config, + ) + autoround.quantize() + + # def test_fp8_model_input_rtn_generation(self): + # model_name = "Qwen/Qwen3-0.6B-FP8" + # ar = AutoRound(model=model_name, iters=0) + # ar.quantize_and_save(output_dir=self.save_folder) + # model = AutoModelForCausalLM.from_pretrained(self.save_folder, torch_dtype="auto", trust_remote_code=True) + # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) # text = "There is a girl who likes adventure," # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) - # print(res) - # shutil.rmtree(self.save_folder, ignore_errors=True) - # - # def test_fallback_layers_regex_exception(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # bits, group_size, sym = 4, 128, True - # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - # layer_config = 
{"model.decoder.layers.12.self_attn.k_proj": {"bits": 16}} - # with self.assertRaises(ValueError): - # autoround = AutoRound( - # model, - # tokenizer=tokenizer, - # bits=bits, - # group_size=group_size, - # sym=sym, - # iters=2, - # seqlen=2, - # dataset=self.llm_dataloader, - # layer_config=layer_config, - # ) - # autoround.quantize() - # - # # def test_fp8_model_input_rtn_generation(self): - # # model_name = "Qwen/Qwen3-0.6B-FP8" - # # ar = AutoRound(model=model_name, iters=0) - # # ar.quantize_and_save(output_dir=self.save_folder) - # # model = AutoModelForCausalLM.from_pretrained(self.save_folder, torch_dtype="auto", trust_remote_code=True) - # # tokenizer = AutoTokenizer.from_pretrained(self.save_folder) - # # text = "There is a girl who likes adventure," - # # inputs = tokenizer(text, return_tensors="pt").to(model.device) - # # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) - # - # def test_dequant_fp8_weight(self): - # from auto_round.utils import dequant_block_fp8_weight - # - # # test pad and unpad - # weight = torch.randn(587, 7168) - # weight_scale = torch.randn(5, 56) - # block_size = [128, 128] - # dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) - # self.assertEqual(dequant_weight.shape.numel(), 4207616) - # - # # test experts are stacked. - # weight = torch.randn([32, 5760, 1440]) - # weight_scale = torch.randn([32, 5760, 90]) - # block_size = [1, 16] - # dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) - # self.assertEqual(len(dequant_weight.shape), 3) - # self.assertEqual(dequant_weight.shape[0], 32) - # self.assertEqual(dequant_weight.shape.numel(), 32 * 5760 * 1440) - # - # def test_mixed_bit_setting(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # layer_config = {"model.decoder.layers.7.fc1": {"bits": 8, "act_bits": 8}} - # ar = AutoRound(model_name, data_type="mx_fp4", act_bits=4, iters=0, layer_config=layer_config) - # ar.quantize() - # layer_config = ar.layer_config - # if ( - # layer_config["model.decoder.layers.7.fc1"]["bits"] != 8 - # or layer_config["model.decoder.layers.7.fc1"]["act_bits"] != 8 - # ): - # raise ValueError("mixed bits is not correct") - # - # def test_alg_ext(self): - # model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - # ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True) - # ar.quantize() - # - # def test_invalid_layer_config(self): - # with self.assertRaises(ValueError): - # layer_config = {"model.decoder.layers.2.self_attnx": {"bits": 2}} - # ar = AutoRound( - # "/tf_dataset/auto_round/models/facebook/opt-125m", - # scheme="W3A16", - # nsamples=1, - # iters=1, - # layer_config=layer_config, - # ) - # ar.quantize() - # with self.assertRaises(ValueError): - # layer_config = {"model.decoder.layers.2.self_attn": {"bit": 2}} # should be bits - # ar = AutoRound( - # "/tf_dataset/auto_round/models/facebook/opt-125m", - # scheme="W3A16", - # nsamples=1, - # iters=1, - # layer_config=layer_config, - # ) - # ar.quantize() - # - # def test_quant_lm_head(self): - # model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - # ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32) - # ar.quantize() - # - # def test_compressor(self): - # model_name = "Qwen/Qwen2-VL-2B-Instruct" - # ar = AutoRound(model_name, enable_adam=True) - # self.assertEqual(ar.optimizer, torch.optim.AdamW) - # self.assertTrue(ar.mllm) - # - # # test old api - # from auto_round import AutoRoundMLLM - # - # 
ar = AutoRoundMLLM(model_name) - # self.assertTrue(ar.mllm) + # print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) + + def test_dequant_fp8_weight(self): + from auto_round.utils import dequant_block_fp8_weight + + # test pad and unpad + weight = torch.randn(587, 7168) + weight_scale = torch.randn(5, 56) + block_size = [128, 128] + dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) + self.assertEqual(dequant_weight.shape.numel(), 4207616) + + # test experts are stacked. + weight = torch.randn([32, 5760, 1440]) + weight_scale = torch.randn([32, 5760, 90]) + block_size = [1, 16] + dequant_weight = dequant_block_fp8_weight(weight, weight_scale, block_size) + self.assertEqual(len(dequant_weight.shape), 3) + self.assertEqual(dequant_weight.shape[0], 32) + self.assertEqual(dequant_weight.shape.numel(), 32 * 5760 * 1440) + + def test_mixed_bit_setting(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + layer_config = {"model.decoder.layers.7.fc1": {"bits": 8, "act_bits": 8}} + ar = AutoRound(model_name, data_type="mx_fp4", act_bits=4, iters=0, layer_config=layer_config) + ar.quantize() + layer_config = ar.layer_config + if ( + layer_config["model.decoder.layers.7.fc1"]["bits"] != 8 + or layer_config["model.decoder.layers.7.fc1"]["act_bits"] != 8 + ): + raise ValueError("mixed bits is not correct") + + def test_alg_ext(self): + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True) + ar.quantize() + + def test_invalid_layer_config(self): + with self.assertRaises(ValueError): + layer_config = {"model.decoder.layers.2.self_attnx": {"bits": 2}} + ar = AutoRound( + "/tf_dataset/auto_round/models/facebook/opt-125m", + scheme="W3A16", + nsamples=1, + iters=1, + layer_config=layer_config, + ) + ar.quantize() + with self.assertRaises(ValueError): + layer_config = {"model.decoder.layers.2.self_attn": {"bit": 2}} # should be bits + ar = AutoRound( + "/tf_dataset/auto_round/models/facebook/opt-125m", + scheme="W3A16", + nsamples=1, + iters=1, + layer_config=layer_config, + ) + ar.quantize() + + def test_quant_lm_head(self): + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" + ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32) + ar.quantize() + + def test_compressor(self): + model_name = "Qwen/Qwen2-VL-2B-Instruct" + ar = AutoRound(model_name, enable_adam=True) + self.assertEqual(ar.optimizer, torch.optim.AdamW) + self.assertTrue(ar.mllm) + + # test old api + from auto_round import AutoRoundMLLM + + ar = AutoRoundMLLM(model_name) + self.assertTrue(ar.mllm) if __name__ == "__main__": From 86c5fab85754a4ba0aba2479378245ffe3ba7ceb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Oct 2025 09:57:05 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_cpu/test_autoround.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 3cf33ef7a..cbd0583df 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -453,7 +453,6 @@ def test_embed_quant(self): ) autoround.quantize() - def test_fallback_layers(self): bits, group_size, sym = 4, 128, True model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" From 11b25a91a2da34c0ac0e74056cd7aef260d7a509 Mon Sep 17 00:00:00 
2001
From: Wenhua Cheng
Date: Wed, 22 Oct 2025 11:32:46 +0800
Subject: [PATCH 13/14] update doc

---
 README.md            | 8 ++++----
 docs/opt_rtn.md      | 3 ++-
 docs/step_by_step.md | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 115d2152e..8522468c7 100644
--- a/README.md
+++ b/README.md
@@ -27,13 +27,13 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,
 ## 🆕 What's New

-[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
+[2025/10] We proposed a fast algorithm to generate **mixed bits/datatypes** schemes in minutes. Please
 refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.

-[2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
+[2025/09] AutoRound now includes experimental support for the **mxfp4 and nvfp4 dtypes**. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
 . We currently recommend exporting to the LLM-Compressor format.

-[2025/08] AutoRound now provides experimental support for an improved INT2 algorithm via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
+[2025/08] AutoRound now provides experimental support for **an improved INT2 algorithm** via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
 for some accuracy results.

 [2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
@@ -67,7 +67,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De
 ✅ **Affordable Quantization Cost**
 Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)

-✅ **Fast mixed bits/data-types scheme generation**
+✅ **Fast Mixed Bits/Dtypes Scheme Generation**
 Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. Accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).

 ✅ **10+ VLMs Support**
diff --git a/docs/opt_rtn.md b/docs/opt_rtn.md
index f694c38a2..354a397c3 100644
--- a/docs/opt_rtn.md
+++ b/docs/opt_rtn.md
@@ -1,4 +1,5 @@
 ### 🧮 Evaluation Results (LM-Eval)
+For 2/3-bit quantization, we strongly recommend not using `iters=0`, except for GGUF:Q2_K_S, which uses a different quantization algorithm.
 4BIT=W4A16
 3BIT=W3A16
@@ -16,7 +17,7 @@ OPT RTN mode
 auto-round --model xxx --iters 0
 ~~~
-For 2/3bit, we strongly recommend not using iter=0.
+

 | Model | RNT/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande |
 |--------------------------------|----------|---------|-----------|---------|--------|--------|------------|
diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index 097b97549..bcf9a54d9 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -369,7 +369,7 @@ Embedding layer is not supported in AutoScheme, it will use the best scheme in o

 ### OPT RTN Mode

-AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. try setting `iters=0` and use `group_size=32` for better results.
+AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Set `iters=0` to enable it; we recommend useing `group_size=32` for better results. Check the [accuracy comparison](./opt_rtn.md) between RTN and OPT RTN mode.
 For the GGUF format, we have optimized the RTN algorithm inspired by llamacpp. To use the original (pure) RTN algorithm instead, enable the `--disable_opt_rtn` option.
 ```python

From abf86a18a5f985e1ddda68193d764098a286fda7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 22 Oct 2025 03:33:26 +0000
Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/step_by_step.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index bcf9a54d9..9a518e31f 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -369,7 +369,7 @@ Embedding layer is not supported in AutoScheme, it will use the best scheme in o

 ### OPT RTN Mode

-AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Set `iters=0` to enable it; we recommend useing `group_size=32` for better results. Check the [accuracy comparison](./opt_rtn.md) between RTN and OPT RTN mode.
+AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Set `iters=0` to enable it; we recommend using `group_size=32` for better results. Check the [accuracy comparison](./opt_rtn.md) between RTN and OPT RTN mode.
 For the GGUF format, we have optimized the RTN algorithm inspired by llamacpp. To use the original (pure) RTN algorithm instead, enable the `--disable_opt_rtn` option.
 ```python
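# Illustrative sketch only: the doc's Python snippet is truncated at this point of the
# patch excerpt, so this block is a hedged stand-in rather than the repository's actual
# example. It assumes the AutoRound API used elsewhere in this patch series
# (AutoRound(model, iters=0, group_size=...) plus quantize_and_save()); the model name
# and output directory below are placeholders.
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model path
# iters=0 selects the calibration-free (optimized) RTN path; group_size=32 follows the
# recommendation in the paragraph above.
ar = AutoRound(model_name, iters=0, group_size=32)
ar.quantize_and_save(output_dir="./tmp_autoround")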