Skip to content

Commit

Permalink
Add auto quantization level as the default tuning process (#595)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <yi4.liu@intel.com>
  • Loading branch information
yiliu30 committed Mar 1, 2023
1 parent 35b6f27 commit cdfb994
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 22 deletions.
6 changes: 2 additions & 4 deletions neural_compressor/conf/config.py
Expand Up @@ -805,7 +805,6 @@ def percent_to_float(data):
'pre_post_process_quantization': True},
'model_wise': {'weight': {'bit': [7.0]},
'activation': {}},
'quant_level': 1,
}): {
Optional('approach', default='post_training_static_quant'): And(
str,
Expand Down Expand Up @@ -899,10 +898,9 @@ def percent_to_float(data):
Optional('op_wise', default=None): {
str: ops_schema
},
Optional('quant_level', default=1): And(int, lambda level: level in [0, 1]),
},
Optional('use_bf16', default=True): bool,
Optional('quant_level', default=1): And(int, lambda level: level in [0, 1]),
Optional('quant_level', default="auto"): And(Or(str, int), lambda level: level in ["auto", 0, 1]),
Optional('graph_optimization'): graph_optimization_schema,
Optional('mixed_precision'): mixed_precision_schema,

Expand Down Expand Up @@ -1178,7 +1176,7 @@ def percent_to_float(data):
'activation': {}},
}): dict,
Optional('use_bf16', default=False): bool,
Optional('quant_level', default=1): int,
Optional('quant_level', default="auto"): Or(str, int),
Optional('tuning', default={
'strategy': {'name': 'basic'},
'accuracy_criterion': {'relative': 0.01, 'higher_is_better': True},
Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/conf/pythonic_config.py
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
performance_only=False,
reduce_range=None,
use_bf16=True,
quant_level=1,
quant_level="auto",
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False):
excluded_precisions = ["bf16"] if not use_bf16 else []
Expand Down
11 changes: 6 additions & 5 deletions neural_compressor/config.py
Expand Up @@ -390,7 +390,7 @@ def __init__(self,
reduce_range=None,
example_inputs=None,
excluded_precisions=[],
quant_level=1,
quant_level="auto",
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False):
"""Initialize _BaseQuantizationConfig class.
Expand Down Expand Up @@ -431,7 +431,8 @@ def __init__(self,
reduce_range: whether use 7 bit
example_inputs: used to trace PyTorch model with torch.jit/torch.fx
excluded_precisions: precisions to be excluded, support 'bf16'
quant_level: support 0 and 1, 0 is conservative strategy, 1 is basic(default) or user-specified strategy
quant_level: support auto, 0 and 1, 0 is conservative strategy, 1 is basic or user-specified
strategy, auto (default) is the combination of 0 and 1.
accuracy_criterion: accuracy constraint settings
use_distributed_tuning: whether use distributed tuning or not
"""
Expand Down Expand Up @@ -868,7 +869,7 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig):
from neural_compressor.config PostTrainingQuantConfig, TuningCriterion
conf = PostTrainingQuantConfig(
quant_level=0, # the quantization level.
quant_level="auto", # the quantization level.
tuning_criterion=TuningCriterion(
timeout=0, # optional. tuning timeout (seconds). When set to 0, early stopping is enabled.
max_trials=100, # optional. max tuning times.
Expand All @@ -890,7 +891,7 @@ def __init__(self,
op_name_list=None,
reduce_range=None,
excluded_precisions=[],
quant_level=1,
quant_level="auto",
tuning_criterion=tuning_criterion,
accuracy_criterion=accuracy_criterion,
use_distributed_tuning=False,
Expand Down Expand Up @@ -964,7 +965,7 @@ def __init__(self,
op_name_list=None,
reduce_range=None,
excluded_precisions=[],
quant_level=1):
quant_level="auto"):
"""Init a QuantizationAwareTrainingConfig object."""
super().__init__(inputs=inputs,
outputs=outputs,
Expand Down
5 changes: 5 additions & 0 deletions neural_compressor/quantization.py
Expand Up @@ -82,6 +82,11 @@ def pre_proccess(self):
" force setting 'tuning.exit_policy.performance_only = True'.".format(performance_only))

strategy = cfg.tuning.strategy.name.lower()

if cfg.quant_level == "auto" or cfg.quantization.quant_level == "auto":
strategy = "auto"
logger.info(f"Start auto tuning.")

if cfg.quantization.quant_level == 0:
strategy = "conservative"
logger.info(f"On the premise that the accuracy meets the conditions, improve the performance.")
Expand Down
106 changes: 106 additions & 0 deletions neural_compressor/strategy/auto.py
@@ -0,0 +1,106 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The auto tuning strategy."""
import copy
from copy import deepcopy
import numpy as np
from collections import OrderedDict
from .strategy import strategy_registry, TuneStrategy, STRATEGIES
from ..utils import logger

from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
from .utils.tuning_structs import OpTuningConfig
from .utils.constant import TUNING_ITEMS_LST

@strategy_registry
class AutoTuneStrategy(TuneStrategy):
"""The auto tuning strategy.
There are three stages executed by auto strategy sequentially,
and the tuning process ends once the condition meets the exit policy.
"""

def __init__(self, model, conf, q_dataloader=None, q_func=None, \
eval_dataloader=None, eval_func=None, resume=None, q_hooks=None):
"""Init an auto tuning strategy.
Args:
model: The FP32 model specified for low precision tuning.
conf: The Conf class instance includes all user configurations.
q_dataloader: Data loader for calibration, mandatory for post-training quantization. Defaults to None.
q_func: Training function for quantization aware training. Defaults to None. Defaults to None.
eval_dataloader: Data loader for evaluation. Defaults to None.
eval_func: The evaluation function provided by user. This function takes model as parameter, and
evaluation dataset and metrics should be encapsulated in this function implementation and
outputs a higher-is-better accuracy scalar value.
resume: The dict containing resume information. Defaults to None.
q_hooks: The dict of training hooks, supported keys are: on_epoch_begin, on_epoch_end, on_step_begin,
on_step_end. Their values are functions to be executed in adaptor layer.. Defaults to None.
"""
super().__init__(model, conf, q_dataloader, q_func, eval_dataloader,\
eval_func, resume, q_hooks)
logger.info(f"*** Start auto tuning")
self.model = model
self.conf = conf
self.q_dataloader = q_dataloader
self.q_func = q_func
self.eval_dataloader = eval_dataloader
self.eval_func = eval_func
self.resume = resume
self.q_hooks = q_hooks
self.strategies_sequence = ['conservative', 'basic']

def sequential_traverse(self):
"""Try different strategies sequentially."""
pre_strategy = self
for strategy_name in self.strategies_sequence:
logger.info(f"*** Start {strategy_name} tuning.")
strategy = STRATEGIES[strategy_name](self.model, self.conf, self.q_dataloader, self.q_func, \
self.eval_dataloader, self.eval_func, self.resume, self.q_hooks)
if pre_strategy:
#TODO add tuning history from the previous stage to current stage.
strategy.baseline = deepcopy(pre_strategy.baseline)
strategy.trials_count = pre_strategy.trials_count
strategy.objectives.baseline = deepcopy(pre_strategy.baseline)
pre_strategy = strategy
strategy.traverse()
self.best_qmodel = strategy.best_qmodel
if self.best_qmodel:
return

def next_tune_cfg(self):
"""Generate and yield the default tuning config.
Returns:
tune_config (dict): A dict containing the tuning configuration for quantization.
"""
tuning_space = self.tuning_space
calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options
_, _, op_tuning_cfg = self.initial_tuning_cfg()
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size_lst[0]
logger.info(f"Quantize the model with default config.")
yield op_tuning_cfg

def traverse(self):
"""Traverse the tuning space."""
# Quantize model with default config
super().traverse()
if self.best_qmodel:
return
else:
# Start to try different strategies sequentially
self.sequential_traverse()
7 changes: 3 additions & 4 deletions neural_compressor/strategy/auto_mixed_precision.py
Expand Up @@ -112,13 +112,12 @@ def traverse(self):
# get fp32 model baseline
self._eval_baseline()

trials_count = 0
for op_tuning_cfg in self.next_tune_cfg():
# add tune_cfg here as quantize use tune_cfg
tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
trials_count += 1
self.trials_count += 1
tuning_history = self._find_tuning_history(tune_cfg)
if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials:
if tuning_history and self.trials_count < self.cfg.tuning.exit_policy.max_trials:
self.last_tune_result = tuning_history['last_tune_result']
self.best_tune_result = tuning_history['best_tune_result']
logger.warn("Find evaluated tuning config, skip.")
Expand All @@ -139,7 +138,7 @@ def traverse(self):
q_config = copy.deepcopy(self.last_qmodel.q_config)
self.last_tune_result = self._evaluate(self.last_qmodel)
self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(op_tuning_cfg)
need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, trials_count)
need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, self.trials_count)
# record the tuning history
saved_tune_cfg = copy.deepcopy(tune_cfg)
saved_last_tune_result = copy.deepcopy(self.last_tune_result)
Expand Down
15 changes: 7 additions & 8 deletions neural_compressor/strategy/strategy.py
Expand Up @@ -152,7 +152,7 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader=
self.cur_best_acc = self.initial_best_acc() # track the current best accuracy
self.cur_best_tuning_cfg = {} # track tuning cfg with the current best accuracy
self.re_quant = False

self.trials_count = 0
self.capability = self.adaptor.query_fw_capability(model)
logger.debug(self.capability)
self.set_tuning_space(conf)
Expand Down Expand Up @@ -575,14 +575,13 @@ def traverse(self):
logger.info("use distributed traverse: {}".format(self.cfg.tuning.use_distributed_tuning))
if self.cfg.tuning.use_distributed_tuning:
return self.distributed_traverse()
trials_count = 0
traverse_start_time = time()
for op_tuning_cfg in self.next_tune_cfg():
tuning_start_time = time()
tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
trials_count += 1
self.trials_count += 1
tuning_history = self._find_tuning_history(tune_cfg)
if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials:
if tuning_history and self.trials_count < self.cfg.tuning.exit_policy.max_trials:
self.last_tune_result = tuning_history['last_tune_result']
self.best_tune_result = tuning_history['best_tune_result']
logger.warn("Find evaluated tuning config, skip.")
Expand Down Expand Up @@ -612,7 +611,7 @@ def traverse(self):
return
self.last_tune_result = self._evaluate(self.last_qmodel)
self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(op_tuning_cfg)
need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, trials_count)
need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, self.trials_count)

# record the tuning history
saved_tune_cfg = copy.deepcopy(tune_cfg)
Expand Down Expand Up @@ -1390,7 +1389,7 @@ def stop(self, timeout, trials_count):
else:
self.tune_data[name][2] = 'n/a'

logger.info("Tune {} result is: {}, Best tune result is: {}".format(trials_count,
logger.info("Tune {} result is: {}, Best tune result is: {}".format(self.trials_count,
last_tune_msg,
best_tune_msg))
output_data = [[info_type,
Expand All @@ -1410,15 +1409,15 @@ def stop(self, timeout, trials_count):
self.tuning_result_data = output_data
Statistics(output_data,
header='Tune Result Statistics',
field_names=['Info Type', 'Baseline', 'Tune {} result'.format(trials_count), \
field_names=['Info Type', 'Baseline', 'Tune {} result'.format(self.trials_count), \
'Best tune result']).print_stat()


if self.cfg.tuning.exit_policy.performance_only:
need_stop = True
elif timeout == 0 and self.best_tune_result:
need_stop = True
elif trials_count >= self.cfg.tuning.exit_policy.max_trials:
elif self.trials_count >= self.cfg.tuning.exit_policy.max_trials:
need_stop = True
else:
need_stop = False
Expand Down

0 comments on commit cdfb994

Please sign in to comment.