diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index e609b9c0cdb..23dffac51e5 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -3329,20 +3329,23 @@ def fuse_fx_model(self, model, is_qat):
                 fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
             except:
                 self.sub_module_list = []
-                self._fuse_sub_graph(tmp_model, prefix='', is_qat=is_qat)
+                module_dict = dict(tmp_model.named_modules())
+                self._fuse_sub_graph(tmp_model, module_dict, prefix='', is_qat=is_qat)
                 fused_model = tmp_model
         except Exception as e:  # pragma: no cover
             self.sub_module_list = []
             fused_model = model._model
-            self._fuse_sub_graph(fused_model, prefix='', is_qat=is_qat)
+            module_dict = dict(fused_model.named_modules())
+            self._fuse_sub_graph(fused_model, module_dict, prefix='', is_qat=is_qat)
             logger.warning("Deepcopy failed: {}, inplace=True now!".format(repr(e)))
         return fused_model
 
-    def _fuse_sub_graph(self, model, prefix, is_qat):
+    def _fuse_sub_graph(self, model, module_dict, prefix, is_qat):
         """This is a helper function to get fused fx sub modules recursively for PyTorch_FXAdaptor.
 
         Args:
             model (object): input model which is PyTorch model.
+            module_dict (dict): module dict of input model.
             prefix (string): prefix of op name.
             is_qat (bool): check quantization approach is qat or not.
 
@@ -3357,11 +3360,13 @@ def _fuse_sub_graph(self, model, prefix, is_qat):
             if type(module) == torch.nn.Dropout:  # pragma: no cover
                 continue
             op_name = prefix + '.' + name if prefix != '' else name
+            if op_name not in module_dict:
+                continue
             if type(module) in fx_white_list \
               and type(module) != torch.nn.Sequential:
                 module = torch.quantization.QuantWrapper(module)
             if self._check_dynamic_control(module):
-                self._fuse_sub_graph(module, op_name, is_qat=is_qat)
+                self._fuse_sub_graph(module, module_dict, op_name, is_qat=is_qat)
             else:
                 try:
                     graph_module = torch.fx.symbolic_trace(module)
@@ -3372,7 +3377,7 @@ def _fuse_sub_graph(self, model, prefix, is_qat):
                     setattr(model, name, fused_model)
                     self.sub_module_list.append(op_name)
                 except:
-                    self._fuse_sub_graph(module, module_dict, op_name, is_qat)
 
     @staticmethod
     def _check_dynamic_control(module):
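
A minimal sketch of why the new `op_name not in module_dict` guard matters (not part of the patch; `Net` and `Sub` are made-up modules for illustration). `module_dict` is snapshotted once from the original model, so any names introduced later, e.g. by `torch.quantization.QuantWrapper` re-parenting a child under its `module` attribute, are absent from it and get skipped during the recursive fusion:

```python
import torch

class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

model = Net()
# Snapshot taken once at the top level, as the patch does.
module_dict = dict(model.named_modules())
print('sub.conv' in module_dict)                          # True

# Wrapping re-parents the child: the conv now lives at 'sub.module.conv'.
model.sub = torch.quantization.QuantWrapper(model.sub)
print('sub.module.conv' in dict(model.named_modules()))   # True
print('sub.module.conv' in module_dict)                   # False -> guard skips it
```

Building the dict once and threading it through the recursion also keeps the lookup O(1) per op, instead of re-walking `named_modules()` at every level.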