fix bug in TorchSmoothQuant (#1149)
* [bug fix] when folding=False and QKV is not fully converted to SQLinear.

Signed-off-by: Xin He <xin3.he@intel.com>

---------

Signed-off-by: Xin He <xin3.he@intel.com>
xin3he committed Aug 17, 2023
1 parent aa4770d commit 0349b9a
Showing 3 changed files with 56 additions and 33 deletions.
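For context: with folding=False, SmoothQuant cannot fold the per-channel activation scale into a preceding layer, so each quantizable torch.nn.Linear is replaced by a wrapper that divides its input by the scale and pre-multiplies its weight by the same scale, which is what SQLinearWrapper does and what this commit fixes for the Q/K/V projections. The snippet below is only a minimal sketch of that idea; MulLinear, the dummy calibration values, and the constructor are illustrative assumptions, not the library's SQLinearWrapper implementation (which, per the diff, takes 1.0 / scale, input_minmax and alpha).

import torch

class MulLinear(torch.nn.Module):
    """Hypothetical stand-in for SQLinearWrapper: scale the input down and the weight up."""
    def __init__(self, linear, input_scale):
        super().__init__()
        self.register_buffer("input_scale", input_scale)
        # y = (x / s) @ (W * s).T + b == x @ W.T + b, so the fp32 output is unchanged.
        linear.weight = torch.nn.Parameter(linear.weight * input_scale)
        self.linear = linear

    def forward(self, x):
        return self.linear(x / self.input_scale)

lin = torch.nn.Linear(4, 8)
x = torch.randn(2, 4)
ref = lin(x)                                       # output before wrapping
act_max = x.abs().amax(dim=0) + 1e-5               # per-input-channel activation max ("calibration")
w_max = lin.weight.detach().abs().amax(dim=0) + 1e-5
alpha = 0.5
scale = act_max.pow(alpha) / w_max.pow(1 - alpha)  # standard SmoothQuant scale formula
wrapped = MulLinear(lin, scale)
print(torch.allclose(wrapped(x), ref, atol=1e-5))  # True: the wrapper is numerically equivalent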
10 changes: 6 additions & 4 deletions neural_compressor/adaptor/pytorch.py
@@ -2605,7 +2605,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
if self.version.release >= Version("1.12.0").release:
# Check save_qconf_summary part is a workaround for an IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loses this attribute
if not hasattr(model._model, "save_qconf_summary"):
if not hasattr(model._model, "save_qconf_summary") or \
not hasattr(model._model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
if self.version.release >= Version("2.1").release:
static_qconfig = ipex.quantization.default_static_qconfig_mapping
@@ -2950,7 +2951,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
ipex_conf.save(self.ipex_config_path)
else:
if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
assert self.q_dataloader or self.example_inputs, \
assert self.q_dataloader is not None or self.example_inputs is not None, \
"IPEX need q_dataloader or example_inputs to prepare the model"
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
if self.version.release >= Version("2.1").release:
@@ -2983,7 +2984,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
model = ipex.quantization.prepare(model, static_qconfig,
example_inputs=self.example_inputs, inplace=True)

if self.q_dataloader or self.example_inputs:
if self.q_dataloader is not None or self.example_inputs is not None:
self._simple_inference(model, self.q_dataloader, iterations=1)
else:
try:
@@ -3141,7 +3142,8 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):

# Check save_qconf_summary part is a workaround for an IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loses this attribute
if not hasattr(model._model, "save_qconf_summary"):
if not hasattr(model._model, "save_qconf_summary") or \
not hasattr(model._model, "load_qconf_summary"):
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if isinstance(self.example_inputs, dict):
model._model = ipex.quantization.prepare(model._model, static_qconfig,
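The assert and if changes above also replace bare truthiness tests with explicit is not None checks. That matters because an example_inputs tensor cannot be coerced to bool, and a dataloader-like object with zero length is falsy even when it was explicitly supplied. A standalone illustration of the pitfall (not the adaptor's code):

import torch

example_inputs = torch.randn(2, 3)            # a multi-element tensor
try:
    if example_inputs:                        # PyTorch refuses to coerce this to bool
        pass
except RuntimeError as err:
    print("truthiness check raised:", err)

empty_loader = []                             # stand-in for a dataloader that yields no batches
print(bool(empty_loader))                     # False, even though a loader was passed in
print(empty_loader is not None)               # True -- the explicit check the commit switches to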
62 changes: 35 additions & 27 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -501,44 +501,43 @@ def _reshape_scale_for_input(self, layer, scale):

return scale

def _scale_layer_weight(self, layer_name, scale): ##input channel
def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel
"""
Scale the layer weights at input channel, depthwise conv output channel
:param layer_name: The layer name
:param scale: The scale to be multiplied
:param alpha: alpha for SQLinearWrapper
:param input_minmax: input_minmax for SQLinearWrapper
:return:
"""
layer = get_module(self.model, layer_name)
if layer.__class__.__name__ == "SQLinearWrapper":
return scale # weigth update is done in SQLinearWrapper initialization
scale = self._reshape_scale_for_weight(layer, scale)
layer.weight = torch.nn.Parameter(layer.weight * scale)
if self.insert_mul:
from .model_wrapper import SQLinearWrapper
layer = get_module(self.model, layer_name)
if isinstance(layer, SQLinearWrapper):
layer._recover_sq_linear()
set_module(self.model, layer_name, layer.sq_linear) ##recover
else:
new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha)
set_module(self.model, layer_name, new_module)
elif self.allow_absorb:
scale = self._reshape_scale_for_weight(layer, scale)
layer.weight = torch.nn.Parameter(layer.weight * scale)
return scale

def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
def _absorb_scales(self, layer_name, scale): ##output channel
"""
Absorb the scale to the layer at output channel
:param layer_name: The module name
:param scale: The scale to be absorbed
:param alpha_key: The alpha passed to SQLinearWrapper
:return:
"""
layer = get_module(self.model, layer_name)
if self.insert_mul:
if layer.__class__.__name__ == "SQLinearWrapper":
layer._recover_sq_linear()
set_module(self.model, layer_name, layer.sq_linear) ##recover
else:
from .model_wrapper import SQLinearWrapper
input_minmax = [self.input_mins[layer_name], self.input_maxes[layer_name]]
new_module = SQLinearWrapper(layer, scale, input_minmax, alpha)
set_module(self.model, layer_name, new_module)
return

if not self.allow_absorb:
return ## change the code style due to too many if/else statements in the following
if self.insert_mul or not self.allow_absorb:
return # absorb is updated in SQLinearWrapper in def _scale_layer_weight

##if self.allow absorb
layer = get_module(self.model, layer_name)
if layer.__class__.__name__ == 'WrapperLayer':
layer = layer.orig_layer
if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \
@@ -650,7 +649,9 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
:param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict
:return:
"""
absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning)
absorb_scales_info, weight_scales_info = self._cal_scales(
absorb_to_layer, input_maxes, alpha, tuning
)
if not absorb_scales_info or not weight_scales_info:
return weight_scales_info, absorb_scales_info
for index, key in enumerate(absorb_to_layer.keys()):
@@ -659,10 +660,13 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
elif isinstance(alpha, dict):
alpha_tmp = alpha[key]
absorb_scale = absorb_scales_info[key]
self._absorb_scales(key, absorb_scale, alpha_tmp)
self._absorb_scales(key, absorb_scale)
layer_names = absorb_to_layer[key]
for layer_name in layer_names:
self._scale_layer_weight(layer_name, weight_scales_info[layer_name])
input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]]
self._scale_layer_weight(
layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax
)
return weight_scales_info, absorb_scales_info

def _check_need_calibration(self, alpha, percentile, op_types,
@@ -1110,10 +1114,14 @@ def _get_example_input(self):
if self.dataloader == None and self.example_inputs == None:
return None
if self.example_inputs is None:
##assert self.dataloader, "Please provide dataloader or example_inputs"
for idx, input in enumerate(self.dataloader):
self.example_inputs = input
break
try:
for idx, (input, label) in enumerate(self.dataloader):
self.example_inputs = input
break
except:
for idx, input in enumerate(self.dataloader):
self.example_inputs = input
break

return self.example_inputs

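The new try/except in _get_example_input above accepts dataloaders that yield either a bare input or an (input, label) pair. A hedged alternative sketch of the same intent, using an isinstance check instead of a bare except (the helper name is an illustrative assumption, and it deliberately ignores the corner case where a bare batch is itself a 2-element tuple of inputs):

import torch

def first_example_input(dataloader):
    # Return the first batch's input, whether the loader yields `input` directly
    # or `(input, label)` pairs; return None if the loader is empty.
    for batch in dataloader:
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            return batch[0]
        return batch
    return None

loader_a = [(torch.randn(1, 3), torch.tensor([0]))]  # (input, label) style batches
loader_b = [torch.randn(1, 3)]                       # bare-input batches
print(first_example_input(loader_a).shape, first_example_input(loader_b).shape)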
17 changes: 15 additions & 2 deletions test/algorithm/test_smooth_quant.py
@@ -273,19 +273,22 @@ def __init__(self):
self.norm = torch.nn.GroupNorm(num_channels=4, num_groups=2)
self.act = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(4, 3, 1, 1)
self.conv3 = torch.nn.Conv2d(4, 3, 1, 1)

def forward(self, x):
out = self.conv1(x)
out = self.norm(out)
out = self.act(out)
out = self.conv2(out)
tmp1 = self.conv2(out)
tmp2 = self.conv3(out)
out = tmp1 + tmp2
return out

model = Model()

sq = TorchSmoothQuant(model, self.conv_dl)
sq.transform(alpha=0.6, calib_iter=2, folding=True)
assert len(sq.absorb_to_layer) == 1
assert len(sq.absorb_to_layer['norm']) == 2

def test_sq_add(self):
class Model(torch.nn.Module):
@@ -626,6 +629,15 @@ def forward(self, x):
sq.transform(alpha=0.5, calib_iter=1) # By default, folding=False
assert isinstance(sq.model.fc1, SQLinearWrapper)

def test_sq_qkv(self):
model = transformers.AutoModelForCausalLM.from_pretrained(
'facebook/opt-125m', torchscript=True,)
sq = TorchSmoothQuant(model, LLMCalibDataloader())
sq.transform(alpha=0.5, calib_iter=-1, folding=False)
assert isinstance(
sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper
)

def test_sq_quant(self):
from neural_compressor import PostTrainingQuantConfig, quantization
class Model(torch.nn.Module):
@@ -734,6 +746,7 @@ def calib_func(model):
calib_func=calib_func,
)

fp32_model = Model()
conf = PostTrainingQuantConfig(
backend="ipex",
calibration_sampling_size=8,
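Since the commit is about Q/K/V not being fully converted, the new test_sq_qkv above could naturally be extended to assert on all three projections rather than k_proj alone; a hypothetical addition (not part of the commit), reusing the test's sq object:

attn = sq.model.model.decoder.layers[0].self_attn
for name in ("q_proj", "k_proj", "v_proj"):
    assert isinstance(getattr(attn, name), SQLinearWrapper), f"{name} was not converted"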
