Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements pytorch/pytorch#39853
2. Adds an `approximate` flag to Gelu (usage sketch after the snippet below)
3. Enables the Tanh Gelu approximation
4. Adds double backward support for Gelu (gradient-check sketch below)
5. Enables Tanh Gelu in NvFuser

```
import torch

def normcdf(x):
    # Standard normal CDF: Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))

def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```
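For reference, a minimal usage sketch of the new flag at the Python level. This is an illustration rather than code from this commit, and it assumes the public `torch.nn.functional.gelu` / `nn.GELU` entry points expose the string-valued `approximate` argument shown above (the fuser-facing schema in the diff below carries it as an int).

```
import torch
import torch.nn.functional as F

x = torch.randn(8)

# Exact (erf-based) Gelu and the tanh approximation enabled by this change.
y_exact = F.gelu(x)                      # approximate='none' is the default
y_tanh = F.gelu(x, approximate='tanh')   # tanh approximation

# The module form takes the same flag.
m = torch.nn.GELU(approximate='tanh')
y_mod = m(x)
```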

Linking XLA PR - pytorch/xla#3039
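The double backward support (item 4) could be exercised with `torch.autograd.gradcheck` / `gradgradcheck`. A hedged sketch, again assuming the Python-level `approximate` argument above and not taken from this commit:

```
import torch
import torch.nn.functional as F

# Double precision keeps the numerical Jacobian checks tight.
x = torch.randn(6, dtype=torch.double, requires_grad=True)

for mode in ('none', 'tanh'):
    fn = lambda t, m=mode: F.gelu(t, approximate=m)
    # First derivative, then second derivative (double backward).
    assert torch.autograd.gradcheck(fn, (x,))
    assert torch.autograd.gradgradcheck(fn, (x,))
```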

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
rdspring1 authored and pytorchmergebot committed Jan 31, 2022
1 parent 084eafd commit 7c7210c
Showing 1 changed file with 64 additions and 4 deletions.
torch/csrc/jit/codegen/cuda/parser.cpp (64 additions, 4 deletions)
@@ -12,6 +12,8 @@
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>

+ #include <ATen/native/Activation.h>

#include <unordered_map>
#include <utility>

@@ -2273,7 +2275,8 @@ class IrParser {
}

{
- auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
+ auto ptr_op = getOperatorForLiteral(
+     "aten::gelu(Tensor self, int approximate=0) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@@ -2283,7 +2286,20 @@
c10::nullopt, value_map[node->inputs()[0]->unique()]);
auto self = list_val.front();
list_val.pop_front();
- auto out = gelu(self);
+
+ auto approximate = constant_as<int64_t>(node->input(1));
+ TORCH_INTERNAL_ASSERT(
+     approximate.has_value(),
+     "The approximate parameter is required.");
+ const bool kApproximate = approximate.value();
+
+ Val* out = nullptr;
+ if (kApproximate == at::Gelu::Tanh) {
+   out = fast_gelu(self);
+ } else {
+   out = unaryOp(UnaryOpType::Gelu, self);
+ }
+
value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
@@ -2293,7 +2309,7 @@

{
auto ptr_op = getOperatorForLiteral(
- "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
+ "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@@ -2308,7 +2324,19 @@
auto self = list_val.front();
list_val.pop_front();

- auto grad_in = gelu_backward(grad_out, self);
+ auto approximate = constant_as<int64_t>(node->input(2));
+ TORCH_INTERNAL_ASSERT(
+     approximate.has_value(),
+     "The approximate parameter is required.");
+ const bool kApproximate = approximate.value();
+
+ Val* grad_in = nullptr;
+ if (kApproximate == at::Gelu::Tanh) {
+   grad_in = fast_gelu_backward(grad_out, self);
+ } else {
+   grad_in = gelu_backward(grad_out, self);
+ }
+
value_map.emplace(
node->output()->unique(), ValueHolder(grad_in, format));
},
@@ -3015,6 +3043,38 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
}
}

+ static auto gelu_schema =
+     getOperatorForLiteral(
+         "aten::gelu(Tensor self, int approximate=0) -> Tensor")
+         ->schema();
+ if (node->matches(gelu_schema)) {
+   switch (offset) {
+     // argument 1: approximate;
+     case 1:
+       profileInt(pr, node, offset);
+       break;
+     default:
+       return false;
+   }
+   return true;
+ }
+
+ static auto gelu_backward_schema =
+     getOperatorForLiteral(
+         "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor")
+         ->schema();
+ if (node->matches(gelu_backward_schema)) {
+   switch (offset) {
+     // argument 2: approximate;
+     case 2:
+       profileInt(pr, node, offset);
+       break;
+     default:
+       return false;
+   }
+   return true;
+ }

static auto softmax_backward_data_schema =
getOperatorForLiteral(
"aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
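For context on the tanh branch above: `fast_gelu_backward` needs to compute the derivative of the tanh approximation given in the commit message. The fuser's actual implementation is not shown in this diff; the Python sketch below simply derives that gradient from the forward formula by the chain rule.

```
import math
import torch

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # 0.7978845608028654
KAPPA = 0.044715

def tanh_gelu_backward(grad_output, x):
    # Forward: 0.5 * x * (1 + tanh(g(x))), with g(x) = sqrt(2/pi) * (x + 0.044715 * x**3)
    inner = SQRT_2_OVER_PI * (x + KAPPA * x ** 3)
    tanh_inner = torch.tanh(inner)
    d_inner = SQRT_2_OVER_PI * (1.0 + 3.0 * KAPPA * x ** 2)
    # Product rule plus chain rule; (1 - tanh^2) is the derivative of tanh.
    dgelu_dx = 0.5 * (1.0 + tanh_inner) + 0.5 * x * (1.0 - tanh_inner ** 2) * d_inner
    return grad_output * dgelu_dx
```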
