Introduce quant_level into mixed precision (#950)
* Introduce quant_level into mixed precision

Signed-off-by: yiliu30 <yi4.liu@intel.com>

yiliu30 committed Jul 10, 2023
1 parent 00e5cb5 commit 0dc6a92
Showing 4 changed files with 195 additions and 39 deletions.
14 changes: 14 additions & 0 deletions neural_compressor/config.py
@@ -1691,6 +1691,8 @@ class MixedPrecisionConfig(object):
model_name (str, optional): The name of the model. Default value is empty.
inputs (list, optional): Inputs of model, default is [].
outputs (list, optional): Outputs of model, default is [].
quant_level (str or int, optional): Supported values are "auto", 0, and 1. 0 is conservative
(fallback op-type-wise), 1 falls back op-wise, and "auto" (the default) combines 0 and 1.
tuning_criterion (TuningCriterion object, optional): Accuracy tuning settings,
it won't work if there is no accuracy tuning process.
accuracy_criterion (AccuracyCriterion object, optional): Accuracy constraint settings,
@@ -1739,6 +1741,7 @@ def __init__(self,
model_name="",
inputs=[],
outputs=[],
quant_level="auto",
tuning_criterion=tuning_criterion,
accuracy_criterion=accuracy_criterion,
excluded_precisions=[],
@@ -1750,6 +1753,7 @@
self.outputs = outputs
self.backend = backend
self.device = device
self.quant_level = quant_level
self.excluded_precisions = excluded_precisions
self.accuracy_criterion = accuracy_criterion
self.tuning_criterion = tuning_criterion
@@ -1788,6 +1792,16 @@ def model_name(self, model_name):
if _check_value("model_name", model_name, str):
self._model_name = model_name

@property
def quant_level(self):
"""Get the quantization level."""
return self._quant_level

@quant_level.setter
def quant_level(self, quant_level):
"""Set the quantization level."""
self._quant_level = quant_level

@property
def accuracy_criterion(self):
"""Get the accuracy criterion."""
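
For context, a minimal usage sketch of the new option, modeled on the tests added in this commit; the model, the `evaluate` helper, and the input/output tensor names are placeholders, not part of the change:

from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

# quant_level=0 falls back op-type-wise, quant_level=1 falls back op-wise,
# and "auto" (the default) runs stage 0 first and then stage 1.
conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")

def eval_func(model):
    # Return the model's accuracy; the strategy uses it to decide whether a
    # fallback step satisfies the accuracy criterion.
    return evaluate(model)  # `evaluate` is a hypothetical user-supplied helper

converted_model = mix_precision.fit(fp32_model, conf, eval_func=eval_func)
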
111 changes: 87 additions & 24 deletions neural_compressor/strategy/auto_mixed_precision.py
@@ -18,11 +18,11 @@
"""The auto-mixed precision strategy."""

import copy
import numpy as np
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from itertools import groupby
from .strategy import strategy_registry, TuneStrategy
from ..utils import logger
from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
from .utils.tuning_sampler import FallbackTuningSampler
from .utils.tuning_structs import OpTuningConfig
from neural_compressor.adaptor.torch_utils.mixed_precision import ipex_mixed_precision

@@ -50,6 +50,7 @@ def _initialize_config(self, conf):
config.domain = getattr(config, 'domain', None)
config.reduce_range = getattr(config, 'reduce_range', None)
config.example_inputs = getattr(config, 'example_inputs', None)
config.quant_level = getattr(config, "quant_level", "auto")
return config

def next_tune_cfg(self):
@@ -79,54 +80,116 @@ def next_tune_cfg(self):
if not target_dtypes:
target_dtypes = ['bf16']
# step1. target_dtype AMAP, collect the ops that support target_dtype
bf16_items_name = []
lower_precision_items_name = []
op_tuning_cfg = {}
for idx, target_dtype in enumerate(target_dtypes):
bf16_items = tuning_space.query_items_by_quant_mode(target_dtype)
if len(bf16_items) == 0 and \
not (idx == len(target_dtypes) - 1 and len(bf16_items_name) == 0):
lower_precision_items = tuning_space.query_items_by_quant_mode(target_dtype)
if len(lower_precision_items) == 0 and \
not (idx == len(target_dtypes) - 1 and len(lower_precision_items_name) == 0):
continue
bf16_items_name = [item.name for item in bf16_items]
lower_precision_items_name = [item.name for item in lower_precision_items]
op_tuning_cfg = deepcopy(initial_op_tuning_cfg)
for op_name_type in bf16_items_name:
for op_name_type in lower_precision_items_name:
op_tuning_cfg[op_name_type] = \
OpTuningConfig(op_name_type[0], op_name_type[1], target_dtype, tuning_space)
calib_sampling_size = 1
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
yield op_tuning_cfg

# step2. fallback
target_dtype = 'fp32'
fallback_items_name_lst = bf16_items_name[::-1]
# step 2, fallback op into fp32
# quant_level:
# auto: op-type-wise -> op-wise
# 0: op-type wise
# 1: op-wise

# if quant level is auto or 0, do op type wise fallback
target_dtype = "fp32"
fallback_items_name_lst = lower_precision_items_name[::-1]
if fallback_items_name_lst:
logger.info(f"Start to fallback op to {target_dtype} one by one.")
self._fallback_started()
op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
logger.info("[Strategy] start fallback op into fp32.")
initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
if self.config.quant_level in ["auto", 0]:
logger.info(f"[Strategy] fallback op into fp32 in op type wise, \
as quant level is {self.config.quant_level}")
for op_tuning_cfg in self.fallback_in_op_type_wise(tuning_space, fallback_items_name_lst,\
deepcopy(initial_op_tuning_cfg), target_dtype):
yield op_tuning_cfg

# if quant level is auto or 1, do op instance fallback
if self.config.quant_level in ["auto", 1]:
logger.info(f"[Strategy] fallback op into fp32 in op wise, \
as quant level is {self.config.quant_level}")
for op_tuning_cfg in self.fallback_in_op_wise(tuning_space, fallback_items_name_lst,\
deepcopy(initial_op_tuning_cfg), target_dtype):
yield op_tuning_cfg

def fallback_in_op_type_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype):
"""Fallback op in op type wise.
Args:
tuning_space: tuning space
fallback_items_name_lst: the list of items to be fallback
initial_op_tuning_cfg: initial tuning config
target_dtype: target data type, such as fp32
Yields:
tuning config
"""
fallback_items_name_lst.sort(key=lambda x: x[1])
op_type_groups = groupby(fallback_items_name_lst, key=lambda x: x[1])
# key: ((op1_name, op_type1),(op2_name, op_type1), (op3_name, op_type1), ...)
# value: target dtype
ops_dtypes = OrderedDict()
for op_type, op_lst in op_type_groups:
ops_dtypes[tuple(op_lst)] = target_dtype
fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
initial_op_tuning_cfg=initial_op_tuning_cfg,
op_dtypes=op_dtypes, accumulate=False)
initial_op_tuning_cfg=initial_op_tuning_cfg,
op_dtypes=ops_dtypes, accumulate=False)
op_fallback_acc_impact = OrderedDict()
for op_index, op_tuning_cfg in enumerate(fallback_sampler):
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
op_tuning_cfg['calib_sampling_size'] = -1
yield op_tuning_cfg
acc, _ = self.last_tune_result
op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc

def fallback_in_op_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype):
"""Fallback op in op wise.
Args:
tuning_space: tuning space
fallback_items_name_lst: the list of items to be fallback
initial_op_tuning_cfg: initial tuning config
target_dtype: target data type, such as fp32
Yields:
tuning config
"""
op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
initial_op_tuning_cfg=initial_op_tuning_cfg,
op_dtypes=op_dtypes, accumulate=False)
op_fallback_acc_impact = OrderedDict()
for op_index, op_tuning_cfg in enumerate(fallback_sampler):
op_tuning_cfg['calib_sampling_size'] = -1
yield op_tuning_cfg
acc, _ = self.last_tune_result
op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc

# do accumulated fallback according to the order in the previous stage
if len(op_fallback_acc_impact) > 0:
ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key],
reverse=self.higher_is_better)
ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key], \
reverse=self.higher_is_better)
op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst)))
logger.info("Start to accumulate fallback to {target_dtype}.")
initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
initial_op_tuning_cfg = copy.deepcopy(op_tuning_cfg)
fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
initial_op_tuning_cfg=initial_op_tuning_cfg,
op_dtypes=op_dtypes, accumulate=True)
initial_op_tuning_cfg=initial_op_tuning_cfg,
op_dtypes=op_dtypes, accumulate=True)
for op_tuning_cfg in fallback_sampler:
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
op_tuning_cfg['calib_sampling_size'] = -1
yield op_tuning_cfg


def traverse(self):
"""Traverse the tuning space according to auto-mixed precision strategy."""
if self.config.backend == "ipex":
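
The op-type-wise stage relies on sorting the fallback items by op type before grouping them, since itertools.groupby only merges adjacent elements. A minimal, self-contained sketch of that grouping; the op names and types below are made up for illustration:

from collections import OrderedDict
from itertools import groupby

# (op_name, op_type) pairs currently running in the lower precision
items = [("conv1", "Conv2D"), ("mm1", "MatMul"), ("conv2", "Conv2D")]
items.sort(key=lambda x: x[1])  # groupby needs the grouping key to be contiguous

ops_dtypes = OrderedDict()
for op_type, group in groupby(items, key=lambda x: x[1]):
    # one fallback trial per op type: every op of that type moves to fp32 together
    ops_dtypes[tuple(group)] = "fp32"

print(ops_dtypes)
# OrderedDict([((('conv1', 'Conv2D'), ('conv2', 'Conv2D')), 'fp32'),
#              ((('mm1', 'MatMul'),), 'fp32')])
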
30 changes: 16 additions & 14 deletions neural_compressor/strategy/utils/tuning_sampler.py
@@ -20,7 +20,7 @@
from itertools import product
import copy
from collections import deque, OrderedDict, defaultdict
from typing import List, Dict, Any
from typing import List, Dict, Any, Union, Tuple
from .tuning_space import TuningSpace, pattern_to_internal, pattern_to_path, quant_mode_from_pattern
from .tuning_structs import OpTuningConfig
from ...utils import logger
@@ -382,8 +382,8 @@ class FallbackTuningSampler(TuningSampler):
def __init__(self,
tuning_space: TuningSpace,
tuning_order_lst: List[TuningOrder],
initial_op_tuning_cfg: Dict[tuple, Any],
op_dtypes: Dict[str, str],
initial_op_tuning_cfg: Dict[Tuple, Any],
op_dtypes: Dict[Union[Tuple, Tuple[Tuple]], str],
accumulate: bool,
skip_first: bool = True
):
@@ -414,21 +414,23 @@ def __iter__(self):
# Only support fallback to lower precision.
if not self.accumulate:
new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype)
self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
config_args = {}
self._set_dtype(op_name_type, config_args)
internal_pattern = pattern_to_internal(target_dtype)
quant_mode = quant_mode_from_pattern(internal_pattern)
new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1],
quant_mode, self.tuning_space,
kwargs=config_args)
op_name_type_lst = [op_name_type] if len(op_name_type) != 1 and \
isinstance(op_name_type[1], str) else op_name_type
for op_name_type in op_name_type_lst:
full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype)
self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
config_args = {}
self._set_dtype(op_name_type, config_args)
internal_pattern = pattern_to_internal(target_dtype)
quant_mode = quant_mode_from_pattern(internal_pattern)
new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1], quant_mode, \
self.tuning_space, kwargs=config_args)

new_tune_cfg.update({op_name_type: new_op_config})
new_tune_cfg.update({op_name_type: new_op_config})
if self.accumulate and skip_first: # skip the first one
skip_first = False
continue
logger.info(f"fallback {op_name_type} to {target_dtype}")
logger.info(f"fallback {op_name_type_lst} to {target_dtype}")
yield new_tune_cfg # need to skip the first one

class LowerBitsSampler(TuningSampler):
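
With op-type-wise fallback, a key in op_dtypes can now be either a single (op_name, op_type) pair or a tuple of such pairs, so the sampler has to tell the two apart before iterating. A small sketch of that check; the function name and data are illustrative only, not part of the commit:

from typing import Tuple, Union

def normalize(op_name_type: Union[Tuple, Tuple[Tuple]]) -> list:
    # A single pair looks like ("conv1", "Conv2D"): length 2 and the second
    # element is a string. A group looks like (("conv1", "Conv2D"), ("conv2", "Conv2D")).
    if len(op_name_type) != 1 and isinstance(op_name_type[1], str):
        return [op_name_type]
    return list(op_name_type)

print(normalize(("conv1", "Conv2D")))
# [('conv1', 'Conv2D')]
print(normalize((("conv1", "Conv2D"), ("conv2", "Conv2D"))))
# [('conv1', 'Conv2D'), ('conv2', 'Conv2D')]
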
79 changes: 78 additions & 1 deletion test/mixed_precision/test_mixed_precision.py
@@ -328,7 +328,7 @@ def test_mixed_precision_with_eval_func(self):
def eval(model):
return 0.5

result = [0., 0.1, 0.102, 0.1006, 0.1005, 0.1004, 0.1002]
result = [0., 0.1, 0.102, 0.1003, 0.1005, 0.1004, 0.1002]

def eval2(model):
del result[0]
@@ -371,6 +371,83 @@ def eval2(model):
output_model = fit(self.tf_model, conf, eval)
self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node]))


def test_mixed_precision_with_quant_level_1(self):

result = [0., 0.1, 0.102]
def eval_func(model):
del result[0]
return result[0]

conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")

output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node]))
self.assertEqual(conf.inputs, 'input')
self.assertEqual(conf.outputs, 'final')

def test_mixed_precision_with_quant_level_2(self):

result = [0., 1, 0.9, 1.1]
# meet acc if fallback all conv
def eval_func(model):
del result[0]
return result[0]

conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")

output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
# no cast in output model
self.assertFalse(any([i.op == 'Cast' for i in output_model.graph_def.node]))

def test_mixed_precision_with_quant_level_3(self):

result = [0., 1, 0.9, 0.9, 1.1]
# meet acc if fallback 1 conv
def eval_func(model):
del result[0]
return result[0]

conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto")

output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
# expect Cast nodes to remain in the output model
count_cast = 0
for node in output_model.graph_def.node:
if node.op == "Cast":
count_cast += 1
self.assertEqual(count_cast, 4)

def test_mixed_precision_with_quant_level_4(self):

result = [0., 1, 0.9, 0.9, 1.1]
# meet acc if fallback the second conv
def eval_func(model):
del result[0]
return result[0]

conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=1)

output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
# expect Cast nodes to remain in the output model
count_cast = 0
for node in output_model.graph_def.node:
if node.op == "Cast":
count_cast += 1
self.assertEqual(count_cast, 4)

def test_mixed_precision_with_quant_level_5(self):
result = [0., 1, 0.9, 0.9, 0.9]
# accuracy criterion is never met
def eval_func(model):
del result[0]
return result[0]

conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=0)

output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func)
self.assertIsNone(output_model)

@unittest.skipIf(PT_VERSION.release < Version("1.11.0").release,
"Please use PyTroch 1.11 or higher version for mixed precision.")
def test_mixed_precision_with_eval_func_pt(self):
Expand Down
