Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type hints #1789

Merged
merged 6 commits on May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion neural_compressor_ort/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

DEFAULT_WORKSPACE = "./nc_workspace/{}/".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

OP_NAME_OR_MODULE_TYPE = Union[str, Callable]

ONNXRT116_VERSION = version.Version("1.16.0")
ONNXRT1161_VERSION = version.Version("1.16.1")
Expand Down
44 changes: 23 additions & 21 deletions neural_compressor_ort/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,11 @@ class BaseConfig(ABC):
"""The base config for all algorithm configs."""

name = constants.BASE_CONFIG
params_list = []
params_list: List[Union[str, TuningParam]] = []

def __init__(
self, white_list: Optional[List[constants.OP_NAME_OR_MODULE_TYPE]] = constants.DEFAULT_WHITE_LIST
self,
white_list: Optional[Union[Union[str, Callable], List[Union[str, Callable]]]] = constants.DEFAULT_WHITE_LIST,
) -> None:
self._global_config: Optional[BaseConfig] = None
# For PyTorch, operator_type is the collective name for module type and functional operation type,
Expand Down Expand Up @@ -233,7 +234,7 @@ def white_list(self):
return self._white_list

@white_list.setter
def white_list(self, op_name_or_type_list: Optional[List[constants.OP_NAME_OR_MODULE_TYPE]]):
def white_list(self, op_name_or_type_list: Optional[List[Union[str, Callable]]]):
self._white_list = op_name_or_type_list

@property
Expand Down Expand Up @@ -316,7 +317,7 @@ def to_json_file(self, filename):
json.dump(config_dict, file, indent=4)
utility.logger.info("Dump the config into %s.", filename)

def to_json_string(self, use_diff: bool = False) -> str:
def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]:
"""Serializes this instance to a JSON string.

Args:
Expand All @@ -333,7 +334,8 @@ def to_json_string(self, use_diff: bool = False) -> str:
config_dict = self.to_dict()
try:
return json.dumps(config_dict, indent=2) + "\n"
except:
except Exception as e:
utility.logger.error("Failed to serialize the config to JSON string: %s", e)
return config_dict

def __repr__(self) -> str:
Expand Down Expand Up @@ -452,8 +454,8 @@ def _get_op_name_op_type_config(self):
return op_type_config_dict, op_name_config_dict

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
self, config_list: Optional[List[BaseConfig]] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Tuple[str, str], OrderedDict[str, BaseConfig]]:
config_mapping = OrderedDict()
if config_list is None:
config_list = [self]
Expand All @@ -468,7 +470,7 @@ def to_config_mapping(
for op_name_pattern in op_name_config_dict:
if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
elif op_name_pattern == op_name: # TODO: map ipex opname to stock pt op_name
elif op_name_pattern == op_name:
config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
return config_mapping

Expand Down Expand Up @@ -496,7 +498,7 @@ def __add__(self, other: BaseConfig) -> BaseConfig:
self.config_list.append(other)
return self

def to_dict(self, params_list=[], operator2str=None):
def to_dict(self):
result = {}
for config in self.config_list:
result[config.name] = config.to_dict()
Expand Down Expand Up @@ -551,8 +553,8 @@ def get_model_info(self, model, *args, **kwargs):
return model_info_dict


def get_all_config_set_from_config_registry() -> Union[BaseConfig, List[BaseConfig]]:
all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls()
def get_all_config_set_from_config_registry() -> List[BaseConfig]:
all_registered_config_cls: List[Type[BaseConfig]] = config_registry.get_all_config_cls()
config_set = []
for config_cls in all_registered_config_cls:
config_set.append(config_cls.get_config_set_for_tuning())
Expand All @@ -561,7 +563,7 @@ def get_all_config_set_from_config_registry() -> Union[BaseConfig, List[BaseConf

def register_supported_configs():
"""Register supported configs."""
all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls()
all_registered_config_cls: List[Type[BaseConfig]] = config_registry.get_all_config_cls()
for config_cls in all_registered_config_cls:
config_cls.register_supported_configs()

Expand All @@ -580,7 +582,7 @@ class RTNConfig(BaseConfig):
"""Config class for round-to-nearest weight-only quantization."""

supported_configs: List[_OperatorConfig] = []
params_list: List[str] = [
params_list: List[Union[str, TuningParam]] = [
"weight_dtype",
"weight_bits",
"weight_group_size",
Expand All @@ -607,7 +609,7 @@ def __init__(
providers: List[str] = ["CPUExecutionProvider"],
layer_wise_quant: bool = False,
quant_last_matmul: bool = True,
white_list: List[constants.OP_NAME_OR_MODULE_TYPE] = constants.DEFAULT_WHITE_LIST,
white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST,
):
"""Init RTN weight-only quantization config.

