Skip to content

Commit

Permalink
fix bug in export_to_onnx API (#1166)
Browse files Browse the repository at this point in the history
  • Loading branch information
xin3he committed Aug 18, 2022
1 parent a28705c commit 158c7f4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions neural_compressor/model/torch_model.py
Expand Up @@ -324,7 +324,7 @@ def graph_info(self):

def export_to_jit(self, example_inputs=None):
if example_inputs is not None:
if isinstance(input, dict) or isinstance(input, UserDict):
if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict):
example_inputs = tuple(example_inputs.values())
else:
logger.warning("Please provide example_inputs for jit.trace")
Expand Down Expand Up @@ -357,8 +357,8 @@ def export_to_fp32_onnx(
fp32_model=None,
):
example_input_names = ['input']
if isinstance(input, dict) or isinstance(input, UserDict):
example_input_names = list(input.keys())
if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict):
example_input_names = list(example_inputs.keys())
model = self.model
if fp32_model:
model = fp32_model
Expand Down

0 comments on commit 158c7f4

Please sign in to comment.