From 7424f69ba9eeec342886398f6204b9567e588d44 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 6 Aug 2025 12:08:31 +0200 Subject: [PATCH 01/10] Implements repeat_interleave --- .../function_libs/torch_lib/ops/core.py | 104 +++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 42 +++++++ .../function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 143 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 595f4a758a..cbfc79dadb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7292,12 +7292,108 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: return op.Tile(self_expanded, repeats) -def aten_repeat_interleave( - repeats: TensorType, output_size: Optional[int] = None +@torch_op("aten::repeat_interleave.Scalar", trace_only=True) +def aten_repeat_interleave_int( + self: TensorType, repeats: int, dim: Optional[int] ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor""" + """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor - raise NotImplementedError() + The trick is to repeat in one direction orthogonal to reshape. + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat_interleave(2, dim=0) + + is equivalent to: + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, t.shape[1])) + """ + if dim is None: + raise NotImplementedError("No conversion available yet when dim is None.") + + self_rank = len(self.shape) + pos_dim = (dim + self_rank) % self_rank + unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) + onehot = op.Concat(op.ConstantOfShape((self_rank,), value=[1]), repeats, axis=0) + tiled = op.Tile(unsqueezed, onehot) + + if dim < -1: + dim += self_rank + return aten_flatten( + tiled, + -2 if dim == -1 else dim, + -1 if dim == -1 else (dim + 1) + ) + +@torch_op("aten::repeat_interleave.Tensor", trace_only=True) +def aten_repeat_interleave_Tensor( + self: TensorType, repeats: Optional[TensorType]=None, dim: Optional[int]=None +) -> TensorType: + """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor + + When `repeats` is a tensor, each line is multiplied + by a different number. + There are multiple strategies. Here is one. + + .. code-block:: python + + import torch + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + times = torch.tensor([2, 3], dtype=torch.int64) + y = x.repeat_interleave(times, dim=0) + print("repeat_interleave") + print(y) + + ci = times.cumsum(dim=0) + rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1)) + srows = times.shape[0] - rows.to(torch.int64).sum(axis=0) + indices = srows.reshape((-1, )) + print("decomposed") + print(x[indices, :]) + """ + if repeats is None: + repeats = self + self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1) + if dim is None: + # flatten + self = op.Reshape(self, [-1]) + rk = 1 + else: + rk = len(self.shape) + + if rk > 2: + shape_x0 = op.Shape(self, start=0, end=1) + shape_x = op.Shape(self, start=1) + self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) + elif rk == 1: + shape_x0 = None + shape_x = None + self = op.Reshape(self, [-1, 1]) + else: + if rk != 2: + raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave") + + ci = op.CumSum(repeats, [0]) + last_ci = op.Gather(ci, [-1]) + trange = op.Range(0, op.Squeeze(last_ci, [0]), 1) + rows = op.Less(trange, op.Unsqueeze(ci, [-1])) + srows = op.Sub( + op.Shape(self, start=0, end=1), + op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]), + ) + indices = op.Reshape(srows, [-1]) + values = op.GatherND(self, op.Unsqueeze(indices, [-1])) + if rk == 2: + return values + return op.Reshape( + values, + op.Concat([-1], shape_x, axis=0) if shape_x else [-1], + ) @torch_op("aten::reshape") diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 7c2978f6de..f3163ae0b0 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -76,6 +76,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind, dim=0) + + onnx_program = torch.onnx.export( + Model(), + ( + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + torch.tensor([1, 2], dtype=torch.int64), + ), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor_none(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind) + + onnx_program = torch.onnx.export( + Model(), + ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.tensor([1, 2, 3, 2], dtype=torch.int64), + ), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index cd2d933309..eb67cb8bbe 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1240,6 +1240,7 @@ def _where_input_wrangler( core_ops.aten_remainder, ), TorchLibOpInfo("repeat", core_ops.aten_repeat), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_tensor), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), From 05c90622392d5bae1da762082177af7cf7635efa Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 6 Aug 2025 12:42:26 +0200 Subject: [PATCH 02/10] remove mistake --- tests/function_libs/torch_lib/ops_test_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index eb67cb8bbe..4051d879dd 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1240,7 +1240,9 @@ def _where_input_wrangler( core_ops.aten_remainder, ), TorchLibOpInfo("repeat", core_ops.aten_repeat), - TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_tensor), + # needs to split into two cases + # TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Scalar), + # TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), From 9560baa9ba0b9be289722c172be3bf691b9baff5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 6 Aug 2025 14:21:57 +0200 Subject: [PATCH 03/10] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index cbfc79dadb..6638b526a3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7323,15 +7323,11 @@ def aten_repeat_interleave_int( if dim < -1: dim += self_rank - return aten_flatten( - tiled, - -2 if dim == -1 else dim, - -1 if dim == -1 else (dim + 1) - ) + return aten_flatten(tiled, -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1)) @torch_op("aten::repeat_interleave.Tensor", trace_only=True) def aten_repeat_interleave_Tensor( - self: TensorType, repeats: Optional[TensorType]=None, dim: Optional[int]=None + self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None ) -> TensorType: """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor From 689a456e78cf3f261032ef0fa5731ff5159c3be9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 6 Aug 2025 14:51:50 +0200 Subject: [PATCH 04/10] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6638b526a3..4f8f443337 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7325,6 +7325,7 @@ def aten_repeat_interleave_int( dim += self_rank return aten_flatten(tiled, -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1)) + @torch_op("aten::repeat_interleave.Tensor", trace_only=True) def aten_repeat_interleave_Tensor( self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None @@ -7373,6 +7374,7 @@ def aten_repeat_interleave_Tensor( else: if rk != 2: raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave") + shape_x = None ci = op.CumSum(repeats, [0]) last_ci = op.Gather(ci, [-1]) @@ -7386,6 +7388,8 @@ def aten_repeat_interleave_Tensor( values = op.GatherND(self, op.Unsqueeze(indices, [-1])) if rk == 2: return values + # shape_x cannot be None at this stage. + assert shape_x is not None # for mypy return op.Reshape( values, op.Concat([-1], shape_x, axis=0) if shape_x else [-1], From 9881f6e25c41ef2c355553f83a650c1439f07767 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 6 Aug 2025 20:51:32 +0200 Subject: [PATCH 05/10] fixed --- .../function_libs/torch_lib/ops/core.py | 55 +++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4f8f443337..00c7edc59c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7292,7 +7292,7 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: return op.Tile(self_expanded, repeats) -@torch_op("aten::repeat_interleave.Scalar", trace_only=True) +@torch_op("aten::repeat_interleave.self_int", trace_only=True) def aten_repeat_interleave_int( self: TensorType, repeats: int, dim: Optional[int] ) -> TensorType: @@ -7318,12 +7318,57 @@ def aten_repeat_interleave_int( self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) - onehot = op.Concat(op.ConstantOfShape((self_rank,), value=[1]), repeats, axis=0) + onehot = op.Concat( + op.ConstantOfShape( + op.Constant(value_ints=[self_rank]), + value=ir.tensor([1], dtype=INT64.dtype) + ), + op.Constant(value_ints=[repeats]), + axis=0, + ) tiled = op.Tile(unsqueezed, onehot) + # tiled has no shape at this stage + # return aten_flatten(tiled, -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1)) if dim < -1: dim += self_rank - return aten_flatten(tiled, -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1)) + + if self_rank == 1: + return op.Identity(tiled) + + start_dim, end_dim = -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1) + if start_dim == 1: + if end_dim in (-1, dim - 1): + return op.Flatten(tiled, axis=start_dim) + elif start_dim == 0: + if end_dim in (-2, dim - 2): + return op.Flatten(tiled, axis=end_dim + 1) + if end_dim < 0: + end_dim = dim + end_dim + + input_size = op.Shape(tiled) + dim_head = op.Slice( + input_size, + op.Constant(value_ints=[0]), + op.Constant(value_ints=[start_dim]), + op.Constant(value_ints=[0]), + ) + final_dims = [dim_head, op.Constant(value_ints=[-1])] + if end_dim < dim - 1: + dim_tail = op.Slice( + input_size, + op.Constant(value_ints=[end_dim + 1]), + op.Constant(value_ints=[dim]), + op.Constant(value_ints=[0]), + ) + final_dims = [ + dim_head, + op.Constant(value_ints=[-1]), + dim_tail, + ] + + final_shape = op.Concat(*final_dims, axis=0) + return op.Reshape(tiled, final_shape) @torch_op("aten::repeat_interleave.Tensor", trace_only=True) @@ -7388,8 +7433,8 @@ def aten_repeat_interleave_Tensor( values = op.GatherND(self, op.Unsqueeze(indices, [-1])) if rk == 2: return values - # shape_x cannot be None at this stage. - assert shape_x is not None # for mypy + # shape_x is None at this stage. + assert shape_x is None # for mypy return op.Reshape( values, op.Concat([-1], shape_x, axis=0) if shape_x else [-1], From 84e297eb87b06778900866fd077273d2e6369aaa Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 7 Aug 2025 19:52:20 +0200 Subject: [PATCH 06/10] lint --- onnxscript/function_libs/torch_lib/ops/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 00c7edc59c..2b1b07d91f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7320,8 +7320,7 @@ def aten_repeat_interleave_int( unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) onehot = op.Concat( op.ConstantOfShape( - op.Constant(value_ints=[self_rank]), - value=ir.tensor([1], dtype=INT64.dtype) + op.Constant(value_ints=[self_rank]), value=ir.tensor([1], dtype=INT64.dtype) ), op.Constant(value_ints=[repeats]), axis=0, From 0e590a44e58874790431750583f5f2b9da535739 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 26 Aug 2025 14:15:19 +0200 Subject: [PATCH 07/10] restore one test worngly merged Signed-off-by: xadupre --- tests/function_libs/torch_lib/e2e_ops_tests.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index d4b38c6aff..58a41d8fc6 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -107,15 +107,25 @@ class Model(torch.nn.Module): def forward(self, x, ind): return torch.repeat_interleave(x, ind) - onnx_program = torch.onnx.export( - Model(), - ( + inputs = ( torch.arange(4, dtype=torch.float32).reshape((2, 2)), torch.tensor([1, 2, 3, 2], dtype=torch.int64), - ), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, dynamo=True, optimize=False, ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "ind"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) def test_sdpa_with_bool_attn_mask(self): class ScaledDotProductAttention(torch.nn.Module): From 9586e579a6fe843ac1b6a20611495886ea485501 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 29 Aug 2025 11:31:46 +0200 Subject: [PATCH 08/10] merge with main --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- tests/function_libs/torch_lib/e2e_ops_tests.py | 6 +++--- tests/function_libs/torch_lib/ops_test_data.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index fe3206f400..ff8ae4c9a9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7293,7 +7293,7 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: @torch_op("aten::repeat_interleave.self_int", trace_only=True) -def aten_repeat_interleave_int( +def aten_repeat_interleave_self_int( self: TensorType, repeats: int, dim: Optional[int] ) -> TensorType: """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 58a41d8fc6..c44774b5cb 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -108,9 +108,9 @@ def forward(self, x, ind): return torch.repeat_interleave(x, ind) inputs = ( - torch.arange(4, dtype=torch.float32).reshape((2, 2)), - torch.tensor([1, 2, 3, 2], dtype=torch.int64), - ) + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.tensor([1, 2, 3, 2], dtype=torch.int64), + ) onnx_program = torch.onnx.export( Model(), inputs, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 538e24e194..9d0e895ddc 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1250,9 +1250,8 @@ def _where_input_wrangler( core_ops.aten_remainder, ), TorchLibOpInfo("repeat", core_ops.aten_repeat), - # needs to split into two cases - # TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Scalar), - # TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor), + TorchLibOpInfo("repeat_interleave.self_int", core_ops.aten_repeat_interleave_self_int), + TorchLibOpInfo("repeat_interleave.Tensor", core_ops.aten_repeat_interleave_Tensor), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), From 7cc457ab441a6e088ccbe52ce49190b51bcc57b5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 29 Aug 2025 12:58:21 +0200 Subject: [PATCH 09/10] fix repeat_interleave --- .../function_libs/torch_lib/ops/core.py | 58 ++++--------------- .../function_libs/torch_lib/e2e_ops_tests.py | 12 +++- .../function_libs/torch_lib/ops_test_data.py | 36 +++++++++++- 3 files changed, 55 insertions(+), 51 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ff8ae4c9a9..39e836f8dd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7294,7 +7294,7 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: @torch_op("aten::repeat_interleave.self_int", trace_only=True) def aten_repeat_interleave_self_int( - self: TensorType, repeats: int, dim: Optional[int] + self: TensorType, repeats: int, dim: Optional[int] = None ) -> TensorType: """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor @@ -7318,55 +7318,18 @@ def aten_repeat_interleave_self_int( self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) - onehot = op.Concat( - op.ConstantOfShape( - op.Constant(value_ints=[self_rank]), value=ir.tensor([1], dtype=INT64.dtype) - ), - op.Constant(value_ints=[repeats]), - axis=0, - ) - tiled = op.Tile(unsqueezed, onehot) - - # tiled has no shape at this stage - # return aten_flatten(tiled, -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1)) - if dim < -1: - dim += self_rank - + tiles = [1] * (self_rank + 1) + tiles[pos_dim + 1] = repeats + tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) + tiled = op.Tile(unsqueezed, tile_repeat) if self_rank == 1: return op.Identity(tiled) - - start_dim, end_dim = -2 if dim == -1 else dim, -1 if dim == -1 else (dim + 1) - if start_dim == 1: - if end_dim in (-1, dim - 1): - return op.Flatten(tiled, axis=start_dim) - elif start_dim == 0: - if end_dim in (-2, dim - 2): - return op.Flatten(tiled, axis=end_dim + 1) - if end_dim < 0: - end_dim = dim + end_dim - - input_size = op.Shape(tiled) - dim_head = op.Slice( - input_size, - op.Constant(value_ints=[0]), - op.Constant(value_ints=[start_dim]), - op.Constant(value_ints=[0]), + final_shape = op.Concat( + op.Shape(self, start=0, end=dim), + op.Constant(value_ints=[-1]), + op.Shape(self, start=dim + 1), + axis=0, ) - final_dims = [dim_head, op.Constant(value_ints=[-1])] - if end_dim < dim - 1: - dim_tail = op.Slice( - input_size, - op.Constant(value_ints=[end_dim + 1]), - op.Constant(value_ints=[dim]), - op.Constant(value_ints=[0]), - ) - final_dims = [ - dim_head, - op.Constant(value_ints=[-1]), - dim_tail, - ] - - final_shape = op.Concat(*final_dims, axis=0) return op.Reshape(tiled, final_shape) @@ -7412,7 +7375,6 @@ def aten_repeat_interleave_Tensor( shape_x = op.Shape(self, start=1) self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) elif rk == 1: - shape_x0 = None shape_x = None self = op.Reshape(self, [-1, 1]) else: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index c44774b5cb..a0d0a0d880 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) - def test_repeat_interleave_integer(self): + def test_repeat_interleave_integer_1(self): class Model(torch.nn.Module): def forward(self, x): return torch.repeat_interleave(x, 3, dim=1) @@ -86,6 +86,16 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_tensor(self): class Model(torch.nn.Module): def forward(self, x, ind): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 9d0e895ddc..01db7161b5 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1250,8 +1250,40 @@ def _where_input_wrangler( core_ops.aten_remainder, ), TorchLibOpInfo("repeat", core_ops.aten_repeat), - TorchLibOpInfo("repeat_interleave.self_int", core_ops.aten_repeat_interleave_self_int), - TorchLibOpInfo("repeat_interleave.Tensor", core_ops.aten_repeat_interleave_Tensor), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) + .skip( + matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is a Tensor"), + ) + .skip( + dtypes=(torch.bool,), + reason="bool not supported", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor) + .skip( + matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is an int"), + ) + .skip( + dtypes=(torch.bool,), + reason="bool not supported", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), From b34332937fadb2abde440fe79d6a69266d08739d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 29 Aug 2025 18:03:43 +0200 Subject: [PATCH 10/10] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 39e836f8dd..2e6bf9530c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7296,7 +7296,7 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: def aten_repeat_interleave_self_int( self: TensorType, repeats: int, dim: Optional[int] = None ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor + """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor The trick is to repeat in one direction orthogonal to reshape.