From 4114272b9ed66a545182c6a63f0399d1e1ecbaf9 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Mon, 5 Jun 2023 16:31:19 +0000 Subject: [PATCH 1/5] Add op(atleast_1d and atleast_2d and atleast_3d) | feat(torchlib) --- .../function_libs/torch_lib/ops/core.py | 60 +++++++++++++++++-- .../function_libs/torch_lib/ops_test_data.py | 15 +++++ 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d9b576a3f6..d08aa2a871 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -666,22 +666,70 @@ def aten_atanh(self: TFloat) -> TFloat: return op.Atanh(self) -def aten_atleast_1d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_1d") +def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" - raise NotImplementedError() + sequence_length = op.Shape(self) + + for i in range(sequence_length): + tensor = op.SequenceAt(self, i) + shape = op.Shape(tensor) + rank = op.Size(shape) + rank_required = op.Constant(value_ints=[1]) + if rank < 1: + # Get how many dim needs to be added + one = op.Constant(value_ints=[1]) + one_count = op.Sub(rank, rank_required) + append_shape = op.Expand(one, one_count) + new_shape = op.Concat(shape, append_shape, axis=0) + # Do we need a new Sequence? + tensor = op.Reshape(tensor, new_shape) + return self -def aten_atleast_2d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_2d") +def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - raise NotImplementedError() + sequence_length = op.Shape(self) + + for i in range(sequence_length): + tensor = op.SequenceAt(self, i) + shape = op.Shape(tensor) + rank = op.Size(shape) + rank_required = op.Constant(value_ints=[2]) + if rank < 3: + # Get how many dim needs to be added + one = op.Constant(value_ints=[1]) + one_count = op.Sub(rank, rank_required) + append_shape = op.Expand(one, one_count) + new_shape = op.Concat(shape, append_shape, axis=0) + # Do we need a new Sequence? + tensor = op.Reshape(tensor, new_shape) + return self -def aten_atleast_3d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_3d") +def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - raise NotImplementedError() + sequence_length = op.Shape(self) + + for i in range(sequence_length): + tensor = op.SequenceAt(self, i) + shape = op.Shape(tensor) + rank = op.Size(shape) + rank_required = op.Constant(value_ints=[3]) + if rank < 1: + # Get how many dim needs to be added + one = op.Constant(value_ints=[1]) + one_count = op.Sub(rank, rank_required) + append_shape = op.Expand(one, one_count) + new_shape = op.Concat(shape, append_shape, axis=0) + # Do we need a new Sequence? + tensor = op.Reshape(tensor, new_shape) + return self def aten_avg_pool1d( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index fc9c1eda26..14fba01bd2 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -344,6 +344,9 @@ def _where_input_wrangler( "atan": core_ops.aten_atan, "atan2": core_ops.aten_atan2, "atanh": core_ops.aten_atanh, + "atleast_1d": core_ops.aten_atleast_1d, + "atleast_2d": core_ops.aten_atleast_2d, + "atleast_3d": core_ops.aten_atleast_3d, "baddbmm": core_ops.aten_baddbmm, "bmm": core_ops.aten_bmm, "broadcast_to": core_ops.aten_broadcast_to, @@ -1480,6 +1483,18 @@ def _where_input_wrangler( torch.float32, torch.float16, ), + "atleast_1d": ( + torch.float32, + torch.float16, + ), + "atleast_2d": ( + torch.float32, + torch.float16, + ), + "atleast_3d": ( + torch.float32, + torch.float16, + ), "baddbmm": ( torch.float32, torch.float16, From d21c79e9c1fcdfe884ef83e18f70cd73f54e5ecb Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Mon, 5 Jun 2023 17:52:49 +0000 Subject: [PATCH 2/5] add constant node for requirement --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d08aa2a871..14efc21669 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -677,7 +677,7 @@ def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: shape = op.Shape(tensor) rank = op.Size(shape) rank_required = op.Constant(value_ints=[1]) - if rank < 1: + if rank < rank_required: # Get how many dim needs to be added one = op.Constant(value_ints=[1]) one_count = op.Sub(rank, rank_required) @@ -699,7 +699,7 @@ def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: shape = op.Shape(tensor) rank = op.Size(shape) rank_required = op.Constant(value_ints=[2]) - if rank < 3: + if rank < rank_required: # Get how many dim needs to be added one = op.Constant(value_ints=[1]) one_count = op.Sub(rank, rank_required) @@ -721,7 +721,7 @@ def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: shape = op.Shape(tensor) rank = op.Size(shape) rank_required = op.Constant(value_ints=[3]) - if rank < 1: + if rank < rank_required: # Get how many dim needs to be added one = op.Constant(value_ints=[1]) one_count = op.Sub(rank, rank_required) From 873957f0b8ef47258f486d16e0bd77df30ff8785 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 6 Jun 2023 03:51:42 +0000 Subject: [PATCH 3/5] USe torch logic --- .../function_libs/torch_lib/ops/core.py | 40 ++++++------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 14efc21669..0dfd2053c0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -676,15 +676,8 @@ def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: tensor = op.SequenceAt(self, i) shape = op.Shape(tensor) rank = op.Size(shape) - rank_required = op.Constant(value_ints=[1]) - if rank < rank_required: - # Get how many dim needs to be added - one = op.Constant(value_ints=[1]) - one_count = op.Sub(rank, rank_required) - append_shape = op.Expand(one, one_count) - new_shape = op.Concat(shape, append_shape, axis=0) - # Do we need a new Sequence? - tensor = op.Reshape(tensor, new_shape) + if rank == 0: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1])) return self @@ -698,15 +691,10 @@ def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: tensor = op.SequenceAt(self, i) shape = op.Shape(tensor) rank = op.Size(shape) - rank_required = op.Constant(value_ints=[2]) - if rank < rank_required: - # Get how many dim needs to be added - one = op.Constant(value_ints=[1]) - one_count = op.Sub(rank, rank_required) - append_shape = op.Expand(one, one_count) - new_shape = op.Concat(shape, append_shape, axis=0) - # Do we need a new Sequence? - tensor = op.Reshape(tensor, new_shape) + if rank == 0: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, 1])) + elif rank == 1: + tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[0])) return self @@ -720,15 +708,13 @@ def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: tensor = op.SequenceAt(self, i) shape = op.Shape(tensor) rank = op.Size(shape) - rank_required = op.Constant(value_ints=[3]) - if rank < rank_required: - # Get how many dim needs to be added - one = op.Constant(value_ints=[1]) - one_count = op.Sub(rank, rank_required) - append_shape = op.Expand(one, one_count) - new_shape = op.Concat(shape, append_shape, axis=0) - # Do we need a new Sequence? - tensor = op.Reshape(tensor, new_shape) + if rank == 0: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, 1, 1])) + elif rank == 1: + tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[0])) + tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) + elif rank == 2: + tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) return self From 7efaf27b0047a91ce3ad473ed237089b2e1a74dc Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 6 Jun 2023 16:59:23 +0000 Subject: [PATCH 4/5] Use SequemceMap --- .../function_libs/torch_lib/ops/core.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 0dfd2053c0..4d5b00e42d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,7 +14,7 @@ import math from typing import Any, Optional, Sequence, Tuple, Union -from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64 +from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -670,42 +670,39 @@ def aten_atanh(self: TFloat) -> TFloat: def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" - sequence_length = op.Shape(self) - - for i in range(sequence_length): - tensor = op.SequenceAt(self, i) + @graph() + def reshape_to_1d(tensor): shape = op.Shape(tensor) rank = op.Size(shape) if rank == 0: tensor = op.Reshape(tensor, op.Constant(value_ints=[1])) - return self + return tensor + + return op.SequenceMap(self, body=reshape_to_1d) @torch_op("aten::atleast_2d") def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - sequence_length = op.Shape(self) - - for i in range(sequence_length): - tensor = op.SequenceAt(self, i) + @graph() + def reshape_to_2d(tensor): shape = op.Shape(tensor) rank = op.Size(shape) if rank == 0: tensor = op.Reshape(tensor, op.Constant(value_ints=[1, 1])) elif rank == 1: tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[0])) - return self + + return op.SequenceMap(self, body=reshape_to_2d) @torch_op("aten::atleast_3d") def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - sequence_length = op.Shape(self) - - for i in range(sequence_length): - tensor = op.SequenceAt(self, i) + @graph() + def reshape_to_3d(tensor): shape = op.Shape(tensor) rank = op.Size(shape) if rank == 0: @@ -715,7 +712,8 @@ def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) elif rank == 2: tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) - return self + + return op.SequenceMap(self, body=reshape_to_3d) def aten_avg_pool1d( From 2be4465f2b1cc3e483434739e0da0b846ab2fc64 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Tue, 6 Jun 2023 17:52:18 +0000 Subject: [PATCH 5/5] refactor core and add songle tensor support --- .../function_libs/torch_lib/ops/core.py | 50 +++++++++++++++---- .../function_libs/torch_lib/ops_test_data.py | 35 +++++++++++++ 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4d5b00e42d..dd142a6b5c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -681,6 +681,17 @@ def reshape_to_1d(tensor): return op.SequenceMap(self, body=reshape_to_1d) +@torch_op("aten::atleast_1d") +def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor: + """atleast_1d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank == 0: + self = op.Reshape(self, op.Constant(value_ints=[1])) + return self + + @torch_op("aten::atleast_2d") def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" @@ -689,14 +700,24 @@ def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: def reshape_to_2d(tensor): shape = op.Shape(tensor) rank = op.Size(shape) - if rank == 0: - tensor = op.Reshape(tensor, op.Constant(value_ints=[1, 1])) - elif rank == 1: - tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[0])) + if rank <= 1: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1])) + return tensor return op.SequenceMap(self, body=reshape_to_2d) +@torch_op("aten::atleast_2d") +def aten_atleast_2d_single_tensor(self: TTensor) -> TTensor: + """atleast_2d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank <= 1: + self = op.Reshape(self, op.Constant(value_ints=[1, -1])) + return self + + @torch_op("aten::atleast_3d") def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" @@ -705,17 +726,28 @@ def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: def reshape_to_3d(tensor): shape = op.Shape(tensor) rank = op.Size(shape) - if rank == 0: - tensor = op.Reshape(tensor, op.Constant(value_ints=[1, 1, 1])) - elif rank == 1: - tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[0])) - tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) + if rank <= 1: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1, 1])) elif rank == 2: tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) + return tensor return op.SequenceMap(self, body=reshape_to_3d) +@torch_op("aten::atleast_3d") +def aten_atleast_3d_single_tensor(self: TTensor) -> TTensor: + """atleast_3d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank <= 1: + self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) + elif rank == 2: + self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) + return self + + def aten_avg_pool1d( self: TensorType, kernel_size: Sequence[int], diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 14fba01bd2..a4ab21a5a9 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -345,8 +345,11 @@ def _where_input_wrangler( "atan2": core_ops.aten_atan2, "atanh": core_ops.aten_atanh, "atleast_1d": core_ops.aten_atleast_1d, + "atleast_1d_single_tensor": core_ops.aten_atleast_1d_single_tensor, "atleast_2d": core_ops.aten_atleast_2d, + "atleast_2d_single_tensor": core_ops.aten_atleast_2d_single_tensor, "atleast_3d": core_ops.aten_atleast_3d, + "atleast_3d_single_tensor": core_ops.aten_atleast_3d_single_tensor, "baddbmm": core_ops.aten_baddbmm, "bmm": core_ops.aten_bmm, "broadcast_to": core_ops.aten_broadcast_to, @@ -811,6 +814,21 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", ), + skip( + "atleast_1d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_1d_single_tensor overload takes single tensor as input", + ), + skip( + "atleast_2d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_2d_single_tensor overload takes single tensor as input", + ), + skip( + "atleast_3d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_3d_single_tensor overload takes single tensor as input", + ), skip( "cat", matcher=lambda sample: sample.input[0].equal(torch.tensor([])), @@ -1169,6 +1187,11 @@ def _where_input_wrangler( ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_single_tensor",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_single_tensor",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_single_tensor",)) + + ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) @@ -1487,14 +1510,26 @@ def _where_input_wrangler( torch.float32, torch.float16, ), + "atleast_1d_single_tensor": ( + torch.float32, + torch.float16, + ), "atleast_2d": ( torch.float32, torch.float16, ), + "atleast_2d_single_tensor": ( + torch.float32, + torch.float16, + ), "atleast_3d": ( torch.float32, torch.float16, ), + "atleast_3d_single_tensor": ( + torch.float32, + torch.float16, + ), "baddbmm": ( torch.float32, torch.float16,