diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 09704199f..a25015b23 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -51,6 +51,7 @@ from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +_INT32_MAX = 2147483647 _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi @@ -9183,15 +9184,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) -> raise NotImplementedError() +@torch_op("aten::unique_consecutive", trace_only=True) def aten_unique_consecutive( - self: TensorType, + x: TensorType, return_inverse: bool = False, return_counts: bool = False, dim: Optional[int] = None, ) -> tuple[TensorType, TensorType, TensorType]: """unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)""" + assert x.dtype in {INT64.dtype, INT32.dtype}, ( + "unique_consecutive not implemented for other type than int32, int64" + ) + rank_x = len(x.shape) - raise NotImplementedError() + zero = op.Constant(value=ir.tensor([0], dtype=x.dtype)) + zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype)) + minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype)) + + if dim is None: + if rank_x != 1: + x = op.Reshape(x, minus_one) + else: + assert rank_x == 1 and dim == 0, ( + f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}." + ) + + lag = op.Concat( + # Hopefully this will never be equal to the first value of the tensor x + # ideally we could do differently but with a higher cost + op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)), + op.Slice(x, zero64, minus_one, zero64), + axis=0, + ) + eq = op.Equal(x, lag) + diff = op.Not(eq) + res = op.Compress(x, diff, axis=0) + + zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype)) + one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype)) + one = op.Constant(value=ir.tensor([1], dtype=x.dtype)) + + inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one) + shape_x = op.Shape(x) + indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim) + points = op.Compress(indices, diff, axis=0) + lagp = op.Concat( + op.Slice(points, one, op.Shape(points), zero), + shape_x, + axis=0, + ) + counts = op.Sub(lagp, points) + return res, inverse, counts @torch_op("aten::_unique", trace_only=True) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index cb272a98a..1546de59b 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -406,6 +406,51 @@ def forward(self, x): onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) + def test_aten_unique_consecutive(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive_int32(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int32) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive_return(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x, return_inverse=True, return_counts=True) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_aten_stft_1(self): class Model(torch.nn.Module): def forward(self, x):