Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import math
from typing import Any, Optional, Sequence, Tuple, Union

from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64
from onnxscript import BOOL, DOUBLE, FLOAT, INT8, INT16, INT32, INT64, graph
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
Expand Down Expand Up @@ -666,22 +666,86 @@ def aten_atanh(self: TFloat) -> TFloat:
return op.Atanh(self)


def aten_atleast_1d(self: TensorType) -> TensorType:
@torch_op("aten::atleast_1d")
def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor:
"""atleast_1d(Tensor self) -> Tensor"""

raise NotImplementedError()
@graph()
def reshape_to_1d(tensor):
shape = op.Shape(tensor)
rank = op.Size(shape)
if rank == 0:
tensor = op.Reshape(tensor, op.Constant(value_ints=[1]))
return tensor

return op.SequenceMap(self, body=reshape_to_1d)


@torch_op("aten::atleast_1d")
def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor:
"""atleast_1d(Tensor self) -> Tensor"""

shape = op.Shape(self)
rank = op.Size(shape)
if rank == 0:
self = op.Reshape(self, op.Constant(value_ints=[1]))
return self


def aten_atleast_2d(self: TensorType) -> TensorType:
@torch_op("aten::atleast_2d")
def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor:
"""atleast_2d(Tensor self) -> Tensor"""

raise NotImplementedError()
@graph()
def reshape_to_2d(tensor):
shape = op.Shape(tensor)
rank = op.Size(shape)
if rank <= 1:
tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1]))
return tensor

return op.SequenceMap(self, body=reshape_to_2d)


@torch_op("aten::atleast_2d")
def aten_atleast_2d_single_tensor(self: TTensor) -> TTensor:
"""atleast_2d(Tensor self) -> Tensor"""

shape = op.Shape(self)
rank = op.Size(shape)
if rank <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
return self


def aten_atleast_3d(self: TensorType) -> TensorType:
@torch_op("aten::atleast_3d")
def aten_atleast_3d(self: Sequence[TTensor]) -> TTensor:
"""atleast_3d(Tensor self) -> Tensor"""

raise NotImplementedError()
@graph()
def reshape_to_3d(tensor):
shape = op.Shape(tensor)
rank = op.Size(shape)
if rank <= 1:
tensor = op.Reshape(tensor, op.Constant(value_ints=[1, -1, 1]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to double-check the downstream usage to make sure whether you want to reshape a tensor of size N to 1x1xN or Nx1x1 ... the above reshapes it to 1xNx1 .... which may be okay, but just wondering.

elif rank == 2:
tensor = op.Unsqueeze(tensor, op.Constant(value_ints=[-1]))
return tensor

return op.SequenceMap(self, body=reshape_to_3d)


@torch_op("aten::atleast_3d")
def aten_atleast_3d_single_tensor(self: TTensor) -> TTensor:
"""atleast_3d(Tensor self) -> Tensor"""

shape = op.Shape(self)
rank = op.Size(shape)
if rank <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
return self


def aten_avg_pool1d(
Expand Down
50 changes: 50 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ def _where_input_wrangler(
"atan": core_ops.aten_atan,
"atan2": core_ops.aten_atan2,
"atanh": core_ops.aten_atanh,
"atleast_1d": core_ops.aten_atleast_1d,
"atleast_1d_single_tensor": core_ops.aten_atleast_1d_single_tensor,
"atleast_2d": core_ops.aten_atleast_2d,
"atleast_2d_single_tensor": core_ops.aten_atleast_2d_single_tensor,
"atleast_3d": core_ops.aten_atleast_3d,
"atleast_3d_single_tensor": core_ops.aten_atleast_3d_single_tensor,
"baddbmm": core_ops.aten_baddbmm,
"bmm": core_ops.aten_bmm,
"broadcast_to": core_ops.aten_broadcast_to,
Expand Down Expand Up @@ -808,6 +814,21 @@ def _where_input_wrangler(
matcher=lambda sample: len(sample.args) != 2,
reason="arange_start_step overload takes three arguments (input, start, step)",
),
skip(
"atleast_1d_single_tensor",
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don’t take a Sequence as input? Where is this op used? How are the inputs produced?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It accepts both tensor and list.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see there are two variants

reason="atleast_1d_single_tensor overload takes single tensor as input",
),
skip(
"atleast_2d_single_tensor",
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
reason="atleast_2d_single_tensor overload takes single tensor as input",
),
skip(
"atleast_3d_single_tensor",
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
reason="atleast_3d_single_tensor overload takes single tensor as input",
),
skip(
"cat",
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
Expand Down Expand Up @@ -1166,6 +1187,11 @@ def _where_input_wrangler(
),
)

ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_single_tensor",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_single_tensor",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_single_tensor",))


ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",))

ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
Expand Down Expand Up @@ -1480,6 +1506,30 @@ def _where_input_wrangler(
torch.float32,
torch.float16,
),
"atleast_1d": (
torch.float32,
torch.float16,
),
"atleast_1d_single_tensor": (
torch.float32,
torch.float16,
),
"atleast_2d": (
torch.float32,
torch.float16,
),
"atleast_2d_single_tensor": (
torch.float32,
torch.float16,
),
"atleast_3d": (
torch.float32,
torch.float16,
),
"atleast_3d_single_tensor": (
torch.float32,
torch.float16,
),
"baddbmm": (
torch.float32,
torch.float16,
Expand Down