From d4ecd92677707da40422428c5742c57775bc1b82 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Aug 2023 22:10:07 +0000 Subject: [PATCH 1/2] Implement `aten::round.decimals` | feat(torchlib) --- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 13 ++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 76d1de9a1a..9122c8db33 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5972,6 +5972,22 @@ def aten_round(self: TFloat) -> TFloat: return op.Round(self) +@torch_op("aten::round.decimals") +def aten_round_decimals(self: TFloat, decimals: int = 0) -> TFloat: + """round.decimals(Tensor self, *, int decimals) -> Tensor""" + + if decimals == 0: + result = op.Round(self) + else: + # Scale the input by 10^decimals, round it, and scale it back. + ten = op.CastLike(10.0, self) + scale = op.Pow(ten, op.CastLike(decimals, self)) + self_scaled = op.Mul(self, scale) + rounded = op.Round(self_scaled) + result = op.Div(rounded, scale) + return result + + def aten_row_indices(self: TensorType) -> TensorType: """row_indices(Tensor(a) self) -> Tensor(a)""" diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 7d8a544446..e014df392c 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1153,23 +1153,21 @@ def _where_input_wrangler( TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), - TorchLibOpInfo( - "round", - core_ops.aten_round, - ) + TorchLibOpInfo("round", core_ops.aten_round) .xfail( variant_name="decimals_0", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", test_class_name="TestOutputConsistencyEager", ) .xfail( variant_name="decimals_3", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", ) .xfail( variant_name="decimals_neg_3", - reason="The op does not support decimals yet", + reason="This variant does not accept decimals", ), + TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), TorchLibOpInfo("rsub", core_ops.aten_rsub), TorchLibOpInfo( @@ -1939,6 +1937,7 @@ def _where_input_wrangler( "nn.functional.upsample_nearest3d", ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",)) From d2ffad214b23c9e80135ce7c7a6c5a1689378c8b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Aug 2023 22:12:09 +0000 Subject: [PATCH 2/2] Simplify --- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9122c8db33..f430d77d48 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5976,16 +5976,12 @@ def aten_round(self: TFloat) -> TFloat: def aten_round_decimals(self: TFloat, decimals: int = 0) -> TFloat: """round.decimals(Tensor self, *, int decimals) -> Tensor""" - if decimals == 0: - result = op.Round(self) - else: - # Scale the input by 10^decimals, round it, and scale it back. - ten = op.CastLike(10.0, self) - scale = op.Pow(ten, op.CastLike(decimals, self)) - self_scaled = op.Mul(self, scale) - rounded = op.Round(self_scaled) - result = op.Div(rounded, scale) - return result + # Scale the input by 10^decimals, round it, and scale it back. + ten = op.CastLike(10.0, self) + scale = op.Pow(ten, op.CastLike(decimals, self)) + self_scaled = op.Mul(self, scale) + rounded = op.Round(self_scaled) + return op.Div(rounded, scale) def aten_row_indices(self: TensorType) -> TensorType: