fix bug to support t5 static quantization (#1402)
* fix bug to support t5 static quantization

* fix bug

* fix bug
xin3he committed Oct 31, 2022
1 parent 6ab5570 commit ee3ef0e
Showing 1 changed file with 10 additions and 5 deletions.
neural_compressor/adaptor/pytorch.py (10 additions, 5 deletions)
@@ -3329,20 +3329,23 @@ def fuse_fx_model(self, model, is_qat):
                 fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
             except:
                 self.sub_module_list = []
-                self._fuse_sub_graph(tmp_model, prefix='', is_qat=is_qat)
+                module_dict = dict(tmp_model.named_modules())
+                self._fuse_sub_graph(tmp_model, module_dict, prefix='', is_qat=is_qat)
                 fused_model = tmp_model
         except Exception as e:  # pragma: no cover
             self.sub_module_list = []
             fused_model = model._model
-            self._fuse_sub_graph(fused_model, prefix='', is_qat=is_qat)
+            module_dict = dict(fused_model.named_modules())
+            self._fuse_sub_graph(fused_model, module_dict, prefix='', is_qat=is_qat)
             logger.warning("Deepcopy failed: {}, inplace=True now!".format(repr(e)))
         return fused_model
 
-    def _fuse_sub_graph(self, model, prefix, is_qat):
+    def _fuse_sub_graph(self, model, module_dict, prefix, is_qat):
         """This is a helper function to get fused fx sub modules recursively for PyTorch_FXAdaptor.
 
         Args:
             model (object): input model which is PyTorch model.
+            module_dict (dict): module dict of input model.
             prefix (string): prefix of op name.
             is_qat (bool): check quantization approach is qat or not.
@@ -3357,11 +3360,13 @@ def _fuse_sub_graph(self, model, prefix, is_qat):
             if type(module) == torch.nn.Dropout:  # pragma: no cover
                 continue
             op_name = prefix + '.' + name if prefix != '' else name
+            if op_name not in module_dict:
+                continue
             if type(module) in fx_white_list \
               and type(module) != torch.nn.Sequential:
                 module = torch.quantization.QuantWrapper(module)
             if self._check_dynamic_control(module):
-                self._fuse_sub_graph(module, op_name, is_qat=is_qat)
+                self._fuse_sub_graph(module, module_dict, op_name, is_qat=is_qat)
             else:
                 try:
                     graph_module = torch.fx.symbolic_trace(module)
@@ -3372,7 +3377,7 @@ def _fuse_sub_graph(self, model, prefix, is_qat):
                     setattr(model, name, fused_model)
                     self.sub_module_list.append(op_name)
                 except:
-                    self._fuse_sub_graph(module, op_name, is_qat)
+                    self._fuse_sub_graph(module, module_dict, op_name, is_qat)
 
     @staticmethod
     def _check_dynamic_control(module):
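The fix threads a module_dict, built once via dict(model.named_modules()), through the recursive fusion and skips any op whose qualified name is absent from it. A plausible reading (our assumption, not stated in the commit) is that the recursion can otherwise descend into children whose qualified names never existed in the original model, for example after a module is wrapped in QuantWrapper, which breaks fusion on models such as T5. Below is a minimal, self-contained sketch of that guard pattern; walk_fusable and the toy model are hypothetical illustrations, not neural_compressor code.

import torch

# Sketch of the guard this commit adds: build a name -> module map once,
# then skip any child whose qualified name is not in the original model.
def walk_fusable(model, module_dict, prefix=''):
    """Recursively yield (op_name, module) pairs present in module_dict."""
    for name, module in model.named_children():
        op_name = prefix + '.' + name if prefix != '' else name
        if op_name not in module_dict:  # the guard added by the commit
            continue
        yield op_name, module
        yield from walk_fusable(module, module_dict, prefix=op_name)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
module_dict = dict(model.named_modules())
for op_name, module in walk_fusable(model, module_dict):
    print(op_name, type(module).__name__)  # prints: 0 Linear, then 1 ReLU

Any name produced by wrapping or other structural edits would fail the op_name in module_dict test and be left alone, which matches the skip added in _fuse_sub_graph.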
