Skip to content

Commit

Permalink
Make quantization config contain only serializable properties.
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Apr 11, 2024
1 parent 20fd761 commit f7fa3a1
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 255 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))

# Run code quality checks
style_check:
black --check .
black .
ruff check .

style:
Expand Down
228 changes: 119 additions & 109 deletions optimum/intel/openvino/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import copy
import inspect
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union

import datasets
import nncf
import torch
from nncf.quantization.advanced_parameters import OverflowFix
Expand Down Expand Up @@ -52,36 +52,6 @@
}


class _replace_properties_values:
    """
    A context manager for temporarily overriding an object's properties.

    On entry, every attribute listed in `property_names` is set to the matching entry of
    `property_values`. On exit, including when the body raises, each attribute is set back
    to the value it held at the moment the `with` block was entered.

    Args:
        obj: The object whose attributes are temporarily replaced.
        property_names: Names of the attributes to override.
        property_values: Replacement values, paired positionally with `property_names`.
    """

    def __init__(self, obj, property_names, property_values):
        self.obj = obj
        self.property_names = property_names
        self.new_property_values = property_values
        # Read the attributes eagerly so that an unknown property name fails at construction time.
        self.old_property_values = self._current_values()

    def _current_values(self):
        # Snapshot of the attributes' present values, in the order of `property_names`.
        return [getattr(self.obj, property_name) for property_name in self.property_names]

    def __enter__(self):
        # Refresh the snapshot: the object may have changed between construction and entry, and
        # restoring the construction-time values would silently discard those changes.
        self.old_property_values = self._current_values()
        for property_name, new_property_value in zip(self.property_names, self.new_property_values):
            setattr(self.obj, property_name, new_property_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the `with` body keeps propagating.
        for property_name, old_property_value in zip(self.property_names, self.old_property_values):
            setattr(self.obj, property_name, old_property_value)


def _is_serializable(obj):
    """
    Tell whether `obj` can be encoded with `json.dumps`.

    Returns:
        `True` if `json.dumps(obj)` completes, `False` if it raises any `Exception`.
    """
    try:
        json.dumps(obj)
    except Exception:
        # Any failure of the encoder is treated as "not serializable".
        return False
    else:
        return True


@dataclass
class OVQuantizationConfigBase(QuantizationConfigMixin):
"""
Expand All @@ -90,53 +60,41 @@ class OVQuantizationConfigBase(QuantizationConfigMixin):

def __init__(
self,
dataset: Optional[Union[str, List[str], nncf.Dataset, datasets.Dataset]] = None,
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = None,
weight_only: Optional[bool] = None,
**kwargs,
):
"""
Args:
dataset (`str or List[str] or nncf.Dataset or datasets.Dataset`, *optional*):
The dataset used for data-aware weight compression or quantization with NNCF.
ignored_scope (`dict or nncf.IgnoredScope`, *optional*):
An ignored scope that defines the list of model nodes to be ignored during quantization.
ignored_scope (`dict`, *optional*):
An ignored scope that defines a list of model nodes to be ignored during quantization. Dictionary
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.
weight_only (`bool`, *optional*):
Used to explicitly specify type of quantization (weight-only or full) to apply.
"""
self.dataset = dataset
if isinstance(ignored_scope, dict):
ignored_scope = nncf.IgnoredScope(**ignored_scope)
if isinstance(ignored_scope, nncf.IgnoredScope):
ignored_scope = ignored_scope.__dict__
self.ignored_scope = ignored_scope
self.num_samples = num_samples
self.weight_only = weight_only

def post_init(self):
if not (self.dataset is None or isinstance(self.dataset, (str, list, nncf.Dataset, datasets.Dataset))):
try:
self.get_ignored_scope_instance()
except Exception as e:
raise ValueError(
"Dataset must be a instance of either string, list of strings, nncf.Dataset or "
f"dataset.Dataset, but found {type(self.dataset)}"
)
if not (self.ignored_scope is None or isinstance(self.ignored_scope, nncf.IgnoredScope)):
raise ValueError(
"Ignored scope must be a instance of either dict, or nncf.IgnoredScope but found "
f"{type(self.dataset)}"
f"Can't create an `IgnoredScope` object from the provided ignored scope dict: {self.ignored_scope}.\n{e}"
)
if not (self.num_samples is None or isinstance(self.num_samples, int) and self.num_samples > 0):
raise ValueError(f"`num_samples` is expected to be a positive integer, but found: {self.num_samples}")

def _to_dict_without_properties(self, property_names: Union[List[str], Tuple[str]]) -> Dict[str, Any]:
"""
Calls to_dict() with given properties overwritten with None. Useful for hiding non-serializable properties.
"""
if len(property_names) == 0:
return super().to_dict()
with _replace_properties_values(self, property_names, [None] * len(property_names)):
result = super().to_dict()
return result

def to_dict(self) -> Dict[str, Any]:
properties_to_omit = [] if _is_serializable(self.dataset) else ["dataset"]
if isinstance(self.ignored_scope, nncf.IgnoredScope):
with _replace_properties_values(self, ["ignored_scope"], [self.ignored_scope.__dict__]):
return self._to_dict_without_properties(properties_to_omit)
return self._to_dict_without_properties(properties_to_omit)
def get_ignored_scope_instance(self) -> nncf.IgnoredScope:
    """
    Build an `nncf.IgnoredScope` from the `ignored_scope` dictionary stored on this config.

    Returns:
        An `nncf.IgnoredScope` created with no arguments when `ignored_scope` is `None`,
        otherwise one created from the dictionary entries passed as keyword arguments.
    """
    # The stored dict is deep-copied, so the keyword arguments are not the objects kept on the config.
    scope_kwargs = {} if self.ignored_scope is None else copy.deepcopy(self.ignored_scope)
    return nncf.IgnoredScope(**scope_kwargs)


class OVConfig(BaseConfig):
Expand All @@ -155,16 +113,11 @@ def __init__(
self.input_info = input_info
self.save_onnx_model = save_onnx_model
self.optimum_version = kwargs.pop("optimum_version", None)
if isinstance(quantization_config, dict):
quantization_config = self._quantization_config_from_dict(quantization_config)
self.quantization_config = quantization_config
self.compression = None # A backward-compatability field for training-time compression parameters

if isinstance(self.quantization_config, dict):
# Config is loaded as dict during deserialization
logger.info(
"`quantization_config` was provided as a dict, in this form it can't be used for quantization. "
"Please provide config as an instance of OVWeightQuantizationConfig or OVQuantizationConfig"
)

bits = (
self.quantization_config.bits if isinstance(self.quantization_config, OVWeightQuantizationConfig) else None
)
Expand All @@ -180,12 +133,40 @@ def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
for name, value in model_inputs.items()
]

@staticmethod
def _quantization_config_from_dict(quantization_config: dict) -> OVQuantizationConfigBase:
    """
    Create a quantization config object from its dictionary form.

    The config class is chosen by comparing the dictionary keys with the argument names of
    `OVWeightQuantizationConfig.__init__` and `OVQuantizationConfig.__init__`. If the keys fit
    exactly one of the two signatures, that class is used. If they fit both or neither, the
    `weight_only` entry decides; when it is missing or `None`, a warning is logged and an
    `OVWeightQuantizationConfig` is created.

    Args:
        quantization_config: Dictionary of quantization parameters.

    Returns:
        The result of `from_dict(quantization_config)` on the selected config class.
    """

    def _accepts_every_key(config_cls) -> bool:
        # True when each dictionary key is a named argument of the class constructor.
        accepted_names = inspect.getfullargspec(config_cls.__init__).args
        return all(key in accepted_names for key in quantization_config)

    fits_weight_only = _accepts_every_key(OVWeightQuantizationConfig)
    fits_full = _accepts_every_key(OVQuantizationConfig)

    if fits_weight_only == fits_full:
        # The signatures can't tell the two apart, so fall back to the explicit flag.
        weight_only = quantization_config.get("weight_only", None)
        if weight_only is None:
            logger.warning(
                "Can't determine type of OV quantization config. Please specify explicitly whether you intend to "
                "run weight-only quantization or not with `weight_only` parameter. Creating an instance of "
                "OVWeightQuantizationConfig."
            )
            return OVWeightQuantizationConfig.from_dict(quantization_config)
        fits_weight_only = weight_only

    if fits_weight_only:
        return OVWeightQuantizationConfig.from_dict(quantization_config)
    return OVQuantizationConfig.from_dict(quantization_config)

def _to_dict_safe(self, to_diff_dict: bool = False) -> Dict[str, Any]:
class ConfigStub:
def to_dict(self):
return None

def to_diff_dict(self):
return None

if self.quantization_config is None:
# Parent to_dict() implementation does not support quantization_config being None
with _replace_properties_values(self, ("quantization_config",), (OVQuantizationConfigBase(),)):
result = super().to_diff_dict() if to_diff_dict else super().to_dict()
del result["quantization_config"]
self_copy = copy.deepcopy(self)
self_copy.quantization_config = ConfigStub()
result = self_copy.to_diff_dict() if to_diff_dict else self_copy.to_dict()
else:
result = super().to_diff_dict() if to_diff_dict else super().to_dict()
return result
Expand All @@ -212,9 +193,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
The number of bits to quantize to.
sym (`bool`, defaults to `False`):
Whether to use symmetric quantization.
tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
tokenizer (`str`, *optional*):
The tokenizer used to process the dataset. You can pass either:
- A custom tokenizer object.
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
Expand All @@ -224,6 +204,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
in a list of strings or just use the one from the list ['wikitext','c4','c4-new','ptb','ptb-new'] for LLMs
or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
Alternatively, you can provide data objects via `calibration_dataset` argument
of `OVQuantizer.quantize()` method.
ratio (`float`, defaults to 1.0):
The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM
and the rest to INT8_ASYM).
Expand All @@ -235,32 +217,44 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
The sensitivity metric for assigning quantization precision to layers. In order to
preserve the accuracy of the model, the more sensitive layers receive a higher precision.
ignored_scope (`dict`, *optional*):
An ignored scope that defined the list of model control flow graph nodes to be ignored during quantization.
An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.
quant_method (`str`, defaults to OVQuantizationMethod.DEFAULT):
Weight compression method to apply.
weight_only (`bool`, *optional*):
Used to explicitly specify type of quantization to apply.
weight_only (`bool`, *optional*):
Used to explicitly specify type of quantization (weight-only or full) to apply.
"""

def __init__(
self,
bits: int = 8,
sym: bool = False,
tokenizer: Optional[Any] = None,
dataset: Optional[Union[str, List[str], nncf.Dataset, datasets.Dataset]] = None,
tokenizer: Optional[str] = None,
dataset: Optional[Union[str, List[str]]] = None,
ratio: float = 1.0,
group_size: Optional[int] = None,
all_layers: Optional[bool] = None,
sensitivity_metric: Optional[str] = None,
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = None,
quant_method: Optional[Union[QuantizationMethod, OVQuantizationMethod]] = OVQuantizationMethod.DEFAULT,
weight_only: Optional[bool] = True,
**kwargs,
):
super().__init__(dataset, ignored_scope, num_samples)
if weight_only is False:
logger.warning(
"Trying to create an instance of `OVWeightQuantizationConfig` with `weight_only` being "
"False. Please check your configuration."
)
super().__init__(ignored_scope, num_samples, True)
self.bits = bits
self.sym = sym
self.tokenizer = tokenizer
self.dataset = dataset
self.group_size = group_size or (-1 if bits == 8 else 128)
self.ratio = ratio
self.all_layers = all_layers
Expand All @@ -277,6 +271,11 @@ def post_init(self):
raise ValueError("`ratio` must between 0 and 1.")
if self.group_size is not None and self.group_size != -1 and self.group_size <= 0:
raise ValueError("`group_size` must be greater than 0 or equal to -1")
if not (self.dataset is None or isinstance(self.dataset, (str, list))):
raise ValueError(
f"Dataset must be a instance of either string or list of strings, but found {type(self.dataset)}. "
f"If you wish to provide a custom dataset please pass it via `calibration_dataset` argument."
)
if self.dataset is not None and isinstance(self.dataset, str):
llm_datasets = ["wikitext", "c4", "c4-new", "ptb", "ptb-new"]
stable_diffusion_datasets = [
Expand All @@ -303,34 +302,31 @@ def post_init(self):
f"For 8-bit quantization, `group_size` is expected to be set to -1, but was set to {self.group_size}"
)

def to_dict(self) -> Dict[str, Any]:
if not _is_serializable(self.tokenizer):
return self._to_dict_without_properties(("tokenizer",))
return super().to_dict()
if self.tokenizer is not None and not isinstance(self.tokenizer, str):
raise ValueError(f"Tokenizer is expected to be a string, but found {self.tokenizer}")


@dataclass
class OVQuantizationConfig(OVQuantizationConfigBase):
def __init__(
self,
dataset: Union[str, List[str], nncf.Dataset, datasets.Dataset],
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = 300,
preset: nncf.QuantizationPreset = None,
model_type: nncf.ModelType = nncf.ModelType.TRANSFORMER,
fast_bias_correction: bool = True,
overflow_fix: OverflowFix = OverflowFix.DISABLE,
weight_only: Optional[bool] = False,
**kwargs,
):
"""
Configuration class containing parameters related to model quantization with NNCF. Compared to weight
compression, during quantization both weights and activations are converted to lower precision.
For weight-only model quantization please see OVWeightQuantizationConfig.
Args:
dataset (`str or List[str] or nncf.Dataset or datasets.Dataset`):
A dataset used for quantization parameters calibration. Required parameter.
ignored_scope (`dict or nncf.IgnoredScope`, *optional*):
An ignored scope that defines the list of model nodes to be ignored during quantization.
ignored_scope (`dict`, *optional*):
An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.
preset (`nncf.QuantizationPreset`, *optional*):
Expand All @@ -346,31 +342,45 @@ def __init__(
Whether to apply fast or full bias correction algorithm.
overflow_fix (`nncf.OverflowFix`, default to OverflowFix.DISABLE):
Parameter for controlling overflow fix setting.
weight_only (`bool`, *optional*):
Used to explicitly specify type of quantization (weight-only or full) to apply.
"""
super().__init__(dataset, ignored_scope, num_samples)
if weight_only is True:
logger.warning(
"Trying to create an instance of `OVQuantizationConfig` with `weight_only` being True. "
"Please check your configuration."
)
super().__init__(ignored_scope, num_samples, False)
# TODO: remove checks below once NNCF is updated to 2.10
if isinstance(overflow_fix, str):
overflow_fix = OverflowFix(overflow_fix)
if isinstance(preset, str):
preset = nncf.QuantizationPreset(preset)

self.preset = preset
self.model_type = model_type
self.fast_bias_correction = fast_bias_correction
self.overflow_fix = overflow_fix
self.post_init()

def post_init(self):
"""
Safety checker that arguments are correct
"""
super().post_init()
if self.dataset is None:
raise ValueError(
"`dataset` is needed to compute the activations range during the calibration step and was not provided."
" In case you only want to apply quantization on the weights, please run weight-only quantization."
)

def to_dict(self) -> Dict[str, Any]:
# TODO: remove code below once NNCF is updated to 2.10
overflow_fix_value = None if self.overflow_fix is None else self.overflow_fix.value
preset_value = None if self.preset is None else self.preset.value
with _replace_properties_values(self, ("overflow_fix", "preset"), (overflow_fix_value, preset_value)):
return super().to_dict()
if isinstance(self.overflow_fix, Enum) or isinstance(self.preset, Enum):
overflow_fix_value = (
None
if self.overflow_fix is None
else self.overflow_fix
if isinstance(self.overflow_fix, str)
else self.overflow_fix.value
)
preset_value = (
None if self.preset is None else self.preset if isinstance(self.preset, str) else self.preset.value
)
self_copy = copy.deepcopy(self)
self_copy.overflow_fix = overflow_fix_value
self_copy.preset = preset_value
return self_copy.to_dict()
return super().to_dict()


def _check_default_4bit_configs(config: PretrainedConfig):
Expand Down
Loading

0 comments on commit f7fa3a1

Please sign in to comment.