diff --git a/.azure-pipelines/scripts/models/env_setup.sh b/.azure-pipelines/scripts/models/env_setup.sh
index 532e5c333d5..2c356d7307b 100644
--- a/.azure-pipelines/scripts/models/env_setup.sh
+++ b/.azure-pipelines/scripts/models/env_setup.sh
@@ -83,27 +83,30 @@ if [[ "${inc_new_api}" == "false" ]]; then
 fi

 cd ${model_src_dir}
-pip install ruamel.yaml==0.17.40
-pip install psutil
-pip install protobuf==4.23.4
-if [[ "${framework}" == "tensorflow" ]]; then
-    if [[ "${fwk_ver}" == *"-official" ]]; then
-        pip install tensorflow==${fwk_ver%-official}
-    else
-        pip install intel-tensorflow==${fwk_ver}
+
+if [[ "${fwk_ver}" != "latest" ]]; then
+    pip install ruamel.yaml==0.17.40
+    pip install psutil
+    pip install protobuf==4.23.4
+    if [[ "${framework}" == "tensorflow" ]]; then
+        if [[ "${fwk_ver}" == *"-official" ]]; then
+            pip install tensorflow==${fwk_ver%-official}
+        else
+            pip install intel-tensorflow==${fwk_ver}
+        fi
+    elif [[ "${framework}" == "pytorch" ]]; then
+        pip install torch==${fwk_ver} -f https://download.pytorch.org/whl/torch_stable.html
+        pip install torchvision==${torch_vision_ver} -f https://download.pytorch.org/whl/torch_stable.html
+    elif [[ "${framework}" == "onnxrt" ]]; then
+        pip install onnx==1.15.0
+        pip install onnxruntime==${fwk_ver}
+    elif [[ "${framework}" == "mxnet" ]]; then
+        pip install numpy==1.23.5
+        echo "re-install pycocotools resolve the issue with numpy..."
+        pip uninstall pycocotools -y
+        pip install --no-cache-dir pycocotools
+        pip install mxnet==${fwk_ver}
     fi
-elif [[ "${framework}" == "pytorch" ]]; then
-    pip install torch==${fwk_ver} -f https://download.pytorch.org/whl/torch_stable.html
-    pip install torchvision==${torch_vision_ver} -f https://download.pytorch.org/whl/torch_stable.html
-elif [[ "${framework}" == "onnxrt" ]]; then
-    pip install onnx==1.15.0
-    pip install onnxruntime==${fwk_ver}
-elif [[ "${framework}" == "mxnet" ]]; then
-    pip install numpy==1.23.5
-    echo "re-install pycocotools resolve the issue with numpy..."
-    pip uninstall pycocotools -y
-    pip install --no-cache-dir pycocotools
-    pip install mxnet==${fwk_ver}
 fi

 if [ -f "requirements.txt" ]; then
diff --git a/.azure-pipelines/scripts/models/run_model_trigger_common.sh b/.azure-pipelines/scripts/models/run_model_trigger_common.sh
index 27886a000bc..6d0bb6ead05 100644
--- a/.azure-pipelines/scripts/models/run_model_trigger_common.sh
+++ b/.azure-pipelines/scripts/models/run_model_trigger_common.sh
@@ -48,6 +48,13 @@ do
     esac
 done

+function check_results() {
+    local control_phrase=$1
+    if [ $(grep "${control_phrase}" ${log_dir}/${model}/${framework}-${model}-tune.log | wc -l) == 0 ];then
+        $BOLD_RED && echo "====== Quantization FAILED!! ======" && $RESET; exit 1
+    fi
+}
+
 log_dir="/neural-compressor/.azure-pipelines/scripts/models"
 SCRIPTS_PATH="/neural-compressor/.azure-pipelines/scripts/models"
 if [[ "${inc_new_api}" == "3x"* ]]; then
