From 7c7210c7518bdeab48a4cb5692d9ac5602c6c421 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Mon, 31 Jan 2022 08:59:01 -0800
Subject: [PATCH] Implement Tanh Gelu Approximation (#61439)

Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
---
 torch/csrc/jit/codegen/cuda/parser.cpp | 68 ++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 4 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index a33b338..4c791e6 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -12,6 +12,8 @@
 #include
 #include
 
+#include
+
 #include
 #include
 
@@ -2273,7 +2275,8 @@ class IrParser {
     }
 
     {
-      auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
+      auto ptr_op = getOperatorForLiteral(
+          "aten::gelu(Tensor self, int approximate=0) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -2283,7 +2286,20 @@
                 c10::nullopt, value_map[node->inputs()[0]->unique()]);
             auto self = list_val.front();
             list_val.pop_front();
-            auto out = gelu(self);
+
+            auto approximate = constant_as(node->input(1));
+            TORCH_INTERNAL_ASSERT(
+                approximate.has_value(),
+                "The approximate parameter is required.");
+            const bool kApproximate = approximate.value();
+
+            Val* out = nullptr;
+            if (kApproximate == at::Gelu::Tanh) {
+              out = fast_gelu(self);
+            } else {
+              out = unaryOp(UnaryOpType::Gelu, self);
+            }
+
             value_map.emplace(
                 node->output()->unique(), ValueHolder(out, format));
           },
@@ -2293,7 +2309,7 @@
 
     {
       auto ptr_op = getOperatorForLiteral(
-          "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
+          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -2308,7 +2324,19 @@
             auto self = list_val.front();
             list_val.pop_front();
 
-            auto grad_in = gelu_backward(grad_out, self);
+            auto approximate = constant_as(node->input(2));
+            TORCH_INTERNAL_ASSERT(
+                approximate.has_value(),
+                "The approximate parameter is required.");
+            const bool kApproximate = approximate.value();
+
+            Val* grad_in = nullptr;
+            if (kApproximate == at::Gelu::Tanh) {
+              grad_in = fast_gelu_backward(grad_out, self);
+            } else {
+              grad_in = gelu_backward(grad_out, self);
+            }
+
             value_map.emplace(
                 node->output()->unique(), ValueHolder(grad_in, format));
           },
@@ -3015,6 +3043,38 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
     }
   }
 
+  static auto gelu_schema =
+      getOperatorForLiteral(
+          "aten::gelu(Tensor self, int approximate=0) -> Tensor")
+          ->schema();
+  if (node->matches(gelu_schema)) {
+    switch (offset) {
+      // argument 1: approximate;
+      case 1:
+        profileInt(pr, node, offset);
+        break;
+      default:
+        return false;
+    }
+    return true;
+  }
+
+  static auto gelu_backward_schema =
+      getOperatorForLiteral(
+          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor")
+          ->schema();
+  if (node->matches(gelu_backward_schema)) {
+    switch (offset) {
+      // argument 2: approximate;
+      case 2:
+        profileInt(pr, node, offset);
+        break;
+      default:
+        return false;
+    }
+    return true;
+  }
+
   static auto softmax_backward_data_schema =
       getOperatorForLiteral(
           "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")