diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f980465bc4..c3892c6cd3 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6682,11 +6682,21 @@ def aten_prelu_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::prod.dim_int", trace_only=True)
-def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
+@torch_op("aten::prod", trace_only=True)
+def aten_prod(self: TReal, dtype: int = -1) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
 
-    # Todo: add test for this function later
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
+    return op.ReduceProd(self)
+
+
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
+    """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
     return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
 
 
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 35e1778ca2..1399264546 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1271,6 +1271,19 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
+    TorchLibOpInfo("prod", core_ops.aten_prod).skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is not None
+        or sample.kwargs.get("keepdim") is not None
+        or "dtype" in sample.kwargs,
+        reason="this Aten overload only accept 1 inputs: self",
+    ),
+    TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
+        matcher=lambda sample: (
+            sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None
+        )
+        or "dtype" in sample.kwargs,
+        reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)",
+    ),
     TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
     TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
     TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
@@ -2203,6 +2216,7 @@ def _where_input_wrangler(
     OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
 )
 ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
+ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))