Fix example inputs issue for IPEX smoothquant (#864)
Signed-off-by: changwangss <chang1.wang@intel.com>
changwangss committed May 10, 2023
1 parent ea65271 commit c8b7533
Showing 2 changed files with 12 additions and 7 deletions.
6 changes: 4 additions & 2 deletions neural_compressor/adaptor/pytorch.py
@@ -1269,7 +1269,8 @@ def smooth_quant(self, model, dataloader, calib_iter, tune_cfg=None, alpha=0.5,
 
         if not hasattr(self, 'sq') or force_re_smooth:
             from .torch_utils.smooth_quant import TorchSmoothQuant
-            self.sq = TorchSmoothQuant(model._model, dataloader=dataloader, q_func=self.q_func)
+            self.sq = TorchSmoothQuant(model._model, dataloader=dataloader, \
+                                       example_inputs=self.example_inputs, q_func=self.q_func)
         kwargs = {} ##different backends may have different default values
         if op_types != None:
             kwargs["op_types"] = op_types
@@ -2986,7 +2987,8 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                 ipex_conf.save(self.ipex_config_path)
             else:
                 if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
-                    assert self.q_dataloader is not None, "IPEX need q_dataloader to prepare the model"
+                    assert self.q_dataloader or self.example_inputs, \
+                        "IPEX need q_dataloader or example_inputs to prepare the model"
                     from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
                     if self.version.release >= Version("2.1").release:
                         # HistogramObserver will cause a performance issue.
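For illustration only (not part of this commit): a minimal sketch of an example_inputs object that satisfies the relaxed precondition above when no calibration dataloader is available. The model, tensor shape, and variable names are assumptions.

import torch
import torch.nn as nn

# Hypothetical model; used only to show what example_inputs should mirror.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# example_inputs matches what model.forward expects; a tuple keeps the
# truthiness check below unambiguous even for multi-element tensors.
example_inputs = (torch.randn(1, 3, 224, 224),)
q_dataloader = None

# Same condition as the assert added in _get_quantizable_ops_recursively:
# either a calibration dataloader or example_inputs must be present.
assert q_dataloader or example_inputs, \
    "IPEX need q_dataloader or example_inputs to prepare the model"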
13 changes: 8 additions & 5 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -174,7 +174,7 @@ class TorchSmoothQuant:
     to recover the weights if needed
     """
 
-    def __init__(self, model, dataloader, q_func=None, traced_model=None):
+    def __init__(self, model, dataloader, example_inputs=None, q_func=None, traced_model=None):
         """
         :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model
         shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model
@@ -187,6 +187,7 @@ def __init__(self, model, dataloader, q_func=None, traced_model=None):
         self.device = device
         self.dtype = dtype
         self.dataloader = dataloader
+        self.example_inputs = example_inputs
         self.q_func = q_func
         self.input_values = {}
         self.output_values = {}
@@ -752,10 +753,12 @@ def _trace(self, op_types):
         no_absorb_layers: A list saving the layers which could not find the absorb layer
         """
         tg = GraphTrace()
-        for idx, input in enumerate(self.dataloader):
-            example_inputs = input
-            break
-        absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.traced_model, example_inputs, op_types)
+        if self.example_inputs is None:
+            assert self.dataloader, "Please provide dataloader or example_inputs"
+            for idx, input in enumerate(self.dataloader):
+                self.example_inputs = input
+                break
+        absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.traced_model, self.example_inputs, op_types)
         return absorb_to_layer, no_absorb_layers


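A hedged usage sketch of the updated TorchSmoothQuant constructor, assuming this version of neural_compressor is importable; the toy model, shapes, and the plain list used as a calibration loader are illustrative. The smoothing itself is still driven by the adaptor's smooth_quant() call shown in the first file.

import torch
import torch.nn as nn
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Path 1: no dataloader; _trace() uses the supplied example_inputs directly.
example_inputs = torch.randn(2, 16)
sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs)

# Path 2: dataloader only; _trace() falls back to the first batch it yields.
calib_loader = [torch.randn(2, 16) for _ in range(4)]
sq2 = TorchSmoothQuant(model, dataloader=calib_loader)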