Expand Down Expand Up @@ -650,7 +652,7 @@ def get_model_params_dict(self):
return result

@classmethod
def register_supported_configs(cls) -> List[_OperatorConfig]:
def register_supported_configs(cls) -> None:
supported_configs = []
linear_rtn_config = RTNConfig(
weight_dtype=["int"],
Expand Down Expand Up @@ -724,15 +726,15 @@ class GPTQConfig(BaseConfig):
"""Config class for gptq weight-only quantization."""

supported_configs: List[_OperatorConfig] = []
params_list: List[str] = [
params_list: List[Union[str, TuningParam]] = [
"weight_dtype",
"weight_bits",
"weight_group_size",
"weight_sym",
"act_dtype",
"accuracy_level",
]
model_params_list: List[str] = [
model_params_list: List[Union[str, TuningParam]] = [
"percdamp",
"blocksize",
"actorder",
Expand All @@ -759,7 +761,7 @@ def __init__(
providers: List[str] = ["CPUExecutionProvider"],
layer_wise_quant: bool = False,
quant_last_matmul: bool = True,
white_list: List[constants.OP_NAME_OR_MODULE_TYPE] = constants.DEFAULT_WHITE_LIST,
white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST,
):
"""Init GPTQ weight-only quantization config.

Expand Down Expand Up @@ -812,7 +814,7 @@ def get_model_params_dict(self):
return result

@classmethod
def register_supported_configs(cls) -> List[_OperatorConfig]:
def register_supported_configs(cls) -> None:
supported_configs = []
linear_gptq_config = GPTQConfig(
weight_dtype=["int"],
Expand Down Expand Up @@ -922,7 +924,7 @@ def __init__(
enable_mse_search: bool = True,
providers: List[str] = ["CPUExecutionProvider"],
quant_last_matmul: bool = True,
white_list: List[constants.OP_NAME_OR_MODULE_TYPE] = constants.DEFAULT_WHITE_LIST,
white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST,
):
"""Init AWQ weight-only quantization config.

Expand Down Expand Up @@ -1064,7 +1066,7 @@ def __init__(
scales_per_op: bool = True,
auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"},
providers: List[str] = ["CPUExecutionProvider"],
white_list: List[constants.OP_NAME_OR_MODULE_TYPE] = constants.DEFAULT_WHITE_LIST,
white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST,
**kwargs,
):
"""Init smooth quant config.
Expand Down
2 changes: 0 additions & 2 deletions neural_compressor_ort/quantization/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
# limitations under the License.

import copy
import enum
import os
import pathlib
import tempfile
import uuid
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union, _GenericAlias

import onnx
import pydantic

from neural_compressor_ort import data_reader, utility
from neural_compressor_ort.quantization import config
Expand Down
6 changes: 3 additions & 3 deletions test/utils/test_general.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for general components."""

import unittest
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

from neural_compressor_ort import constants, utility
from neural_compressor_ort.quantization import config, tuning
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
weight_dtype: str = "int",
weight_bits: int = 4,
target_op_type_list: List[str] = ["Conv", "Gemm"],
white_list: Optional[List[constants.OP_NAME_OR_MODULE_TYPE]] = constants.DEFAULT_WHITE_LIST,
white_list: Optional[List[Union[str, Callable]]] = constants.DEFAULT_WHITE_LIST,
):
"""Init fake config.

Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
weight_dtype: str = "int",
weight_bits: int = 4,
target_op_type_list: List[str] = ["Conv", "Gemm"],
white_list: Optional[List[constants.OP_NAME_OR_MODULE_TYPE]] = constants.DEFAULT_WHITE_LIST,
white_list: Optional[List[Union[str, Callable]]] = constants.DEFAULT_WHITE_LIST,
):
"""Init fake config.

Expand Down