diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d9b576a3f6..dd142a6b5c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -14,7 +14,7 @@ import math from typing import Any, Optional, Sequence, Tuple, Union -from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64 +from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -666,22 +666,86 @@ def aten_atanh(self: TFloat) -> TFloat: return op.Atanh(self) -def aten_atleast_1d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_1d") +def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" - raise NotImplementedError() + @graph() + def reshape_to_1d(tensor): + shape = op.Shape(tensor) + rank = op.Size(shape) + if rank == 0: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1])) + return tensor + + return op.SequenceMap(self, body=reshape_to_1d) + + +@torch_op("aten::atleast_1d") +def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor: + """atleast_1d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank == 0: + self = op.Reshape(self, op.Constant(value_ints=[1])) + return self -def aten_atleast_2d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_2d") +def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - raise NotImplementedError() + @graph() + def reshape_to_2d(tensor): + shape = op.Shape(tensor) + rank = op.Size(shape) + if rank <= 1: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1])) + return tensor + + return op.SequenceMap(self, body=reshape_to_2d) + + +@torch_op("aten::atleast_2d") +def aten_atleast_2d_single_tensor(self: TTensor) -> TTensor: + """atleast_2d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank <= 1: + self = op.Reshape(self, op.Constant(value_ints=[1, -1])) + return self -def aten_atleast_3d(self: TensorType) -> TensorType: +@torch_op("aten::atleast_3d") +def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - raise NotImplementedError() + @graph() + def reshape_to_3d(tensor): + shape = op.Shape(tensor) + rank = op.Size(shape) + if rank <= 1: + tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1, 1])) + elif rank == 2: + tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1])) + return tensor + + return op.SequenceMap(self, body=reshape_to_3d) + + +@torch_op("aten::atleast_3d") +def aten_atleast_3d_single_tensor(self: TTensor) -> TTensor: + """atleast_3d(Tensor self) -> Tensor""" + + shape = op.Shape(self) + rank = op.Size(shape) + if rank <= 1: + self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) + elif rank == 2: + self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) + return self def aten_avg_pool1d( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index fc9c1eda26..a4ab21a5a9 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -344,6 +344,12 @@ def _where_input_wrangler( "atan": core_ops.aten_atan, "atan2": core_ops.aten_atan2, "atanh": core_ops.aten_atanh, + "atleast_1d": core_ops.aten_atleast_1d, + "atleast_1d_single_tensor": core_ops.aten_atleast_1d_single_tensor, + "atleast_2d": core_ops.aten_atleast_2d, + "atleast_2d_single_tensor": core_ops.aten_atleast_2d_single_tensor, + "atleast_3d": core_ops.aten_atleast_3d, + "atleast_3d_single_tensor": core_ops.aten_atleast_3d_single_tensor, "baddbmm": core_ops.aten_baddbmm, "bmm": core_ops.aten_bmm, "broadcast_to": core_ops.aten_broadcast_to, @@ -808,6 +814,21 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", ), + skip( + "atleast_1d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_1d_single_tensor overload takes single tensor as input", + ), + skip( + "atleast_2d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_2d_single_tensor overload takes single tensor as input", + ), + skip( + "atleast_3d_single_tensor", + matcher=lambda sample: isinstance(sample.input, (list, tuple)), + reason="atleast_3d_single_tensor overload takes single tensor as input", + ), skip( "cat", matcher=lambda sample: sample.input[0].equal(torch.tensor([])), @@ -1166,6 +1187,11 @@ def _where_input_wrangler( ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_single_tensor",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_single_tensor",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_single_tensor",)) + + ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) @@ -1480,6 +1506,30 @@ def _where_input_wrangler( torch.float32, torch.float16, ), + "atleast_1d": ( + torch.float32, + torch.float16, + ), + "atleast_1d_single_tensor": ( + torch.float32, + torch.float16, + ), + "atleast_2d": ( + torch.float32, + torch.float16, + ), + "atleast_2d_single_tensor": ( + torch.float32, + torch.float16, + ), + "atleast_3d": ( + torch.float32, + torch.float16, + ), + "atleast_3d_single_tensor": ( + torch.float32, + torch.float16, + ), "baddbmm": ( torch.float32, torch.float16,