Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
Signed-off-by: neginraoof <neginmr@utexas.edu>
  • Loading branch information
neginraoof committed Jul 13, 2021
1 parent 4f6b030 commit 35549fa
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
49 changes: 34 additions & 15 deletions onnx/test/automatic_upgrade_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def _test_op_upgrade(
initializer=[], # type: List[Any]
attrs={}, # type: Dict[Text, Any]
seq_inputs=[], # type: List[int]
seq_outputs=[] # type: List[int]
seq_outputs=[], # type: List[int]
optional_inputs=[], # type: List[int]
optional_outputs=[] # type: List[int]
): # type: (...) -> None
global tested_ops
tested_ops.append(op)
Expand All @@ -43,32 +45,45 @@ def _test_op_upgrade(
if input_types is None:
input_types = [TensorProto.FLOAT] * n_inputs
is_sequence = [0 if id not in seq_inputs else 1 for id in range(n_inputs)]
is_optional = [0 if id not in optional_inputs else 1 for id in range(n_inputs)]
# turn empty strings into [0] to ease type analysis, even though those entries
# will be ignored
input_shapes_cast = cast(List[List[int]],
[[0] if isinstance(shape, str) else shape for shape in input_shapes]
)
inputs = [
helper.make_tensor_value_info(name, ttype, shape) if is_sequence == 0
else helper.make_tensor_sequence_value_info(name, ttype, shape)
for (name, ttype, shape, is_sequence)
in zip(input_names, input_types, input_shapes_cast, is_sequence) if name != ''
]
inputs = []
for (name, ttype, shape, is_sequence, is_optional) in \
zip(input_names, input_types, input_shapes_cast, is_sequence, is_optional):
if name != '':
if is_sequence:
inputs += [helper.make_tensor_sequence_value_info(name, ttype, shape)]
elif is_optional:
type_proto = helper.make_tensor_type_proto(ttype, shape)
type_proto2 = helper.make_optional_type_proto(type_proto)
inputs += [helper.make_value_info(name, type_proto2)]
else:
inputs += [helper.make_tensor_value_info(name, ttype, shape)]

n_outputs = len(output_shapes)
output_names = list(string.ascii_lowercase)[n_inputs:n_inputs + n_outputs]
if output_types is None:
output_types = [TensorProto.FLOAT] * n_outputs
is_sequence = [0 if id not in seq_outputs else 1 for id in range(n_outputs)]
is_optional = [0 if id not in optional_outputs else 1 for id in range(n_outputs)]
output_shapes_cast = cast(List[List[int]],
[[0] if isinstance(shape, str) else shape for shape in output_shapes]
)
outputs = [
helper.make_tensor_value_info(name, ttype, shape) if is_sequence == 0
else helper.make_tensor_sequence_value_info(name, ttype, shape)
for (name, ttype, shape, is_sequence)
in zip(output_names, output_types, output_shapes_cast, is_sequence)
]
outputs = []
for (name, ttype, shape, is_sequence, is_optional) in \
zip(output_names, output_types, output_shapes_cast, is_sequence, is_optional):
if is_sequence:
inputs += [helper.make_tensor_sequence_value_info(name, ttype, shape)]
elif is_optional:
type_proto = helper.make_tensor_type_proto(ttype, shape)
type_proto2 = helper.make_optional_type_proto(type_proto)
outputs += [helper.make_value_info(name, type_proto2)]
else:
outputs += [helper.make_tensor_value_info(name, ttype, shape)]

node = helper.make_node(op, input_names, output_names, **attrs)
graph = helper.make_graph([node], op, inputs, outputs, initializer)
Expand Down Expand Up @@ -1015,15 +1030,19 @@ def test_ops_tested(self): # type: () -> None
all_schemas = onnx.defs.get_all_schemas()
all_op_names = [schema.name for schema in all_schemas if schema.domain == '']
excluded_ops = [
# Sequence-based ops disabled because the version converter doesn't play nicely with sequences
# Sequence-based and Optional-based ops disabled because
# the version converter doesn't play nicely with sequences
'ConcatFromSequence',
'SequenceAt',
'SequenceConstruct',
'SequenceEmpty',
'SequenceErase',
'SequenceInsert',
'SequenceLength',
'SplitToSequence'
'SplitToSequence',
'Optional',
'OptionalGetElement',
"OptionalHasElement"
]
all_op_names = [op for op in all_op_names if op not in excluded_ops]

Expand Down
2 changes: 1 addition & 1 deletion onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import print_function
from __future__ import unicode_literals

from onnx import checker, helper, TensorProto, NodeProto, GraphProto, ValueInfoProto, ModelProto, ONNX_ML, SparseTensorProto
from onnx import checker, helper, TensorProto, NodeProto, GraphProto, ValueInfoProto, ModelProto, ONNX_ML, SparseTensorProto, TypeProto
from onnx.defs import ONNX_DOMAIN, ONNX_ML_DOMAIN, AI_ONNX_PREVIEW_TRAINING_DOMAIN
from onnx.helper import make_node, make_tensor, make_tensor_value_info, make_empty_tensor_value_info, make_opsetid, make_tensor_sequence_value_info
from typing import Sequence, Union, Text, Tuple, List, Any, Optional
Expand Down

0 comments on commit 35549fa

Please sign in to comment.