Skip to content

Commit

Permalink
fix bug in recover API (#1274)
Browse files Browse the repository at this point in the history
  • Loading branch information
xin3he committed Sep 21, 2022
1 parent 4f3f4af commit fd7a53f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions neural_compressor/adaptor/pytorch.py
Expand Up @@ -3135,7 +3135,10 @@ def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=''):
# get scale and zero_point of modules.
modules = dict(model.named_modules())
for key in tune_cfg['op']:
sub_name = key[0].replace(prefix + '.', '', 1)
if prefix:
sub_name = key[0].replace(prefix + '.', '', 1)
else:
sub_name = key[0]
if sub_name in modules:
value = tune_cfg['op'][key]
assert isinstance(value, dict)
Expand All @@ -3146,7 +3149,10 @@ def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=''):
# get scale and zero_point of getattr ops (like quantize ops).
for node in model.graph.nodes:
if node.op == 'get_attr':
sub_name = prefix + '--' + node.target
if prefix:
sub_name = prefix + '--' + node.target
else:
sub_name = node.target
if 'scale' in node.target:
tune_cfg['get_attr'][sub_name] = float(getattr(model, node.target))
elif 'zero_point' in node.target:
Expand Down

0 comments on commit fd7a53f

Please sign in to comment.