diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index e25c97cdf1e6e..a323e79ed33f4 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1156,7 +1156,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio... xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('sort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decompos... xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('special.log_ndtr', ''), # aten.special_log_ndtr.default - couldn't find symbolic meta function/de... diff --git a/test/test_ops.py b/test/test_ops.py index 78e72ebafce75..4ecbc59f5d74b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1594,6 +1594,7 @@ class TestRefsOpsInfo(TestCase): '_refs.to', '_refs.ones', '_refs.ones_like', + '_refs.special.expit', '_refs.std_var', '_refs.swap_axes', '_refs.uniform', @@ -1653,6 +1654,7 @@ class TestRefsOpsInfo(TestCase): '_refs.positive', '_refs.ravel', '_refs.reshape', + '_refs.special.expit', '_refs.square', '_refs.tensor_split', '_refs.to', diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b51f3642414f1..2b31e185199c8 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1260,7 +1260,6 @@ def f(a, b, c, d, e): xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me... xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me... - xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decomposition xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f... xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta... diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 70c57e32c0e28..a9b2234a36791 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -19,6 +19,8 @@ __all__ = [ "bessel_j0", "bessel_j1", + "entr", + "expit", "i0e", "i1", "i1e", @@ -45,6 +47,24 @@ def bessel_j1(a: TensorLikeType) -> TensorLikeType: return prims.bessel_j1(a) +@register_decomposition(torch.ops.aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +# alias for sigmoid +expit = torch.sigmoid + + @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e ) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 91cc2e371ca33..aa82678d0a7a4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16946,6 +16946,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ElementwiseUnaryPythonRefInfo( "_refs.sigmoid", torch_opinfo_name="sigmoid", + aliases=('_refs.special.expit',), # Reference: https://github.com/pytorch/pytorch/issues/56012 handles_complex_extremal_values=False, handles_large_floats=False, diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index b7fe80f064deb..3f6ddff3a7830 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -657,6 +657,12 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs): supports_nvfuser=False, op_db=op_db, ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.entr", + torch_opinfo_name="special.entr", + supports_nvfuser=False, + op_db=op_db, + ), ElementwiseUnaryPythonRefInfo( "_refs.special.i0e", torch_opinfo_name="special.i0e",