From a7d777a7f30e97d786ce52ce9ed23ec69cdcad80 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 7 Nov 2025 10:39:13 -0800 Subject: [PATCH] [torchlib] Fix mod on SymInt Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index be30520878..7472e8eec6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3732,7 +3732,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"), trace_only=True, ) def aten_ge(self: TTensor, other: TTensor) -> BOOL: @@ -3749,6 +3749,12 @@ def aten_ge(self: TTensor, other: TTensor) -> BOOL: return op.GreaterOrEqual(self, other) +@torch_op("_operator::ge", trace_only=True) +def operator_ge(self: TTensor, other: TTensor) -> BOOL: + # operator.ge for SymInt + return op.GreaterOrEqual(self, other) + + def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: """geqrf(Tensor self) -> (Tensor a, Tensor tau)""" @@ -3872,7 +3878,7 @@ def aten_gru_cell( @torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"), trace_only=True, ) def aten_gt(self: TTensor, other: TTensor) -> BOOL: @@ -3890,6 +3896,12 @@ def aten_gt(self: TTensor, other: TTensor) -> BOOL: return op.Greater(self, other) +@torch_op("_operator::gt", trace_only=True) +def operator_gt(self: TTensor, other: TTensor) -> BOOL: + # operator.gt for SymInt + return op.Greater(self, other) + + @torch_op("aten::hamming_window", trace_only=True) def aten_hamming_window( window_length: int, @@ -4705,7 +4717,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"), trace_only=True, ) def aten_le(self: TTensor, other: TTensor) -> BOOL: @@ -4723,6 +4735,12 @@ def aten_le(self: TTensor, other: TTensor) -> BOOL: return op.LessOrEqual(self, other) +@torch_op("_operator::le", trace_only=True) +def operator_le(self: TTensor, other: TTensor) -> BOOL: + # operator.le for SymInt + return op.LessOrEqual(self, other) + + @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor: """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor""" @@ -4992,7 +5010,7 @@ def aten_lstm_mps_backward( @torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"), trace_only=True, ) def aten_lt(self: TTensor, other: TTensor) -> BOOL: @@ -5009,6 +5027,12 @@ def aten_lt(self: TTensor, other: TTensor) -> BOOL: return op.Less(self, other) +@torch_op("_operator::lt", trace_only=True) +def operator_lt(self: TTensor, other: TTensor) -> BOOL: + # operator.lt for SymInt + return op.Less(self, other) + + def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: """lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor""" @@ -7076,9 +7100,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) +@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -7094,6 +7116,12 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) +@torch_op("_operator::mod", trace_only=True) +def operator_mod(self: TTensor, other: TTensor) -> TTensor: + # Modulus operator % on SymInt + return op.Mod(self, other) + + def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)"""