@@ -90,16 +97,19 @@ elif [ "${mode}" == "tuning" ]; then
         2>&1 | tee -a ${log_dir}/${model}/${framework}-${model}-tune.log
     $BOLD_YELLOW && echo "====== check tuning status. ======" && $RESET
     if [[ "${inc_new_api}" == "3x"* ]]; then
-        control_phrase="Quantization end."
+        control_phrase_1="Preparation end."
+        check_results $control_phrase_1
+        control_phrase_2="Conversion end."
+        check_results $control_phrase_2
     else
         control_phrase="model which meet accuracy goal."
+        check_results $control_phrase
+        if [ $(grep "${control_phrase}" ${log_dir}/${model}/${framework}-${model}-tune.log | grep "Not found" | wc -l) == 1 ];then
+            $BOLD_RED && echo "====== Quantization FAILED!! ======" && $RESET; exit 1
+        fi
     fi
-    if [ $(grep "${control_phrase}" ${log_dir}/${model}/${framework}-${model}-tune.log | wc -l) == 0 ];then
-        $BOLD_RED && echo "====== Quantization FAILED!! ======" && $RESET; exit 1
-    fi
-    if [ $(grep "${control_phrase}" ${log_dir}/${model}/${framework}-${model}-tune.log | grep "Not found" | wc -l) == 1 ];then
-        $BOLD_RED && echo "====== Quantization FAILED!! ======" && $RESET; exit 1
-    fi
+
+    $BOLD_GREEN && echo "====== Quantization SUCCEED!! ======" && $RESET

 elif [ "${mode}" == "fp32_benchmark" ]; then
     cd ${WORK_SOURCE_DIR}/${model_src_dir}
@@ -149,6 +159,10 @@ elif [ "${mode}" == "collect_log" ]; then
     cd ${WORK_SOURCE_DIR}/${model_src_dir}
     $BOLD_YELLOW && echo "workspace ${WORK_SOURCE_DIR}/${model_src_dir}" && $RESET
     $BOLD_YELLOW && echo "====== collect logs of model ${model} =======" && $RESET
+    if [ "${framework}" == "pytorch" ] && [ "${fwk_ver}" == "latest" ]; then
+        fwk_ver=$(python -c "import torch; print(torch.__version__)")
+    fi
+
     python -u ${SCRIPTS_PATH}/collect_log_model.py \
         --framework=${framework} \
         --fwk_ver=${fwk_ver} \
diff --git a/.azure-pipelines/scripts/models/run_pytorch_models_trigger.sh b/.azure-pipelines/scripts/models/run_pytorch_models_trigger.sh
index d8172bd7c2d..32bd2eb0109 100644
--- a/.azure-pipelines/scripts/models/run_pytorch_models_trigger.sh
+++ b/.azure-pipelines/scripts/models/run_pytorch_models_trigger.sh
@@ -21,12 +21,6 @@ do
     esac
 done

-echo "specify FWs version..."
-source /neural-compressor/.azure-pipelines/scripts/fwk_version.sh 'latest'
-FRAMEWORK="pytorch"
-FRAMEWORK_VERSION=${pytorch_version}
-TORCH_VISION_VERSION=${torchvision_version}
-
 dataset_location=""
 input_model=""
 yaml=""
@@ -72,6 +66,17 @@ elif [ "${model}" == "opt_125m_woq_gptq_int4_dq_ggml" ]; then
     tuning_cmd="bash run_quant.sh --topology=opt_125m_woq_gptq_int4_dq_ggml"
 fi

+echo "Specify FWs version..."
+
+FRAMEWORK="pytorch"
+source /neural-compressor/.azure-pipelines/scripts/fwk_version.sh 'latest'
+if [[ "${inc_new_api}" == "3x"* ]]; then
+    FRAMEWORK_VERSION="latest"
+else
+    FRAMEWORK_VERSION=${pytorch_version}
+    TORCH_VISION_VERSION=${torchvision_version}
+fi
+
 /bin/bash run_model_trigger_common.sh \
     --yaml=${yaml} \
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
index 090474f4356..acd0e8987b6 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
@@ -236,11 +236,11 @@ def get_user_model():

     # 3.x api
     if args.approach == 'weight_only':
