From 51d821099ac184b16ff335cf60b0dedef756de39 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Sun, 12 Jan 2025 23:52:58 -0500
Subject: [PATCH 1/7] add bf16 for mixed precision

Signed-off-by: yiliu30
---
 neural_compressor/common/base_config.py       |  8 +++++++-
 .../torch/algorithms/pt2e_quant/core.py       |  4 +++-
 .../pt2e_quant/half_precision_rewriter.py     |  5 +++--
 .../torch/algorithms/pt2e_quant/utility.py    |  3 ++-
 .../torch/quantization/autotune.py            | 15 ++++++++++++---
 test/3x/torch/quantization/test_pt2e_quant.py | 19 +++++++++++++++++++
 6 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 97c50c2333d..2e0e195b0e3 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import copy
 import inspect
 import json
 import os
@@ -539,6 +540,7 @@ def expand(self) -> List[BaseConfig]:
                 tuning_param_pair = dict(zip(tuning_param_name_lst, params_values))
                 tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair}
                 new_config = self.__class__(**tmp_params_dict)
+                new_config.local_config = copy.deepcopy(self.local_config)
                 logger.info(new_config.to_dict())
                 config_list.append(new_config)
         logger.info("Expanded the %s and got %d configs.", self.__class__.name, len(config_list))
@@ -629,9 +631,13 @@ def __eq__(self, other: BaseConfig) -> bool:
         """
         if not isinstance(other, type(self)):
             return False
-        return self.params_list == other.params_list and all(
+
+        params_equal = self.params_list == other.params_list and all(
             getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
         )
+        local_config_equal = self.local_config == other.local_config
+        global_config_equal = self.global_config == other.global_config
+        return params_equal and local_config_equal and global_config_equal
 
 
 class ComposableConfig(BaseConfig):
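
Note: the two base_config.py changes above are coupled: expand() now gives every expanded config its own copy of local_config, and __eq__ now also compares local_config and global_config. Without the __eq__ change, two configs that differ only in a per-op override compared equal. A minimal sketch of the intended behavior (hypothetical snippet, using the same config API as the tests below):

    config1 = INT8StaticQuantConfig()
    config2 = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype="bf16", act_dtype="bf16"))
    assert config1 != config2  # holds only once local_config participates in __eq__
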
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/core.py b/neural_compressor/torch/algorithms/pt2e_quant/core.py
index 4707295cd32..491e2b1456d 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/core.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -18,6 +18,7 @@
 
 from typing import Any
 
+import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
@@ -102,4 +103,5 @@ def half_precision_transformation(self, model, config):
         """
         half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
         logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
-        hp_rewriter.transformation(model, half_precision_node_set)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.float16)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.bfloat16)
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index 0c3c328845f..37c3cf5f847 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -107,6 +107,7 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:
 
 
 _register_pattern_pair(torch.float16)
+_register_pattern_pair(torch.bfloat16)
 
 
 def get_filter_fn(node_list, fn):
@@ -201,11 +202,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_name_filters = []
     for op_type_name, config in op_type_configs.items():  # pragma: no cover
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_type_filter(op_type)
             op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_name_filter(op_name)
             op_name_filters.append(filter)
     node_set_from_user_config = set()
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
index e31efabf0a6..409c8dacdbc 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -26,7 +26,7 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
-from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5
+from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, logger
 
 
 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
@@ -79,6 +79,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals
 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
     NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
     if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:  # pragma: no cover
+        logger.debug("Got non-quantizable data types, skipping quantization.")
         return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index 2c6dcaa768f..fe290908127 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -54,6 +54,15 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
 
+def _deepcopy_warp(model):
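+    # Preserve export-related attributes that a plain deepcopy may drop.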
+    additional_attr_lst = ["_exported", "dynamic_shapes"]
+    original_attr = {key: getattr(model, key, None) for key in additional_attr_lst}
+    new_model = deepcopy(model)
+    for key, value in original_attr.items():
+        setattr(new_model, key, value)
+    return new_model
+
+
 @dump_elapsed_time("Pass auto-tune")
 def autotune(
     model: torch.nn.Module,
@@ -81,7 +90,7 @@ def autotune(
     best_quant_model = None
     eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
     config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
-    baseline: float = eval_func_wrapper.evaluate(deepcopy(model))
+    baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader, 1):
@@ -90,7 +99,7 @@ def autotune(
         logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
-            deepcopy(model),
+            _deepcopy_warp(model),
             quant_config=quant_config,
             run_fn=run_fn,
             run_args=run_args,
@@ -112,7 +121,7 @@ def autotune(
         best_quant_config: BaseConfig = best_trial_record.quant_config
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
-            deepcopy(model),
+            _deepcopy_warp(model),
             quant_config=best_quant_config,
             run_fn=run_fn,
             run_args=run_args,
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 0ef8571157d..22372d18bd8 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -283,3 +283,22 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         opt_model = torch.compile(converted_model)
         out = opt_model(*example_inputs)
         assert out is not None
+
+    @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
+    @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, force_not_import_ipex):
+        # config1: int8 for all
+        # config2: half precision for linear
+        from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
+        from neural_compressor.torch.quantization.autotune import autotune, TuningConfig
+        config1 = INT8StaticQuantConfig()
+        config2 = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype))
+        tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
+        def fake_eval_fn(model):
+            return 1.0
+        def run_fn(model):
+            for i in range(2):
+                model(*example_inputs)
+        model, example_inputs = self.build_model_include_conv_and_linear()
+        model = export(model, example_inputs=example_inputs)
+        qmodel = autotune(model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs)
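
For context on the FN_ARGS_MAPPING table that the next patches extend: the example tensors are never executed, they only parameterize torch.fx pattern tracing so the rewriter can search for an fp32 call and replace it with a half-precision equivalent. A self-contained sketch of that mechanism (illustrative only; the actual pattern_factory/PatternPair machinery in half_precision_rewriter.py differs):

    import torch
    from torch.fx import subgraph_rewriter, symbolic_trace

    def fp32_linear(x, w):
        # search pattern: traced into an FX graph and matched structurally
        return torch.nn.functional.linear(x, w)

    def bf16_linear(x, w):
        # replacement: cast inputs to bf16, compute, cast the result back
        return torch.nn.functional.linear(x.to(torch.bfloat16), w.to(torch.bfloat16)).to(torch.float32)

    class ToyModel(torch.nn.Module):
        def forward(self, x, w):
            return torch.nn.functional.linear(x, w)

    gm = symbolic_trace(ToyModel())
    subgraph_rewriter.replace_pattern(gm, fp32_linear, bf16_linear)
    print(gm.code)  # the linear call is now wrapped in to(bfloat16)/to(float32) casts
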

From e4a4fb3ac474a71f102f3717e759127db48942e2 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Mon, 20 Jan 2025 03:33:24 -0500
Subject: [PATCH 2/7] add more ops

Signed-off-by: yiliu30
---
 .../pt2e_quant/half_precision_rewriter.py | 28 +++++++++++++++----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index 37c3cf5f847..618633d4185 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -54,16 +54,34 @@ class PatternPair:
 
 
 # Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
-# TODO: complete the mapping
+# conv1d, conv2d, conv3d, bmm, mm, linalg_vecdot, baddbmm, addmm, addbmm,
+# linear, matmul, _convolution, conv_tbc, mkldnn_rnn_layer, conv_transpose1d,
+# conv_transpose2d, conv_transpose3d, prelu, scaled_dot_product_attention, _native_multi_head_attention
+
 FN_ARGS_MAPPING: FuncArgsMappingType = {
     torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
     torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # conv2d w/o bias
+    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0), torch.randn(0)),  # conv2d w/ bias
+    torch.bmm: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # bmm
+    torch.mm: (torch.randn(0, 0), torch.randn(0, 0)),  # mm
 }
 
-# TODO: complete the mapping
-FN_ATEN_OPS_MAPPING = {
-    torch.nn.functional.linear: torch.ops.aten.linear.default,
+
+# module cls -> function name
+NN_MODULES_MAPPING = {
+    torch.nn.Linear: torch.nn.functional.linear,
+    torch.nn.Conv2d: torch.nn.functional.conv2d,
+    torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
 }
+for nn_cls, fn in NN_MODULES_MAPPING.items():
+    if fn in FN_ARGS_MAPPING:
+        FN_ARGS_MAPPING[nn_cls] = FN_ARGS_MAPPING[fn]
+
+
+# Use the mapping from xiq
+FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
+
 SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()
 
 
@@ -101,7 +119,7 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:
     for fn, fn_args in FN_ARGS_MAPPING.items():
         pattern_pair = pattern_factory(fn, fn_args)
         HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
-    utils.logger.info(
+    utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )

From de59f73cdd3254a05a6641fb2af82a02133938e7 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 21 Jan 2025 04:02:58 -0500
Subject: [PATCH 3/7] add more ops

Signed-off-by: yiliu30
---
 .../pt2e_quant/half_precision_rewriter.py | 27 +++++++--------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index 618633d4185..82375e132be 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -25,7 +25,7 @@
 from torch.fx.subgraph_rewriter import Match
 from typing_extensions import TypeAlias
 
-from neural_compressor.common import utils
+from neural_compressor.common import logger, utils
 
 # =============================================================================
 # Search and replace patterns
@@ -53,32 +53,24 @@ class PatternPair:
 FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
 
 
-# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
-# conv1d, conv2d, conv3d, bmm, mm, linalg_vecdot, baddbmm, addmm, addbmm,
-# linear, matmul, _convolution, conv_tbc, mkldnn_rnn_layer, conv_transpose1d,
-# conv_transpose2d, conv_transpose3d, prelu, scaled_dot_product_attention, _native_multi_head_attention
-
+# Align with xiq, as it relies on xiq's set_module_xx capability
 FN_ARGS_MAPPING: FuncArgsMappingType = {
     torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
     torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
-    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # conv2d w/o bias
-    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0), torch.randn(0)),  # conv2d w/ bias
-    torch.bmm: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # bmm
-    torch.mm: (torch.randn(0, 0), torch.randn(0, 0)),  # mm
+    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),  # matmul
+    torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # matmul
+    torch.matmul: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # matmul
 }
 
-# module cls -> function name
-NN_MODULES_MAPPING = {
+# module cls <-> function name
+NN_MODULES_TO_NN_FN = {
     torch.nn.Linear: torch.nn.functional.linear,
     torch.nn.Conv2d: torch.nn.functional.conv2d,
     torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
 }
-for nn_cls, fn in NN_MODULES_MAPPING.items():
-    if fn in FN_ARGS_MAPPING:
-        FN_ARGS_MAPPING[nn_cls] = FN_ARGS_MAPPING[fn]
-
-
 # Use the mapping from xiq
 FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
@@ -109,6 +101,7 @@ def replace_fn_wrapper(fn_args, fn):
 
 def _register_pattern_pair(dtype: torch.dtype) -> None:
     for fn, fn_args in FN_ARGS_MAPPING.items():
+        logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
         pattern_pair = pattern_factory(fn, fn_args)
         HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair

From d7f99dad97e5182b859bded499f60364f2507150 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Wed, 22 Jan 2025 04:16:14 -0500
Subject: [PATCH 4/7] add more uts

Signed-off-by: yiliu30
---
 .../pt2e_quant/half_precision_rewriter.py     | 50 +++++++++------
 test/3x/torch/quantization/test_pt2e_quant.py | 62 ++++++++++++++-----
 2 files changed, 80 insertions(+), 32 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index 82375e132be..b95681174ae 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Rewrite the FP32 operators to FP16 or BF16 operators."""
 
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple
@@ -50,25 +51,31 @@ class PatternPair:
 
 # key: torch func
 # value: the tuple of args
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]
 
 # Align with xiq, as it relies on xiq's set_module_xx capability
 FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
-    torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # matmul
+    # Note: order matters
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),
+    ],
 }
 
 # module cls <-> function name
 NN_MODULES_TO_NN_FN = {
     torch.nn.Linear: torch.nn.functional.linear,
     torch.nn.Conv2d: torch.nn.functional.conv2d,
-    torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
 }
 
 # Use the mapping from xiq
 FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
@@ -78,7 +85,10 @@ class PatternPair:
 
 PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
 
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}
 
 # FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
 # BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
@@ -108,10 +118,11 @@ def replace_fn_wrapper(fn_args, fn):
 
 
 def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
     utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )
@@ -194,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
 
 def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
     """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion({target_dtype}) completed.")
     if utils.level_name == "DEBUG":  # pragma: no cover
         gm.print_readable(True)
@@ -249,5 +261,7 @@ def get_half_precision_node_set(gm, config):
     for node in possible_node_set:
         if node.target in SUPPORTED_OPERATORS:
             half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
     return half_precision_node_set
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 22372d18bd8..642d16b5239 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -29,7 +29,6 @@ def _is_ipex_imported():
     monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported)
     monkeypatch.setattr("neural_compressor.torch.export.pt2e_export.is_ipex_imported", _is_ipex_imported)
 
-
 class TestPT2EQuantization:
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)
@@ -53,15 +52,15 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         return bar, example_inputs
 
     @staticmethod
-    def build_model_include_conv_and_linear():
+    def build_model_include_conv_and_linear(bias=True):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, bias=True):
                 super(Model, self).__init__()
-                self.conv1 = torch.nn.Conv2d(3, 6, 5)
+                self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=bias)
                 self.pool = torch.nn.MaxPool2d(2, 2)
-                self.conv2 = torch.nn.Conv2d(6, 16, 5)
-                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
-                self.fc2 = torch.nn.Linear(120, 84)
+                self.conv2 = torch.nn.Conv2d(6, 16, 5, bias=bias)
+                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120, bias=bias)
+                self.fc2 = torch.nn.Linear(120, 84, bias=bias)
 
             def forward(self, x):
                 x = self.conv1(x)
@@ -74,7 +73,7 @@ def forward(self, x):
 
                 return x
 
-        model = Model()
+        model = Model(bias)
         example_inputs = (torch.randn(1, 3, 32, 32),)
         return model, example_inputs
 
@@ -286,19 +285,54 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
 
     @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
     @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
-    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, force_not_import_ipex):
+    @pytest.mark.parametrize("op_name", ["conv1", "fc1"])
+    @pytest.mark.parametrize("bias", [True, False])
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name, bias, force_not_import_ipex):
+        # Test for auto-tune with mixed int8 and 16bits
+        # Just make sure the pattern matches, not the accuracy.
         # config1: int8 for all
-        # config2: half precision for linear
+        # config2: half precision for linear/conv
         from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
         from neural_compressor.torch.quantization.autotune import autotune, TuningConfig
+
         config1 = INT8StaticQuantConfig()
-        config2 = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype))
+        config2 = INT8StaticQuantConfig().set_local(
+            op_name, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
+        )
         tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
+        eval_result = [1, 1, 2]
+
         def fake_eval_fn(model):
-            return 1.0
+            res = eval_result.pop(0)
+            return res
+
         def run_fn(model):
             for i in range(2):
                 model(*example_inputs)
-        model, example_inputs = self.build_model_include_conv_and_linear()
+
+        model, example_inputs = self.build_model_include_conv_and_linear(bias)
         model = export(model, example_inputs=example_inputs)
-        qmodel = autotune(model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs)
+        qmodel = autotune(
+            model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs
+        )
+
+        # check the half node
+        expected_node_occurrence = {
+            # 4 `aten.to` for target op if bias else 3
+            torch.ops.aten.to.dtype: (3 + int(bias))
+        }
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(qmodel)
+        for node, cnt in expected_node_occurrence.items():
+            assert (
+                node_in_graph.get(node, 0) == cnt
+            ), f"Node {node} should occur {cnt} times, but {node_in_graph.get(node, 0)}"
+        # inference
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(qmodel)
+        out = opt_model(*example_inputs)
+        assert out is not None

From a468783862202fff8770d45e81e543c10249a3ab Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 22 Jan 2025 17:17:57 +0800
Subject: [PATCH 5/7] disable conv1d for rtn on HPU (#2112)

---
 test/3x/torch/quantization/weight_only/test_rtn.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py
index 4f2b9c44752..352e0246001 100644
--- a/test/3x/torch/quantization/weight_only/test_rtn.py
+++ b/test/3x/torch/quantization/weight_only/test_rtn.py
@@ -309,6 +309,8 @@ def test_rtn_with_quantize_API(self):
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
     # TODO: (4, True, 32, 0), group_dim=0, format not supported
+    # TODO [SW-216127]: it's not high priority, so we can implement it later.
+    @pytest.mark.skipif(is_hpex_available(), reason="These tests are not supported on HPU for now.")
     @pytest.mark.parametrize(
         "bits, use_sym, group_size, group_dim",
         [

From ee23363fe64204ddfae61cdc27d467e24e5162a5 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Wed, 22 Jan 2025 04:20:16 -0500
Subject: [PATCH 6/7] minor fix

Signed-off-by: yiliu30
---
 test/3x/torch/quantization/test_pt2e_quant.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 642d16b5239..e593f3e1528 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -288,7 +288,6 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
     @pytest.mark.parametrize("op_name", ["conv1", "fc1"])
     @pytest.mark.parametrize("bias", [True, False])
     def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name, bias, force_not_import_ipex):
-        # Test for auto-tune with mixed int8 and 16bits
         # Just make sure the pattern matches, not the accuracy.
         # config1: int8 for all
         # config2: half precision for linear/conv
@@ -318,7 +317,7 @@ def run_fn(model):
 
         # check the half node
         expected_node_occurrence = {
-            # 4 `aten.to` for target op if bias else 3
+            # 4 `aten.to` for target op if bias else 3 `aten.to`
            torch.ops.aten.to.dtype: (3 + int(bias))
         }
         expected_node_occurrence = {
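
The final patch below lets the tuning config target an operator type as well as an op name. Both call shapes, as exercised by the parametrized test (same StaticQuantConfig API as above):

    config_by_name = INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype="bf16", act_dtype="bf16"))
    config_by_type = INT8StaticQuantConfig().set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="bf16", act_dtype="bf16"))

With a type key the override applies to every matching module; the test model holds two Linear and two Conv2d modules, which is why the expected `aten.to` counts double for the nn.Module cases in the table below.
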
From c1807391b013361c38f2a4b75f625cdbbe2d0836 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 23 Jan 2025 01:55:26 -0500
Subject: [PATCH 7/7] add op_type

Signed-off-by: yiliu30
---
 test/3x/torch/quantization/test_pt2e_quant.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index e593f3e1528..96410054e6b 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -285,9 +285,9 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
 
     @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
     @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
-    @pytest.mark.parametrize("op_name", ["conv1", "fc1"])
+    @pytest.mark.parametrize("op_name_or_type", ["conv1", "fc1", torch.nn.Linear, torch.nn.Conv2d])
     @pytest.mark.parametrize("bias", [True, False])
-    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name, bias, force_not_import_ipex):
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name_or_type, bias, force_not_import_ipex):
         # Just make sure the pattern matches, not the accuracy.
         # config1: int8 for all
         # config2: half precision for linear/conv
@@ -296,7 +296,7 @@ def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name_or
 
         config1 = INT8StaticQuantConfig()
         config2 = INT8StaticQuantConfig().set_local(
-            op_name, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
+            op_name_or_type, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
         )
         tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
         eval_result = [1, 1, 2]
@@ -315,11 +315,17 @@ def run_fn(model):
             model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs
         )
 
-        # check the half node
+        # Calculate the expected number of `aten.to` operations based on bias and op_name_or_type
+        """
+        | Bias  | op_name | nn.Module |
+        |-------|---------|-----------|
+        | True  | 4       | 8         |
+        | False | 3       | 6         |
+        """
         expected_node_occurrence = {
-            # 4 `aten.to` for target op if bias else 3 `aten.to`
-            torch.ops.aten.to.dtype: (3 + int(bias))
+            torch.ops.aten.to.dtype: (3 + int(bias)) * (1 if isinstance(op_name_or_type, str) else 2)
         }
+
         expected_node_occurrence = {
             torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
         }
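
Taken together, the series enables an end-to-end flow like the following sketch (paraphrasing the test above; the `export` import path is assumed from the test fixtures, and the model and eval function are stand-ins):

    import torch

    from neural_compressor.torch.export import export
    from neural_compressor.torch.quantization.autotune import TuningConfig, autotune
    from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.fc1(x)

    model = TinyModel()
    example_inputs = (torch.randn(1, 8),)
    exported_model = export(model, example_inputs=example_inputs)

    # candidate 1: int8 everywhere; candidate 2: int8 with fc1 kept in bf16
    config_set = [
        INT8StaticQuantConfig(),
        INT8StaticQuantConfig().set_local("fc1", StaticQuantConfig(w_dtype="bf16", act_dtype="bf16")),
    ]

    def eval_fn(q_model):
        return 1.0  # plug in a real accuracy metric

    def run_fn(q_model):
        q_model(*example_inputs)  # calibration pass

    qmodel = autotune(
        model=exported_model,
        tune_config=TuningConfig(config_set=config_set),
        eval_fn=eval_fn,
        run_fn=run_fn,
        example_inputs=example_inputs,
    )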