From 2ae9f2107eb7ac1dd54a74087b1a352cbf6d9670 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 12:59:36 -0800 Subject: [PATCH 1/4] [torchlib] Fix and implement overloads for aten::remainder Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 099b786d74..3d0e9fd5df 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7657,7 +7657,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) +@torch_op("aten::remainder.Tensor", trace_only=True) def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -7673,6 +7673,36 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) +@torch_op("aten::remainder.Scalar", trace_only=True) +def aten_remainder(self: TTensor, other: float) -> TTensor: + """remainder.Scalar(Tensor self, Scalar other) -> Tensor""" + + other_tensor = ir.tensor(other, dtype=self.dtype) + + if self.dtype.is_integer(): + return op.Mod(self, other_tensor) + + # a - a.div(b, rounding_mode="floor") * b + rounded_quotient = op.Floor(op.Div(self, other_tensor)) + + return op.Sub(self, op.Mul(rounded_quotient, other_tensor)) + + +@torch_op("aten::remainder.Scalar_Tensor", trace_only=True) +def aten_remainder_scalar_tensor(self: TTensor, other: TTensor) -> TTensor: + """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + + self_tensor = ir.tensor(self, dtype=other.dtype) + + if other.dtype.is_integer(): + return op.Mod(self_tensor, other) + + # a - a.div(b, rounding_mode="floor") * b + rounded_quotient = op.Floor(op.Div(self_tensor, other)) + + return op.Sub(self_tensor, op.Mul(rounded_quotient, other)) + + @torch_op("_operator::mod", trace_only=True) def operator_mod(self: TTensor, other: TTensor) -> TTensor: # Modulus operator % on SymInt From 3778f7dbaffc174b7f3f20b8a13c7262cd9d151b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 13:42:16 -0800 Subject: [PATCH 2/4] Apply suggestion from @justinchuby --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3d0e9fd5df..63c6c1b8c1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7689,7 +7689,7 @@ def aten_remainder(self: TTensor, other: float) -> TTensor: @torch_op("aten::remainder.Scalar_Tensor", trace_only=True) -def aten_remainder_scalar_tensor(self: TTensor, other: TTensor) -> TTensor: +def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor: """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" self_tensor = ir.tensor(self, dtype=other.dtype) From 98e81a2813e274902e9a98b676bebb4cf46ea453 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 13:56:32 -0800 Subject: [PATCH 3/4] fix Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 63c6c1b8c1..b0d0ad467b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7674,7 +7674,7 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: @torch_op("aten::remainder.Scalar", trace_only=True) -def aten_remainder(self: TTensor, other: float) -> TTensor: +def aten_remainder_scalar(self: TTensor, other: float) -> TTensor: """remainder.Scalar(Tensor self, Scalar other) -> Tensor""" other_tensor = ir.tensor(other, dtype=self.dtype) From 4ce07c9943fe96a786c315040a2c797dfd5fb337 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 14:01:02 -0800 Subject: [PATCH 4/4] Refactor Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b0d0ad467b..254378bf09 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op("aten::remainder.Tensor", trace_only=True) -def aten_remainder(self: TTensor, other: TTensor) -> TTensor: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - if self.dtype.is_integer(): +def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor: + if integer: return op.Mod(self, other) # TODO(justinchuby): Improve fp16 precision by following the logic in @@ -7673,19 +7670,19 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) +@torch_op("aten::remainder.Tensor", trace_only=True) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: + """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + + return _aten_remainder(self, other, integer=self.dtype.is_integer()) + + @torch_op("aten::remainder.Scalar", trace_only=True) def aten_remainder_scalar(self: TTensor, other: float) -> TTensor: """remainder.Scalar(Tensor self, Scalar other) -> Tensor""" other_tensor = ir.tensor(other, dtype=self.dtype) - - if self.dtype.is_integer(): - return op.Mod(self, other_tensor) - - # a - a.div(b, rounding_mode="floor") * b - rounded_quotient = op.Floor(op.Div(self, other_tensor)) - - return op.Sub(self, op.Mul(rounded_quotient, other_tensor)) + return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer()) @torch_op("aten::remainder.Scalar_Tensor", trace_only=True) @@ -7693,14 +7690,7 @@ def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor: """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" self_tensor = ir.tensor(self, dtype=other.dtype) - - if other.dtype.is_integer(): - return op.Mod(self_tensor, other) - - # a - a.div(b, rounding_mode="floor") * b - rounded_quotient = op.Floor(op.Div(self_tensor, other)) - - return op.Sub(self_tensor, op.Mul(rounded_quotient, other)) + return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer()) @torch_op("_operator::mod", trace_only=True)