Skip to content

Commit

Permalink
Enhance 3.x API (#1397)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: chensuyue <suyue.chen@intel.com>
  • Loading branch information
yiliu30 and chensuyue committed Nov 21, 2023
1 parent 54e4d43 commit 8447d70
Show file tree
Hide file tree
Showing 19 changed files with 1,438 additions and 63 deletions.
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/ut/run_3x_pt.sh
Expand Up @@ -6,6 +6,7 @@ echo "${test_case}"
# install requirements
echo "set up UT env..."
pip install -r /neural-compressor/requirements_pt.txt
pip install transformers
pip install coverage
pip install pytest
pip list
Expand Down
1 change: 0 additions & 1 deletion .azure-pipelines/ut-3x-pt.yml
Expand Up @@ -14,7 +14,6 @@ pr:
- setup.py
- requirements.txt
- requirements_pt.txt
- .azure-pipelines/scripts/ut

pool: ICX-16C

Expand Down
1 change: 0 additions & 1 deletion .azure-pipelines/ut-basic-no-cover.yml
Expand Up @@ -12,7 +12,6 @@ pr:
- test
- setup.py
- requirements.txt
- .azure-pipelines/scripts/ut
exclude:
- test/neural_coder
- test/3x
Expand Down
1 change: 0 additions & 1 deletion .azure-pipelines/ut-basic.yml
Expand Up @@ -12,7 +12,6 @@ pr:
- test
- setup.py
- requirements.txt
- .azure-pipelines/scripts/ut
exclude:
- test/neural_coder
- test/3x
Expand Down
2 changes: 2 additions & 0 deletions neural_compressor/common/__init__.py
Expand Up @@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.common.logger import level, log, info, debug, warn, warning, error, fatal
142 changes: 117 additions & 25 deletions neural_compressor/common/base_config.py
Expand Up @@ -19,10 +19,14 @@

import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from neural_compressor.common.logger import Logger
from neural_compressor.common.utility import BASE_CONFIG, COMPOSABLE_CONFIG, GLOBAL, LOCAL

logger = Logger().get_logger()

from neural_compressor.common.utility import BASE_CONFIG, GLOBAL, OPERATOR_NAME
from neural_compressor.utils import logger

# Dictionary to store registered configurations
registered_configs = {}
Expand Down Expand Up @@ -57,31 +61,47 @@ class BaseConfig(ABC):
name = BASE_CONFIG

def __init__(self) -> None:
self.global_config: Optional[BaseConfig] = None
self._global_config: Optional[BaseConfig] = None
# For PyTorch, operator_type is the collective name for module type and functional operation type,
# for example, `torch.nn.Linear`, and `torch.nn.functional.linear`.
self.operator_type_config: Dict[Union[str, Callable], Optional[BaseConfig]] = {}
self.operator_name_config: Dict[str, Optional[BaseConfig]] = {}

def set_operator_name(self, operator_name: str, config: BaseConfig) -> BaseConfig:
self.operator_name_config[operator_name] = config
return self

def _set_operator_type(self, operator_type: Union[str, Callable], config: BaseConfig) -> BaseConfig:
# TODO (Yi), clean the usage
# hide it from user, as we can use set_operator_name with regular expression to convert its functionality
self.operator_type_config[operator_type] = config
# local config is the collections of operator_type configs and operator configs
self._local_config: Dict[str, Optional[BaseConfig]] = {}

@property
def global_config(self):
if self._global_config is None:
self._global_config = self.__class__(**self.to_dict())
return self._global_config

@global_config.setter
def global_config(self, config):
self._global_config = config

@property
def local_config(self):
return self._local_config

@local_config.setter
def local_config(self, config):
self._local_config = config

def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
if operator_name in self.local_config:
logger.warning("The configuration for %s has already been set, update it.", operator_name)
if self.global_config is None:
self.global_config = self.__class__(**self.to_dict())
self.local_config[operator_name] = config
return self

def to_dict(self, params_list=[], operator2str=None):
result = {}
global_config = {}
for param in params_list:
global_config[param] = getattr(self, param)
if bool(self.operator_name_config):
result[OPERATOR_NAME] = {}
for op_name, config in self.operator_name_config.items():
result[OPERATOR_NAME][op_name] = config.to_dict()
if bool(self.local_config):
result[LOCAL] = {}
for op_name, config in self.local_config.items():
result[LOCAL][op_name] = config.to_dict()
result[GLOBAL] = global_config
else:
result = global_config
Expand All @@ -99,10 +119,10 @@ def from_dict(cls, config_dict, str2operator=None):
The constructed config.
"""
config = cls(**config_dict.get(GLOBAL, {}))
operator_config = config_dict.get(OPERATOR_NAME, {})
operator_config = config_dict.get(LOCAL, {})
if operator_config:
for op_name, op_config in operator_config.items():
config.set_operator_name(op_name, cls(**op_config))
config.set_local(op_name, cls(**op_config))
return config

@classmethod
Expand All @@ -120,7 +140,7 @@ def to_json_file(self, filename):
config_dict = self.to_dict()
with open(filename, "w", encoding="utf-8") as file:
json.dump(config_dict, file, indent=4)
logger.info(f"Dump the config into {filename}")
logger.info("Dump the config into %s.", filename)

def to_json_string(self, use_diff: bool = False) -> str:
"""Serializes this instance to a JSON string.
Expand All @@ -137,7 +157,7 @@ def to_json_string(self, use_diff: bool = False) -> str:
config_dict = self.to_diff_dict(self)
else:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
return json.dumps(config_dict, indent=2) + "\n"

def __repr__(self) -> str:
return f"{self.__class__.__name__} {self.to_json_string()}"
Expand All @@ -154,10 +174,82 @@ def validate(self, user_config: BaseConfig):
pass

def __add__(self, other: BaseConfig) -> BaseConfig:
# TODO(Yi) implement config add, like RTNWeightOnlyQuantConfig() + GPTQWeightOnlyQuantConfig()
pass
if isinstance(other, type(self)):
for op_name, config in other.local_config.items():
self.set_local(op_name, config)
return self
else:
return ComposableConfig(configs=[self, other])

def _get_op_name_op_type_config(self):
op_type_config_dict = dict()
op_name_config_dict = dict()
for name, config in self.local_config.items():
if self._is_op_type(name):
op_type_config_dict[name] = config
else:
op_name_config_dict[name] = config
return op_type_config_dict, op_name_config_dict

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
config_mapping = OrderedDict()
if config_list is None:
config_list = [self]
for config in config_list:
global_config = config.global_config
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
for op_name, op_type in model_info:
config_mapping.setdefault(op_type, OrderedDict())[op_name] = global_config
if op_type in op_type_config_dict:
config_mapping[op_type][op_name] = op_name_config_dict[op_type]
if op_name in op_name_config_dict:
config_mapping[op_type][op_name] = op_name_config_dict[op_name]
return config_mapping

@staticmethod
def _is_op_type(name: str) -> bool:
# TODO (Yi), ort and tf need override it
return not isinstance(name, str)


class ComposableConfig(BaseConfig):
name = COMPOSABLE_CONFIG

def __init__(self, configs: List[BaseConfig]) -> None:
self.config_list = configs

def __add__(self, other: BaseConfig) -> BaseConfig:
if isinstance(other, type(self)):
self.config_list.extend(other.config_list)
else:
self.config_list.append(other)
return self

def to_dict(self, params_list=[], operator2str=None):
result = {}
for config in self.config_list:
result[config.name] = config.to_dict()
return result

@classmethod
def from_dict(cls, config_dict, str2operator=None):
# TODO(Yi)
pass

def to_json_string(self, use_diff: bool = False) -> str:
return json.dumps(self.to_dict(), indent=2) + "\n"

def __repr__(self) -> str:
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[str, BaseConfig]:
return super().to_config_mapping(self.config_list, model_info)

@classmethod
def register_supported_configs(cls):
"""Add all supported configs."""
raise NotImplementedError
134 changes: 134 additions & 0 deletions neural_compressor/common/logger.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logger: handles logging functionalities."""

import logging
import os


class Logger(object):
"""Logger class."""

__instance = None

def __new__(cls):
"""Create a singleton Logger instance."""
if Logger.__instance is None:
Logger.__instance = object.__new__(cls)
Logger.__instance._log()
return Logger.__instance

def _log(self):
"""Setup the logger format and handler."""
LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
self._logger = logging.getLogger("neural_compressor")
self._logger.handlers.clear()
self._logger.setLevel(LOGLEVEL)
formatter = logging.Formatter(
"%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d] %(message)s", "%Y-%m-%d %H:%M:%S"
)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
self._logger.addHandler(streamHandler)
self._logger.propagate = False

def get_logger(self):
"""Get the logger."""
return self._logger


def _pretty_dict(value, indent=0):
"""Make the logger dict pretty."""
prefix = "\n" + " " * (indent + 4)
if isinstance(value, dict):
items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value]
return "{%s}" % (",".join(items) + "\n" + " " * indent)
elif isinstance(value, list):
items = [prefix + _pretty_dict(item, indent + 4) for item in value]
return "[%s]" % (",".join(items) + "\n" + " " * indent)
elif isinstance(value, tuple):
items = [prefix + _pretty_dict(item, indent + 4) for item in value]
return "(%s)" % (",".join(items) + "\n" + " " * indent)
else:
return repr(value)


level = Logger().get_logger().level
DEBUG = logging.DEBUG


def log(level, msg, *args, **kwargs):
"""Output log with the level as a parameter."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().log(level, line, *args, **kwargs)
else:
Logger().get_logger().log(level, msg, *args, **kwargs)


def debug(msg, *args, **kwargs):
"""Output log with the debug level."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().debug(line, *args, **kwargs)
else:
Logger().get_logger().debug(msg, *args, **kwargs)


def error(msg, *args, **kwargs):
"""Output log with the error level."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().error(line, *args, **kwargs)
else:
Logger().get_logger().error(msg, *args, **kwargs)


def fatal(msg, *args, **kwargs):
"""Output log with the fatal level."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().fatal(line, *args, **kwargs)
else:
Logger().get_logger().fatal(msg, *args, **kwargs)


def info(msg, *args, **kwargs):
"""Output log with the info level."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().info(line, *args, **kwargs)
else:
Logger().get_logger().info(msg, *args, **kwargs)


def warn(msg, *args, **kwargs):
"""Output log with the warning level."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().warning(line, *args, **kwargs)
else:
Logger().get_logger().warning(msg, *args, **kwargs)


def warning(msg, *args, **kwargs):
"""Output log with the warning level (Alias of the method warn)."""
if isinstance(msg, dict):
for _, line in enumerate(_pretty_dict(msg).split("\n")):
Logger().get_logger().warning(line, *args, **kwargs)
else:
Logger().get_logger().warning(msg, *args, **kwargs)
4 changes: 3 additions & 1 deletion neural_compressor/common/utility.py
Expand Up @@ -20,8 +20,10 @@

# constants for configs
GLOBAL = "global"
OPERATOR_NAME = "operator_name"
LOCAL = "local"

# config name
BASE_CONFIG = "base_config"
COMPOSABLE_CONFIG = "composable_config"
RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
DUMMY_CONFIG = "dummy_config"
8 changes: 7 additions & 1 deletion neural_compressor/torch/__init__.py
Expand Up @@ -15,4 +15,10 @@
from neural_compressor.torch.utils import register_algo
from neural_compressor.torch.algorithms import rtn_quantize_entry

from neural_compressor.torch.quantization import quantize, RTNWeightQuantConfig, get_default_rtn_config
from neural_compressor.torch.quantization import (
quantize,
RTNWeightQuantConfig,
get_default_rtn_config,
DummyConfig,
get_default_dummy_config,
)

0 comments on commit 8447d70

Please sign in to comment.