Skip to content

Commit

Permalink
Fix wrapper function signature
Browse files Browse the repository at this point in the history
Signed-off-by: JakubBachurskiQC <jakub.bachurski@quantco.com>
  • Loading branch information
JakubBachurskiQC committed Aug 4, 2022
1 parent 0993ff5 commit 0d33d68
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion onnx/test/inference_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@


def call_infer_types(
schema: onnx.defs.OpSchema, inputs: Iterable[onnx.TypeProto], num_outputs: int, node: Optional[onnx.NodeProto] = None
schema: onnx.defs.OpSchema, inputs: Iterable[onnx.TypeProto],
num_outputs: Optional[int] = None, node: Optional[onnx.NodeProto] = None
) -> Optional[List[onnx.TypeProto]]:
inputs = list(inputs)

if not schema.has_type_and_shape_inference_function: # type: ignore
return None

if node is None:
if num_outputs is None:
raise ValueError("Either node or num_outputs must be specified.")
input_names = [f"in{i}" for i in range(len(inputs))]
output_names = [f"out{i}" for i in range(num_outputs)]
node = onnx.helper.make_node(schema.name, input_names, output_names)
Expand Down

0 comments on commit 0d33d68

Please sign in to comment.