Skip to content

Commit

Permalink
Bring back math_silu_backward which works for all backends. (pytorch#…
Browse files Browse the repository at this point in the history
…49439)

Summary: Pull Request resolved: pytorch#49439

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb, ngimel

Differential Revision: D25594129

Pulled By: ailzhang

fbshipit-source-id: 627bbea9ba478ee3a8edcc6695abab6431900192
  • Loading branch information
Ailing Zhang authored and hwangdeyu committed Dec 23, 2020
1 parent 6a59ef2 commit 3b1186d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
7 changes: 7 additions & 0 deletions aten/src/ATen/native/Activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ Tensor silu_backward(
return grad_input;
}

Tensor math_silu_backward(
const Tensor& grad_output,
const Tensor& input) {
auto input_sigmoid = at::sigmoid(input);
return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid)));
}

template <typename scalar_t>
inline void _rrelu_with_noise_train(
Tensor& output,
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3665,6 +3665,9 @@
- func: silu_backward(Tensor grad_output, Tensor self) -> Tensor
use_c10_dispatcher: full
python_module: nn
dispatch:
CPU, CUDA: silu_backward
Math: math_silu_backward

- func: sigmoid(Tensor self) -> Tensor
use_c10_dispatcher: full
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/test/math_kernel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,11 @@ TEST(MathKernelTest, Addr) {
}
}
}

TEST(MathKernelTest, SiluBackward) {
const auto input = rand({20, 10});
const auto grad_output = rand({20, 10});
auto out = at::native::silu_backward(grad_output, input);
auto math_out = at::native::math_silu_backward(grad_output, input);
ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6);
}

0 comments on commit 3b1186d

Please sign in to comment.