diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dfbd56270..1a688a427 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt: @torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: +def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: """floor_divide(Tensor self, Tensor other) -> Tensor""" - return op.Floor(op.Div(self, other)) + if self.dtype.is_floating_point(): + return op.Floor(op.Div(self, other)) + assert self.dtype.is_integer() -@torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide_int(self: TInt, other: TInt) -> TInt: - """floor_divide(Tensor self, Tensor other) -> Tensor""" + if not self.dtype.is_signed(): + return op.Div(self, other) - # TODO(justinchuby): This can be simplified if we can constrain the - # inputs to be positive integers. Consider how we can embed constraints in the model. - dtype = self.dtype - self = op.Cast(self, to=FLOAT.dtype) - other = op.Cast(other, to=FLOAT.dtype) - result = op.Floor(op.Div(self, other)) - return op.Cast(result, to=dtype) + # Convert truncation to flooring + # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 + # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + offset = op.And( + op.Not(op.Equal(op.Sign(self), op.Sign(other))), + op.Cast(op.Mod(self, other), to=BOOL.dtype), + ) + offset = op.Cast(offset, to=self.dtype) + return op.Sub(op.Div(self, other), offset) @torch_op("_operator::floordiv", trace_only=True) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 4f4a3872e..b03cb5880 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2270,18 +2270,9 @@ def __init__(self): opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", - dtypes=common_dtype.floating_types_and_half(), + dtypes=common_dtype.all_types_and_half(), rhs_make_tensor_kwargs=dict(exclude_zero=True), ), - opinfo_core.BinaryUfuncInfo( - "ops.aten.floor_divide.int", - aten_name="floor_divide", - op=torch.ops.aten.floor_divide, - dtypes=common_dtype.integral_types(), - # Create only positive inputs - lhs_make_tensor_kwargs=dict(low=0), - rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), - ), opinfo_core.OpInfo( "ops.aten.hamming_window", aten_name="hamming_window", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 98d10d9e5..92495d201 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -794,7 +794,6 @@ def _where_input_wrangler( TorchLibOpInfo("flatten", core_ops.aten_flatten), TorchLibOpInfo("floor", core_ops.aten_floor), TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), - TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full),