Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements pytorch/pytorch#39853
2. Adds an `approximate` string flag to Gelu
3. Enables the Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser (a usage sketch follows the parser.cpp diff below)

```
import torch

def normcdf(x):
    # Standard normal CDF: Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))

def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```
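
The new flag is exposed through both the functional and module APIs; a quick sketch (assuming a PyTorch build that includes this change, with illustrative tensor values):

```
import torch
import torch.nn.functional as F

x = torch.randn(8)

exact = F.gelu(x)  # approximate='none' is the default
tanh_approx = F.gelu(x, approximate='tanh')

# The module form accepts the same flag.
gelu_tanh = torch.nn.GELU(approximate='tanh')
assert torch.allclose(gelu_tanh(x), tanh_approx)

# The two variants agree closely but not bit-for-bit.
print((exact - tanh_approx).abs().max())
```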

Linked XLA PR: pytorch/xla#3039
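
Item 4 above (double backward) can be exercised via `torch.autograd.grad` with `create_graph=True`; a minimal sketch:

```
import torch
import torch.nn.functional as F

x = torch.randn(4, requires_grad=True)
y = F.gelu(x, approximate='tanh').sum()

# First derivative, keeping the graph so it can be differentiated again.
(grad,) = torch.autograd.grad(y, x, create_graph=True)

# Second derivative of sum(gelu(x)) with respect to x.
(grad2,) = torch.autograd.grad(grad.sum(), x)
print(grad2)
```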

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
rdspring1 authored and pytorchmergebot committed Feb 14, 2022
1 parent 53c9bed commit 1b970ec
Showing 2 changed files with 105 additions and 6 deletions.
4 changes: 4 additions & 0 deletions graph_fuser.cpp
@@ -77,6 +77,10 @@ Value* createConditionalConstant(Node* profile_ivalue) {
// int
val = IValue(
static_cast<int>(profile_ivalue->i(Symbol::attr("profiled_int"))));
} else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_str"))) {
// str
val = IValue(static_cast<std::string>(
profile_ivalue->s(Symbol::attr("profiled_str"))));
} else {
GRAPH_DEBUG("profile_ivalue: ", *profile_ivalue);
TORCH_WARN(
107 changes: 101 additions & 6 deletions parser.cpp
@@ -12,6 +12,8 @@
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>

#include <ATen/native/Activation.h>

#include <unordered_map>
#include <utility>

@@ -62,6 +64,7 @@ const auto& intListAttr = Symbol::attr("profiled_int_list");
const auto& intAttr = Symbol::attr("profiled_int");
const auto& boolListAttr = Symbol::attr("profiled_bool_list");
const auto& boolAttr = Symbol::attr("profiled_bool");
const auto& strAttr = Symbol::attr("profiled_str");

typedef Val* CgValue;
typedef Expr* CgOp;
@@ -2273,7 +2276,8 @@ class IrParser {
}

{
auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
auto ptr_op = getOperatorForLiteral(
"aten::gelu(Tensor self, *, str approximate='none') -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@@ -2283,7 +2287,21 @@
c10::nullopt, value_map[node->inputs()[0]->unique()]);
auto self = list_val.front();
list_val.pop_front();
auto out = gelu(self);

auto approximate = constant_as<std::string>(node->input(1));
TORCH_INTERNAL_ASSERT(
approximate.has_value(),
"The approximate parameter is required.");
const auto kApproximate = approximate.value();

Val* out = nullptr;
if (at::native::get_gelutype_enum(kApproximate) ==
at::native::GeluType::Tanh) {
out = fast_gelu(self);
} else {
out = unaryOp(UnaryOpType::Gelu, self);
}

value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
@@ -2293,7 +2311,7 @@

{
auto ptr_op = getOperatorForLiteral(
"aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
"aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@@ -2308,7 +2326,20 @@
auto self = list_val.front();
list_val.pop_front();

auto grad_in = gelu_backward(grad_out, self);
auto approximate = constant_as<std::string>(node->input(2));
TORCH_INTERNAL_ASSERT(
approximate.has_value(),
"The approximate parameter is required.");
const auto kApproximate = approximate.value();

Val* grad_in = nullptr;
if (at::native::get_gelutype_enum(kApproximate) ==
at::native::GeluType::Tanh) {
grad_in = fast_gelu_backward(grad_out, self);
} else {
grad_in = gelu_backward(grad_out, self);
}

value_map.emplace(
node->output()->unique(), ValueHolder(grad_in, format));
},
@@ -2453,9 +2484,13 @@
}
value_map_.emplace(val->unique(), cg_val);
return true;
} else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
} else if (
val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(StringType::get())) ||
val->type()->isSubtypeOf(static_cast<c10::TypePtr>(NoneType::get()))) {
// TODO: should we consider adding support for NoneType;
// String scalars are only used in parsing rules;
// Do not register string with codegen IR.
return true;
} else if (val->type()->cast<ListType>()) {
// TODO: we don't support list type in codegen yet;
@@ -2646,6 +2681,34 @@ void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) {
pn->setCallback(ivalue_profiler);
}

void profileString(ProfilingRecord* pr, Node* node, size_t offset) {
auto pn = insertProfileIValueOp(node, offset, pr);

const auto ivalue_profiler = [pr, pn](Stack& stack) {
std::lock_guard<std::mutex> lock(pr->mutex_);

// TODO: we don't care about merging multiple profiling runs as we don't
// support it at all;
int64_t frame_id = 0;
pop(stack, frame_id);
IValue value;
pop(stack, value);
TORCH_INTERNAL_ASSERT(
value.isString(), "profiling seeing the wrong data type");
if (!pn->hasAttribute(strAttr)) {
pn->s_(strAttr, value.toStringRef());
} else {
const auto& profiled_str = pn->s(strAttr);
const auto& input_str = value.toStringRef();
TORCH_INTERNAL_ASSERT(
input_str == profiled_str, "profiling ivalue doesn't support merge");
}
push(stack, value);
};

pn->setCallback(ivalue_profiler);
}

void profileBool(ProfilingRecord* pr, Node* node, size_t offset) {
auto pn = insertProfileIValueOp(node, offset, pr);

@@ -3015,6 +3078,38 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
}
}

static auto gelu_schema =
getOperatorForLiteral(
"aten::gelu(Tensor self, *, str approximate='none') -> Tensor")
->schema();
if (node->matches(gelu_schema)) {
switch (offset) {
// argument 1: approximate;
case 1:
profileString(pr, node, offset);
break;
default:
return false;
}
return true;
}

static auto gelu_backward_schema =
getOperatorForLiteral(
"aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor")
->schema();
if (node->matches(gelu_backward_schema)) {
switch (offset) {
// argument 2: approximate;
case 2:
profileString(pr, node, offset);
break;
default:
return false;
}
return true;
}

static auto softmax_backward_data_schema =
getOperatorForLiteral(
"aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
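The parser and profiling changes above target scripted graphs like the one below. A minimal sketch, assuming a CUDA build with NvFuser enabled as the JIT fuser: the profiling executor records the constant `'tanh'` string over the warm-up runs, and the parse rule then dispatches to the fast-gelu path.

```
import torch

@torch.jit.script
def tanh_gelu(x):
    return torch.nn.functional.gelu(x, approximate='tanh')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1024, device=device)

# Warm-up runs let the profiling executor record the 'tanh' constant
# before the fuser parses and fuses the graph.
for _ in range(3):
    tanh_gelu(x)
```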
