Enhance smoothquant and ipex op_type capability. (#808)
Signed-off-by: Xin He <xin3.he@intel.com>
xin3he committed Apr 18, 2023
1 parent bf4fb7e commit 603811e
Showing 7 changed files with 164 additions and 49 deletions.
71 changes: 48 additions & 23 deletions neural_compressor/adaptor/pytorch.py
@@ -19,6 +19,7 @@
import gc
import math
import os
import re
from collections import OrderedDict, UserDict, namedtuple
from packaging.version import Version
import yaml
@@ -1304,23 +1305,26 @@ def qdq_quantize(self, model, tune_cfg):
op_name = op_name.rstrip('.sq_linear')
fallback_op_name_list.append(op_name)

smoothquant_op_info = {'sq_linear': {}, 'qdq_linear': []}
stats_result['SQLinearWrapper'] = {'INT8(QDQ)': 0, 'BF16': 0, 'FP32': 0}
for name, module in q_model.named_modules():
if isinstance(module, SQLinearWrapper):
smoothquant_op_info['sq_linear'][name] = module.input_scale
if name not in fallback_op_name_list:
smoothquant_scale_info[name] = {
'input_scale_for_mul': module.input_scale,
'quant_scale': module.scale,
'quant_zero_point': module.zero_point,
'quant_dtype': module.dtype,
}
smoothquant_op_info['qdq_linear'].append(name+'.sq_linear')
new_module = QDQLinear(module.sq_linear, module.scale, module.zero_point, module.dtype)
set_module(q_model, name+'.sq_linear', new_module)
stats_result['SQLinearWrapper']['INT8(QDQ)'] += 1
else:
stats_result['SQLinearWrapper']['FP32'] += 1

tune_cfg['recipe_cfgs']['smoothquant_scale_info'] = smoothquant_scale_info
tune_cfg['recipe_cfgs']['smoothquant_op_info'] = smoothquant_op_info
model._model = q_model
model.q_config = copy.deepcopy(tune_cfg)
field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
@@ -2425,7 +2429,10 @@ def get_non_quant_modules(self, model_kwargs):
"<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "adaptiveavgpool2d",
"Linear_Relu": "linear",
"<class 'torch.nn.modules.linear.Linear'>": "linear",
"<class 'torch.nn.modules.pooling.MaxPool2d'>": "maxpool2d"
"<class 'torch.nn.modules.pooling.MaxPool2d'>": "maxpool2d",
're': {
"<built-in method matmul of type object at": "matmul"
}
}


@@ -2476,7 +2483,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \
and self.version.release >= Version("2.1").release \
and self.approach != 'post_training_dynamic_quant':
return self.qdq_quantize(model, tune_cfg, dataloader)
return self.qdq_quantize(model, tune_cfg, dataloader, q_func)

assert self.approach != 'quant_aware_training', \
"Intel PyTorch Extension didn't support quantization aware training mode"
@@ -2956,9 +2963,18 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
self.default_cfgs = copy.deepcopy(self.cfgs)
self.fuse_ops = self.get_fuse_ops(self.cfgs)
for op_cfg in self.cfgs:
quantizable_ops.append(
(op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]]
if op_cfg["name"] in unify_op_type_mapping_ipex else op_cfg["name"]))
if op_cfg["name"] in unify_op_type_mapping_ipex:
quantizable_ops.append((op_cfg["id"],
unify_op_type_mapping_ipex[op_cfg["name"]]))
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex['re'].items():
if re.match(pattern, op_cfg["name"]):
re_flag = True
quantizable_ops.append((op_cfg["id"], unify_op_type))
break
if not re_flag:
quantizable_ops.append((op_cfg["id"], op_cfg["name"]))
else:
ops_name, op_infos_from_cfgs, input_tensor_id_op_name, \
output_tensor_id_op_name = torch_utils.util.paser_cfgs(self.cfgs)
@@ -2970,11 +2986,19 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
if len(name) == 1:
module_key = name[0][0]
op_cfg_id = name[0][2]
quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex \
[self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type']] \
if self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type'] \
in unify_op_type_mapping_ipex else \
self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type']))
ipex_op_type = self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type']
if ipex_op_type in unify_op_type_mapping_ipex:
quantizable_ops.append((tuple(name),
unify_op_type_mapping_ipex[ipex_op_type]))
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex['re'].items():
if re.match(pattern, ipex_op_type):
re_flag = True
quantizable_ops.append((tuple(name), unify_op_type))
break
if not re_flag:
quantizable_ops.append((tuple(name), ipex_op_type))
else:
op_type = ""
for op_name in name:
@@ -3021,7 +3045,7 @@ def get_fuse_ops(self, default_cfgs):
op_patterns.append([(value[0], value[1]), (cur_id, cur_op)])
return op_patterns

def qdq_quantize(self, model, tune_cfg, dataloader):
def qdq_quantize(self, model, tune_cfg, dataloader, q_func):
assert not self.version.release < Version("2.1").release, \
"IPEX version >= 2.1 is required for SmoothQuant."

@@ -3042,14 +3066,9 @@ def qdq_quantize(self, model, tune_cfg, dataloader):
smoothquant_scale_info = {}
for name, module in q_model.named_modules():
if isinstance(module, SQLinearWrapper):
weight_scale = module._get_weight_scale()
smoothquant_scale_info[name + '.sq_linear'] = {
'input_scale_for_mul': module.input_scale,
'input_scale_after_mul': module.scale,
'input_zero_point_after_mul': module.zero_point,
'input_dtype': module.dtype,
'weight_scale_after_mul': weight_scale,
}
'alpha': module.alpha,
}
module.ipex = True
# Note: save weight scale before recover
module._recover_sq_linear()
@@ -3058,14 +3077,21 @@ def qdq_quantize(self, model, tune_cfg, dataloader):
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
q_model = ipex.quantization.prepare(q_model, static_qconfig, \
example_inputs=self.example_inputs, inplace=True)
self.calib_func(q_model, dataloader, tmp_iterations=1) # fake calibration
self.calib_func(q_model, dataloader, tmp_iterations=1) # fake calibration to generate qconf_summary
q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)

# update ipex_config.json with smoothquant_scale_info
update_sq_scale(self.ipex_config_path, smoothquant_scale_info)
# enable fallback
self._cfg_to_qconfig(tune_cfg)
# update ipex_config.json with smoothquant_scale_info
update_sq_scale(self.ipex_config_path, smoothquant_scale_info)
q_model.load_qconf_summary(qconf_summary=self.ipex_config_path)
# real calibration for other operators
if q_func is not None:
q_func(q_model)
else:
iterations = tune_cfg.get('calib_iteration', 1)
self.model_calibration(q_model, dataloader, iterations, None,
tune_cfg.get('calib_sampling_size', 1))

if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
(self.version.release >= Version("1.11.0").release):
@@ -4080,7 +4106,6 @@ def _check_dynamic_control(module):
fused_model (GraphModule): fused GraphModule model from torch.fx.
"""
import inspect
import re
try:
lines = inspect.getsource(module.forward)
# Proxy obj. will always be detected as `not None`.
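
Note (not part of this commit): the exact-match-then-regex lookup added above is written out twice, once for the `cfgs` path and once for the `q_op_infos` path. A minimal, hypothetical helper showing the same lookup order; the name `_unify_op_type` is illustrative only:

import re

def _unify_op_type(ipex_op_type, mapping):
    # 1) exact key, 2) regex patterns stored under the special 're' key, 3) raw IPEX name
    if ipex_op_type in mapping:
        return mapping[ipex_op_type]
    for pattern, unified in mapping.get('re', {}).items():
        if re.match(pattern, ipex_op_type):
            return unified
    return ipex_op_type

# e.g. "<built-in method matmul of type object at 0x7f...>" -> "matmul"
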
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/pytorch_ipex.yaml
@@ -84,7 +84,7 @@
'conv1d': *cap_s8_1_10_Conv2d,
'conv3d': *cap_s8_1_10_Conv2d,
'linear': *cap_s8_1_10_Conv2d,
'default': {
'default': &cap_s8_1_10_default {
'weight': {
'dtype': ['int8'],
'scheme': ['sym'],
@@ -98,6 +98,8 @@
'algorithm': ['minmax']
}
},
'add': *cap_s8_1_10_default,
'matmul': *cap_s8_1_10_default,
},
'dynamic': {},
'quant_aware': {}
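
Note (not part of this commit): the new `&cap_s8_1_10_default` anchor lets `add` and `matmul` reuse the generic int8 capability entry instead of duplicating it. A small sketch, with a simplified capability dict, of how the alias resolves once the YAML is loaded:

import yaml

snippet = """
static:
  default: &cap_default
    activation: {dtype: ['uint8'], scheme: ['asym'], algorithm: ['minmax']}
  add: *cap_default
  matmul: *cap_default
"""
caps = yaml.safe_load(snippet)
# 'add' and 'matmul' carry exactly the same capability entry as 'default'.
assert caps['static']['add'] == caps['static']['default']
assert caps['static']['matmul'] == caps['static']['default']
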
40 changes: 37 additions & 3 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -34,7 +34,7 @@ def get_torch_version():


class QDQLinear(torch.nn.Module):
def __init__(self, module, scale, zero_point, dtype):
def __init__(self, module, scale=1, zero_point=0, dtype=torch.quint8):
super().__init__()
if PT_VERSION < Version("1.13.0").release:
import torch.nn.quantized as nnq
@@ -59,9 +59,10 @@ def qdq_weight(self):


class SQLinearWrapper(torch.nn.Module):
def __init__(self, module, input_scale, input_minmax, dtype=torch.quint8):
def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8):
super().__init__()
self.input_scale = input_scale
self.register_buffer('input_scale', input_scale)
self.alpha = alpha
self.dtype = dtype
# calculate and only save scale, zero_point to avoid memory usage
self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype)
@@ -104,3 +105,36 @@ def _recover_sq_linear(self):
scale = self.input_scale.view(1, self.input_scale.shape[0])
with torch.no_grad():
self.sq_linear.weight *= scale


def _wrapper_sq_linear(tmp_model, input_scale_dict):
"""Help function to generate a fake SmoothQuant model for loading weights"""
class SQLinearWrapper(torch.nn.Module):
def __init__(self, module, input_scale):
super().__init__()
self.register_buffer('input_scale', input_scale)
self.add_module('sq_linear', module)

def forward(self, X):
X = torch.mul(X, self.input_scale)
X = self.sq_linear(X)
return X

module_name_list = input_scale_dict.keys()
from .smooth_quant import get_module, set_module
for name in module_name_list:
module = get_module(tmp_model, name)
input_scale = input_scale_dict[name]
new_module = SQLinearWrapper(module, input_scale)
set_module(tmp_model, name, new_module)
return tmp_model


def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
"""Help function to generate a fake QDQ model for loading weights"""
from .smooth_quant import get_module, set_module
for name in module_name_list:
module = get_module(tmp_model, name)
new_module = QDQLinear(module)
set_module(tmp_model, name, new_module)
return tmp_model
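
Note (not part of this commit): the two fake wrappers above only exist to recreate the module hierarchy that a SmoothQuant (folding=False) checkpoint was saved with, so that `load_state_dict()` can match keys such as `<name>.input_scale` and `<name>.sq_linear.weight`. A self-contained sketch of that key layout, using an illustrative `_FakeSQ` class with the same structure:

import torch
from torch import nn

class _FakeSQ(nn.Module):
    def __init__(self, module, input_scale):
        super().__init__()
        self.register_buffer('input_scale', input_scale)
        self.add_module('sq_linear', module)

    def forward(self, x):
        return self.sq_linear(torch.mul(x, self.input_scale))

plain = nn.Linear(3, 4)
print(list(plain.state_dict()))    # ['weight', 'bias']
wrapped = _FakeSQ(plain, torch.ones(3))
print(list(wrapped.state_dict()))  # ['input_scale', 'sq_linear.weight', 'sq_linear.bias']
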
41 changes: 20 additions & 21 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -399,11 +399,12 @@ def _scale_layer_weight(self, layer_name, scale): ##input channel
layer.weight *= scale
return scale

def _absorb_scales(self, layer_name, scale): ##output channel
def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
"""
Absorb the scale to the layer at output channel
:param layer_name: The module name
:param scale: The scale to be absorbed
:param alpha: The alpha value passed to SQLinearWrapper
:return:
"""
from .model_wrapper import SQLinearWrapper
@@ -413,7 +414,7 @@ def _absorb_scales(self, layer_name, scale): ##output channel
set_module(self.model, layer_name, layer.sq_linear) ##recover
else:
input_minmax = [self.input_mins[layer_name], self.input_maxes[layer_name]]
new_module = SQLinearWrapper(layer, scale, input_minmax)
new_module = SQLinearWrapper(layer, scale, input_minmax, alpha)
set_module(self.model, layer_name, new_module)

elif self.allow_absorb:
@@ -507,7 +508,7 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5):
scale = torch.clip(input_power / weight_power, min=1e-5)
scale[input_power == 0] = 1.0

self._absorb_scales(key, 1.0 / scale)
self._absorb_scales(key, 1.0 / scale, alpha_key)
absorb_scales_info[key] = 1.0 / scale
layer_names = absorb_to_layer[key]
for layer_name in layer_names:
@@ -902,24 +903,22 @@ def update_sq_scale(ipex_config_path, smoothquant_scale_info):
for module_name, v in ipex_config.items():
if 'q_op_infos' in v and v['q_op_infos']:
for op_num, v1 in v['q_op_infos'].items():
if 'weight_tensor_infos' in v1 and v1['weight_tensor_infos']:
op_name = v1['fqn']
if op_name in smoothquant_scale_info:
input_scale_for_mul = \
smoothquant_scale_info[op_name]['input_scale_for_mul'].tolist()
input_scale_after_mul = \
smoothquant_scale_info[op_name]['input_scale_after_mul'].tolist()
input_zero_point_after_mul = \
smoothquant_scale_info[op_name]['input_zero_point_after_mul'].tolist()
weight_scale_for_mul = \
(1 / smoothquant_scale_info[op_name]['input_scale_for_mul']).tolist()
weight_scale_after_mul = \
smoothquant_scale_info[op_name]['weight_scale_after_mul'].tolist()
v1['input_tensor_infos'][0]['smooth_quant_scaling_factor'] = input_scale_for_mul
v1['input_tensor_infos'][0]['scale'] = input_scale_after_mul
v1['input_tensor_infos'][0]['zero_point'] = input_zero_point_after_mul
v1['weight_tensor_infos'][0]['smooth_quant_scaling_factor'] = weight_scale_for_mul
v1['weight_tensor_infos'][0]['scale'] = weight_scale_after_mul
# update alpha data instead of updating weight scale
op_name = v1['fqn'] # fqn always exists even if it's empty.
if op_name in smoothquant_scale_info:
# observers were overridden by the fallback step; set them back.
v1['activation_observer'] = {'name': 'SmoothQuantActivationObserver',
'smooth_quant_enabled': False, 'dtype': 'torch.quint8',
'qscheme': 'torch.per_tensor_affine', 'reduce_range': False,
'quant_min': 0, 'quant_max': 255,
'alpha': smoothquant_scale_info[op_name]['alpha']
}
v1['weight_observer'] = {'name': 'SmoothQuantWeightObserver',
'smooth_quant_enabled': False, 'dtype': 'torch.qint8',
'qscheme': 'torch.per_channel_symmetric', 'reduce_range': False,
'quant_min': -128, 'quant_max': 127,
'alpha': smoothquant_scale_info[op_name]['alpha'] #only update alpha
}
f.close()
# overwrite ipex_config_path
with open(ipex_config_path, 'w') as f1:
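
Note (not part of this commit): `update_sq_scale()` now rewrites only the observer entries with the recorded alpha instead of pre-computed scales. The fragment below sketches the shape of the ipex_config.json entries it walks; the field names come from the code above, the values are made up:

ipex_config_fragment = {
    "module": {
        "q_op_infos": {
            "0": {
                "fqn": "fc1",
                "op_type": "<class 'torch.nn.modules.linear.Linear'>",
                # only these two observer dicts are rewritten, located by 'fqn'
                "activation_observer": {"name": "SmoothQuantActivationObserver", "alpha": 0.5},
                "weight_observer": {"name": "SmoothQuantWeightObserver", "alpha": 0.5},
            }
        }
    }
}
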
10 changes: 10 additions & 0 deletions neural_compressor/utils/pytorch.py
@@ -260,6 +260,16 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
elif tune_cfg['approach'] == "post_training_static_quant":
approach_quant_mode = 'static'

recipe_cfgs = tune_cfg.get('recipe_cfgs', None)
if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \
and not recipe_cfgs['smooth_quant_args']['folding'] \
and approach_quant_mode != 'dynamic':
from ..adaptor.torch_utils.model_wrapper import _wrapper_sq_linear, _wrapper_qdq_linear
model = _wrapper_sq_linear(model, recipe_cfgs['smoothquant_op_info']['sq_linear'])
model = _wrapper_qdq_linear(model, recipe_cfgs['smoothquant_op_info']['qdq_linear'])
model.load_state_dict(stat_dict)
return model

for _, op_cfg in tune_cfg['op'].items():
if 'quant_mode' not in op_cfg['activation']:
op_cfg['activation']['quant_mode'] = approach_quant_mode
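
Note (not part of this commit): with this early return, `load()` detects a non-folding SmoothQuant checkpoint from `recipe_cfgs`, re-wraps the FP32 model with the two helpers, and loads the saved state dict. A minimal usage sketch; the directory name and `Model` class are placeholders, and the new test below exercises the same flow:

from neural_compressor.utils.pytorch import load

model_origin = Model()  # the matching FP32 architecture
qdq_model = load('./saved_result', model_origin)  # returns the re-wrapped QDQ model
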
7 changes: 6 additions & 1 deletion test/algorithm/test_smooth_quant.py
@@ -585,7 +585,7 @@ def forward(self, x):
out = self.fc2(out)
return out

input_ids = torch.randn([1, 3])
input_ids = torch.randn([2, 3])
fp32_model = Model()
conf = PostTrainingQuantConfig(
calibration_sampling_size=8,
@@ -608,6 +608,11 @@ def __iter__(self):
assert isinstance(q_model.model.fc1, SQLinearWrapper)
assert isinstance(fp32_model.fc1, SQLinearWrapper) # for smoothquant, inplace=True.

q_model.save('saved_result')
from neural_compressor.utils.pytorch import load
model_origin = Model()
qdq_model = load("./saved_result", model_origin)

fp32_model = Model()
origin_bias = float(fp32_model.fc1.bias[0])
conf = PostTrainingQuantConfig(
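
Note (not part of this commit): the new test only asserts that loading succeeds. A natural follow-up check, with an illustrative tolerance, would compare the reloaded model's output against the tuned model:

with torch.no_grad():
    out_tuned = q_model.model(input_ids)
    out_loaded = qdq_model(input_ids)
assert torch.allclose(out_tuned, out_loaded, atol=1e-04)
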
