diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 96f64bbb8..767dffacf 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8957,7 +8957,12 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): # We can create a definitive split op if the input shape is static # Only torch>=2.7 supports correctly generating the correct number of outputs for Split - outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + num_outputs = self.shape[dim] + if num_outputs != 1: + outputs = op.Split(self, axis=dim, num_outputs=num_outputs) + else: + outputs = [self] + return [op.Squeeze(out, [dim]) for out in outputs] return op.SplitToSequence(self, axis=dim, keepdims=False)