From cc4291c33f5e5db582e4ce8546db333fe9fb8bfb Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Fri, 14 Nov 2025 01:16:37 -0500 Subject: [PATCH 1/2] support tuning target_bits Signed-off-by: He, Xin3 --- neural_compressor/common/base_config.py | 6 +- neural_compressor/common/tuning_param.py | 2 +- .../torch/algorithms/weight_only/autoround.py | 58 +++++++++++++- .../torch/quantization/algorithm_entry.py | 17 +++- .../torch/quantization/autotune.py | 14 +++- .../torch/quantization/config.py | 45 ++++++++++- .../weight_only/test_autoround.py | 79 ++++++++++++++++++- 7 files changed, 206 insertions(+), 15 deletions(-) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index d6cb8108c13..b8c1eac83ab 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -475,7 +475,7 @@ def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any: # Get the parameters and their default values parameters = signature.parameters - return parameters.get(param).default + return parameters.get(param).annotation def expand(self) -> List[BaseConfig]: """Expand the config. @@ -522,8 +522,8 @@ def expand(self) -> List[BaseConfig]: # 1. The param is a string. # 2. The param is a `TuningParam` instance. if isinstance(param, str): - default_param = self.get_the_default_value_of_param(config, param) - tuning_param = TuningParam(name=param, tunable_type=List[type(default_param)]) + param_annotation = self.get_the_default_value_of_param(config, param) + tuning_param = TuningParam(name=param, tunable_type=List[param_annotation]) elif isinstance(param, TuningParam): tuning_param = param else: diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py index d3f7d452e1c..924a6eed8a8 100644 --- a/neural_compressor/common/tuning_param.py +++ b/neural_compressor/common/tuning_param.py @@ -118,8 +118,8 @@ def is_tunable(self, value: Any) -> bool: assert isinstance( self.tunable_type, typing._GenericAlias ), f"Expected a type hint, got {self.tunable_type} instead." 
- DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type) try: + DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type) new_args = DynamicInputArgsModel(input_args=value) return True except Exception as e: diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 5fa3b253cfa..69d00a320be 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -16,7 +16,7 @@ import json import time from functools import lru_cache -from typing import Optional, Union +from typing import Optional, Union, Iterable import torch @@ -41,6 +41,7 @@ def _is_auto_round_available(): from neural_compressor.torch.algorithms import Quantizer from neural_compressor.torch.utils import get_accelerator, logger +from neural_compressor.common.utils import Statistics from .utility import CapturedDataloader, InputCaptureModule @@ -104,6 +105,14 @@ def __init__( guidance_scale: float = 7.5, num_inference_steps: int = 50, generator_seed: int = None, + # 0.9 + target_bits: int = None, + options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"), + shared_layers: Optional[Iterable[Iterable[str]]] = None, + ignore_scale_zp_bits: bool = False, + auto_scheme_method: str = "default", + auto_scheme_batch_size: int = None, + auto_scheme_device_map: str = None, **kwargs, ): """Init a AutQRoundQuantizer object. @@ -238,6 +247,13 @@ def __init__( self.guidance_scale = guidance_scale self.num_inference_steps = num_inference_steps self.generator_seed = generator_seed + self.target_bits = target_bits + self.options = options + self.shared_layers = shared_layers + self.ignore_scale_zp_bits = ignore_scale_zp_bits + self.auto_scheme_method = auto_scheme_method + self.auto_scheme_batch_size = auto_scheme_batch_size + self.auto_scheme_device_map = auto_scheme_device_map def _is_w4afp8(self) -> bool: return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()]) @@ -273,6 +289,18 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model = model.orig_model if pipe is not None: model = pipe + if self.target_bits is not None: + from auto_round import AutoScheme + self.scheme = AutoScheme( + avg_bits=self.target_bits, + options=self.options, + shared_layers=self.shared_layers, + ignore_scale_zp_bits=self.ignore_scale_zp_bits, + method=self.auto_scheme_method, + batch_size=self.auto_scheme_batch_size, + device_map=self.auto_scheme_device_map, + low_gpu_mem_usage=self.low_gpu_mem_usage, + ) rounder = AutoRound( model, layer_config=self.layer_config, @@ -338,6 +366,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True) model = rounder.model model.autoround_config = rounder.layer_config + + dump_model_op_stats(rounder.layer_config) + return model @@ -452,3 +483,28 @@ def get_mllm_dataloader( quant_nontext_module=quant_nontext_module, ) return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen, nsamples + + +def dump_model_op_stats(layer_config): + """Dump quantizable ops stats of model to user.""" + # TODO: collect more ops besides Linear + res = {} + res["Linear"] = {} + for name, info in layer_config.items(): + if 'data_type' in info: + data_type_str = info['data_type'].upper() + if 'bits' in info and str(info["bits"]) not in info["data_type"]: + 
data_type_str += str(info['bits']) + res["Linear"][data_type_str] = res.get("Linear", {}).get(data_type_str, 0) + 1 + + # update stats format for dump. + field_names = ["Op Type", "Total"] + dtype_list = list(res["Linear"].keys()) + field_names.extend(dtype_list) + output_data = [] + for op_type in res.keys(): + field_results = [op_type, sum(res[op_type].values())] + field_results.extend([res[op_type][dtype] for dtype in dtype_list]) + output_data.append(field_results) + + Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat() diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 1f9b50d7339..953e5b48271 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -646,6 +646,15 @@ def autoround_quantize_entry( guidance_scale = quant_config.to_dict().get("guidance_scale", 7.5) num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50) generator_seed = quant_config.to_dict().get("generator_seed", None) + # 0.9.0: auto scheme parameters + target_bits=quant_config.target_bits + options=quant_config.options + shared_layers=quant_config.shared_layers + ignore_scale_zp_bits=quant_config.ignore_scale_zp_bits + auto_scheme_method=quant_config.auto_scheme_method + auto_scheme_batch_size=quant_config.auto_scheme_batch_size + auto_scheme_device_map=quant_config.auto_scheme_device_map + kwargs.pop("example_inputs") quantizer = get_quantizer( @@ -702,12 +711,18 @@ def autoround_quantize_entry( guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator_seed=generator_seed, + target_bits=target_bits, + options=options, + shared_layers=shared_layers, + ignore_scale_zp_bits=ignore_scale_zp_bits, + auto_scheme_method=auto_scheme_method, + auto_scheme_batch_size=auto_scheme_batch_size, + auto_scheme_device_map=auto_scheme_device_map, ) model = quantizer.execute(model=model, mode=mode, *args, **kwargs) model.qconfig = configs_mapping model.save = MethodType(save, model) postprocess_model(model, mode, quantizer) - dump_model_op_stats(mode, configs_mapping) return model diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py index fe290908127..cda8a739507 100644 --- a/neural_compressor/torch/quantization/autotune.py +++ b/neural_compressor/torch/quantization/autotune.py @@ -23,7 +23,7 @@ from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning from neural_compressor.common.utils import dump_elapsed_time from neural_compressor.torch.quantization import quantize -from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig +from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, AutoRoundConfig, RTNConfig from neural_compressor.torch.utils import constants, logger __all__ = [ @@ -62,6 +62,17 @@ def _deepcopy_warp(model): setattr(new_model, key, value) return new_model +def _preprocess_model_quant_config(model, quant_config): + """Preprocess model and quant config before quantization.""" + for config in quant_config.config_set: + # handle tokenizer attribute in AutoRoundConfig + if isinstance(config, AutoRoundConfig): + _tokenizer_backup = getattr(config, "tokenizer", None) + if _tokenizer_backup is not None: + setattr(model, "tokenizer", _tokenizer_backup) + delattr(config, "tokenizer") + return model, quant_config + @dump_elapsed_time("Pass 
auto-tune") def autotune( @@ -88,6 +99,7 @@ def autotune( The quantized model. """ best_quant_model = None + model, tune_config = _preprocess_model_quant_config(model, tune_config) eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args) config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config) baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model)) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 9c8f4eb3609..e181a9dd0b0 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -19,12 +19,12 @@ import copy -import importlib +import inspect import json from collections import OrderedDict from typing import Any, Callable, Dict, List, NamedTuple, Optional from typing import OrderedDict as OrderedDictType -from typing import Tuple, Union +from typing import Tuple, Union, Iterable import torch @@ -99,6 +99,18 @@ def _get_op_name_op_type_config(self): op_type_config_dict[name] = config return op_type_config_dict, op_name_config_dict + @classmethod + def _generate_params_list(cls) -> List[str]: + sig = inspect.signature(cls.__init__) + params_list = list(sig.parameters.keys())[1:] + if "white_list" in params_list: + params_list.remove("white_list") + if "args" in params_list: + params_list.remove("args") + if "kwargs" in params_list: + params_list.remove("kwargs") + return params_list + ######################## RNT Config ############################### @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN) @@ -976,7 +988,7 @@ def __init__( enable_torch_compile: bool = False, # v0.7 scheme: str | dict = "W4A16", - device_map: [str, int, torch.device, dict] = 0, + device_map: str | int | torch.device | dict = 0, # mllm quant_nontext_module: bool = False, extra_data_dir: str = None, @@ -987,6 +999,15 @@ def __init__( quant_lm_head: bool = False, # v0.8 enable_adam: bool = False, + # v0.9: auto scheme parameters + target_bits: int = None, + options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"), + shared_layers: Optional[Iterable[Iterable[str]]] = None, + ignore_scale_zp_bits: bool = False, + auto_scheme_method: str = "default", + auto_scheme_device_map: str = None, + auto_scheme_batch_size: int = None, + # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, **kwargs, ): @@ -1039,9 +1060,17 @@ def __init__( device_map: The device to be used for tuning. scheme (str| dict | QuantizationScheme ): A preset scheme that defines the quantization configurations. white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types. - Default is DEFAULT_WHITE_LIST. + target_bits (int): The target bit width for quantization (default is None). + options (Union[str, list[Union[str]], tuple[Union[str], ...]]): The options for mixed-precision quantization. + shared_layers (Optional[Iterable[Iterable[str]]]): The shared layers for mixed-precision quantization. + ignore_scale_zp_bits (bool): Whether to ignore scale and zero-point bits (default is False). + auto_scheme_method (str): The method for automatic scheme selection (default is "default"). + auto_scheme_device_map (str): The device map for automatic scheme selection (default is None). + auto_scheme_batch_size (int): The batch size for automatic scheme selection (default is 8). 
""" super().__init__(white_list=white_list) + self.params_list = self.__class__._generate_params_list() + self.params_list.remove("options") # option is a list but not a tunable parameter self.enable_full_range = enable_full_range self.batch_size = batch_size @@ -1057,6 +1086,7 @@ def __init__( self.super_bits = super_bits self.super_group_size = super_group_size self.amp = amp + self.enable_adam = enable_adam self.lr_scheduler = lr_scheduler self.enable_quanted_input = enable_quanted_input self.enable_minmax_tuning = enable_minmax_tuning @@ -1087,6 +1117,13 @@ def __init__( self.scheme = scheme self.device_map = device_map self.quant_lm_head = quant_lm_head + self.target_bits = target_bits + self.options = options + self.shared_layers = shared_layers + self.ignore_scale_zp_bits = ignore_scale_zp_bits + self.auto_scheme_method = auto_scheme_method + self.auto_scheme_device_map = auto_scheme_device_map + self.auto_scheme_batch_size = auto_scheme_batch_size # add kwargs for k, v in kwargs.items(): setattr(self, k, v) diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index 96bb39d38f2..3eb5d52850a 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -7,6 +7,8 @@ from packaging.version import Version, parse import os from functools import lru_cache +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig + @lru_cache(None) def is_habana_framework_installed(): @@ -299,7 +301,6 @@ def test_mllm(self): @pytest.mark.parametrize("scheme", ["W4A16","W2A16","W3A16","W8A16","MXFP4","MXFP8", "NVFP4","FPW8A16","FP8_STATIC"]) def test_scheme(self, scheme): # INC API - from transformers import AutoModelForCausalLM, AutoTokenizer fp32_model = AutoModelForCausalLM.from_pretrained( "facebook/opt-125m", torchscript=True, @@ -335,7 +336,6 @@ def test_scheme(self, scheme): out = inc_model(inp)[0] # AutoRound API - from transformers import AutoModelForCausalLM, AutoTokenizer fp32_model = transformers.AutoModelForCausalLM.from_pretrained( "facebook/opt-125m", torchscript=True, @@ -367,7 +367,79 @@ def test_scheme(self, scheme): assert torch.all(out_ar.eq(out)) shutil.rmtree(output_dir, ignore_errors=True) shutil.rmtree(quantized_model_path, ignore_errors=True) - + + + @pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.") + def test_target_bits(self): + fp32_model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torchscript=True, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained( + "facebook/opt-125m", trust_remote_code=True) + + output_dir = "./saved_inc" + quant_config = AutoRoundConfig( + tokenizer=tokenizer, + nsamples=32, + seqlen=10, + iters=1, + target_bits=5, + options=("MXFP4", "MXFP8"), + enable_torch_compile=True, + low_gpu_mem_usage=True, + export_format="auto_round", + ) + # quantizer execute + model = prepare(model=fp32_model, quant_config=quant_config) + model = convert(model) + # mxfp4/8 model inference relys on autoround extension for vLLM. + assert model.model.decoder.layers[0].self_attn.k_proj.data_type =="mx_fp8", \ + "model is not quantized correctly, please check." + assert model.model.decoder.layers[1].fc1.data_type =="mx_fp4", \ + "model is not quantized correctly, please check." 
+ + + def test_target_bits_autotune(self): + from neural_compressor.torch.quantization import TuningConfig, autotune + baseline = 1 + eval_result = [0.9, 0.8, 0.99] + acc_list = [baseline] + eval_result + + def eval_acc_fn(model) -> float: + acc = acc_list.pop(0) + return acc + + fp32_model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torchscript=True, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained( + "facebook/opt-125m", trust_remote_code=True) + # AutoRound API + custom_tune_config = TuningConfig( + config_set=[ + AutoRoundConfig( + tokenizer=tokenizer, + target_bits=[5, 6, 7], + options=("MXFP4", "MXFP8"), + enable_torch_compile=True, + low_gpu_mem_usage=True, + export_format="auto_round", + iters=0, + ) + ] + ) + best_model = autotune(model=fp32_model, tune_config=custom_tune_config, eval_fn=eval_acc_fn) + # mxfp4/8 model inference relys on autoround extension for vLLM. + assert best_model.model.decoder.layers[0].self_attn.k_proj.data_type =="mx_fp8", \ + "model is not quantized correctly, please check." + assert best_model.model.decoder.layers[1].fc1.data_type =="mx_fp8", \ + "model is not quantized correctly, please check." + + @pytest.mark.skipif(not is_habana_framework_installed(), reason="Habana framework is not installed") @pytest.mark.skipif(os.getenv("PT_HPU_LAZY_MODE", "0") == "1", reason="Lazy mode is enabled") @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed") @@ -376,7 +448,6 @@ class TestAutoRoundHPU: def setup_class(self): model_name = "TheBloke/Llama-2-7B-Chat-GPTQ" - from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader config = LlamaConfig(num_hidden_layers=2) From d70e5bbc235372157bdad95dd01b4cf2470173c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Nov 2025 06:24:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/autoround.py | 13 +++++++------ .../torch/quantization/algorithm_entry.py | 15 +++++++-------- neural_compressor/torch/quantization/autotune.py | 1 + neural_compressor/torch/quantization/config.py | 4 ++-- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 69d00a320be..2ab9aead9e2 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -16,7 +16,7 @@ import json import time from functools import lru_cache -from typing import Optional, Union, Iterable +from typing import Iterable, Optional, Union import torch @@ -39,9 +39,9 @@ def _is_auto_round_available(): from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401 from auto_round.schemes import QuantizationScheme +from neural_compressor.common.utils import Statistics from neural_compressor.torch.algorithms import Quantizer from neural_compressor.torch.utils import get_accelerator, logger -from neural_compressor.common.utils import Statistics from .utility import CapturedDataloader, InputCaptureModule @@ -291,6 +291,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model = pipe if self.target_bits is not None: from auto_round import AutoScheme + self.scheme = AutoScheme( 
avg_bits=self.target_bits, options=self.options, @@ -491,10 +492,10 @@ def dump_model_op_stats(layer_config): res = {} res["Linear"] = {} for name, info in layer_config.items(): - if 'data_type' in info: - data_type_str = info['data_type'].upper() - if 'bits' in info and str(info["bits"]) not in info["data_type"]: - data_type_str += str(info['bits']) + if "data_type" in info: + data_type_str = info["data_type"].upper() + if "bits" in info and str(info["bits"]) not in info["data_type"]: + data_type_str += str(info["bits"]) res["Linear"][data_type_str] = res.get("Linear", {}).get(data_type_str, 0) + 1 # update stats format for dump. diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 953e5b48271..296fa35c683 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -647,14 +647,13 @@ def autoround_quantize_entry( num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50) generator_seed = quant_config.to_dict().get("generator_seed", None) # 0.9.0: auto scheme parameters - target_bits=quant_config.target_bits - options=quant_config.options - shared_layers=quant_config.shared_layers - ignore_scale_zp_bits=quant_config.ignore_scale_zp_bits - auto_scheme_method=quant_config.auto_scheme_method - auto_scheme_batch_size=quant_config.auto_scheme_batch_size - auto_scheme_device_map=quant_config.auto_scheme_device_map - + target_bits = quant_config.target_bits + options = quant_config.options + shared_layers = quant_config.shared_layers + ignore_scale_zp_bits = quant_config.ignore_scale_zp_bits + auto_scheme_method = quant_config.auto_scheme_method + auto_scheme_batch_size = quant_config.auto_scheme_batch_size + auto_scheme_device_map = quant_config.auto_scheme_device_map kwargs.pop("example_inputs") quantizer = get_quantizer( diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py index cda8a739507..26209b60beb 100644 --- a/neural_compressor/torch/quantization/autotune.py +++ b/neural_compressor/torch/quantization/autotune.py @@ -62,6 +62,7 @@ def _deepcopy_warp(model): setattr(new_model, key, value) return new_model + def _preprocess_model_quant_config(model, quant_config): """Preprocess model and quant config before quantization.""" for config in quant_config.config_set: diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index e181a9dd0b0..dd1bc132776 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -22,9 +22,9 @@ import inspect import json from collections import OrderedDict -from typing import Any, Callable, Dict, List, NamedTuple, Optional +from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional from typing import OrderedDict as OrderedDictType -from typing import Tuple, Union, Iterable +from typing import Tuple, Union import torch
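
Usage sketch for the new target_bits tuning path (not applied by this patch; adapted from the test_target_bits_autotune test added above). The model name, tokenizer, and eval_fn below are illustrative placeholders; internally, each target_bits candidate is passed to auto_round.AutoScheme as avg_bits, which picks a per-layer scheme from `options` so the average weight bit-width matches the candidate.

    # Minimal sketch, assuming auto-round >= 0.9 and the changes in this patch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from neural_compressor.torch.quantization import AutoRoundConfig, TuningConfig, autotune

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")


    def eval_fn(q_model) -> float:
        # Placeholder: return a real task metric (higher is better) for the candidate model.
        return 1.0


    # Each value in target_bits becomes one tuning trial; autotune stops at the first
    # candidate whose eval_fn result meets the tolerance against the fp32 baseline.
    tune_config = TuningConfig(
        config_set=[
            AutoRoundConfig(
                tokenizer=tokenizer,
                target_bits=[5, 6, 7],          # tunable average bit-width candidates
                options=("MXFP4", "MXFP8"),     # mixed-precision schemes AutoScheme may assign per layer
                iters=0,
                export_format="auto_round",
            )
        ]
    )
    best_model = autotune(model=model, tune_config=tune_config, eval_fn=eval_fn)

Note that `options` is excluded from params_list in AutoRoundConfig, so only target_bits (not the scheme candidates) is expanded into separate tuning trials.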