From f0ff3cb3261d957023642980948b436aba8911c2 Mon Sep 17 00:00:00 2001 From: Sebastian Mossburger Date: Wed, 5 Nov 2025 16:04:31 +0100 Subject: [PATCH 1/2] fix(outputs): Cast to sequence if num_outputs=1 --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 96f64bbb8a..a9a7f90419 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8957,7 +8957,11 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): # We can create a definitive split op if the input shape is static # Only torch>=2.7 supports correctly generating the correct number of outputs for Split - outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + num_outputs = self.shape[dim] + outputs = op.Split(self, axis=dim, num_outputs=num_outputs) + if num_outputs == 1: + outputs = [outputs] + return [op.Squeeze(out, [dim]) for out in outputs] return op.SplitToSequence(self, axis=dim, keepdims=False) From 58a167225bf4da786bc34db8251eba718264534f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 5 Nov 2025 13:14:51 -0800 Subject: [PATCH 2/2] Update onnxscript/function_libs/torch_lib/ops/core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a9a7f90419..767dffacf7 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8958,9 +8958,10 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: # We can create a definitive split op if the input shape is static # Only torch>=2.7 supports correctly generating the correct number of outputs for Split num_outputs = self.shape[dim] - outputs = op.Split(self, axis=dim, num_outputs=num_outputs) - if num_outputs == 1: - outputs = [outputs] + if num_outputs != 1: + outputs = op.Split(self, axis=dim, num_outputs=num_outputs) + else: + outputs = [self] return [op.Squeeze(out, [dim]) for out in outputs]