Skip to content

Commit

Permalink
Implement Tanh Gelu Approximation (#61439)
Browse files Browse the repository at this point in the history
Summary:
1. Implements pytorch/pytorch#39853
2. Adds an `approximate` string flag ('none' or 'tanh') to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
  • Loading branch information
rdspring1 authored and cyyever committed Feb 17, 2022
1 parent 01cfd84 commit 50eef6c
Show file tree
Hide file tree
Showing 51 changed files with 828 additions and 273 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
Expand Up @@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
Expand Down
22 changes: 12 additions & 10 deletions aten/src/ATen/native/Activation.cpp
Expand Up @@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
build_borrowing_binary_op(maybe_get_output(), grad, self);
}

// Meta function for gelu (declared via TORCH_META_FUNC).
// Only `self` is forwarded to build_unary_op together with the (possibly
// absent) output; the `approximate` argument is accepted but not read here.
// NOTE(review): this listing is a rendered diff hunk, so the pre- and
// post-change signature lines appear back to back; the one taking
// `approximate` looks like the post-change line — confirm against the real file.
TORCH_META_FUNC(gelu) (const Tensor & self) {
TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
build_unary_op(maybe_get_output(), self);
}

// Meta function for gelu_backward (declared via TORCH_META_FUNC).
// Forwards `grad` and `self` to build_borrowing_binary_op together with the
// (possibly absent) output; the `approximate` argument is not read here.
// NOTE(review): this listing is a rendered diff hunk, so the pre- and
// post-change parameter lines appear back to back; the one taking
// `approximate` looks like the post-change line — confirm against the real file.
TORCH_META_FUNC(gelu_backward) (
const Tensor& grad, const Tensor& self
const Tensor& grad, const Tensor& self, c10::string_view approximate
) {
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
Expand Down Expand Up @@ -324,37 +324,39 @@ bool use_mkldnn(const Tensor& input) {
}

// CPU implementation of gelu, writing into `result`.
// The `approximate` string is converted to a GeluType once, up front.
// In MKLDNN-enabled builds the ideep eltwise path is taken when
// use_mkldnn(self) returns true; every other case goes through the GeluKernel
// dispatch stub.
// NOTE(review): this listing is a rendered diff hunk, so pre- and post-change
// lines appear back to back; the variants that mention `approximate` /
// `approximate_type` look like the post-change lines — confirm against the
// real file.
TORCH_IMPL_FUNC(gelu_out_cpu) (
const Tensor& self, const Tensor& result
const Tensor& self, c10::string_view approximate, const Tensor& result
) {
// Converts the string to a GeluType (get_gelutype_enum is not defined in this hunk).
auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self)) {
// The ideep call below uses eltwise_gelu_erf, so the second form of the
// condition restricts it to GeluType::None.
if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor y = itensor_from_tensor(result);
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
} else {
GeluKernel(kCPU, *this);
GeluKernel(kCPU, *this, approximate_type);
}
#else
// Build without MKLDNN: always use the dispatch stub.
GeluKernel(kCPU, *this);
GeluKernel(kCPU, *this, approximate_type);
#endif
}

// CPU implementation of gelu_backward, writing into `grad_input`.
// The `approximate` string is converted to a GeluType once, up front.
// In MKLDNN-enabled builds the ideep eltwise backward path is taken when
// use_mkldnn(self) returns true; every other case goes through the
// GeluBackwardKernel dispatch stub.
// NOTE(review): this listing is a rendered diff hunk, so pre- and post-change
// lines appear back to back; the variants that mention `approximate` /
// `approximate_type` look like the post-change lines — confirm against the
// real file.
TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
const Tensor& grad, const Tensor& self, const Tensor& grad_input
const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
// Converts the string to a GeluType (get_gelutype_enum is not defined in this hunk).
auto approximate_type = get_gelutype_enum(approximate);
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self)) {
// The ideep call below uses eltwise_gelu_erf, so the second form of the
// condition restricts it to GeluType::None.
if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor grady = itensor_from_tensor(grad);
ideep::tensor gradx = itensor_from_tensor(grad_input);
ideep::eltwise_backward::compute(x, grady, gradx,
ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
} else {
GeluBackwardKernel(kCPU, *this);
GeluBackwardKernel(kCPU, *this, approximate_type);
}
#else
// Build without MKLDNN: always use the dispatch stub.
GeluBackwardKernel(kCPU, *this);
GeluBackwardKernel(kCPU, *this, approximate_type);
#endif
}

Expand Down
23 changes: 21 additions & 2 deletions aten/src/ATen/native/Activation.h
Expand Up @@ -14,6 +14,23 @@ class TensorBase;

namespace at { namespace native {

// Selects which formulation of the gelu function a kernel evaluates.
enum GeluType {
  None = 0, // Baseline gelu (no approximation)
  Tanh = 1, // Tanh gelu approximation
  END = 2   // Trailing sentinel; presumably marks one past the last real mode
};

// Maps the user-facing `approximate` string onto its GeluType value:
//   "none" -> GeluType::None
//   "tanh" -> GeluType::Tanh
// Any other string fails the TORCH_CHECK below.
//
// Declared `inline` rather than `static`: this definition lives in a header,
// and `static` would give every including translation unit its own private
// copy (plus an unused-function warning wherever it is never called).
inline GeluType get_gelutype_enum(const c10::string_view approximate) {
  if (approximate == "none") {
    return GeluType::None;
  }
  if (approximate == "tanh") {
    return GeluType::Tanh;
  }
  // Unconditional failure: reached only for an unrecognized string.
  TORCH_CHECK(false, "approximate argument must be either none or tanh.");
}

using structured_activation_fn = void (*)(TensorIteratorBase&);
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);

Expand All @@ -35,6 +52,8 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);

DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
Expand All @@ -43,8 +62,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
Expand Down
173 changes: 127 additions & 46 deletions aten/src/ATen/native/cpu/Activation.cpp
Expand Up @@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(TensorIteratorBase& it) {
void GeluKernelImpl(TensorIteratorBase& it, GeluType approximate) {
auto grain_size = at::internal::GRAIN_SIZE;
// Numbers based on benchmarking.
// Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
Expand All @@ -187,53 +187,134 @@ void GeluKernelImpl(TensorIteratorBase& it) {
if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
grain_size = it.numel() / at::get_num_threads();
}
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
if (approximate == GeluType::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
},
[&](Vec x_vec) {
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
},
grain_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
}
}

// CPU backward kernel for gelu: for each (dy, x) pair it produces
// dy * d/dx gelu(x). Each branch dispatches over the floating types plus
// BFloat16 and supplies a scalar lambda and a vectorized lambda to
// cpu_kernel_vec.
//   - GeluType::Tanh: differentiates
//       0.5 * x * (1 + tanh(kBeta * (x + kKappa * x^3)))
//     with the product rule.
//   - otherwise: dy * (cdf + x * pdf), with cdf built from erf and pdf from
//     exp(-0.5 * x^2).
// NOTE(review): this listing is a rendered diff hunk, so the pre-change body
// (single-argument signature) is immediately followed by the post-change
// definition taking `approximate` — confirm against the real file.
void GeluBackwardKernelImpl(TensorIteratorBase& it) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
void GeluBackwardKernelImpl(TensorIteratorBase& it, GeluType approximate) {
if (approximate == GeluType::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
// kBeta = sqrt(2) * (2 / sqrt(pi)) * 0.5 = sqrt(2 / pi)
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kThreeVec(scalar_t(3));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
// Scalar path.
[](scalar_t dy, scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_sq = x * x;
auto x_cube = x_sq * x;
auto inner = kBeta * (x + kKappa * x_cube);
auto tanh_inner = std::tanh(inner);

// Forward value factors as left * right.
auto left = scalar_t(0.5) * x;
auto right = scalar_t(1) + tanh_inner;

// d(left)/dx * right
auto left_derivative = scalar_t(0.5) * right;

// d/du tanh(u) = 1 - tanh(u)^2
auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
// d(inner)/dx = kBeta * (1 + 3 * kKappa * x^2)
auto inner_derivative =
kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
// left * d(right)/dx, via the chain rule
auto right_derivative = left * tanh_derivative * inner_derivative;

return dy * (left_derivative + right_derivative);
},
// Vectorized path: same arithmetic as the scalar lambda, on Vec lanes.
[&](Vec dy_vec, Vec x_vec) {
auto x_sq = x_vec * x_vec;
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec =
kBetaVec * (x_vec + kKappaVec * x_cube);
auto tanh_inner_vec = inner_vec.tanh();

auto left_vec = kPointFiveVec * x_vec;
auto right_vec = kOneVec + tanh_inner_vec;

auto left_derivative_vec = kPointFiveVec * right_vec;

auto tanh_derivative_vec =
kOneVec - tanh_inner_vec * tanh_inner_vec;
auto inner_derivative_vec =
kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
auto right_derivative_vec =
left_vec * tanh_derivative_vec * inner_derivative_vec;

return dy_vec * (left_derivative_vec + right_derivative_vec);
});
});
} else {
// Every non-Tanh GeluType value takes the erf-based branch.
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
// kBeta here = (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5 = 1 / sqrt(2 * pi)
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
// cdf = 0.5 * (1 + erf(x / sqrt(2)))
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
// pdf = kBeta * exp(-0.5 * x^2)
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
}
}

void hardsigmoid_kernel(TensorIteratorBase& iter) {
Expand Down
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/Activation.cpp
Expand Up @@ -156,15 +156,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
}

// CUDA implementation of gelu. The tensor parameters are unnamed (commented
// out) and the work is delegated to GeluCUDAKernelImpl on *this; in the
// variant taking `approximate`, the string is first converted with
// get_gelutype_enum.
// NOTE(review): this listing is a rendered diff hunk, so pre- and post-change
// lines appear back to back; the lines mentioning `approximate` look like the
// post-change ones — confirm against the real file.
TORCH_IMPL_FUNC(gelu_out_cuda) (
const Tensor& /*self*/, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this);
const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

// CUDA implementation of gelu_backward. The tensor parameters are unnamed
// (commented out) and the work is delegated to GeluBackwardCUDAKernelImpl on
// *this; in the variant taking `approximate`, the string is first converted
// with get_gelutype_enum.
// NOTE(review): this listing is a rendered diff hunk, so pre- and post-change
// lines appear back to back; the lines mentioning `approximate` look like the
// post-change ones — confirm against the real file.
TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this);
const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

}} // namespace at::native

0 comments on commit 50eef6c

Please sign in to comment.