From b37b082f3c037f2d9caf6855fca0d38351664389 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 24 Jan 2025 10:25:22 -0800
Subject: [PATCH 1/2] [torchlib] Fix prod

---
 onnxscript/function_libs/torch_lib/ops/core.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f980465bc4..b6b2567e88 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6682,11 +6682,21 @@ def aten_prelu_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::prod.dim_int", trace_only=True)
-def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
+@torch_op("aten::prod.default", trace_only=True)
+def aten_prod(self: TReal, dtype: int = -1) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
 
-    # Todo: add test for this function later
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
+    return op.ReduceProd(self)
+
+
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
+    """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
     return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
 
 

From 49862327816632f7d7215e1aa368147caf428d0f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 24 Jan 2025 10:35:35 -0800
Subject: [PATCH 2/2] test

---
 onnxscript/function_libs/torch_lib/ops/core.py |  4 ++--
 tests/function_libs/torch_lib/ops_test_data.py | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b6b2567e88..c3892c6cd3 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6682,7 +6682,7 @@ def aten_prelu_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::prod.default", trace_only=True)
+@torch_op("aten::prod", trace_only=True)
 def aten_prod(self: TReal, dtype: int = -1) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
 
@@ -6692,7 +6692,7 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal:
 
 
 @torch_op("aten::prod.dim_int", trace_only=True)
-def aten_prod(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
+def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
     """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
 
     if dtype != -1 and dtype is not None:
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 35e1778ca2..1399264546 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1271,6 +1271,19 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
+    TorchLibOpInfo("prod", core_ops.aten_prod).skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is not None
+        or sample.kwargs.get("keepdim") is not None
+        or sample.kwargs.get("dtype") != -1,
+        reason="this Aten overload only accept 1 inputs: self",
+    ),
+    TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
+        matcher=lambda sample: (
+            sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None
+        )
+        or sample.kwargs.get("dtype") != -1,
+        reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)",
+    ),
     TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
     TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
     TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
@@ -2203,6 +2216,7 @@ def _where_input_wrangler(
     OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
 )
 ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
+ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))