Skip to content
Merged
38 changes: 30 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,8 @@ def aten_atanh(self: TFloat) -> TFloat:
return op.Atanh(self)


@torch_op("aten::atleast_1d")
def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor:
@torch_op("aten::atleast_1d", private=True)
def _aten_atleast_1d_onnx(self: Sequence[TTensor]) -> TTensor:
"""atleast_1d(Tensor self) -> Tensor"""

@graph()
Expand All @@ -681,6 +681,11 @@ def reshape_to_1d(tensor):
return op.SequenceMap(self, body=reshape_to_1d)


@torch_op("aten::atleast_1d")
def aten_atleast_1d(self: Sequence[TTensor]) -> TTensor:
    """atleast_1d(Tensor self) -> Tensor

    Public registration for aten::atleast_1d; delegates to the private
    ONNX implementation so other ops (e.g. hstack) can reuse it directly.
    """
    return _aten_atleast_1d_onnx(self)


@torch_op("aten::atleast_1d")
def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor:
"""atleast_1d(Tensor self) -> Tensor"""
Expand All @@ -692,8 +697,8 @@ def aten_atleast_1d_single_tensor(self: TTensor) -> TTensor:
return self


@torch_op("aten::atleast_2d")
def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor:
@torch_op("aten::atleast_2d", private=True)
def _aten_atleast_2d_onnx(self: Sequence[TTensor]) -> TTensor:
"""atleast_2d(Tensor self) -> Tensor"""

@graph()
Expand All @@ -707,6 +712,11 @@ def reshape_to_2d(tensor):
return op.SequenceMap(self, body=reshape_to_2d)


@torch_op("aten::atleast_2d")
def aten_atleast_2d(self: Sequence[TTensor]) -> TTensor:
    """atleast_2d(Tensor self) -> Tensor

    Public registration for aten::atleast_2d; delegates to the private
    ONNX implementation so other ops (e.g. vstack) can reuse it directly.
    """
    return _aten_atleast_2d_onnx(self)


@torch_op("aten::atleast_2d")
def aten_atleast_2d_single_tensor(self: TTensor) -> TTensor:
"""atleast_2d(Tensor self) -> Tensor"""
Expand Down Expand Up @@ -2875,10 +2885,20 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:
raise NotImplementedError()


def aten_hstack(tensors: Sequence[TensorType]) -> TensorType:
@torch_op("aten::hstack", trace_only=True)
def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
    """hstack(Tensor[] tensors) -> Tensor"""

    # Promote every input to at least rank 1 via the shared private op.
    tensors = _aten_atleast_1d_onnx(tensors)

    # NOTE: The if/else graph has different shape/type which breaks the
    # graph matching. We need to use trace_only.
    # Rank-1 inputs are concatenated along axis 0; higher ranks along axis 1.
    concat_axis = 0 if len(tensors[0].shape) == 1 else 1
    return op.ConcatFromSequence(tensors, axis=concat_axis, new_axis=0)


def aten_hypot(self: TensorType, other: TensorType) -> TensorType:
Expand Down Expand Up @@ -6596,10 +6616,12 @@ def aten_view_copy(self: TensorType, size: INT64) -> TensorType:
raise NotImplementedError()


def aten_vstack(tensors: Sequence[TensorType]) -> TensorType:
@torch_op("aten::vstack")
def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
    """vstack(Tensor[] tensors) -> Tensor"""

    # Promote every input to at least rank 2, then stack along the first axis.
    at_least_2d = _aten_atleast_2d_onnx(tensors)
    return op.ConcatFromSequence(at_least_2d, axis=0)


@torch_op("aten::where")
Expand Down
18 changes: 18 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def _where_input_wrangler(
"unflatten": (core_ops.aten_unflatten, _unflatten_input_wrangler),
"unsqueeze": core_ops.aten_unsqueeze,
"view": core_ops.aten_view,
"vstack": core_ops.aten_vstack,
"where": (core_ops.aten_where, _where_input_wrangler),
"xlogy": special_ops.aten_special_xlogy,
"zeros": core_ops.aten_zeros,
Expand All @@ -533,6 +534,7 @@ def _where_input_wrangler(
"convolution": core_ops.aten_convolution,
"empty_like": core_ops.aten_empty_like,
"grid_sampler_2d": core_ops.aten_grid_sampler_2d,
"hstack": core_ops.aten_hstack,
"nn.functional.grid_sample": (core_ops.aten_grid_sampler, _grid_sample_input_wrangler),
"index_select": core_ops.aten_index_select,
"layer_norm": core_ops.aten_layer_norm,
Expand Down Expand Up @@ -620,6 +622,10 @@ def _where_input_wrangler(
variant_name="partial_views",
reason="ONNX doesn't have partial view for tensor",
),
xfail(
"hstack",
reason="fixme: A bug of constant-propagation optimization within the subgraph, we can avoid it by turning off graph-optimizations in session options",
),
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
xfail(
"max",
Expand Down Expand Up @@ -760,6 +766,10 @@ def _where_input_wrangler(
reason="fixme: ORT fails with invalid model: 'INVALID_ARGUMENT : Failed to load model with error: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)'",
test_class_name="TestOutputConsistencyFullGraph",
),
xfail(
"vstack",
reason="fixme: A bug of constant-propagation optimization within the subgraph, we can avoid it by turning off graph-optimizations in session options",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we turn it off in test and file a bug to ORT?

Copy link
Contributor Author

@titaiwangms titaiwangms Jun 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would nested affect performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xfail actually works. I think CI breaks because of something else. Actually, @gramalingam spotted the bug, but I can still file an issue to track it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

),
)


Expand Down Expand Up @@ -1697,6 +1707,10 @@ def _where_input_wrangler(
torch.float32,
torch.float16,
),
"hstack": (
torch.float32,
torch.float16,
),
# "is_same_size": core_ops.aten_is_same_size, # no test case in OPS_DB
# "is_nonzero": core_ops.aten_is_nonzero, # no test case in OPS_DB
"index_put_bool": (
Expand Down Expand Up @@ -2276,6 +2290,10 @@ def _where_input_wrangler(
torch.float32,
torch.float16,
),
"vstack": (
torch.float32,
torch.float16,
),
"where": (
torch.float32,
torch.float16,
Expand Down