diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 6d1a59de469..0323956eb9f 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -324,7 +324,7 @@ def graph_info(self): def export_to_jit(self, example_inputs=None): if example_inputs is not None: - if isinstance(input, dict) or isinstance(input, UserDict): + if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict): example_inputs = tuple(example_inputs.values()) else: logger.warning("Please provide example_inputs for jit.trace") @@ -357,8 +357,8 @@ def export_to_fp32_onnx( fp32_model=None, ): example_input_names = ['input'] - if isinstance(input, dict) or isinstance(input, UserDict): - example_input_names = list(input.keys()) + if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict): + example_input_names = list(example_inputs.keys()) model = self.model if fp32_model: model = fp32_model