From 83b2a0935cf3c196dbe99918f4535b9c59587949 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 14 Oct 2024 23:40:19 +0000
Subject: [PATCH 1/2] [torchlib] Do not register rsub

---
 onnxscript/function_libs/torch_lib/ops/core.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 395f1fcac9..9a60571508 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7352,18 +7352,11 @@ def aten_rsqrt(self: TFloat) -> TFloat:
     return op.Reciprocal(op.Sqrt(self))
 
 
-@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"))
+# Do not register rsub. It will be decomposed and type promoted by torch
 def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
-    return op.Sub(other, op.Mul(self, alpha))
-
-
-@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"), trace_only=True, complex=True)
-def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
-    """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-
-    return aten_rsub(self, other, alpha)
+    raise NotImplementedError
 
 
 @torch_op("aten::scalar_tensor", trace_only=True)

From 6f1766819b065420a603a3f42920765c8bbd951e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 14 Oct 2024 23:47:47 +0000
Subject: [PATCH 2/2] rsub

---
 tests/function_libs/torch_lib/ops_test_data.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index c180c1b71b..35c691109f 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1360,8 +1360,6 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals),
     TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt),
-    TorchLibOpInfo("rsub", core_ops.aten_rsub),
-    TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True),
     TorchLibOpInfo(
         "scalar_tensor",
         core_ops.aten_scalar_tensor,