diff --git a/docs/source/tuning_strategies.md b/docs/source/tuning_strategies.md
index 6e11941559a..f922b77feed 100644
--- a/docs/source/tuning_strategies.md
+++ b/docs/source/tuning_strategies.md
@@ -200,6 +200,39 @@ tuning:
     random_seed: 9527
 ```
 
+### MSE_v2
+
+#### Design
+
+`MSE_v2` is a two-stage fallback strategy for few-shot mixed quantization,
+composed of three key components. First, a multi-batch order combination
+based on per-layer fallback MSE values evaluates layer sensitivity from only
+a few batches. Second, a sensitivity gradient, combined with a beam search
+to avoid local optima, refines that sensitivity estimate. Third, a
+quantize-again procedure removes redundant fallback layers to protect
+performance. `MSE_v2` is especially effective for models with a long
+full-dataset evaluation time and a large number of tuning trials.
+
+#### Usage
+`MSE_v2` is used much like `MSE`. To enable it, set the strategy name to
+`mse_v2`. The optional `confidence_batches` field specifies the number of
+batches used in the sensitivity calculation.
+
+
+```yaml
+tuning:
+  strategy:
+    name: mse_v2
+    confidence_batches: 2
+  accuracy_criterion:
+    relative: 0.01
+  exit_policy:
+    timeout: 0
+  random_seed: 9527
+```
+
 ### TPE
 
 #### Design
diff --git a/examples/.config/model_params_pytorch.json b/examples/.config/model_params_pytorch.json
index df42ff22308..fcd42576fcb 100644
--- a/examples/.config/model_params_pytorch.json
+++ b/examples/.config/model_params_pytorch.json
@@ -216,6 +216,33 @@
       "batch_size": 100,
       "new_benchmark": false
     },
+    "efficientnet_b0_fx": {
+      "model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
+      "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+      "input_model": "",
+      "yaml": "conf.yaml",
+      "strategy": "mse_v2",
+      "batch_size": 100,
+      "new_benchmark": false
+    },
+    "efficientnet_b3_fx": {
+      "model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
+      "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+      "input_model": "",
+      "yaml": "conf.yaml",
+      "strategy": "mse_v2",
+      "batch_size": 100,
+      "new_benchmark": false
+    },
+    "efficientnet_b7_fx": {
+      "model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
+      "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+      "input_model": "",
+      "yaml": "conf.yaml",
+      "strategy": "mse_v2",
+      "batch_size": 100,
+      "new_benchmark": false
+    },
     "bert_base_MRPC": {
       "model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
       "dataset_location": "",
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 64dcd747092..9627ad2386a 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -3182,7 +3182,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
         Returns:
             None
         """
-
         module_dict = dict(model.named_modules())
         for op_name, child in model.named_modules():
             if self.is_fused_module(child):
@@ -3507,6 +3506,28 @@ def _check_dynamic_control(module):
                 logger.info('Module has no forward function')
         return False
 
+    def get_output_op_names(self, *args, **kwargs):
+        return None
+
+    def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names,
+                                 confidence_batches, fallback=True, requantize_cfgs=None):
+        """A strategy helper that ranks, by quantization sensitivity, the quantizable ops
from model. + + Args: + model (object): INC model containing fp32 model + dataloader (string): dataloader contains real data. + tune_cfg (dict): dictionary of tune configure for each op. + fallback (bool): switch method in fallback stage and re-quantize stage + + Returns: + ops_lst (list): sorted op list by sensitivity + """ + from .torch_utils.util import get_fallback_order + ordered_ops = get_fallback_order(self, model.model, dataloader, tune_cfg, + confidence_batches, fallback, requantize_cfgs) + return ordered_ops + class PyTorchQuery(QueryBackendCapability): def __init__(self, local_config_file=None): diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index da5118b61a8..d56760bfd7a 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -94,6 +94,8 @@ def __init__(self, framework_specific_info): self.optype_statistics = None + self._last_dequantize_ops = None + def log_histogram(self, writer, tag, values, step=0, bins=1000): import tensorflow as tf # Convert to a numpy array @@ -1453,8 +1455,161 @@ def recover_tuned_model(self, model, q_config): def diagnosis_helper(self, fp32_model, quan_model, tune_cfg, save_path): from .tf_utils.util import tf_diagnosis_helper return tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path) + + def get_output_op_names(self, qmodel): + from .tf_utils.graph_util import GraphAnalyzer + + graph_def = GraphAnalyzer().parse_graph(qmodel.graph_def) + output_op_names = set() + + for output_opname in qmodel.output_node_names: + op_count = 0 + stack = [output_opname] + while stack: + opname = stack.pop() + while True: + op_count += 1 + if opname not in graph_def: + break + op = graph_def[opname] + if op.node.op == 'Dequantize': + output_op_names.add(opname) + break + next_opnames = op.node.input + if not next_opnames: + break + elif len(next_opnames) > 1: + stack += next_opnames[1:] + + opname = next_opnames[0] + + output_op_names = list(output_op_names) + logger.debug(f"output op names: {output_op_names}") + return output_op_names + + def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names, + confidence_batches, fallback=True, requantize_cfgs=None): + """Compute the op sensitivity. + + The sensitivity metric is the mse between the output of the last quantized op of + the quantized model and the output of its corresponding op in the fp32 model. + + 1. Backup the tune cfg + 2. Fallback each int8 op and compute its mse if use fallback (with 'fallback == True'), + or re-quantize each fp32 op(fallen back in the previous stage) and compute its MSE if not. + 3. Sorted op name list according to its MSE + + Args: + fp32_model: The fp32 model. + dataloader: the dataloader with full dataset. + tune_cfg: tuning config + fallback: denote fallback stage or re-quantize stage + requantize_cfgs: the dict of tuning configs for all re-quantizable ops + + Returns: + A list of op names, sorted by its MSE sensitivity. + """ + from copy import deepcopy + fp32_op_cfg = {'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'}, + 'weight': {'dtype': 'fp32'}} + if fallback: + ops_list = [op for op, config in tune_cfg['op'].items() + if config['activation']['quant_mode'] in ('static', 'dynamic')] + replace_cfgs = {op : fp32_op_cfg for op in tune_cfg['op']} + else: + ops_list = [op for op, config in tune_cfg['op'].items() + if config['activation']['quant_mode'] == 'fp32' and op in requantize_cfgs] + replace_cfgs = requantize_cfgs + + # Step2. 
compute mse + mse_result = self._get_mse_order( + model, deepcopy(tune_cfg), replace_cfgs, ops_list, dataloader, + output_op_names, confidence_batches) + + # Step3. sort + mse_order = [op for op, _ in sorted(mse_result.items(), key=lambda i: i[1])] + logger.debug("Dump MSE order:") + for op in mse_order: + logger.debug(f"{op}: {mse_result[op]}") + return mse_order + + def _get_mse_order(self, fp32_model, tune_cfg, replace_cfgs, ops_lst, dataloader, + output_op_names, confidence_batches): + op_cfg = tune_cfg['op'] + mse_result = {} + partial_dataloader = self._partial_dataloader(dataloader, confidence_batches) + + fp32_output = self._inference_model_on_batches( + fp32_model, tune_cfg, partial_dataloader, output_op_names) + + for op in ops_lst: + # backup and set replace tuning config + backup_cfg = op_cfg[op] + op_cfg[op] = replace_cfgs[op] + + # quantize and inference the model + q_model = self.quantize(tune_cfg, fp32_model, partial_dataloader) + q_output = self._inference_model_on_batches( + q_model, tune_cfg, partial_dataloader, output_op_names) + + mse_result[op] = self._calculate_mse(fp32_output, q_output) + + # recover tune_cfg + op_cfg[op] = backup_cfg + + return mse_result + + def _partial_dataset_of(self, dataloader, confidence_batches): + from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset + if isinstance(dataloader.dataset, DummyDataset): + assert(isinstance(confidence_batches, int)) + ds = copy.deepcopy(dataloader.dataset) + ds.dataset = ds.dataset[:confidence_batches] + return ds + else: + return dataloader.dataset.take(confidence_batches) + + def _partial_dataloader(self, dataloader, confidence_batches): + return type(dataloader)( + dataset=self._partial_dataset_of(dataloader, confidence_batches), + batch_size=dataloader.batch_size, + last_batch=dataloader.last_batch, + collate_fn=dataloader.collate_fn, + sampler=dataloader.sampler, + batch_sampler=dataloader.batch_sampler, + num_workers=dataloader.num_workers, + pin_memory=dataloader.pin_memory, + shuffle=dataloader.shuffle, + distributed=dataloader.distributed) + + def _calculate_mse(self, fp32_output, q_output): + result = [] + for i, j in zip(fp32_output, q_output): + result.append(np.square(i - j).mean()) + return np.array(result).mean() + + def _inference_model_on_batches(self, model, tune_cfg, dataloader, + output_op_names): + from .tf_utils.util import generate_feed_dict + + input_tensors = model.input_tensor + output_tensors = [] + for op in output_op_names: + for tensor in model.graph.get_operation_by_name(op).outputs: + output_tensors.append(tensor) + + predictions = [] + for index, (inputs, _) in enumerate(dataloader): + feed_dict = generate_feed_dict(input_tensors, inputs) + + pred = model.sess.run(output_tensors, feed_dict) + for item in pred: + predictions.append(item) + + return predictions + @adaptor_registry class Tensorflow_ITEXAdaptor(TensorFlowAdaptor): def __init__(self, framework_specific_info): diff --git a/neural_compressor/adaptor/tf_utils/graph_converter.py b/neural_compressor/adaptor/tf_utils/graph_converter.py index 5f255aa71f1..caa9afe033d 100644 --- a/neural_compressor/adaptor/tf_utils/graph_converter.py +++ b/neural_compressor/adaptor/tf_utils/graph_converter.py @@ -34,7 +34,7 @@ from .transform_graph.insert_logging import InsertLogging from .transform_graph.rerange_quantized_concat import RerangeQuantizedConcat from .transform_graph.bias_correction import BiasCorrection -from .util import iterator_sess_run,version1_gt_version2,version1_eq_version2 +from .util 
import generate_feed_dict, iterator_sess_run,version1_gt_version2,version1_eq_version2 from .util import version1_gte_version2,version1_lte_version2,version1_lt_version2 from .quantize_graph.quantize_graph_for_intel_cpu import QuantizeGraphForIntel from .quantize_graph_common import QuantizeGraphHelper diff --git a/neural_compressor/adaptor/tf_utils/util.py b/neural_compressor/adaptor/tf_utils/util.py index f95ea4f2d80..8a4ff70beb8 100644 --- a/neural_compressor/adaptor/tf_utils/util.py +++ b/neural_compressor/adaptor/tf_utils/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +from collections import OrderedDict, UserDict import os import numpy as np from google.protobuf import text_format @@ -493,3 +494,62 @@ def _parse_config(q_config, cfg, op_list): if op_name_and_type[0] in op_list: updated_cfg['op'][op_name_and_type] = cfg['op'][op_name_and_type] return dequan_min_max, updated_cfg + +def generate_feed_dict(input_tensor, inputs): + if len(input_tensor) == 1: + feed_dict = {} + if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) \ + or isinstance(inputs, UserDict): + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name == t_name: + feed_dict[tensor] = inputs[name] + break + else: + feed_dict = {input_tensor[0]: inputs} # get raw tensor using index [0] + else: + assert len(input_tensor) == len(inputs), \ + 'inputs len must equal with input_tensor' + feed_dict = {} + if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) \ + or isinstance(inputs, UserDict): + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name in [tensor.name, t_name]: + feed_dict[tensor] = inputs[name] + break + else: + # sometimes the input_tensor is not the same order with inputs + # we should check and pair them + def check_shape(tensor, data): + # scalar or 1 dim default True + if tensor.shape == None or \ + len(tensor.shape.dims) == 1 or \ + not hasattr(data, 'shape'): + return True + tensor_shape = tuple(tensor.shape) + data_shape = tuple(data.shape) + for tensor_dim, data_dim in zip(tensor_shape, data_shape): + if tensor_dim is not None and tensor_dim != data_dim: + return False + return True + + disorder_tensors = [] + disorder_inputs = [] + for idx, sort_tensor in enumerate(input_tensor): + sort_input = inputs[idx] + if check_shape(sort_tensor, sort_input): + feed_dict.update({sort_tensor: sort_input}) + else: + disorder_tensors.append(sort_tensor) + disorder_inputs.append(sort_input) + for i, dis_tensor in enumerate(disorder_tensors): + for j, dis_input in enumerate(disorder_inputs): + if check_shape(dis_tensor, dis_input): + feed_dict.update({dis_tensor: dis_input}) + break + return feed_dict \ No newline at end of file diff --git a/neural_compressor/adaptor/torch_utils/util.py b/neural_compressor/adaptor/torch_utils/util.py index cddc0d6f4e8..71b15c02d36 100644 --- a/neural_compressor/adaptor/torch_utils/util.py +++ b/neural_compressor/adaptor/torch_utils/util.py @@ -14,13 +14,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
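The `generate_feed_dict` helper above pairs session input tensors with batch data even when the two are not in the same order, by testing shape compatibility. A minimal, self-contained sketch of that pairing rule (the `shapes_compatible` name is hypothetical, not part of the library):

```python
import numpy as np

# Mirrors generate_feed_dict's check_shape rule: a tensor accepts a data array
# when every known (non-None) tensor dimension matches the data's dimension.
def shapes_compatible(tensor_shape, data):
    # Scalars, rank-1 tensors, and shapeless data are accepted by default.
    if tensor_shape is None or len(tensor_shape) == 1 or not hasattr(data, 'shape'):
        return True
    for tensor_dim, data_dim in zip(tensor_shape, data.shape):
        if tensor_dim is not None and tensor_dim != data_dim:
            return False
    return True

print(shapes_compatible((None, 3, 3, 1), np.zeros((8, 3, 3, 1))))      # True
print(shapes_compatible((None, 224, 224, 3), np.zeros((8, 3, 3, 1))))  # False
```

Out-of-order inputs are thus re-paired by shape rather than by position, which is why the helper tolerates dataloaders whose tuple order differs from the graph's input order.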
- import copy import re import numpy as np from collections import UserDict +from ...utils import logger from ...utils.utility import LazyImport, CpuInfo +tqdm = LazyImport("tqdm") torch = LazyImport("torch") def get_embedding_contiguous(model): @@ -492,3 +493,260 @@ def unwrap_proxy(a): torch.nn.Sequential.forward = orig_nn_sequential_forward # type: ignore[assignment] new_module.__class__ = CopyDispatchModule return new_module + +def fetch_module(model, op_name): + module = model + name_list = op_name.split('.') + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + else: + module = module + return module + +def set_module(model, op_name, new_module): + module = model + name_list = op_name.split('.') + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + else: + module = module + setattr(module, name_list[-1], new_module) + return module + +def simple_inference(model, input): + with torch.no_grad(): + if type(input) is dict: + output = model(**input) + elif type(input) is tuple or type(input) is list: + try: + output = model(*input) + except: + output = model(input) + else: + output = model(input) + return output + +def get_example_input(dataloader, i=1): + iter = 0 + try: + for example_inp, label in dataloader: + if iter == i: + break + else: + iter += 1 + except: + for example_inp in dataloader: + if iter == i: + break + else: + iter += 1 + return example_inp + + +def get_fallback_order(adaptor, fp32_model, dataloader, tune_cfg, + confidence_batches, fallback=False, requantize_cfgs=None): + fp32_model.eval() + order_dict = {} + for i in range(0, confidence_batches): + example_input = get_example_input(dataloader, i) + if fallback: + ordered_ops = get_mse_order_per_fp32(adaptor, fp32_model, example_input, tune_cfg) + for i, name in enumerate(ordered_ops): + order_dict[name] = order_dict.get(name, 0) + len(order_dict) - i + ordered_ops = sorted(order_dict, key=lambda k: order_dict[k], reverse=True) + else: + ordered_ops = get_mse_order_per_int8(adaptor, fp32_model, example_input, tune_cfg) + for i, name in enumerate(ordered_ops): + order_dict[name] = order_dict.get(name, 0) + len(order_dict) - i + return ordered_ops + +op_cfg_mapping = {} +def get_mse_order_per_fp32(adaptor, model, example_inp, tune_cfg): + """a helper method to check the mse influence to last module after QDQ(quant/dequant). + Args: + model(torch.fx.GraphModule/torch.nn.Module): A torch model. + dataloader(torch.utils.data.DataLoader): The calibration dataloader. + tune_cfg (dict): dictionary of quantization configuration. + Returns: + fallback_order (dict/list): The fallback order for strategy. 
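+
+    Example (an illustrative sketch; argument values are hypothetical):
+        ordered_ops = get_mse_order_per_fp32(adaptor, model, example_inp, tune_cfg)
+        # ordered_ops is sorted by ascending MSE against the fp32 reference output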
+ """ + + inner_output = None + def output_hook(self, input, output): + nonlocal inner_output + inner_output = output + return output + + op_type_dict = {} + for k, v in tune_cfg['op'].keys(): + op_type_dict[k] = v + + from ..pytorch import _cfg_to_qconfig, _cfgs_to_fx_cfgs, PyTorch_FXAdaptor + op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg["approach"]) + # insert hook to get output tesnor from last module + last_module_name = list(op_cfgs.keys())[-1] + module = fetch_module(model, last_module_name) # get last module + module.register_forward_hook(output_hook) + # record fp32 model output tensor at first + output_fp32 = simple_inference(model, example_inp) + inner_output_fp32 = inner_output + + fx_op_cfgs = {} + fallback_order = {} + logger.info('Evaluate the sensitivity for each int8 operation') + for op_name, qconfig in tqdm(op_cfgs.items()): + global op_cfg_mapping + if op_name not in op_cfg_mapping: + op_cfg_mapping[op_name] = qconfig + tmp_model = copy.deepcopy(model) + if not qconfig: + continue + op_cfgs[op_name] = None + fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"]) + op_cfgs[op_name] = qconfig + from torch.quantization.quantize_fx import prepare_fx,convert_fx + # do quantization + if adaptor.sub_module_list is None: + tmp_model = prepare_fx(tmp_model, fx_op_cfgs,) + else: + PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \ + tmp_model, prefix='') + simple_inference(tmp_model, example_inp) + if adaptor.sub_module_list is None: + tmp_model = convert_fx(tmp_model) + else: + PyTorch_FXAdaptor.convert_sub_graph(adaptor.sub_module_list, \ + tmp_model, prefix='') + + # insert hook to get output tesnor from last module + module = fetch_module(tmp_model, list(op_cfgs.keys())[-1]) # get last module + module.register_forward_hook(output_hook) + output_qdq = simple_inference(tmp_model, example_inp) + inner_output_int8 = inner_output.dequantize() if \ + inner_output.dtype == torch.quint8 else inner_output + mse_val = (inner_output_fp32 - inner_output_int8).pow(2).sum() + fallback_order[(op_name, op_type_dict[op_name])] = mse_val + + ordered_ops = sorted(fallback_order.keys(), key=lambda key: fallback_order[key], \ + reverse=False) + min_mse, max_mse = fallback_order[ordered_ops[0]], fallback_order[ordered_ops[-1]] + + if min_mse < 0.8 * max_mse: + return ordered_ops + + + double_check_list = [] + for op_name in ordered_ops: + if min_mse <= fallback_order[op_name] <= (max_mse - min_mse) * 0.1 + min_mse: + double_check_list.append(op_name) + + check_num = min(len(ordered_ops)//10, 5) + double_check_list = ordered_ops[:check_num] + worst_op_name = ordered_ops[-1] + op_cfgs[worst_op_name[0]] = None # fallback worst module first + new_fallback_order = {} + + logger.info('Evaluate the sensitivity gradient for selected operations') + for op_name, op_type in tqdm(double_check_list): + tmp_model = copy.deepcopy(model) + qconfig = op_cfgs[op_name] + op_cfgs[op_name] = None + fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"]) + op_cfgs[op_name] = qconfig + from torch.quantization.quantize_fx import prepare_fx,convert_fx + # do quantization + if adaptor.sub_module_list is None: + tmp_model = prepare_fx(tmp_model, fx_op_cfgs,) + else: + PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \ + tmp_model, prefix='') + simple_inference(tmp_model, example_inp) + if adaptor.sub_module_list is None: + tmp_model = convert_fx(tmp_model) + else: + PyTorch_FXAdaptor.convert_sub_graph(adaptor.sub_module_list, \ + tmp_model, prefix='') + + # insert 
hook to get output tesnor from last module + module = fetch_module(tmp_model, last_module_name) # get last module + module.register_forward_hook(output_hook) + output_qdq = simple_inference(tmp_model, example_inp) + inner_output_int8 = inner_output.dequantize() if \ + inner_output.dtype == torch.quint8 else inner_output + mse_val = (inner_output_fp32 - inner_output_int8).pow(2).sum() + new_fallback_order[(op_name, op_type_dict[op_name])] = mse_val + + ordered_ops = sorted(new_fallback_order.keys(), key=lambda key: new_fallback_order[key], \ + reverse=False) + + return ordered_ops + +def get_mse_order_per_int8(adaptor, fp32_model, example_input, tune_cfg): + inner_output = None + def output_hook(self, input, output): + nonlocal inner_output + inner_output = output + return output + + op_type_dict = {} + for k, v in tune_cfg['op'].keys(): + op_type_dict[k] = v + + example_inp = example_input + + from ..pytorch import _cfg_to_qconfig + op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg["approach"]) + module = fetch_module(fp32_model, list(op_cfgs.keys())[-1]) # get last module + # insert hook to get output tesnor from last module + module.register_forward_hook(output_hook) + # record fp32 model output tensor at first + output_fp32 = simple_inference(fp32_model, example_inp) + inner_output_fp32 = inner_output + + quant_list = [] + for k, v in tune_cfg['op'].items(): + if k[1] in ['LayerNorm', 'Dropout', 'InstanceNorm3d']: + continue + if v['weight']['dtype'] == 'fp32': + quant_list.append(k) + fallback_order = {} + logger.info('Evaluate the sensitivity for each fp32 operation') + for op_name, op_type in tqdm(quant_list): + if op_name in op_cfg_mapping: + tmp_model = copy.deepcopy(fp32_model) + from ..pytorch import _cfg_to_qconfig, _cfgs_to_fx_cfgs, PyTorch_FXAdaptor + op_cfgs[op_name] = op_cfg_mapping[op_name] + fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"]) + from torch.quantization.quantize_fx import prepare_fx,convert_fx + # do quantization + if adaptor.sub_module_list is None: + tmp_model = prepare_fx(tmp_model, fx_op_cfgs,) + else: + PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \ + tmp_model, prefix='') + simple_inference(tmp_model, example_inp) + if adaptor.sub_module_list is None: + tmp_model = convert_fx(tmp_model) + else: + PyTorch_FXAdaptor.convert_sub_graph(adaptor.sub_module_list, \ + tmp_model, prefix='') + + + # record int8 model output tensor + module = fetch_module(tmp_model, list(op_cfgs.keys())[-1]) # get last module + module.register_forward_hook(output_hook) + output_qdq = simple_inference(tmp_model, example_inp) + inner_output_int8 = inner_output + if inner_output_fp32.dtype == torch.quint8: + inner_output_fp32 = inner_output_fp32.dequantize() + if inner_output_int8.dtype == torch.quint8: + inner_output_int8 = inner_output_int8.dequantize() + + mse_val = (inner_output_fp32 - inner_output_int8).pow(2).sum() + fallback_order[(op_name, op_type_dict[op_name])] = mse_val + # re-insert fp32 module into model + ordered_ops = sorted(fallback_order.keys(), key=lambda key: fallback_order[key], \ + reverse=False) + return ordered_ops diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py index 2cb547e6cbc..5e889d34cec 100644 --- a/neural_compressor/conf/config.py +++ b/neural_compressor/conf/config.py @@ -860,11 +860,13 @@ def percent_to_float(data): 'diagnosis': False, }): { Optional('strategy', default={'name': 'basic'}): { - 'name': And(str, lambda s: s in STRATEGIES), Optional('sigopt_api_token'): str, + 'name': 
And(str, lambda s: s in STRATEGIES), + Optional('sigopt_api_token'): str, Optional('sigopt_project_id'): str, Optional('sigopt_experiment_name', default='nc-tune'): str, Optional('accuracy_weight', default=1.0): float, - Optional('latency_weight', default=1.0): float + Optional('latency_weight', default=1.0): float, + Optional('confidence_batches', default=2): int } , Hook('accuracy_criterion', handler=_valid_accuracy_field): object, Optional('accuracy_criterion', default={'relative': 0.01}): { diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 10c666f948b..584d9050108 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -419,7 +419,7 @@ def strategy(self): @strategy.setter def strategy(self, strategy): if check_value('strategy', strategy, str, - ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe']): + ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2']): self._strategy = strategy @property @@ -562,7 +562,7 @@ def strategy(self): @strategy.setter def strategy(self, strategy): if check_value('strategy', strategy, str, - ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe']): + ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2']): self._strategy = strategy @property diff --git a/neural_compressor/contrib/strategy/sigopt.py b/neural_compressor/contrib/strategy/sigopt.py index 18732ebe59c..8a1e7a34164 100644 --- a/neural_compressor/contrib/strategy/sigopt.py +++ b/neural_compressor/contrib/strategy/sigopt.py @@ -20,8 +20,8 @@ from neural_compressor.utils.utility import LazyImport from neural_compressor.strategy.strategy import strategy_registry, TuneStrategy from collections import OrderedDict -from neural_compressor.strategy.st_utils.tuning_sampler import OpWiseTuningSampler -from neural_compressor.strategy.st_utils.tuning_structs import OpTuningConfig +from neural_compressor.strategy.utils.tuning_sampler import OpWiseTuningSampler +from neural_compressor.strategy.utils.tuning_structs import OpTuningConfig sigopt = LazyImport('sigopt') diff --git a/neural_compressor/contrib/strategy/tpe.py b/neural_compressor/contrib/strategy/tpe.py index 6722b31cb38..730a9f9fef0 100644 --- a/neural_compressor/contrib/strategy/tpe.py +++ b/neural_compressor/contrib/strategy/tpe.py @@ -24,8 +24,8 @@ from neural_compressor.utils.utility import LazyImport from neural_compressor.strategy.strategy import strategy_registry, TuneStrategy from collections import OrderedDict -from neural_compressor.strategy.st_utils.tuning_sampler import OpWiseTuningSampler -from neural_compressor.strategy.st_utils.tuning_structs import OpTuningConfig +from neural_compressor.strategy.utils.tuning_sampler import OpWiseTuningSampler +from neural_compressor.strategy.utils.tuning_structs import OpTuningConfig hyperopt = LazyImport('hyperopt') diff --git a/neural_compressor/strategy/auto_mixed_precision.py b/neural_compressor/strategy/auto_mixed_precision.py index 4b59cf2cced..fc8350f8a10 100644 --- a/neural_compressor/strategy/auto_mixed_precision.py +++ b/neural_compressor/strategy/auto_mixed_precision.py @@ -21,8 +21,8 @@ from .strategy import strategy_registry, TuneStrategy from ..utils import logger -from .st_utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler -from .st_utils.tuning_structs import OpTuningConfig +from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler +from .utils.tuning_structs import OpTuningConfig @strategy_registry diff --git 
a/neural_compressor/strategy/basic.py b/neural_compressor/strategy/basic.py index c35398dd4bb..c3478789d82 100644 --- a/neural_compressor/strategy/basic.py +++ b/neural_compressor/strategy/basic.py @@ -21,9 +21,9 @@ from .strategy import strategy_registry, TuneStrategy from ..utils import logger -from .st_utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler -from .st_utils.tuning_structs import OpTuningConfig -from .st_utils.tuning_space import TUNING_ITEMS_LST +from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig +from .utils.tuning_space import TUNING_ITEMS_LST @strategy_registry class BasicTuneStrategy(TuneStrategy): diff --git a/neural_compressor/strategy/bayesian.py b/neural_compressor/strategy/bayesian.py index 6090d75faf3..e36371cd88d 100644 --- a/neural_compressor/strategy/bayesian.py +++ b/neural_compressor/strategy/bayesian.py @@ -27,8 +27,8 @@ from ..utils import logger from .strategy import strategy_registry, TuneStrategy -from .st_utils.tuning_sampler import OpWiseTuningSampler -from .st_utils.tuning_structs import OpTuningConfig +from .utils.tuning_sampler import OpWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig @strategy_registry diff --git a/neural_compressor/strategy/conservative.py b/neural_compressor/strategy/conservative.py index d4806e59ad5..32c80a69f45 100644 --- a/neural_compressor/strategy/conservative.py +++ b/neural_compressor/strategy/conservative.py @@ -25,7 +25,7 @@ from typing import Dict, List, Tuple, OrderedDict from .strategy import strategy_registry, TuneStrategy -from .st_utils.tuning_space import TuningItem +from .utils.tuning_space import TuningItem from ..utils import logger from ..utils.utility import Statistics diff --git a/neural_compressor/strategy/exhaustive.py b/neural_compressor/strategy/exhaustive.py index 9a2320ed820..fb329332c8c 100644 --- a/neural_compressor/strategy/exhaustive.py +++ b/neural_compressor/strategy/exhaustive.py @@ -20,8 +20,8 @@ from collections import OrderedDict from .strategy import strategy_registry, TuneStrategy -from .st_utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler -from .st_utils.tuning_structs import OpTuningConfig +from .utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig from ..utils import logger @strategy_registry diff --git a/neural_compressor/strategy/mse.py b/neural_compressor/strategy/mse.py index 8dcc060ef42..7783c2fee57 100644 --- a/neural_compressor/strategy/mse.py +++ b/neural_compressor/strategy/mse.py @@ -16,14 +16,17 @@ # limitations under the License. 
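Note that the `st_utils` package is renamed to `utils` throughout this patch, so every strategy now imports its samplers and tuning structs from the new path, for example:

```python
# New import locations after the st_utils -> utils rename (paths as used in this patch).
from neural_compressor.strategy.utils.tuning_sampler import OpWiseTuningSampler
from neural_compressor.strategy.utils.tuning_structs import OpTuningConfig
```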
import copy +from copy import deepcopy import numpy as np from collections import OrderedDict from typing import Dict, Any, List from .strategy import strategy_registry, TuneStrategy from ..utils import logger +from time import time -from .st_utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler -from .st_utils.tuning_structs import OpTuningConfig +from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler +from .utils.tuning_structs import OpTuningConfig +from .utils.helper import tuning_record_msg @strategy_registry class MSETuneStrategy(TuneStrategy): @@ -175,7 +178,7 @@ def next_tune_cfg(self): initial_op_tuning_cfg[item.name] = OpTuningConfig(op_name, op_type, 'fp32', tuning_space) calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options for calib_sampling_size in calib_sampling_size_lst: - # step1. collect the ops that support static and dynamic + # Collect the ops that support static and dynamic quant_mode_wise_items = OrderedDict() query_order = ['static', 'dynamic', 'bf16', 'fp32'] pre_items = set() @@ -193,9 +196,9 @@ def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict): for quant_mode, quant_mode_items in quant_mode_wise_items.items(): initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict) - # step3. optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight) - early_stop_tuning = False - stage1_cnt = 0 + # Optype-wise tuning + early_stop_tuning = True + stage1_cnt = 0 int8_ops = quant_mode_wise_items['dynamic'] + quant_mode_wise_items['static'] stage1_max = min(5, len(int8_ops)) # TODO set a more appropriate value op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], @@ -208,14 +211,13 @@ def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict): op_tuning_cfg['calib_sampling_size'] = calib_sampling_size yield op_tuning_cfg - # step4. fallback the ops supported both static and dynamic from static to dynamic - # tuning items: None + # Fallback the ops supported both static and dynamic from static to dynamic static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if item in tuning_space.query_items_by_quant_mode('dynamic')] if static_dynamic_items: logger.info("Fallback all ops that support both dynamic and static to dynamic.") else: - logger.info("Non ops that support both dynamic") + logger.info("No op support both dynamic and static") def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): new_op_tuning_cfg = deepcopy(op_tuning_cfg) @@ -230,14 +232,13 @@ def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg) - # step5. 
fallback + # Fallback to float point datatypes ('bf16' or 'fp32') for target_dtype in ['bf16', 'fp32']: fallback_items_lst = [item for item in int8_ops if item in tuning_space.query_items_by_quant_mode(target_dtype)] if fallback_items_lst: logger.info(f"Start to fallback op to {target_dtype} one by one.") - self._fallback_started() - # replace it with sorted items list + # Replace it with sorted items list fallback_items_name_lst = [item.name for item in fallback_items_lst] # TODO check the best_qmodel ordered_op_name_types = self.mse_impact_lst(fallback_items_name_lst, self.model, self.best_qmodel) @@ -254,11 +255,11 @@ def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): acc, _ = self.last_tune_result op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc - # do accumulated fallback according to the order in the previous stage + # Do accumulated fallback according to the order in the previous stage if len(op_fallback_acc_impact) > 0: ordered_ops = sorted(op_fallback_acc_impact.keys(), - key=lambda key: op_fallback_acc_impact[key], - reverse=self.higher_is_better) + key=lambda key: op_fallback_acc_impact[key], + reverse=self.higher_is_better) op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst))) logger.info(f"Start to accumulate fallback to {target_dtype}.") initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) diff --git a/neural_compressor/strategy/mse_v2.py b/neural_compressor/strategy/mse_v2.py new file mode 100644 index 00000000000..7abc424283e --- /dev/null +++ b/neural_compressor/strategy/mse_v2.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from copy import deepcopy +import numpy as np +from collections import OrderedDict +from typing import Dict, Any, List +from .strategy import strategy_registry, TuneStrategy +from ..utils import logger +from time import time + +from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler +from .utils.tuning_structs import OpTuningConfig +from .utils.helper import tuning_record_msg + +@strategy_registry +class MSE_V2TuneStrategy(TuneStrategy): + """The tuning strategy using MSE policy in tuning space. + + This MSE policy runs fp32 model and int8 model seperately to get all activation tensors, + and then compares those tensors by MSE algorithm to order all ops with MSE distance for + deciding the impact of each op to final accuracy. + It will be used to define opwise tuningspace by priority. + + Args: + model (object): The FP32 model specified for low precision tuning. + conf (Class): The Conf class instance initialized from user yaml + config file. + q_dataloader (generator): Data loader for calibration, mandatory for + post-training quantization. + It is iterable and should yield a tuple (input, + label) for calibration dataset containing label, + or yield (input, _) for label-free calibration + dataset. 
The input could be a object, list, tuple or + dict, depending on user implementation, as well as + it can be taken as model input. + q_func (function, optional): Reserved for future use. + eval_dataloader (generator, optional): Data loader for evaluation. It is iterable + and should yield a tuple of (input, label). + The input could be a object, list, tuple or dict, + depending on user implementation, as well as it can + be taken as model input. The label should be able + to take as input of supported metrics. If this + parameter is not None, user needs to specify + pre-defined evaluation metrics through configuration + file and should set "eval_func" parameter as None. + Tuner will combine model, eval_dataloader and + pre-defined metrics to run evaluation process. + eval_func (function, optional): The evaluation function provided by user. + This function takes model as parameter, and + evaluation dataset and metrics should be + encapsulated in this function implementation and + outputs a higher-is-better accuracy scalar value. + + The pseudo code should be something like: + + def eval_func(model): + input, label = dataloader() + output = model(input) + accuracy = metric(output, label) + return accuracy + dicts (dict, optional): The dict containing resume information. + Defaults to None. + + """ + + def __init__(self, model, conf, q_dataloader, q_func=None, + eval_dataloader=None, eval_func=None, dicts=None, q_hooks=None): + self.ordered_ops = None + super( + MSE_V2TuneStrategy, + self).__init__( + model, + conf, + q_dataloader, + q_func, + eval_dataloader, + eval_func, + dicts, + q_hooks) + + def __getstate__(self): + for history in self.tuning_history: + if self._same_yaml(history['cfg'], self.cfg): + history['ordered_ops'] = self.ordered_ops + save_dict = super().__getstate__() + return save_dict + + def next_tune_cfg(self): + """The generator of yielding next tuning config to traverse by concrete strategies + according to last tuning result. + + Yields: + tune_config (dict): It's a dict containing the tuning configuration to run. 
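+
+        Note (illustrative sketch, not a verbatim excerpt): the traverse()
+        driver consumes this generator roughly as:
+
+            for op_tuning_cfg in self.next_tune_cfg():
+                tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
+                # quantize with tune_cfg, evaluate, then update best-so-far state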
+ """ + + best_op_tuning_cfg = None + if len(self.metric_name) == 1 or self.metric_weight is not None: + best_acc = float('-inf') if self.higher_is_better else float('inf') + else: + best_acc = [float('-inf') if higher_is_better else float('inf') for \ + higher_is_better in self.metric_criterion] + + from copy import deepcopy + tuning_space = self.tuning_space + initial_op_tuning_cfg = {} + for item in tuning_space.root_item.options: + if item.item_type == 'op': + op_name, op_type = item.name + initial_op_tuning_cfg[item.name] = OpTuningConfig(op_name, op_type, 'fp32', tuning_space) + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for calib_sampling_size in calib_sampling_size_lst: + # Collect the ops that support static and dynamic + quant_mode_wise_items = OrderedDict() + query_order = ['static', 'dynamic', 'bf16', 'fp16', 'fp32'] + pre_items = set() + for quant_mode in query_order: + items = tuning_space.query_items_by_quant_mode(quant_mode) + filtered_items = [item for item in items if item not in pre_items] + pre_items = pre_items.union(set(items)) + quant_mode_wise_items[quant_mode] = filtered_items + + def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict): + for item in items_lst: + op_item_dtype_dict[item.name] = target_quant_mode + + op_item_dtype_dict = OrderedDict() + for quant_mode, quant_mode_items in quant_mode_wise_items.items(): + initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict) + + # Optype-wise tuning + early_stop_tuning = True + stage1_cnt = 0 + int8_ops = quant_mode_wise_items['dynamic'] + quant_mode_wise_items['static'] + stage1_max = 2 # TODO set a more appropriate value + op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + for op_tuning_cfg in op_wise_tuning_sampler: + stage1_cnt += 1 + if early_stop_tuning and stage1_cnt > stage1_max: + logger.info("Early stopping the stage 1.") + break + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + + # Fallback the ops supported both static and dynamic from static to dynamic + static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if + item in tuning_space.query_items_by_quant_mode('dynamic')] + if static_dynamic_items: + logger.info("Fallback all ops that support both dynamic and static to dynamic.") + else: + logger.info("No op support both dynamic and static") + + def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): + new_op_tuning_cfg = deepcopy(op_tuning_cfg) + new_op_tuning_cfg.op_quant_mode = 'dynamic' + return new_op_tuning_cfg + + new_op_tuning_cfg = deepcopy(self.cur_best_tuning_cfg) + for item in static_dynamic_items: + new_op_tuning_cfg[item.name] = dynamic_op_tuning_cfg_from_static(new_op_tuning_cfg[item.name]) + new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield new_op_tuning_cfg + + # Fallback one by one by op sensitivity(mse) + # 1. while the accuracy requirements not met: # to improve the accuracy + # 1) calculate the sensitivity of int8 ops in current state. + # 2) fallback the op with higher sensitivity accumulatively + # 2. 
after the accuracy requirement is met:  # to improve the performance
+            #     1) calculate the sensitivity of fp32 ops in the current state
+            #     2) re-quantize the op with lower sensitivity accumulatively
+            tune_cfg = deepcopy(self.cur_best_tuning_cfg)
+            requantize_cfg = deepcopy(self._tune_cfg_converter(self.cur_best_tuning_cfg))
+            self.output_op_names = self.adaptor.get_output_op_names(self.cur_best_qmodel)
+            self.confidence_batches = (self.cfg.tuning.strategy.confidence_batches
+                                       if self.cfg.tuning.strategy.confidence_batches is not None else 2)
+            tune_cfg_backup = deepcopy(tune_cfg)
+            quant_ops_in_tune_cfg = self._collect_ops_by_quant_mode(tune_cfg, 'dynamic') + \
+                                    self._collect_ops_by_quant_mode(tune_cfg, 'static')
+            op_quant_cfgs = {op_info: tune_cfg_backup[op_info] for op_info in quant_ops_in_tune_cfg}
+            fallback_records = []
+            self.re_quant = True
+            while not self.objectives.compare(self.last_tune_result, self.baseline):
+                # Record the time spent calculating the sensitivity
+                start = time()
+                ops_lst = self.adaptor.calculate_op_sensitivity(self.model,
+                                                                self.calib_dataloader,
+                                                                deepcopy(self._tune_cfg_converter(tune_cfg)),
+                                                                self.output_op_names,
+                                                                self.confidence_batches,
+                                                                fallback=True)
+                logger.debug(f"*** The op sensitivity analysis took {time() - start:.2f}s.")
+                select_op_info = ops_lst[0]
+                logger.info(f"*** The op {select_op_info} has the highest sensitivity in the current state; \
+                            falling back to fp32.")
+                tune_cfg[select_op_info] = OpTuningConfig(select_op_info[0],
+                                                          select_op_info[1],
+                                                          'fp32',
+                                                          self.tuning_space)
+                # Record the fallback history
+                if not fallback_records:
+                    fallback_records = [[select_op_info]]
+                else:
+                    fallback_records.append(fallback_records[-1] + [select_op_info])
+                logger.debug(f"*** The fallback ops record: \n{tuning_record_msg(fallback_records)}")
+                yield tune_cfg
+
+            logger.info(f"*** Accuracy requirements met; stop falling back ops.")
+            while self.objectives.compare(self.last_tune_result, self.baseline):
+                if len(fallback_records) == 0 or len(fallback_records[-1]) <= 1:
+                    logger.info(f"*** Stop re-quantizing: no int8 op, or only one int8 op, left.")
+                    break
+                logger.info(f"*** Start to re-quantize the ops fallen back in the previous stage.")
+                # Track the current fallback ops
+                tmp_fallback_ops = fallback_records[-1] if fallback_records else []
+                start = time()
+                ops_lst = self.adaptor.calculate_op_sensitivity(self.model,
+                                                                self.calib_dataloader,
+                                                                deepcopy(self._tune_cfg_converter(tune_cfg)),
+                                                                self.output_op_names,
+                                                                self.confidence_batches,
+                                                                fallback=False,
+                                                                requantize_cfgs=requantize_cfg['op'])
+                logger.debug(f"*** The op sensitivity analysis took {time() - start:.2f}s.")
+                if not ops_lst:
+                    logger.warning("No op to be requantized")
+                    break
+                for select_op_info in ops_lst:
+                    #assert select_op_info in tmp_fallback_ops, f"{select_op_info} not in fallback list."
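+                    # The sensitivity list may contain ops that were never fallen
+                    # back in this pass; a hard assert would abort tuning, so such
+                    # ops are skipped and the re-quantize search continues.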
+                    if select_op_info not in tmp_fallback_ops:
+                        logger.debug(f"{select_op_info} not in fallback list.")
+                        continue
+
+                    new_fallback_ops = deepcopy(tmp_fallback_ops)
+                    new_fallback_ops.remove(select_op_info)
+                    if new_fallback_ops not in fallback_records:
+                        logger.info(f"*** The op {select_op_info} has the lowest sensitivity in the current state; \
+                                    re-quantizing it.")
+                        tune_cfg[select_op_info] = op_quant_cfgs[select_op_info]
+                        fallback_records.append(new_fallback_ops)
+                        logger.debug(f"*** The fallback ops record: \n{tuning_record_msg(fallback_records)}")
+                        yield tune_cfg
+                        break
+                    else:
+                        logger.debug(f"*** Skip re-quantizing {select_op_info}: this config has already been evaluated.")
+                        continue
+            self.re_quant = False
+            logger.info(f"*** Accuracy requirement not met; stop re-quantizing ops.")
\ No newline at end of file
diff --git a/neural_compressor/strategy/random.py b/neural_compressor/strategy/random.py
index d1c65011375..812b9e40003 100644
--- a/neural_compressor/strategy/random.py
+++ b/neural_compressor/strategy/random.py
@@ -19,8 +19,8 @@
 from .strategy import strategy_registry, TuneStrategy
 from collections import OrderedDict
-from .st_utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler
-from .st_utils.tuning_structs import OpTuningConfig
+from .utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_structs import OpTuningConfig
 from ..utils import logger
 
 @strategy_registry
diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py
index db4cae2b1d4..a5ab572b288 100644
--- a/neural_compressor/strategy/strategy.py
+++ b/neural_compressor/strategy/strategy.py
@@ -42,12 +42,13 @@
 import copy
 import numpy as np
 from collections import OrderedDict
+from time import time
 
 from ..utils import logger
 
-from .st_utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
-from .st_utils.tuning_space import TuningItem, TuningSpace
-from .st_utils.tuning_structs import OpTuningConfig
+from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_space import TuningItem, TuningSpace
+from .utils.tuning_structs import OpTuningConfig
 
 STRATEGIES = {}
 
@@ -140,6 +141,7 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None,
         self.tune_data = {}
         self.tune_result_record = []
         self.tuning_history = []
+        self.tuning_result_data = []
         # The tuning history ever made, structured like below:
         # [
         #   {
@@ -170,6 +172,8 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None,
         self.best_qmodel = None
         self.cur_best_acc = self.initial_best_acc() # track the current best accuracy
         self.cur_best_tuning_cfg = {} # track tuning cfg with the current best accuracy
+        self.cur_best_qmodel = None # track quantized model with the current best accuracy
+        self.re_quant = False
 
         self.capability = self.adaptor.query_fw_capability(model)
         logger.debug(self.capability)
@@ -226,8 +230,9 @@ def traverse(self):
             self.show_baseline_info()
 
         trials_count = 0
-
+        traverse_start_time = time()
        for op_tuning_cfg in self.next_tune_cfg():
+            tuning_start_time = time()
             tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
             trials_count += 1
             tuning_history = self._find_tuning_history(tune_cfg)
@@ -263,8 +268,22 @@ def traverse(self):
                                          q_config=self.q_model.q_config)
             self.tune_result_record.append(copy.deepcopy(self.last_tune_result))
             self.tune_cfg = tune_cfg
+            now_time = time()
+            acc_res_msg = ""
+            performance_res_msg = ""
+            if self.tuning_result_data:
+                acc_res_msg = "[ " + "| ".join(self.tuning_result_data[0]) + " ]"
+                performance_res_msg = "[ " + "| ".join(self.tuning_result_data[1]) + " ]"
+            logger.debug(f"*** The accuracy of the last tuning is: {acc_res_msg}")
+            logger.debug(f"*** The performance of the last tuning is: {performance_res_msg}")
+            logger.debug(f"*** The last tuning time: {(now_time - tuning_start_time):.2f} s")
+            logger.debug(f"*** The total tuning time so far: {(now_time - traverse_start_time):.2f} s")
+
             self._dump_tuning_process_statistics()
             if need_stop:
+                if self.re_quant:
+                    logger.info("*** Do not stop the tuning process; re-quantize the ops.")
+                    continue
                 if self.cfg.tuning.diagnosis and self.cfg.tuning.diagnosis.diagnosis_after_tuning:
                     logger.debug(f'*** Start to do diagnosis (inspect tensor).')
                     self._diagnosis()
@@ -277,6 +296,7 @@ def traverse(self):
                     self.best_qmodel = recover(self.model.model,
                                                os.path.join(self.cfg.tuning.workspace.path, 'history.snapshot'),
                                                best_trail)
+                    logger.debug(f"*** Update the best qmodel by recovering from history.")
                     self.best_tune_result = best_result
                     self._dump_tuning_process_statistics()
                     break
@@ -461,6 +481,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
         else:
             tune_cfg['calib_iteration'] = 1
         tune_cfg['advance'] = self.cfg.quantization.advance
+        tune_cfg['approach'] = self.cfg.quantization.approach
         return tune_cfg
 
     def set_tuning_space(self, conf):
@@ -618,22 +639,26 @@ def update_best_op_tuning_cfg(self, op_tuning_cfg):
             acc, _ = self.last_tune_result
             if self.cur_best_tuning_cfg is None:
                 self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+                self.cur_best_qmodel = self.last_qmodel
             if not isinstance(acc, list) and ((self.higher_is_better and acc >= self.cur_best_acc) \
                 or (not self.higher_is_better and acc <= self.cur_best_acc)):
                 self.cur_best_acc = acc
                 self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+                self.cur_best_qmodel = self.last_qmodel
             elif len(self.metric_name) > 1 and self.metric_weight is not None:
                 acc = np.mean(np.array(acc) * self.metric_weight)
                 if (self.higher_is_better and acc >= self.cur_best_acc) or \
                     (not self.higher_is_better and acc <= self.cur_best_acc):
                     self.cur_best_acc = acc
                     self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+                    self.cur_best_qmodel = self.last_qmodel
             elif len(self.metric_name) > 1 and self.metric_weight is None:
                 if all([acc_i >= best_i if higher_is_better else acc_i <= best_i for \
                     acc_i, best_i, higher_is_better in \
                     zip(acc, self.cur_best_acc, self.metric_criterion)]):
                     self.cur_best_acc = acc
-                    self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+                    self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+                    self.cur_best_qmodel = self.last_qmodel
         logger.debug(f"Best acc is {self.cur_best_acc}.")
         return self.cur_best_acc, self.cur_best_tuning_cfg
@@ -810,10 +835,18 @@ def stop(self, timeout, trials_count):
                 del self.best_qmodel
             self.best_tune_result = self.last_tune_result
             self.best_qmodel = self.last_qmodel
+            logger.debug(f"*** Update the best qmodel with the result {self.best_tune_result}")
            if self.metric_met_point == 0:
                 self.metric_met_point = self.tuning_times
-        else:
-            del self.last_qmodel
+
+        # track the model with the highest accuracy
+        if self.best_tune_result and self.last_tune_result: # (acc, [perf])
+            if self.re_quant and self.objectives.accuracy_meets():
+                self.best_tune_result = self.last_tune_result
+                self.best_qmodel = self.last_qmodel
+                logger.debug(f"*** Update the best qmodel with the result {self.best_tune_result}.")
+            else:
+                logger.debug(f"*** Accuracy does not meet the requirements; do not update the best qmodel.")
 
         if self.last_tune_result:
             last_tune =
self.last_tune_result[0] if \ @@ -900,7 +933,7 @@ def stop(self, timeout, trials_count): '{:.4f} '.format(self.last_tune_result[1][i]) if self.last_tune_result else 'n/a', '{:.4f} '.format(self.best_tune_result[1][i]) if self.best_tune_result else 'n/a'] \ for i, obj in enumerate(self.objectives.representation)]) - + self.tuning_result_data = output_data Statistics(output_data, header='Tune Result Statistics', field_names=['Info Type', 'Baseline', 'Tune {} result'.format(trials_count), \ @@ -1010,6 +1043,13 @@ def _add_tuning_history(self, tune_cfg=None, tune_result=None, **kwargs): def _fake_eval_func(self, model): return 1. + def _collect_ops_by_quant_mode(self, tune_cfg, quant_mode): + ops_lst = [] + for op_info, op_config in tune_cfg.items(): + if isinstance(op_config, OpTuningConfig) and quant_mode in op_config.op_quant_mode: + ops_lst.append(op_info) + return ops_lst + def _diagnosis(self): import logging logger = logging.getLogger("neural_compressor") diff --git a/neural_compressor/strategy/st_utils/__init__.py b/neural_compressor/strategy/utils/__init__.py similarity index 88% rename from neural_compressor/strategy/st_utils/__init__.py rename to neural_compressor/strategy/utils/__init__.py index e2fa444b0ba..db8d0fcfdf8 100644 --- a/neural_compressor/strategy/st_utils/__init__.py +++ b/neural_compressor/strategy/utils/__init__.py @@ -17,4 +17,5 @@ from .tuning_sampler import TuningSampler, OpWiseTuningSampler, OpTypeWiseTuningSampler, FallbackTuningSampler from .tuning_structs import OpTuningConfig -from .tuning_space import TuningItem, TuningSpace \ No newline at end of file +from .tuning_space import TuningItem, TuningSpace +from .helper import tuning_record_msg \ No newline at end of file diff --git a/neural_compressor/strategy/utils/helper.py b/neural_compressor/strategy/utils/helper.py new file mode 100644 index 00000000000..ce3ca6867bd --- /dev/null +++ b/neural_compressor/strategy/utils/helper.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
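The `tuning_record_msg` helper defined just below renders the fallback history: each record is a list of `(op_name, op_type)` tuples, emitted as one comma-joined line per record. A small usage sketch with hypothetical op names:

```python
# Two fallback records: first conv1 alone, then conv1 plus fc.
records = [
    [('conv1', 'conv2d')],
    [('conv1', 'conv2d'), ('fc', 'linear')],
]
# tuning_record_msg(records) returns:
# "('conv1', 'conv2d')\n('conv1', 'conv2d'),('fc', 'linear')"
```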
+ +def tuning_record_msg(records): + records_str_lst = [[str(e) for e in record] for record in records] + record_msg = '\n'.join(','.join(record) for record in records_str_lst) + return record_msg \ No newline at end of file diff --git a/neural_compressor/strategy/st_utils/tuning_sampler.py b/neural_compressor/strategy/utils/tuning_sampler.py similarity index 100% rename from neural_compressor/strategy/st_utils/tuning_sampler.py rename to neural_compressor/strategy/utils/tuning_sampler.py diff --git a/neural_compressor/strategy/st_utils/tuning_space.py b/neural_compressor/strategy/utils/tuning_space.py similarity index 100% rename from neural_compressor/strategy/st_utils/tuning_space.py rename to neural_compressor/strategy/utils/tuning_space.py diff --git a/neural_compressor/strategy/st_utils/tuning_structs.py b/neural_compressor/strategy/utils/tuning_structs.py similarity index 100% rename from neural_compressor/strategy/st_utils/tuning_structs.py rename to neural_compressor/strategy/utils/tuning_structs.py diff --git a/test/adaptor/tensorflow_adaptor/test_tensorflow_calculate_op_sensitivity.py b/test/adaptor/tensorflow_adaptor/test_tensorflow_calculate_op_sensitivity.py new file mode 100644 index 00000000000..5a9c5af6c0e --- /dev/null +++ b/test/adaptor/tensorflow_adaptor/test_tensorflow_calculate_op_sensitivity.py @@ -0,0 +1,136 @@ +import os +import shutil +import unittest +import tensorflow as tf +import numpy as np + +def build_msev2_yaml(): + mse_yaml = ''' + model: + name: fake_yaml + framework: tensorflow + inputs: x + outputs: op2_to_store + device: cpu + evaluation: + accuracy: + metric: + topk: 1 + tuning: + strategy: + name: mse_v2 + accuracy_criterion: + relative: 0.01 + exit_policy: + max_trials: 10 + timeout: 3600 + ''' + with open('mse_yaml.yaml', 'w', encoding="utf-8") as f: + f.write(mse_yaml) + +def build_fake_model(): + try: + graph = tf.Graph() + graph_def = tf.compat.v1.GraphDef() + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x') + y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y') + z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z') + op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store') + op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID', ) + last_identity = tf.identity(op2, name='op2_to_store') + sess.run(tf.compat.v1.global_variables_initializer()) + constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store']) + + graph_def.ParseFromString(constant_graph.SerializeToString()) + with graph.as_default(): + tf.import_graph_def(graph_def, name='') + except: + graph = tf.Graph() + graph_def = tf.compat.v1.GraphDef() + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x') + y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y') + z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z') + op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store') + op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID') + last_identity = tf.identity(op2, name='op2_to_store') + + sess.run(tf.compat.v1.global_variables_initializer()) + constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store']) + + graph_def.ParseFromString(constant_graph.SerializeToString()) + with 
+            tf.import_graph_def(graph_def, name='')
+    return graph
+
+class TestGetOutputTensor(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        build_msev2_yaml()
+        self.model = build_fake_model()
+
+    @classmethod
+    def tearDownClass(self):
+        os.remove('mse_yaml.yaml')
+        shutil.rmtree('./saved', ignore_errors=True)
+        shutil.rmtree('runs', ignore_errors=True)
+
+    def test_get_output_op_names(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization('mse_yaml.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.model = self.model
+        qmodel = quantizer.fit()
+
+        self.assertEqual(
+            quantizer.strategy.adaptor.get_output_op_names(qmodel),
+            ["Conv2D_dummy_biasadd"])
+
+
+    def test_calculate_op_sensitivity(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization("mse_yaml.yaml")
+        quantizer.model = self.model
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.pre_process()
+
+        dataloader = quantizer._calib_dataloader
+        strategy = quantizer.strategy
+        adaptor = strategy.adaptor
+        tune_cfg_generator = strategy.next_tune_cfg()
+        tune_cfg = strategy._tune_cfg_converter(next(tune_cfg_generator))
+        output_op_names = ["Conv2D_dummy_biasadd"]
+
+        op_sensitivity = adaptor.calculate_op_sensitivity(
+            model=quantizer.model,
+            dataloader=dataloader,
+            tune_cfg=tune_cfg,
+            output_op_names=output_op_names,
+            confidence_batches=1,
+            fallback=True)
+        self.assertIn(('op_to_store', 'conv2d'), op_sensitivity)
+        self.assertIn(('Conv2D', 'conv2d'), op_sensitivity)
+
+        tune_cfg['op'][('op_to_store', 'conv2d')] = {
+            'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'},
+            'weight': {'dtype': 'fp32'}}
+
+        op_sensitivity = adaptor.calculate_op_sensitivity(
+            model=quantizer.model,
+            dataloader=dataloader,
+            tune_cfg=tune_cfg,
+            output_op_names=output_op_names,
+            confidence_batches=1,
+            fallback=True)
+        self.assertNotIn(('op_to_store', 'conv2d'), op_sensitivity)
+        self.assertIn(('Conv2D', 'conv2d'), op_sensitivity)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/test/quantization/test_quantization.py b/test/quantization/test_quantization.py
index 61698fc21d0..b98102fad55 100644
--- a/test/quantization/test_quantization.py
+++ b/test/quantization/test_quantization.py
@@ -197,8 +197,8 @@ def build_fake_strategy():
         "from collections import OrderedDict \n",
         "from .strategy import strategy_registry, TuneStrategy \n",
         "from ..utils import logger \n",
-        "from .st_utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler \n",
-        "from .st_utils.tuning_structs import OpTuningConfig \n",
+        "from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler \n",
+        "from .utils.tuning_structs import OpTuningConfig \n",
         "import copy \n",
         "@strategy_registry \n",
         "class FakeTuneStrategy(TuneStrategy): \n",
diff --git a/test/strategy/test_mse_v2.py b/test/strategy/test_mse_v2.py
new file mode 100644
index 00000000000..e28adba79ce
--- /dev/null
+++ b/test/strategy/test_mse_v2.py
@@ -0,0 +1,152 @@
+import copy
+import os
+import shutil
+import unittest
+import tensorflow as tf
+import numpy as np
+import torchvision
+from neural_compressor.experimental import Quantization, common
+
+
+def build_mse_yaml_tf():
+    mse_yaml = '''
+    model:
+      name: fake_yaml
+      framework: tensorflow
+      inputs: x
+      outputs: op2_to_store
+      device: cpu
+    evaluation:
+      accuracy:
+        metric:
+          topk: 1
+    tuning:
+      strategy:
+        name: mse_v2
+      accuracy_criterion:
+        relative: 0.01
+      exit_policy:
+        max_trials: 10
+        timeout: 3600
+      random_seed: 9527
+    '''
+    with open('mse_yaml_tf.yaml', 'w', encoding="utf-8") as f:
+        f.write(mse_yaml)
+
+def build_mse_yaml_pytorch():
+    mse_yaml = '''
+    model:
+      name: resnet18
+      framework: pytorch_fx
+
+    tuning:
+      strategy:
+        name: mse_v2
+      accuracy_criterion:
+        relative: 0.01
+      exit_policy:
+        timeout: 0
+    '''
+    with open('mse_yaml_pytorch.yaml', 'w', encoding="utf-8") as f:
+        f.write(mse_yaml)
+
+def build_fake_model():
+    try:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID')
+            last_identity = tf.identity(op2, name='op2_to_store')
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    except Exception:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID')
+            last_identity = tf.identity(op2, name='op2_to_store')
+
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    return graph
+class Test_MSEV2Strategy_Tensorflow(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        build_mse_yaml_tf()
+        self.model = build_fake_model()
+
+    @classmethod
+    def tearDownClass(self):
+        os.remove('mse_yaml_tf.yaml')
+        shutil.rmtree('./saved', ignore_errors=True)
+        shutil.rmtree('runs', ignore_errors=True)
+        shutil.rmtree('nc_workspace', ignore_errors=True)
+
+    def test_quantization_saved(self):
+        i = [0]  # use a mutable type (list) to wrap the int object
+        def fake_eval_func(_):
+            # results for trials 1..10; eval_list[0] is skipped because i[0] is incremented before use
+            eval_list = [0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1]
+            i[0] += 1
+            return eval_list[i[0]]
+
+        quantizer = Quantization('mse_yaml_tf.yaml')
+
+        quantizer.model = self.model
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.eval_func = fake_eval_func
+        q_model = quantizer.fit()
+        self.assertIsNotNone(q_model)
+        q_model.save('./saved')
+
+class Test_MSEV2Strategy_PyTorch(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        build_mse_yaml_pytorch()
+        self.model = torchvision.models.resnet18()
+
+    @classmethod
+    def tearDownClass(self):
+        os.remove('mse_yaml_pytorch.yaml')
+        shutil.rmtree('./saved', ignore_errors=True)
+        shutil.rmtree('runs', ignore_errors=True)
+        shutil.rmtree('nc_workspace', ignore_errors=True)
+
+    def test_quantization_saved(self):
+        i = [0]
+        def fake_eval_func(model):
+            # results for trials 1..9; acc_lst[0] is skipped because i[0] is incremented before use
+            acc_lst = [1, 1, 0, 0, 0, 0, 1, 1.1, 1.5, 1.1]
+            i[0] += 1
+            return acc_lst[i[0]]
+
+        model = copy.deepcopy(self.model)
+        quantizer = Quantization('mse_yaml_pytorch.yaml')
+        dataset = quantizer.dataset('dummy', (1, 3, 224, 224))
+        quantizer.model = model
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_func = fake_eval_func
+        q_model = quantizer.fit()
+        self.assertIsNotNone(q_model)
+        q_model.save('./saved')
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/strategy/test_mse_v2_2.x.py b/test/strategy/test_mse_v2_2.x.py
new file mode 100644
index 00000000000..a6b0219c62a
--- /dev/null
+++ b/test/strategy/test_mse_v2_2.x.py
@@ -0,0 +1,141 @@
+import copy
+import os
+import shutil
+import unittest
+import tensorflow as tf
+import numpy as np
+import torchvision
+
+def build_fake_model():
+    try:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID')
+            last_identity = tf.identity(op2, name='op2_to_store')
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    except Exception:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID')
+            last_identity = tf.identity(op2, name='op2_to_store')
+
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    return graph
+
+class Test_MSEV2Strategy(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.tf_model = build_fake_model()
+        self.torch_model = torchvision.models.resnet18()
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree('./saved', ignore_errors=True)
+        shutil.rmtree('nc_workspace', ignore_errors=True)
+
+    def test_quantization_saved_tf(self):
+        i = [0]  # use a mutable type (list) to wrap the int object
+        def fake_eval_func(_):
+            # results for trials 1..10; eval_list[0] is skipped because i[0] is incremented before use
+            eval_list = [0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1]
+            i[0] += 1
+            return eval_list[i[0]]
+
+        from neural_compressor.quantization import fit
+        from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
+        from neural_compressor.data import DATASETS, DATALOADERS
+        dataset = DATASETS("tensorflow")["dummy"]((100, 3, 3, 1))
+        dataloader = DATALOADERS['tensorflow'](dataset)
+
+        conf = PostTrainingQuantConfig(
+            approach="static",
+            optimization_level=1,
+            tuning_criterion=TuningCriterion(strategy="mse_v2"))
+
+        q_model = fit(
+            model=self.tf_model,
+            conf=conf,
+            calib_dataloader=dataloader,
+            eval_dataloader=dataloader,
+            eval_func=fake_eval_func)
+        self.assertIsNotNone(q_model)
+
+    def test_quantization_saved_tf_with_confidence_batches(self):
+        i = [0]  # use a mutable type (list) to wrap the int object
+        def fake_eval_func(_):
+            # results for trials 1..10; eval_list[0] is skipped because i[0] is incremented before use
+            eval_list = [0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1]
+            i[0] += 1
+            return eval_list[i[0]]
+
+        from neural_compressor.quantization import fit
+        from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
+        from neural_compressor.data import DATASETS, DATALOADERS
+        dataset = DATASETS("tensorflow")["dummy"]((100, 3, 3, 1))
+        dataloader = DATALOADERS['tensorflow'](dataset)
+
+        conf = PostTrainingQuantConfig(
+            approach="static",
+            optimization_level=1,
+            tuning_criterion=TuningCriterion(
+                strategy="mse_v2",
+                strategy_kwargs={
+                    "confidence_batches": 5,
+                }))
+
+        q_model = fit(
+            model=self.tf_model,
+            conf=conf,
+            calib_dataloader=dataloader,
+            eval_dataloader=dataloader,
+            eval_func=fake_eval_func)
+        self.assertIsNotNone(q_model)
+
+    def test_quantization_saved_torch(self):
+        i = [0]
+        def fake_eval_func(model):
+            acc_lst = [1, 1, 0, 0, 0, 0, 1, 1.1, 1.5, 1.1]
+            i[0] += 1
+            return acc_lst[i[0]]
+
+        from neural_compressor.quantization import fit
+        from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
+        from neural_compressor.data import DATASETS, DATALOADERS
+        dataset = DATASETS("pytorch")["dummy"]((1, 3, 224, 224))
+        dataloader = DATALOADERS['pytorch'](dataset)
+
+        conf = PostTrainingQuantConfig(
+            approach="static",
+            optimization_level=1,
+            tuning_criterion=TuningCriterion(strategy="mse_v2"))
+
+        q_model = fit(
+            model=self.torch_model,
+            conf=conf,
+            calib_dataloader=dataloader,
+            eval_dataloader=dataloader,
+            eval_func=fake_eval_func)
+        self.assertIsNotNone(q_model)
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/strategy/test_tuning_sampler.py b/test/strategy/test_tuning_sampler.py
index f305454728c..51310ebaea2 100644
--- a/test/strategy/test_tuning_sampler.py
+++ b/test/strategy/test_tuning_sampler.py
@@ -1,7 +1,7 @@
-from neural_compressor.strategy.st_utils.tuning_sampler import OpTypeWiseTuningSampler, ModelWiseTuningSampler
-from neural_compressor.strategy.st_utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler
-from neural_compressor.strategy.st_utils.tuning_structs import OpTuningConfig
-from neural_compressor.strategy.st_utils.tuning_space import TuningSpace
+from neural_compressor.strategy.utils.tuning_sampler import OpTypeWiseTuningSampler, ModelWiseTuningSampler
+from neural_compressor.strategy.utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler
+from neural_compressor.strategy.utils.tuning_structs import OpTuningConfig
+from neural_compressor.strategy.utils.tuning_space import TuningSpace
 from collections import OrderedDict
 from copy import deepcopy
 import unittest
diff --git a/test/strategy/test_tuning_space.py b/test/strategy/test_tuning_space.py
index d7be2c4ac76..5696050c332 100644
--- a/test/strategy/test_tuning_space.py
+++ b/test/strategy/test_tuning_space.py
@@ -1,4 +1,4 @@
-from neural_compressor.strategy.st_utils.tuning_space import TuningItem, TuningSpace
+from neural_compressor.strategy.utils.tuning_space import TuningItem, TuningSpace
 from neural_compressor.conf.dotdict import DotDict
 from copy import deepcopy
 import unittest
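
Note: the `tuning_record_msg` helper added above in `neural_compressor/strategy/utils/helper.py` only renders a list of tuning records as comma-separated rows, one record per line. A minimal sketch of its behavior; the record values below are invented for illustration:

```python
# Mirrors the helper added in neural_compressor/strategy/utils/helper.py;
# the records are made-up (op, metric) pairs for illustration only.
def tuning_record_msg(records):
    records_str_lst = [[str(e) for e in record] for record in records]
    return '\n'.join(','.join(record) for record in records_str_lst)

records = [[('op_to_store', 'conv2d'), 0.91],
           [('Conv2D', 'conv2d'), 0.88]]
print(tuning_record_msg(records))
# ('op_to_store', 'conv2d'),0.91
# ('Conv2D', 'conv2d'),0.88
```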
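`TuneStrategy._collect_ops_by_quant_mode` (added above) filters a converted tune config for the op items whose quant mode matches the query. A self-contained sketch of the same logic, using a hypothetical stand-in for `OpTuningConfig` (the real class lives in `neural_compressor/strategy/utils/tuning_structs.py`):

```python
# Hypothetical stand-in for OpTuningConfig; only the attribute the filter
# reads (op_quant_mode) is modeled here.
class FakeOpTuningConfig:
    def __init__(self, op_quant_mode):
        self.op_quant_mode = op_quant_mode  # e.g. 'static', 'dynamic', 'fp32'

def collect_ops_by_quant_mode(tune_cfg, quant_mode):
    # Same filtering logic as the strategy method above.
    return [op_info for op_info, op_config in tune_cfg.items()
            if isinstance(op_config, FakeOpTuningConfig)
            and quant_mode in op_config.op_quant_mode]

cfg = {('op_to_store', 'conv2d'): FakeOpTuningConfig('static'),
       ('Conv2D', 'conv2d'): FakeOpTuningConfig('fp32')}
print(collect_ops_by_quant_mode(cfg, 'static'))  # [('op_to_store', 'conv2d')]
```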
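The TensorFlow test above also shows how the new adaptor hook can be driven directly: `calculate_op_sensitivity` returns an op list ordered by MSE sensitivity, and falling an op back amounts to writing an fp32 entry into the tune config before re-measuring. A sketch of that flow, assuming `adaptor`, `model`, `dataloader`, `tune_cfg`, and `output_op_names` are prepared as in `test_calculate_op_sensitivity`:

```python
# Op list ordered by MSE sensitivity (fallback candidates only).
ordered_ops = adaptor.calculate_op_sensitivity(
    model=model, dataloader=dataloader, tune_cfg=tune_cfg,
    output_op_names=output_op_names, confidence_batches=1, fallback=True)

# Fall back the first candidate to fp32 using the config shape from the tests;
# on the next call that op no longer appears among the fallback candidates.
tune_cfg['op'][ordered_ops[0]] = {
    'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'},
    'weight': {'dtype': 'fp32'}}
```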
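For end users, enabling the strategy through the 2.x API follows the pattern exercised in `test_mse_v2_2.x.py`. A condensed sketch; the model, dataloader, and eval function are assumed to be defined as in those tests, and the `confidence_batches` value is illustrative:

```python
from neural_compressor.quantization import fit
from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig

conf = PostTrainingQuantConfig(
    approach="static",
    optimization_level=1,
    tuning_criterion=TuningCriterion(
        strategy="mse_v2",
        # Optional: number of batches used when measuring per-op sensitivity.
        strategy_kwargs={"confidence_batches": 2}))

# model, dataloader, and eval_func as prepared in the tests above.
q_model = fit(model=model, conf=conf, calib_dataloader=dataloader,
              eval_dataloader=dataloader, eval_func=eval_func)
```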