6 changes: 3 additions & 3 deletions neural_compressor/common/base_config.py
@@ -475,7 +475,7 @@ def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any:

# Get the parameters and their default values
parameters = signature.parameters
return parameters.get(param).default
return parameters.get(param).annotation

def expand(self) -> List[BaseConfig]:
"""Expand the config.
@@ -522,8 +522,8 @@ def expand(self) -> List[BaseConfig]:
# 1. The param is a string.
# 2. The param is a `TuningParam` instance.
if isinstance(param, str):
default_param = self.get_the_default_value_of_param(config, param)
tuning_param = TuningParam(name=param, tunable_type=List[type(default_param)])
param_annotation = self.get_the_default_value_of_param(config, param)
tuning_param = TuningParam(name=param, tunable_type=List[param_annotation])
elif isinstance(param, TuningParam):
tuning_param = param
else:
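
Note: the base_config.py change makes config expansion derive the tunable type from the parameter's type annotation rather than from the type of its default value. A minimal standalone sketch (hypothetical DemoConfig, not the neural_compressor API) of why the annotation is the more reliable source when a default is None:

import inspect
from typing import List

class DemoConfig:
    def __init__(self, group_size: int = 128, scheme: str = None):
        self.group_size = group_size
        self.scheme = scheme

param = inspect.signature(DemoConfig.__init__).parameters["scheme"]
print(type(param.default))     # <class 'NoneType'> -- List[NoneType] is not a useful tuning space
print(param.annotation)        # <class 'str'>
print(List[param.annotation])  # typing.List[str], the intended tunable_type
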
2 changes: 1 addition & 1 deletion neural_compressor/common/tuning_param.py
@@ -118,8 +118,8 @@ def is_tunable(self, value: Any) -> bool:
assert isinstance(
self.tunable_type, typing._GenericAlias
), f"Expected a type hint, got {self.tunable_type} instead."
DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type)
try:
DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type)
new_args = DynamicInputArgsModel(input_args=value)
return True
except Exception as e:
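
Note: in tuning_param.py, constructing the dynamic validation model now happens inside the try block, so failures while building the model (not just while validating a value) are reported as "not tunable" instead of raising. A rough sketch of the pattern using pydantic's create_model directly, which is an assumption about what TuningParam.create_input_args_model does internally:

from typing import Any, List
from pydantic import create_model

def is_tunable(value: Any, tunable_type) -> bool:
    try:
        # Both model-construction and validation failures fall through to False.
        DynamicInputArgsModel = create_model("DynamicInputArgsModel", input_args=(tunable_type, ...))
        DynamicInputArgsModel(input_args=value)
        return True
    except Exception:
        return False

print(is_tunable([4, 8], List[int]))   # True
print(is_tunable("W4A16", List[int]))  # False
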
59 changes: 58 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -16,7 +16,7 @@
import json
import time
from functools import lru_cache
from typing import Optional, Union
from typing import Iterable, Optional, Union

import torch

@@ -39,6 +39,7 @@ def _is_auto_round_available():
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401
from auto_round.schemes import QuantizationScheme

from neural_compressor.common.utils import Statistics
from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_accelerator, logger

@@ -104,6 +105,14 @@ def __init__(
guidance_scale: float = 7.5,
num_inference_steps: int = 50,
generator_seed: int = None,
# 0.9
target_bits: int = None,
options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
shared_layers: Optional[Iterable[Iterable[str]]] = None,
ignore_scale_zp_bits: bool = False,
auto_scheme_method: str = "default",
auto_scheme_batch_size: int = None,
auto_scheme_device_map: str = None,
**kwargs,
):
"""Init a AutQRoundQuantizer object.
@@ -238,6 +247,13 @@ def __init__(
self.guidance_scale = guidance_scale
self.num_inference_steps = num_inference_steps
self.generator_seed = generator_seed
self.target_bits = target_bits
self.options = options
self.shared_layers = shared_layers
self.ignore_scale_zp_bits = ignore_scale_zp_bits
self.auto_scheme_method = auto_scheme_method
self.auto_scheme_batch_size = auto_scheme_batch_size
self.auto_scheme_device_map = auto_scheme_device_map

def _is_w4afp8(self) -> bool:
return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
@@ -273,6 +289,19 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
model = model.orig_model
if pipe is not None:
model = pipe
if self.target_bits is not None:
from auto_round import AutoScheme

self.scheme = AutoScheme(
avg_bits=self.target_bits,
options=self.options,
shared_layers=self.shared_layers,
ignore_scale_zp_bits=self.ignore_scale_zp_bits,
method=self.auto_scheme_method,
batch_size=self.auto_scheme_batch_size,
device_map=self.auto_scheme_device_map,
low_gpu_mem_usage=self.low_gpu_mem_usage,
)
rounder = AutoRound(
model,
layer_config=self.layer_config,
@@ -338,6 +367,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True)
model = rounder.model
model.autoround_config = rounder.layer_config

dump_model_op_stats(rounder.layer_config)

return model


@@ -452,3 +484,28 @@ def get_mllm_dataloader(
quant_nontext_module=quant_nontext_module,
)
return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen, nsamples


def dump_model_op_stats(layer_config):
"""Dump quantizable ops stats of model to user."""
# TODO: collect more ops besides Linear
res = {}
res["Linear"] = {}
for name, info in layer_config.items():
if "data_type" in info:
data_type_str = info["data_type"].upper()
if "bits" in info and str(info["bits"]) not in info["data_type"]:
data_type_str += str(info["bits"])
res["Linear"][data_type_str] = res.get("Linear", {}).get(data_type_str, 0) + 1

# update stats format for dump.
field_names = ["Op Type", "Total"]
dtype_list = list(res["Linear"].keys())
field_names.extend(dtype_list)
output_data = []
for op_type in res.keys():
field_results = [op_type, sum(res[op_type].values())]
field_results.extend([res[op_type][dtype] for dtype in dtype_list])
output_data.append(field_results)

Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
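
Note: a toy illustration of what the new dump_model_op_stats() reports for a hypothetical mixed-precision layer_config (layer names and dtypes are made up); the exact table rendering comes from Statistics.print_stat():

layer_config = {
    "model.layers.0.self_attn.q_proj": {"data_type": "mx_fp", "bits": 4},
    "model.layers.0.self_attn.k_proj": {"data_type": "mx_fp", "bits": 8},
    "model.layers.0.mlp.gate_proj": {"data_type": "mx_fp", "bits": 4},
}
dump_model_op_stats(layer_config)
# Prints a "Mixed Precision Statistics" table along the lines of:
#   Op Type | Total | MX_FP4 | MX_FP8
#   Linear  |     3 |      2 |      1
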
16 changes: 15 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -646,6 +646,14 @@ def autoround_quantize_entry(
guidance_scale = quant_config.to_dict().get("guidance_scale", 7.5)
num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50)
generator_seed = quant_config.to_dict().get("generator_seed", None)
# 0.9.0: auto scheme parameters
target_bits = quant_config.target_bits
options = quant_config.options
shared_layers = quant_config.shared_layers
ignore_scale_zp_bits = quant_config.ignore_scale_zp_bits
auto_scheme_method = quant_config.auto_scheme_method
auto_scheme_batch_size = quant_config.auto_scheme_batch_size
auto_scheme_device_map = quant_config.auto_scheme_device_map

kwargs.pop("example_inputs")
quantizer = get_quantizer(
@@ -702,12 +710,18 @@ def autoround_quantize_entry(
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator_seed=generator_seed,
target_bits=target_bits,
options=options,
shared_layers=shared_layers,
ignore_scale_zp_bits=ignore_scale_zp_bits,
auto_scheme_method=auto_scheme_method,
auto_scheme_batch_size=auto_scheme_batch_size,
auto_scheme_device_map=auto_scheme_device_map,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
postprocess_model(model, mode, quantizer)
dump_model_op_stats(mode, configs_mapping)
return model


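
Note: a hedged end-to-end sketch of the one-shot path that exercises the new auto-scheme arguments (algorithm entry -> AutoRoundQuantizer -> AutoScheme). The model name, bit budget, and the tokenizer keyword (assumed to be picked up via **kwargs) are placeholders, not values taken from this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")   # placeholder model
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

config = AutoRoundConfig(
    tokenizer=tokenizer,            # assumed to be accepted through **kwargs
    target_bits=6,                  # average bit budget; a non-None value triggers AutoScheme in convert()
    options=("MXFP4", "MXFP8"),     # candidate schemes to mix
    ignore_scale_zp_bits=False,
    auto_scheme_method="default",
)
model = prepare(model, config)
model = convert(model, config)      # dump_model_op_stats() then prints the per-dtype Linear counts
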
15 changes: 14 additions & 1 deletion neural_compressor/torch/quantization/autotune.py
@@ -23,7 +23,7 @@
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.common.utils import dump_elapsed_time
from neural_compressor.torch.quantization import quantize
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, AutoRoundConfig, RTNConfig
from neural_compressor.torch.utils import constants, logger

__all__ = [
@@ -63,6 +63,18 @@ def _deepcopy_warp(model):
return new_model


def _preprocess_model_quant_config(model, quant_config):
"""Preprocess model and quant config before quantization."""
for config in quant_config.config_set:
# handle tokenizer attribute in AutoRoundConfig
if isinstance(config, AutoRoundConfig):
_tokenizer_backup = getattr(config, "tokenizer", None)
if _tokenizer_backup is not None:
setattr(model, "tokenizer", _tokenizer_backup)
delattr(config, "tokenizer")
return model, quant_config


@dump_elapsed_time("Pass auto-tune")
def autotune(
model: torch.nn.Module,
@@ -88,6 +100,7 @@ def autotune(
The quantized model.
"""
best_quant_model = None
model, tune_config = _preprocess_model_quant_config(model, tune_config)
eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
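
Note: a usage sketch of the autotune path. When an AutoRoundConfig in the config set carries a tokenizer, _preprocess_model_quant_config() moves it onto the model and removes it from the config before tuning starts. The model, tokenizer, tuning space, and trivial eval_fn below are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.torch.quantization import AutoRoundConfig, autotune

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

def eval_fn(q_model) -> float:
    return 1.0  # placeholder accuracy metric

tune_config = TuningConfig(config_set=[AutoRoundConfig(tokenizer=tokenizer, bits=[4, 8])])
best_model = autotune(model=model, tune_config=tune_config, eval_fn=eval_fn)
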
45 changes: 41 additions & 4 deletions neural_compressor/torch/quantization/config.py
@@ -19,10 +19,10 @@


import copy
import importlib
import inspect
import json
from collections import OrderedDict
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional
from typing import OrderedDict as OrderedDictType
from typing import Tuple, Union

@@ -99,6 +99,18 @@ def _get_op_name_op_type_config(self):
op_type_config_dict[name] = config
return op_type_config_dict, op_name_config_dict

@classmethod
def _generate_params_list(cls) -> List[str]:
sig = inspect.signature(cls.__init__)
params_list = list(sig.parameters.keys())[1:]
if "white_list" in params_list:
params_list.remove("white_list")
if "args" in params_list:
params_list.remove("args")
if "kwargs" in params_list:
params_list.remove("kwargs")
return params_list


######################## RNT Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
@@ -976,7 +988,7 @@ def __init__(
enable_torch_compile: bool = False,
# v0.7
scheme: str | dict = "W4A16",
device_map: [str, int, torch.device, dict] = 0,
device_map: str | int | torch.device | dict = 0,
# mllm
quant_nontext_module: bool = False,
extra_data_dir: str = None,
Expand All @@ -987,6 +999,15 @@ def __init__(
quant_lm_head: bool = False,
# v0.8
enable_adam: bool = False,
# v0.9: auto scheme parameters
target_bits: int = None,
options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
shared_layers: Optional[Iterable[Iterable[str]]] = None,
ignore_scale_zp_bits: bool = False,
auto_scheme_method: str = "default",
auto_scheme_device_map: str = None,
auto_scheme_batch_size: int = None,
# Tuning space
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
**kwargs,
):
@@ -1039,9 +1060,17 @@ def __init__(
device_map: The device to be used for tuning.
scheme (str| dict | QuantizationScheme ): A preset scheme that defines the quantization configurations.
white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
Default is DEFAULT_WHITE_LIST.
target_bits (int): The target bit width for quantization (default is None).
options (Union[str, list[Union[str]], tuple[Union[str], ...]]): The options for mixed-precision quantization.
shared_layers (Optional[Iterable[Iterable[str]]]): The shared layers for mixed-precision quantization.
ignore_scale_zp_bits (bool): Whether to ignore scale and zero-point bits (default is False).
auto_scheme_method (str): The method for automatic scheme selection (default is "default").
auto_scheme_device_map (str): The device map for automatic scheme selection (default is None).
auto_scheme_batch_size (int): The batch size for automatic scheme selection (default is 8).
"""
super().__init__(white_list=white_list)
self.params_list = self.__class__._generate_params_list()
self.params_list.remove("options") # option is a list but not a tunable parameter

self.enable_full_range = enable_full_range
self.batch_size = batch_size
@@ -1057,6 +1086,7 @@ def __init__(
self.super_bits = super_bits
self.super_group_size = super_group_size
self.amp = amp
self.enable_adam = enable_adam
self.lr_scheduler = lr_scheduler
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
@@ -1087,6 +1117,13 @@ def __init__(
self.scheme = scheme
self.device_map = device_map
self.quant_lm_head = quant_lm_head
self.target_bits = target_bits
self.options = options
self.shared_layers = shared_layers
self.ignore_scale_zp_bits = ignore_scale_zp_bits
self.auto_scheme_method = auto_scheme_method
self.auto_scheme_device_map = auto_scheme_device_map
self.auto_scheme_batch_size = auto_scheme_batch_size
# add kwargs
for k, v in kwargs.items():
setattr(self, k, v)
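
Note: a small illustration of what the new BaseConfig._generate_params_list() helper collects, using a made-up config class rather than the real AutoRoundConfig. AutoRoundConfig.__init__ then removes "options" from the resulting params_list because it is a list-valued argument rather than a tunable hyperparameter:

import inspect

class ToyConfig:
    def __init__(self, bits: int = 4, group_size: int = 128, options: tuple = ("MXFP4",), white_list=None, **kwargs):
        pass

sig = inspect.signature(ToyConfig.__init__)
params_list = list(sig.parameters.keys())[1:]   # drop `self`
for skipped in ("white_list", "args", "kwargs"):
    if skipped in params_list:
        params_list.remove(skipped)
print(params_list)   # ['bits', 'group_size', 'options']
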