From 2f9de2dacd7fc470c061ef6614ea05750716e525 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 15:03:30 -0700 Subject: [PATCH 1/3] Improve aten_floor_divide for int inputs Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 28 +++++++++++-------- tests/function_libs/torch_lib/extra_opinfo.py | 11 +------- .../function_libs/torch_lib/ops_test_data.py | 1 - 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dfbd562708..5566ef554a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt: @torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: +def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: """floor_divide(Tensor self, Tensor other) -> Tensor""" - return op.Floor(op.Div(self, other)) + if self.dtype.is_floating_point(): + return op.Floor(op.Div(self, other)) + assert self.dtype.is_integer() -@torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide_int(self: TInt, other: TInt) -> TInt: - """floor_divide(Tensor self, Tensor other) -> Tensor""" + if not self.dtype.is_signed(): + return op.Div(self, other) - # TODO(justinchuby): This can be simplified if we can constrain the - # inputs to be positive integers. Consider how we can embed constraints in the model. - dtype = self.dtype - self = op.Cast(self, to=FLOAT.dtype) - other = op.Cast(other, to=FLOAT.dtype) - result = op.Floor(op.Div(self, other)) - return op.Cast(result, to=dtype) + # Convert truncation to flooring + # Reference: https://pytorch.org/docs/stable/generated/torch.floor_divide.html + # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + offset = op.Or( + op.Equal(op.Sign(self), op.Sign(other)), + op.Not(op.Cast(op.Mod(self, other), to=BOOL.dtype)), + ) + offset = op.Cast(offset, to=self.dtype) + return op.Sub(op.Div(self, other), offset) @torch_op("_operator::floordiv", trace_only=True) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 4f4a3872e1..b03cb5880a 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2270,18 +2270,9 @@ def __init__(self): opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", - dtypes=common_dtype.floating_types_and_half(), + dtypes=common_dtype.all_types_and_half(), rhs_make_tensor_kwargs=dict(exclude_zero=True), ), - opinfo_core.BinaryUfuncInfo( - "ops.aten.floor_divide.int", - aten_name="floor_divide", - op=torch.ops.aten.floor_divide, - dtypes=common_dtype.integral_types(), - # Create only positive inputs - lhs_make_tensor_kwargs=dict(low=0), - rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), - ), opinfo_core.OpInfo( "ops.aten.hamming_window", aten_name="hamming_window", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 98d10d9e5b..92495d201a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -794,7 +794,6 @@ def _where_input_wrangler( TorchLibOpInfo("flatten", core_ops.aten_flatten), TorchLibOpInfo("floor", core_ops.aten_floor), TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), - TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), From 80d9e1f0d23e450d19920f47f13160f1160acf9d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 15:16:25 -0700 Subject: [PATCH 2/3] Fix Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5566ef554a..5a4c4338c8 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3703,9 +3703,9 @@ def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: # Reference: https://pytorch.org/docs/stable/generated/torch.floor_divide.html # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) - offset = op.Or( - op.Equal(op.Sign(self), op.Sign(other)), - op.Not(op.Cast(op.Mod(self, other), to=BOOL.dtype)), + offset = op.And( + op.Not(op.Equal(op.Sign(self), op.Sign(other))), + op.Cast(op.Mod(self, other), to=BOOL.dtype), ) offset = op.Cast(offset, to=self.dtype) return op.Sub(op.Div(self, other), offset) From 0a053b03880cd41423fc192ffdda393d131067e4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 15:19:51 -0700 Subject: [PATCH 3/3] ref Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5a4c4338c8..1a688a4277 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3700,7 +3700,7 @@ def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: return op.Div(self, other) # Convert truncation to flooring - # Reference: https://pytorch.org/docs/stable/generated/torch.floor_divide.html + # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) offset = op.And(