-        from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
+        from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, prepare, convert, quantize
         from neural_compressor.torch.utils import get_double_quant_config
         weight_sym = True if args.woq_scheme == "sym" else False
         double_quant_config_dict = get_double_quant_config(args.double_quant_type)
-        
+
         if args.woq_algo == "RTN":
             if args.double_quant_type is not None:
                 double_quant_config_dict.update(
@@ -269,9 +269,8 @@ def get_user_model():
                 double_quant_group_size=args.double_quant_group_size,
             )
             quant_config.set_local("lm_head", RTNConfig(dtype="fp32"))
-            user_model = quantize(
-                model=user_model, quant_config=quant_config
-            )
+            user_model = prepare(model=user_model, quant_config=quant_config)
+            user_model = convert(model=user_model)
         elif args.woq_algo == "GPTQ":
             from utils import DataloaderPreprocessor
             dataloaderPreprocessor = DataloaderPreprocessor(
@@ -326,12 +325,12 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                 double_quant_group_size=args.double_quant_group_size,
             )
             quant_config.set_local("lm_head", GPTQConfig(dtype="fp32"))
-            user_model = quantize(
-                model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
-            )
+            user_model = prepare(model=user_model, quant_config=quant_config)
+            run_fn_for_gptq(user_model, dataloader_for_calibration)
+            user_model = convert(user_model)
     else:
         if args.sq:
-            from neural_compressor.torch.quantization import SmoothQuantConfig, quantize
+            from neural_compressor.torch.quantization import SmoothQuantConfig

             # alpha can be a float number of a list of float number.
             args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
@@ -339,11 +338,11 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                 quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
             else:
                 quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)
-            
+
             if re.search("gpt", user_model.config.model_type):
                 quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         else:
-            from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig
+            from neural_compressor.torch.quantization import get_default_static_config, StaticQuantConfig

             quant_config = get_default_static_config()
             if re.search("gpt", user_model.config.model_type):
@@ -364,12 +363,23 @@ def run_fn(model):
                     except ValueError:
                         pass
                 return
-            
+
             from utils import get_example_inputs
             example_inputs = get_example_inputs(user_model, calib_dataloader)
-            user_model = quantize(
-                model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
-            )
+            if args.sq:
+                # currently, smooth quant only support quantize API
+                # TODO: support prepare/convert API for smooth quant
+                from neural_compressor.torch.quantization import quantize
+
+                user_model = quantize(
+                    model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
+                )
+            else:
+                from neural_compressor.torch.quantization import prepare, convert
+
+                user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
+                run_fn(user_model)
+                user_model = convert(user_model)

         user_model.save(args.output_dir)
