diff --git a/onnx/test/inference_function_test.py b/onnx/test/inference_function_test.py index cf6d6fdb9cc..fc2c196a514 100644 --- a/onnx/test/inference_function_test.py +++ b/onnx/test/inference_function_test.py @@ -6,7 +6,8 @@ def call_infer_types( - schema: onnx.defs.OpSchema, inputs: Iterable[onnx.TypeProto], num_outputs: int, node: Optional[onnx.NodeProto] = None + schema: onnx.defs.OpSchema, inputs: Iterable[onnx.TypeProto], + num_outputs: Optional[int] = None, node: Optional[onnx.NodeProto] = None ) -> Optional[List[onnx.TypeProto]]: inputs = list(inputs) @@ -14,6 +15,8 @@ def call_infer_types( return None if node is None: + if num_outputs is None: + raise ValueError("Either node or num_outputs must be specified.") input_names = [f"in{i}" for i in range(len(inputs))] output_names = [f"out{i}" for i in range(num_outputs)] node = onnx.helper.make_node(schema.name, input_names, output_names)