From 4eb25a3bef6bdd735f868d9278ae79d9f4565f92 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 Oct 2025 08:55:33 -0700 Subject: [PATCH 1/4] Simplify aten_unbind when shape is static Add static shape handling to aten_unbind function --- onnxscript/function_libs/torch_lib/ops/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a688a4277..a713453bdb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8764,6 +8764,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" + if isinstance(self.shape[dim], int): + # We can create a definitive split op if the input shape is static + outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + return [op.Squeeze(out, [self.shape[dim]]) for out in outputs] + return op.SplitToSequence(self, axis=dim, keepdims=False) From af3dfa2353669737c8a63b07c7457e033ff8d833 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 2 Oct 2025 09:30:33 -0700 Subject: [PATCH 2/4] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a713453bdb..2a69c1d7a7 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8767,7 +8767,7 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: if isinstance(self.shape[dim], int): # We can create a definitive split op if the input shape is static outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) - return [op.Squeeze(out, [self.shape[dim]]) for out in outputs] + return [op.Squeeze(out, [dim]) for out in outputs] return op.SplitToSequence(self, axis=dim, keepdims=False) From 2f181fe16e626ac365b3c9eb87293630dd01e011 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 6 Oct 2025 09:20:07 -0700 Subject: [PATCH 3/4] torchversion Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a9f397bc68..dfcd5c35b0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8776,8 +8776,9 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - if isinstance(self.shape[dim], int): + if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): # We can create a definitive split op if the input shape is static + # Only torch>=2.7 supports correctly generating the correct number of outputs for Split outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) return [op.Squeeze(out, [dim]) for out in outputs] From 99d0441d06f1358ecb2be89c934e696eb27f6c4d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 6 Oct 2025 09:23:21 -0700 Subject: [PATCH 4/4] Test Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test.py | 3 ++- tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 7ba6f9d37f..45875043ea 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -39,6 +39,7 @@ from torch.utils import _pytree as pytree import onnxscript +from onnxscript._internal import version_utils from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -200,7 +201,7 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("unbind") + or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7")) or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} ): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index ff4a68d2f6..1b998b1d22 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1478,6 +1478,7 @@ def _where_input_wrangler( reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", ) .xfail( + enabled_if=version_utils.torch_older_than("2.7"), dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ),