Skip to content

Commit

Permalink
[primTorch] special: entr, expit (pytorch#86592)
Browse files Browse the repository at this point in the history
Add _refs for `entr` & `expit`.

cc @mruberry @kshitij12345!
Pull Request resolved: pytorch#86592
Approved by: https://github.com/mruberry
  • Loading branch information
khushi-411 authored and pytorchmergebot committed Oct 12, 2022
1 parent a47f93b commit 2344135
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 2 deletions.
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('sort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition
xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decompos...
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('special.log_ndtr', ''), # aten.special_log_ndtr.default - couldn't find symbolic meta function/de...
Expand Down
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.to',
'_refs.ones',
'_refs.ones_like',
'_refs.special.expit',
'_refs.std_var',
'_refs.swap_axes',
'_refs.uniform',
Expand Down Expand Up @@ -1653,6 +1654,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.positive',
'_refs.ravel',
'_refs.reshape',
'_refs.special.expit',
'_refs.square',
'_refs.tensor_split',
'_refs.to',
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,6 @@ def f(a, b, c, d, e):
xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me...
xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me...
xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition
xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decomposition
xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f...
xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta...
Expand Down
20 changes: 20 additions & 0 deletions torch/_refs/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
__all__ = [
"bessel_j0",
"bessel_j1",
"entr",
"expit",
"i0e",
"i1",
"i1e",
Expand All @@ -45,6 +47,24 @@ def bessel_j1(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_j1(a)


@register_decomposition(torch.ops.aten.special_entr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def entr(a: TensorLikeType) -> TensorLikeType:
return torch.where(
torch.isnan(a),
a,
torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
)


# alias for sigmoid
expit = torch.sigmoid


@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e
)
Expand Down
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16946,6 +16946,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.sigmoid",
torch_opinfo_name="sigmoid",
aliases=('_refs.special.expit',),
# Reference: https://github.com/pytorch/pytorch/issues/56012
handles_complex_extremal_values=False,
handles_large_floats=False,
Expand Down
6 changes: 6 additions & 0 deletions torch/testing/_internal/opinfo/definitions/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,12 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.entr",
torch_opinfo_name="special.entr",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.i0e",
torch_opinfo_name="special.i0e",
Expand Down

0 comments on commit 2344135

Please sign in to comment.