@@ -394,7 +404,7 @@ def run_fn(model):
     user_model.eval()
     from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
     eval_args = LMEvalParser(
-        model="hf", 
+        model="hf",
         user_model=user_model,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
@@ -417,7 +427,7 @@ def run_fn(model):

     samples = args.iters * args.batch_size
     eval_args = LMEvalParser(
-        model="hf", 
+        model="hf",
         user_model=user_model,
         tokenizer=tokenizer,
         batch_size=args.batch_size,
@@ -436,4 +446,4 @@ def run_fn(model):
     print("Accuracy: %.5f" % acc)
     print('Throughput: %.3f samples/sec' % (samples / (end - start)))
     print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
-    print('Batch size = %d' % args.batch_size)
\ No newline at end of file
+    print('Batch size = %d' % args.batch_size)
diff --git a/neural_compressor/common/__init__.py b/neural_compressor/common/__init__.py
index 6e340fd1a04..93b3de4b22b 100644
--- a/neural_compressor/common/__init__.py
+++ b/neural_compressor/common/__init__.py
@@ -17,7 +17,7 @@
     logger,
     Logger,
     TuningLogger,
-    log_quant_execution,
+    log_process,
     set_random_seed,
     set_resume_from,
     set_workspace,
@@ -32,6 +32,8 @@
     "level",
     "logger",
     "Logger",
+    "TuningLogger",
+    "log_process",
     "set_workspace",
     "set_random_seed",
     "set_resume_from",
diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py
index 097bae60381..f0ddc3b442b 100644
--- a/neural_compressor/common/utils/constants.py
+++ b/neural_compressor/common/utils/constants.py
@@ -45,3 +45,12 @@
 from typing import Callable, Union

 OP_NAME_OR_MODULE_TYPE = Union[str, Callable]
+
+# mode name
+from enum import Enum
+
+
+class Mode(Enum):
+    PREPARE = "prepare"
+    CONVERT = "convert"
+    QUANTIZE = "quantize"
diff --git a/neural_compressor/common/utils/logger.py b/neural_compressor/common/utils/logger.py
index 2f86dc3e92b..94a7ca09c50 100644
--- a/neural_compressor/common/utils/logger.py
+++ b/neural_compressor/common/utils/logger.py
@@ -20,6 +20,8 @@
 import logging
 import os

+from neural_compressor.common.utils import Mode
+
 __all__ = [
     "level",
     "Logger",  # TODO: not expose it
@@ -140,6 +142,17 @@ def warning(msg, *args, **kwargs):
 logger = Logger


+def _get_log_msg(mode):
+    log_msg = None
+    if mode == Mode.QUANTIZE:
+        log_msg = "Quantization"
+    elif mode == Mode.PREPARE:  # pragma: no cover
+        log_msg = "Preparation"
+    elif mode == Mode.CONVERT:  # pragma: no cover
+        log_msg = "Conversion"
+    return log_msg
+
+
 class TuningLogger:
     """A unified logger for the tuning/quantization process.

@@ -155,12 +168,16 @@ def trial_start(cls, trial_index: int = None) -> None:
         logger.info("%d-trail started.", trial_index)

     @classmethod
-    def quantization_start(cls, stacklevel=2) -> None:
-        logger.info("Quantization started.", stacklevel=stacklevel)
+    def execution_start(cls, mode=Mode.QUANTIZE, stacklevel=2):
+        log_msg = _get_log_msg(mode)
+        assert log_msg is not None, "Please check `mode` in execution_start function of TuningLogger class."
+        logger.info("{} started.".format(log_msg), stacklevel=stacklevel)

     @classmethod
-    def quantization_end(cls, stacklevel=2) -> None:
-        logger.info("Quantization end.", stacklevel=stacklevel)
+    def execution_end(cls, mode=Mode.QUANTIZE, stacklevel=2):
+        log_msg = _get_log_msg(mode)
+        assert log_msg is not None, "Please check `mode` in execution_end function of TuningLogger class."
+        logger.info("{} end.".format(log_msg), stacklevel=stacklevel)

     @classmethod
     def evaluation_start(cls) -> None:
diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py
index b169d418b68..1d92b2277ad 100644
--- a/neural_compressor/common/utils/utility.py
+++ b/neural_compressor/common/utils/utility.py
@@ -22,7 +22,7 @@
 import cpuinfo
 import psutil

-from neural_compressor.common.utils import TuningLogger, logger
+from neural_compressor.common.utils import Mode, TuningLogger, logger

 __all__ = [
     "set_workspace",
@@ -30,7 +30,7 @@
     "set_resume_from",
     "set_tensorboard",
     "dump_elapsed_time",
-    "log_quant_execution",
+    "log_process",
     "singleton",
     "LazyImport",
     "CpuInfo",
@@ -206,14 +206,21 @@ def set_tensorboard(tensorboard: bool):
 default_tuning_logger = TuningLogger()


-def log_quant_execution(func):
-    def wrapper(*args, **kwargs):
-        default_tuning_logger.quantization_start(stacklevel=4)
+def log_process(mode=Mode.QUANTIZE):
+    def log_process_wrapper(func):
+        def inner_wrapper(*args, **kwargs):
+            start_log = default_tuning_logger.execution_start
+            end_log = default_tuning_logger.execution_end

-        # Call the original function
-        result = func(*args, **kwargs)
+            start_log(mode=mode, stacklevel=4)

-        default_tuning_logger.quantization_end(stacklevel=4)
-        return result
+            # Call the original function
+            result = func(*args, **kwargs)

-    return wrapper
+            end_log(mode=mode, stacklevel=4)
+
+            return result
+
+        return inner_wrapper
+
+    return log_process_wrapper
diff --git a/neural_compressor/onnxrt/quantization/autotune.py b/neural_compressor/onnxrt/quantization/autotune.py
index 45a53737384..7cddcc3a8b3 100644
--- a/neural_compressor/onnxrt/quantization/autotune.py
+++ b/neural_compressor/onnxrt/quantization/autotune.py
@@ -75,10 +75,10 @@ def autotune(
         if calibration_data_reader is not None:
             calibration_data_reader.rewind()
         tuning_logger.trial_start(trial_index=trial_index)
-        tuning_logger.quantization_start()
+        tuning_logger.execution_start()
         logger.debug("quant config: {}".format(quant_config))
         q_model = _quantize(model_input, quant_config=quant_config, calibration_data_reader=calibration_data_reader)
-        tuning_logger.quantization_end()
+        tuning_logger.execution_end()
         tuning_logger.evaluation_start()
         with tempfile.TemporaryDirectory(prefix="ort.quant.") as tmp_dir:
             # evaluate API requires str input
diff --git a/neural_compressor/onnxrt/quantization/quantize.py b/neural_compressor/onnxrt/quantization/quantize.py
index 487715f8e16..eee9f3162f1 100644
--- a/neural_compressor/onnxrt/quantization/quantize.py
+++ b/neural_compressor/onnxrt/quantization/quantize.py
@@ -19,7 +19,7 @@

 from neural_compressor.common import Logger
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import log_quant_execution
+from neural_compressor.common.utils import log_process
 from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
 from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME
 from neural_compressor.onnxrt.utils.utility import algos_mapping
@@ -32,7 +32,7 @@ def _need_apply(quant_config: BaseConfig, algo_name):


 # * only for internal usage now
-@log_quant_execution
+@log_process()
 def _quantize(
     model_input: Union[Path, str],
     quant_config: BaseConfig,
diff --git a/neural_compressor/tensorflow/quantization/autotune.py b/neural_compressor/tensorflow/quantization/autotune.py
index e89756eece6..5bd588c0c0c 100644
--- a/neural_compressor/tensorflow/quantization/autotune.py
+++ b/neural_compressor/tensorflow/quantization/autotune.py
@@ -53,10 +53,10 @@ def autotune(
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
-        tuning_logger.quantization_start()
+        tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
         q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
-        tuning_logger.quantization_end()
+        tuning_logger.execution_end()
         tuning_logger.evaluation_start()
         eval_result: float = eval_func_wrapper.evaluate(q_model)
         tuning_logger.evaluation_end()
diff --git a/neural_compressor/tensorflow/quantization/quantize.py b/neural_compressor/tensorflow/quantization/quantize.py
index 52285bff2e4..fa613759515 100644
--- a/neural_compressor/tensorflow/quantization/quantize.py
+++ b/neural_compressor/tensorflow/quantization/quantize.py
@@ -18,7 +18,7 @@

 from neural_compressor.common import logger
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import STATIC_QUANT, log_quant_execution
+from neural_compressor.common.utils import STATIC_QUANT, log_process
 from neural_compressor.tensorflow.utils import BaseModel, KerasModel, Model, algos_mapping


@@ -26,7 +26,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
     return any(config.name == algo_name for config in configs_mapping.values())


-@log_quant_execution
+@log_process()
 def quantize_model(
     model: Union[str, tf.keras.Model, BaseModel],
     quant_config: Union[BaseConfig, list],
diff --git a/neural_compressor/torch/algorithms/base_algorithm.py b/neural_compressor/torch/algorithms/base_algorithm.py
index c458c210e33..50a8d189233 100644
--- a/neural_compressor/torch/algorithms/base_algorithm.py
+++ b/neural_compressor/torch/algorithms/base_algorithm.py
@@ -18,7 +18,7 @@

 import torch

-from neural_compressor.torch.utils import Mode
+from neural_compressor.common.utils import Mode


 class Quantizer(ABC):
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 41b63ae0490..8ca5e222877 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -18,7 +18,18 @@

 import torch

-from neural_compressor.common.utils import AUTOROUND, AWQ, FP8_QUANT, GPTQ, HQQ, RTN, SMOOTH_QUANT, STATIC_QUANT, TEQ
+from neural_compressor.common.utils import (
+    AUTOROUND,
+    AWQ,
+    FP8_QUANT,
+    GPTQ,
+    HQQ,
+    RTN,
+    SMOOTH_QUANT,
+    STATIC_QUANT,
+    TEQ,
+    Mode,
+)
 from neural_compressor.torch.quantization import (
     AutoRoundConfig,
     AWQConfig,
@@ -30,14 +41,7 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import (
-    Mode,
-    get_quantizer,
-    is_ipex_imported,
-    logger,
-    postprocess_model,
-    register_algo,
-)
+from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
 from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index 403d53c36cd..2d0aa2bd2e0 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -61,7 +61,7 @@ def autotune(
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
-        tuning_logger.quantization_start()
+        tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
@@ -72,7 +72,7 @@ def autotune(
             inplace=True,
             example_inputs=example_inputs,
         )
-        tuning_logger.quantization_end()
+        tuning_logger.execution_end()
         tuning_logger.evaluation_start()
         eval_result: float = eval_func_wrapper.evaluate(q_model)
         tuning_logger.evaluation_end()
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 4d27ac263d6..47f1e89667b 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -18,9 +18,9 @@
 import torch

 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import log_quant_execution
+from neural_compressor.common.utils import Mode, log_process
 from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig
-from neural_compressor.torch.utils import Mode, is_ipex_available, logger
+from neural_compressor.torch.utils import is_ipex_available, logger
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info

 FRAMEWORK_NAME = "torch"
@@ -30,7 +30,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
     return any(config.name == algo_name for config in configs_mapping.values())


-@log_quant_execution
+@log_process(mode=Mode.QUANTIZE)
 def quantize(
     model: torch.nn.Module,
     quant_config: BaseConfig,
@@ -86,6 +86,7 @@ def quantize(
     return q_model


+@log_process(mode=Mode.PREPARE)
 def prepare(
     model: torch.nn.Module,
     quant_config: BaseConfig,
@@ -143,6 +144,7 @@ def prepare(
     return prepared_model


+@log_process(mode=Mode.CONVERT)
 def convert(
     model: torch.nn.Module,
     quant_config: BaseConfig = None,
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index 3ea196ccdfb..af0b4f8b79d 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -13,7 +13,6 @@
 # limitations under the License.


-from enum import Enum
 from typing import Callable, Dict, List, Tuple, Union

 import torch
@@ -24,6 +23,7 @@
 from typing_extensions import TypeAlias

 from neural_compressor.common import logger
+from neural_compressor.common.utils import Mode

 OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]]

@@ -131,12 +131,6 @@ def get_double_quant_config(double_quant_type):
     return DOUBLE_QUANT_CONFIGS[double_quant_type]


-class Mode(Enum):
-    PREPARE = "prepare"
-    CONVERT = "convert"
-    QUANTIZE = "quantize"
-
-
 def get_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs):
     """Get the quantizer.