Add mse_v2 tuning strategy (#218)
Signed-off-by: intel-zhangyi <yi5.zhang@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Xin He <xin3.he@intel.com>
3 people committed Dec 9, 2022
1 parent 7ffd5e5 commit 80311f6
Showing 31 changed files with 1,365 additions and 54 deletions.
33 changes: 33 additions & 0 deletions docs/source/tuning_strategies.md
@@ -200,6 +200,39 @@ tuning:
random_seed: 9527
```

### MSE_v2

#### Design

`MSE_v2` is a two-stage fallback strategy for few-shot mixed quantization,
built from three key components. First, a multi-batch order combination
based on per-layer fallback MSE values evaluates layer sensitivity with
only a few batches of data. Second, a sensitivity gradient, combined with
a beam search, refines the sensitivity evaluation and mitigates the
local-optimum problem. Third, a quantize-again procedure removes redundant
fallback layers to protect performance. `MSE_v2` is especially effective
for models with long full-dataset evaluation times and large tuning counts.
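
As a rough illustration only (this sketch is not part of the commit, and it omits the
sensitivity-gradient and beam-search details), the two-stage flow can be summarized as
follows; `measure_mse` and `meets_accuracy` are hypothetical callbacks:

```python
# Simplified, illustrative sketch of the MSE_v2 two-stage flow (not the library code).
# measure_mse(op): hypothetical callback returning the model-output MSE vs. the fp32
#   model when `op` alone is fallen back to fp32; a lower value means `op` contributed
#   more quantization error, i.e. it is more sensitive.
# meets_accuracy(ops): hypothetical callback returning True if the model with `ops`
#   fallen back satisfies the accuracy criterion.
def mse_v2_sketch(int8_ops, measure_mse, meets_accuracy):
    # Stage 1: fall back the most sensitive ops first until accuracy is met.
    fallen_back = []
    for op in sorted(int8_ops, key=measure_mse):
        fallen_back.append(op)
        if meets_accuracy(fallen_back):
            break
    # Stage 2: "quantize again" -- drop redundant fallbacks (least sensitive first)
    # whenever accuracy still holds, to recover performance.
    for op in sorted(fallen_back, key=measure_mse, reverse=True):
        trial = [o for o in fallen_back if o is not op]
        if trial and meets_accuracy(trial):
            fallen_back = trial
    return fallen_back
```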

#### Usage
`MSE_v2` is used in much the same way as `MSE`. To use the `MSE_v2` tuning
strategy, set the strategy name to `mse_v2`. The optional `confidence_batches`
field specifies how many batches are used in the sensitivity calculation.
A minimal Python sketch follows the yaml example below.


```yaml
tuning:
strategy:
name: mse_v2
confidence_batches: 2
accuracy_criterion:
relative: 0.01
exit_policy:
timeout: 0
random_seed: 9527
```
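
For reference, a minimal sketch of driving this configuration from Python with the
Neural Compressor 1.x experimental API is shown below; the model choice,
`my_calib_dataloader`, and `my_eval_func` are placeholders you would supply:

```python
# Minimal sketch: assumes a conf.yaml that contains the tuning section shown above.
import torchvision
from neural_compressor.experimental import Quantization, common

quantizer = Quantization("conf.yaml")                    # picks up strategy: mse_v2
quantizer.model = common.Model(torchvision.models.resnet18(pretrained=True))
quantizer.calib_dataloader = my_calib_dataloader         # placeholder calibration dataloader
quantizer.eval_func = my_eval_func                       # placeholder: returns a single accuracy value
q_model = quantizer.fit()                                # runs tuning with the mse_v2 strategy
q_model.save("./saved_results")
```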

### TPE

#### Design
27 changes: 27 additions & 0 deletions examples/.config/model_params_pytorch.json
@@ -216,6 +216,33 @@
"batch_size": 100,
"new_benchmark": false
},
"efficientnet_b0_fx": {
"model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
"dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
"input_model": "",
"yaml": "conf.yaml",
"strategy": "mse_v2",
"batch_size": 100,
"new_benchmark": false
},
"efficientnet_b3_fx": {
"model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
"dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
"input_model": "",
"yaml": "conf.yaml",
"strategy": "mse_v2",
"batch_size": 100,
"new_benchmark": false
},
"efficientnet_b7_fx": {
"model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/",
"dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
"input_model": "",
"yaml": "conf.yaml",
"strategy": "mse_v2",
"batch_size": 100,
"new_benchmark": false
},
"bert_base_MRPC": {
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
"dataset_location": "",
23 changes: 22 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -3182,7 +3182,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
Returns:
None
"""

module_dict = dict(model.named_modules())
for op_name, child in model.named_modules():
if self.is_fused_module(child):
@@ -3507,6 +3506,28 @@ def _check_dynamic_control(module):
logger.info('Module has no forward function')
return False

def get_output_op_names(self, *args, **kwargs):
    """Output op names are only needed by the TensorFlow adaptor; return None for PyTorch."""
    return None

def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names,
confidence_batches, fallback=True, requantize_cfgs=None):
"""This is a helper function for `query_fw_capability`,
and it will get all quantizable ops from model.
Args:
model (object): INC model containing fp32 model
dataloader (string): dataloader contains real data.
tune_cfg (dict): dictionary of tune configure for each op.
fallback (bool): switch method in fallback stage and re-quantize stage
Returns:
ops_lst (list): sorted op list by sensitivity
"""
from .torch_utils.util import get_fallback_order
ordered_ops = get_fallback_order(self, model.model, dataloader, tune_cfg,
confidence_batches, fallback, requantize_cfgs)
return ordered_ops


class PyTorchQuery(QueryBackendCapability):
def __init__(self, local_config_file=None):
155 changes: 155 additions & 0 deletions neural_compressor/adaptor/tensorflow.py
@@ -94,6 +94,8 @@ def __init__(self, framework_specific_info):

self.optype_statistics = None

self._last_dequantize_ops = None

def log_histogram(self, writer, tag, values, step=0, bins=1000):
import tensorflow as tf
# Convert to a numpy array
@@ -1453,8 +1455,161 @@ def recover_tuned_model(self, model, q_config):
def diagnosis_helper(self, fp32_model, quan_model, tune_cfg, save_path):
from .tf_utils.util import tf_diagnosis_helper
return tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path)

def get_output_op_names(self, qmodel):
from .tf_utils.graph_util import GraphAnalyzer

graph_def = GraphAnalyzer().parse_graph(qmodel.graph_def)
output_op_names = set()

for output_opname in qmodel.output_node_names:
op_count = 0
stack = [output_opname]
while stack:
opname = stack.pop()
while True:
op_count += 1
if opname not in graph_def:
break
op = graph_def[opname]
if op.node.op == 'Dequantize':
output_op_names.add(opname)
break
next_opnames = op.node.input
if not next_opnames:
break
elif len(next_opnames) > 1:
stack += next_opnames[1:]

opname = next_opnames[0]

output_op_names = list(output_op_names)
logger.debug(f"output op names: {output_op_names}")
return output_op_names

def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names,
confidence_batches, fallback=True, requantize_cfgs=None):
"""Compute the op sensitivity.
The sensitivity metric is the mse between the output of the last quantized op of
the quantized model and the output of its corresponding op in the fp32 model.
1. Backup the tune cfg
2. Fallback each int8 op and compute its mse if use fallback (with 'fallback == True'),
or re-quantize each fp32 op(fallen back in the previous stage) and compute its MSE if not.
3. Sorted op name list according to its MSE
Args:
fp32_model: The fp32 model.
dataloader: the dataloader with full dataset.
tune_cfg: tuning config
fallback: denote fallback stage or re-quantize stage
requantize_cfgs: the dict of tuning configs for all re-quantizable ops
Returns:
A list of op names, sorted by its MSE sensitivity.
"""
from copy import deepcopy

fp32_op_cfg = {'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'},
'weight': {'dtype': 'fp32'}}

if fallback:
ops_list = [op for op, config in tune_cfg['op'].items()
if config['activation']['quant_mode'] in ('static', 'dynamic')]
replace_cfgs = {op : fp32_op_cfg for op in tune_cfg['op']}
else:
ops_list = [op for op, config in tune_cfg['op'].items()
if config['activation']['quant_mode'] == 'fp32' and op in requantize_cfgs]
replace_cfgs = requantize_cfgs

# Step2. compute mse
mse_result = self._get_mse_order(
model, deepcopy(tune_cfg), replace_cfgs, ops_list, dataloader,
output_op_names, confidence_batches)

# Step3. sort
mse_order = [op for op, _ in sorted(mse_result.items(), key=lambda i: i[1])]
logger.debug("Dump MSE order:")
for op in mse_order:
logger.debug(f"{op}: {mse_result[op]}")
return mse_order

def _get_mse_order(self, fp32_model, tune_cfg, replace_cfgs, ops_lst, dataloader,
output_op_names, confidence_batches):
op_cfg = tune_cfg['op']
mse_result = {}
partial_dataloader = self._partial_dataloader(dataloader, confidence_batches)

fp32_output = self._inference_model_on_batches(
fp32_model, tune_cfg, partial_dataloader, output_op_names)

for op in ops_lst:
# backup and set replace tuning config
backup_cfg = op_cfg[op]
op_cfg[op] = replace_cfgs[op]

# quantize and inference the model
q_model = self.quantize(tune_cfg, fp32_model, partial_dataloader)
q_output = self._inference_model_on_batches(
q_model, tune_cfg, partial_dataloader, output_op_names)

mse_result[op] = self._calculate_mse(fp32_output, q_output)

# recover tune_cfg
op_cfg[op] = backup_cfg

return mse_result

def _partial_dataset_of(self, dataloader, confidence_batches):
from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset
if isinstance(dataloader.dataset, DummyDataset):
assert(isinstance(confidence_batches, int))
ds = copy.deepcopy(dataloader.dataset)
ds.dataset = ds.dataset[:confidence_batches]
return ds
else:
return dataloader.dataset.take(confidence_batches)

def _partial_dataloader(self, dataloader, confidence_batches):
return type(dataloader)(
dataset=self._partial_dataset_of(dataloader, confidence_batches),
batch_size=dataloader.batch_size,
last_batch=dataloader.last_batch,
collate_fn=dataloader.collate_fn,
sampler=dataloader.sampler,
batch_sampler=dataloader.batch_sampler,
num_workers=dataloader.num_workers,
pin_memory=dataloader.pin_memory,
shuffle=dataloader.shuffle,
distributed=dataloader.distributed)

def _calculate_mse(self, fp32_output, q_output):
result = []
for i, j in zip(fp32_output, q_output):
result.append(np.square(i - j).mean())
return np.array(result).mean()

def _inference_model_on_batches(self, model, tune_cfg, dataloader,
output_op_names):
from .tf_utils.util import generate_feed_dict

input_tensors = model.input_tensor
output_tensors = []
for op in output_op_names:
for tensor in model.graph.get_operation_by_name(op).outputs:
output_tensors.append(tensor)

predictions = []
for index, (inputs, _) in enumerate(dataloader):
feed_dict = generate_feed_dict(input_tensors, inputs)

pred = model.sess.run(output_tensors, feed_dict)
for item in pred:
predictions.append(item)

return predictions

@adaptor_registry
class Tensorflow_ITEXAdaptor(TensorFlowAdaptor):
def __init__(self, framework_specific_info):
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tf_utils/graph_converter.py
@@ -34,7 +34,7 @@
from .transform_graph.insert_logging import InsertLogging
from .transform_graph.rerange_quantized_concat import RerangeQuantizedConcat
from .transform_graph.bias_correction import BiasCorrection
from .util import iterator_sess_run,version1_gt_version2,version1_eq_version2
from .util import generate_feed_dict, iterator_sess_run,version1_gt_version2,version1_eq_version2
from .util import version1_gte_version2,version1_lte_version2,version1_lt_version2
from .quantize_graph.quantize_graph_for_intel_cpu import QuantizeGraphForIntel
from .quantize_graph_common import QuantizeGraphHelper
60 changes: 60 additions & 0 deletions neural_compressor/adaptor/tf_utils/util.py
@@ -16,6 +16,7 @@
# limitations under the License.
#

from collections import OrderedDict, UserDict
import os
import numpy as np
from google.protobuf import text_format
@@ -493,3 +494,62 @@ def _parse_config(q_config, cfg, op_list):
if op_name_and_type[0] in op_list:
updated_cfg['op'][op_name_and_type] = cfg['op'][op_name_and_type]
return dequan_min_max, updated_cfg

def generate_feed_dict(input_tensor, inputs):
if len(input_tensor) == 1:
feed_dict = {}
if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) \
or isinstance(inputs, UserDict):
for name in inputs:
for tensor in input_tensor:
pos = tensor.name.rfind(":")
t_name = tensor.name if pos < 0 else tensor.name[:pos]
if name == t_name:
feed_dict[tensor] = inputs[name]
break
else:
feed_dict = {input_tensor[0]: inputs} # get raw tensor using index [0]
else:
assert len(input_tensor) == len(inputs), \
    'the length of inputs must equal the length of input_tensor'
feed_dict = {}
if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) \
or isinstance(inputs, UserDict):
for name in inputs:
for tensor in input_tensor:
pos = tensor.name.rfind(":")
t_name = tensor.name if pos < 0 else tensor.name[:pos]
if name in [tensor.name, t_name]:
feed_dict[tensor] = inputs[name]
break
else:
# sometimes input_tensor is not in the same order as inputs,
# so check their shapes and pair them
def check_shape(tensor, data):
# scalar or 1 dim default True
if tensor.shape == None or \
len(tensor.shape.dims) == 1 or \
not hasattr(data, 'shape'):
return True
tensor_shape = tuple(tensor.shape)
data_shape = tuple(data.shape)
for tensor_dim, data_dim in zip(tensor_shape, data_shape):
if tensor_dim is not None and tensor_dim != data_dim:
return False
return True

disorder_tensors = []
disorder_inputs = []
for idx, sort_tensor in enumerate(input_tensor):
sort_input = inputs[idx]
if check_shape(sort_tensor, sort_input):
feed_dict.update({sort_tensor: sort_input})
else:
disorder_tensors.append(sort_tensor)
disorder_inputs.append(sort_input)
for i, dis_tensor in enumerate(disorder_tensors):
for j, dis_input in enumerate(disorder_inputs):
if check_shape(dis_tensor, dis_input):
feed_dict.update({dis_tensor: dis_input})
break
return feed_dict
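
For illustration only (this example is not part of the diff), `generate_feed_dict` can be
exercised on a toy graph; the placeholder names `image` and `label` are assumptions:

```python
# Hypothetical, self-contained example of how generate_feed_dict pairs named
# inputs with graph input tensors (TF1-compat graph mode).
import numpy as np
import tensorflow.compat.v1 as tf
from neural_compressor.adaptor.tf_utils.util import generate_feed_dict

tf.disable_eager_execution()
graph = tf.Graph()
with graph.as_default():
    image = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name="image")
    label = tf.placeholder(tf.int64, shape=[None], name="label")

inputs = {"image": np.zeros((1, 224, 224, 3), dtype=np.float32),
          "label": np.zeros((1,), dtype=np.int64)}
feed_dict = generate_feed_dict([image, label], inputs)
# Keys are matched against each tensor's name with the trailing ":0" stripped,
# so the dict above maps to the "image" and "label" placeholders regardless of order.
```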
