Skip to content

Commit

Permalink
Port sign to structured (pytorch#57588)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#57588

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28224600

Pulled By: ezyang

fbshipit-source-id: 71de5211617c1eba34192e23831136ae5c403e61
  • Loading branch information
Freey0 authored and mrshenli committed May 7, 2021
1 parent 2ddf019 commit 59fb012
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
15 changes: 7 additions & 8 deletions aten/src/ATen/native/UnaryOps.cpp
Expand Up @@ -102,6 +102,12 @@ TORCH_META_FUNC(floor) (const Tensor& self) {
build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(sign) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
build_unary_op(maybe_get_output(), self);
}

} // namespace meta

namespace native {
Expand Down Expand Up @@ -144,6 +150,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal)
CREATE_UNARY_TORCH_IMPL_FUNC(round)
CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt)
CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid)
CREATE_UNARY_TORCH_IMPL_FUNC(sign)
CREATE_UNARY_TORCH_IMPL_FUNC(sin)
CREATE_UNARY_TORCH_IMPL_FUNC(sinc)
CREATE_UNARY_TORCH_IMPL_FUNC(sinh)
Expand Down Expand Up @@ -402,14 +409,6 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); }
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }

Tensor& sign_out(const Tensor& self, Tensor& result) {
TORCH_CHECK(!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
return unary_op_impl_out(result, self, sign_stub);
}
Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); }
Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); }

Tensor& sgn_out(const Tensor& self, Tensor& result) {
if (self.is_complex()) {
return unary_op_impl_out(result, self, sgn_stub);
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -6430,18 +6430,22 @@

- func: sign(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: sign.out
variants: function, method
dispatch:
CompositeExplicitAutograd: sign

- func: sign_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured_delegate: sign.out
variants: method
dispatch:
CompositeExplicitAutograd: sign_

- func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: sign_out

Expand Down

0 comments on commit 59fb012

Please sign in to comment.