diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2cbecdcfc..d3254a427 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4984,30 +4984,33 @@ def aten_linspace( pin_memory: bool = False, ) -> TensorType: """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1 or dtype is None: dtype = FLOAT.dtype - # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896 if steps == 0: return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype) if steps == 1: return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype) - rg = aten_arange_start(0, steps, dtype=dtype) - start = op.Cast(start, to=dtype) - end = op.Cast(end, to=dtype) - steps_float = op.Cast(steps, to=dtype) - one = op.Cast(1.0, to=dtype) - two = op.Cast(2.0, to=dtype) - steps_minus_1 = op.Cast(steps - 1, to=dtype) - step = op.Div(op.Sub(end, start), steps_minus_1) - return op.Where( - rg < op.Div(steps_float, two), - start + step * rg, - end - step * (steps_float - one - rg), + compute_dtype = FLOAT.dtype + + rg = aten_arange_start(0, steps, dtype=compute_dtype) + start_f = op.Cast(start, to=compute_dtype) + end_f = op.Cast(end, to=compute_dtype) + steps_f = op.Cast(steps, to=compute_dtype) + one = op.Cast(1.0, to=compute_dtype) + two = op.Cast(2.0, to=compute_dtype) + steps_minus_1 = op.Sub(steps_f, one) + step = op.Div(op.Sub(end_f, start_f), steps_minus_1) + + lin_vals = op.Where( + rg < op.Div(steps_f, two), + op.Add(start_f, op.Mul(step, rg)), + op.Sub(end_f, op.Mul(step, op.Sub(op.Sub(steps_f, one), rg))), ) + return op.Cast(lin_vals, to=dtype) + @torch_op("aten::log", trace_only=True) def aten_log(self: TFloat) -> TFloat: