Skip to content

Commit

Permalink
Support OptionalType to get shape | feat(graph_building) (#890)
Browse files Browse the repository at this point in the history
Due to the need of supporting None args, we need to support None
(torch.OptionalType) to get shape/dtype.
To support pytorch/pytorch#105263

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
titaiwangms and justinchuby committed Jul 18, 2023
1 parent 71b724b commit ab54cb1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def shape(self) -> Tuple[int | None, ...] | None:
if value_type is None:
return None
value_type = typing.cast(torch.TensorType, value_type)
shape = value_type.varyingSizes()
if isinstance(value_type, torch.OptionalType):
shape = value_type.getElementType().varyingSizes() # type: ignore[attr-defined]
else:
shape = value_type.varyingSizes()
if shape is None:
return None
return tuple(shape)
Expand Down
6 changes: 6 additions & 0 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ def outer(x: FLOAT[1, 2, 3]):
expected = outer.to_model_proto()
onnxscript.testing.assert_isomorphic(traced, expected)

def test_add_input_with_optionaltype_does_not_raise_torch_internal_error(self):
graph = graph_building.TorchScriptGraph()
x = graph.add_input(input_name=None)
with evaluator.default_as(self.tracer):
_ = x.shape


class TestTorchScriptGraph(unittest.TestCase):
def test_add_initializer_raises_when_the_same_name_used_for_different_tensors(self):
Expand Down

0 comments on commit ab54cb1

Please sign in to comment.