diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 76d1de9a1a..f430d77d48 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5972,6 +5972,18 @@ def aten_round(self: TFloat) -> TFloat: return op.Round(self) +@torch_op("aten::round.decimals") +def aten_round_decimals(self: TFloat, decimals: int = 0) -> TFloat: + """round.decimals(Tensor self, *, int decimals) -> Tensor""" + + # Scale the input by 10^decimals, round it, and scale it back. + ten = op.CastLike(10.0, self) + scale = op.Pow(ten, op.CastLike(decimals, self)) + self_scaled = op.Mul(self, scale) + rounded = op.Round(self_scaled) + return op.Div(rounded, scale) + + def aten_row_indices(self: TensorType) -> TensorType: """row_indices(Tensor(a) self) -> Tensor(a)""" diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 7d8a544446..e014df392c 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1153,23 +1153,21 @@ def _where_input_wrangler( TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), - TorchLibOpInfo( - "round", - core_ops.aten_round, - ) + TorchLibOpInfo("round", core_ops.aten_round) .xfail( variant_name="decimals_0", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", test_class_name="TestOutputConsistencyEager", ) .xfail( variant_name="decimals_3", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", ) .xfail( variant_name="decimals_neg_3", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", ), + TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), TorchLibOpInfo("rsub", core_ops.aten_rsub), TorchLibOpInfo( @@ -1939,6 +1937,7 @@ def _where_input_wrangler( "nn.functional.upsample_nearest3d", ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))