Fix UnionTypeInfo bug (#4980)
* Fix shape-inference for if-operator

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Refactor code

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Fix lint errors

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
gramalingam committed Mar 22, 2023
1 parent 4af67b9 commit cb60e12
Showing 2 changed files with 77 additions and 3 deletions.
27 changes: 24 additions & 3 deletions onnx/defs/shape_inference.cc
@@ -230,6 +230,28 @@ void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& targ
UnionShapeInfoForTensor(source_shape, target_type);
}

void UnionShapeInfo(const TypeProto_Tensor& source_type, TypeProto_Tensor& target_type) {
// The union of a tensor of unknown rank and a tensor of known rank is a tensor of unknown rank.
// Hence, if the source_type had unknown rank, we clear the shape of the target_type.
// Otherwise, UnionShapeInfoForTensor handles the rest.
if (source_type.has_shape()) {
UnionShapeInfoForTensor(source_type.shape(), target_type);
} else {
target_type.clear_shape();
}
}

void UnionShapeInfo(const TypeProto_SparseTensor& source_type, TypeProto_SparseTensor& target_type) {
// The union of a tensor of unknown rank and a tensor of known rank is a tensor of unknown rank.
// Hence, if the source_type had unknown rank, we clear the shape of the target_type.
// Otherwise, UnionShapeInfoForTensor handles the rest.
if (source_type.has_shape()) {
UnionShapeInfoForTensor(source_type.shape(), target_type);
} else {
target_type.clear_shape();
}
}

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
UnionShapeInfoForTensor(source_shape, target_type);
}
@@ -249,16 +271,15 @@ void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
"Mismatched tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
}

UnionShapeInfoForTensor(source_type.tensor_type().shape(), *target_type.mutable_tensor_type());
UnionShapeInfo(source_type.tensor_type(), *target_type.mutable_tensor_type());
} else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
auto source_elem_type = source_type.sparse_tensor_type().elem_type();
auto target_elem_type = target_type.sparse_tensor_type().elem_type();
if (source_elem_type != target_elem_type) {
fail_type_inference(
"Mismatched sparse tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
}

UnionShapeInfoForTensor(source_type.sparse_tensor_type().shape(), *target_type.mutable_sparse_tensor_type());
UnionShapeInfo(source_type.sparse_tensor_type(), *target_type.mutable_sparse_tensor_type());
} else if (target_case == TypeProto::ValueCase::kSequenceType) {
if (!source_type.sequence_type().has_elem_type()) {
fail_type_inference("source sequence type missing element type.");
53 changes: 53 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -33,6 +33,7 @@
make_tensor_sequence_value_info,
make_tensor_value_info,
)
from onnx.parser import parse_graph


class TestShapeInferenceHelper(unittest.TestCase):
@@ -3853,6 +3854,58 @@ def test_if_with_different_shapes_in_then_else_branches(self) -> None:

self._assert_inferred(graph, [make_tensor_value_info("if_output", TensorProto.FLOAT, (None,))]) # type: ignore

def test_if_no_shape_in_then_branch(self) -> None:
then_graph = parse_graph(
"then_graph () => (then_output) { then_output = ReduceSum <keepdims=0> (X, axes) }"
)
else_graph = parse_graph(
"else_graph () => (else_output) { else_output = ReduceSum <keepdims=0> (X) }"
)
graph = self._make_graph(
[
("cond", TensorProto.BOOL, (1,)),
("X", TensorProto.FLOAT, (4, 8, 16)),
("axes", TensorProto.INT64, (1,)),
],
[
make_node(
"If",
["cond"],
["if_output"],
then_branch=then_graph,
else_branch=else_graph,
)
],
[],
)
self._assert_inferred(graph, [make_tensor_value_info("if_output", TensorProto.FLOAT, None)]) # type: ignore

def test_if_no_shape_in_else_branch(self) -> None:
then_graph = parse_graph(
"then_graph () => (then_output) { then_output = ReduceSum <keepdims=0> (X) }"
)
else_graph = parse_graph(
"else_graph () => (else_output) { else_output = ReduceSum <keepdims=0> (X, axes) }"
)
graph = self._make_graph(
[
("cond", TensorProto.BOOL, (1,)),
("X", TensorProto.FLOAT, (4, 8, 16)),
("axes", TensorProto.INT64, (1,)),
],
[
make_node(
"If",
["cond"],
["if_output"],
then_branch=then_graph,
else_branch=else_graph,
)
],
[],
)
self._assert_inferred(graph, [make_tensor_value_info("if_output", TensorProto.FLOAT, None)]) # type: ignore

def test_if_with_different_optional_shapes_in_then_else_branches(self) -> None:
# Create a simple If node where the 'then' subgraph adds to the current value, and the 'else' subgraph
# subtracts.
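For reference, the following is a minimal, self-contained sketch (not part of the commit) that exercises the merge rule implemented by the new UnionShapeInfo overloads above: when one branch of an If produces a tensor of unknown rank, the inferred If output also has unknown rank. The model construction mirrors the new tests; the graph name, concrete shapes, and opset assumption (>= 13, so that ReduceSum takes axes as an input) are illustrative.

from onnx import TensorProto, shape_inference
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.parser import parse_graph

# The then-branch reduces all axes to a scalar, so its output rank is known (0).
# The else-branch reduces over a dynamic `axes` input with keepdims=0, so its output
# rank is unknown.
then_graph = parse_graph(
    "then_graph () => (then_output) { then_output = ReduceSum <keepdims=0> (X) }"
)
else_graph = parse_graph(
    "else_graph () => (else_output) { else_output = ReduceSum <keepdims=0> (X, axes) }"
)

graph = make_graph(
    [
        make_node(
            "If",
            ["cond"],
            ["if_output"],
            then_branch=then_graph,
            else_branch=else_graph,
        )
    ],
    "main",
    inputs=[
        make_tensor_value_info("cond", TensorProto.BOOL, (1,)),
        make_tensor_value_info("X", TensorProto.FLOAT, (4, 8, 16)),
        make_tensor_value_info("axes", TensorProto.INT64, (1,)),
    ],
    outputs=[make_tensor_value_info("if_output", TensorProto.FLOAT, None)],
)

inferred = shape_inference.infer_shapes(make_model(graph))
# With the fix, `if_output` is inferred as FLOAT with no shape (unknown rank): the
# union of a known-rank branch output and an unknown-rank branch output has unknown rank.
print(inferred.graph.output[0])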
