From 974f6632268b4dacce81c2b5bcbde0de9cd165b5 Mon Sep 17 00:00:00 2001
From: Aravind-11
Date: Sat, 8 Nov 2025 19:37:25 -0700
Subject: [PATCH] Fixes #854 - linspace now correctly handles int64 dtype

---
 .../function_libs/torch_lib/ops/core.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 2cbecdcfc..d3254a427 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4984,30 +4984,33 @@ def aten_linspace(
     pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
     if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
-
     # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
     if steps == 0:
         return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    rg = aten_arange_start(0, steps, dtype=dtype)
-    start = op.Cast(start, to=dtype)
-    end = op.Cast(end, to=dtype)
-    steps_float = op.Cast(steps, to=dtype)
-    one = op.Cast(1.0, to=dtype)
-    two = op.Cast(2.0, to=dtype)
-    steps_minus_1 = op.Cast(steps - 1, to=dtype)
-    step = op.Div(op.Sub(end, start), steps_minus_1)
-    return op.Where(
-        rg < op.Div(steps_float, two),
-        start + step * rg,
-        end - step * (steps_float - one - rg),
+    compute_dtype = FLOAT.dtype
+
+    rg = aten_arange_start(0, steps, dtype=compute_dtype)
+    start_f = op.Cast(start, to=compute_dtype)
+    end_f = op.Cast(end, to=compute_dtype)
+    steps_f = op.Cast(steps, to=compute_dtype)
+    one = op.Cast(1.0, to=compute_dtype)
+    two = op.Cast(2.0, to=compute_dtype)
+    steps_minus_1 = op.Sub(steps_f, one)
+    step = op.Div(op.Sub(end_f, start_f), steps_minus_1)
+
+    lin_vals = op.Where(
+        rg < op.Div(steps_f, two),
+        op.Add(start_f, op.Mul(step, rg)),
+        op.Sub(end_f, op.Mul(step, op.Sub(op.Sub(steps_f, one), rg))),
     )
+    return op.Cast(lin_vals, to=dtype)
+
 
 
 @torch_op("aten::log", trace_only=True)
 def aten_log(self: TFloat) -> TFloat:
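
For context, a minimal standalone sketch of the bug this patch fixes. It is not part of the patch: numpy stands in for the ONNX ops, the helper names linspace_old and linspace_new are hypothetical, and it assumes ONNX Div on int64 tensors truncates like C integer division. The sketch also simplifies to one-sided interpolation, whereas the patched op.Where keeps the two-sided form that interpolates from start on the left half and from end on the right half.

    # Standalone sketch, NOT part of the patch. Shows why running the
    # whole interpolation in the target int64 dtype truncated the step.
    import numpy as np

    def linspace_old(start, end, steps):
        # Old behavior: arange, step, and interpolation all live in int64,
        # so the 10/4 step truncates to 2 and the result never reaches end.
        rg = np.arange(steps, dtype=np.int64)
        step = (np.int64(end) - np.int64(start)) // np.int64(steps - 1)
        return np.int64(start) + step * rg

    def linspace_new(start, end, steps):
        # New behavior: interpolate in float32 (compute_dtype), then cast
        # once at the end, mirroring the final op.Cast(lin_vals, to=dtype).
        rg = np.arange(steps, dtype=np.float32)
        step = (np.float32(end) - np.float32(start)) / np.float32(steps - 1)
        return (np.float32(start) + step * rg).astype(np.int64)

    print(linspace_old(0, 10, 5))  # [0 2 4 6 8]      -- drifts, misses end
    print(linspace_new(0, 10, 5))  # [ 0  2  5  7 10]

The second result matches torch.linspace(0, 10, 5, dtype=torch.int64), which likewise computes in floating point and truncates only when converting to the requested integer dtype.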