diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 7bdab046419..bd9da6a4593 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
   KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
-  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
+  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, c10::string_view), fp32)
   KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
   KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
   KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index ff79939830c..424dacf124e 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }
 
-TORCH_META_FUNC(gelu) (const Tensor & self) {
+TORCH_META_FUNC(gelu) (const Tensor & self, c10::string_view approximate) {
   build_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(gelu_backward) (
-  const Tensor& grad, const Tensor& self
+  const Tensor& grad, const Tensor& self, c10::string_view approximate
 ) {
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }
@@ -324,37 +324,39 @@ bool use_mkldnn(const Tensor& input) {
 }
 
 TORCH_IMPL_FUNC(gelu_out_cpu) (
-  const Tensor& self, const Tensor& result
+  const Tensor& self, c10::string_view approximate, const Tensor& result
 ) {
+  auto approximate_type = get_gelutype_enum(approximate);
 #if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self)) {
+  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
     const ideep::tensor& x = itensor_from_tensor(self);
     ideep::tensor y = itensor_from_tensor(result);
     ideep::eltwise_forward::compute(
       x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
   } else {
-    GeluKernel(kCPU, *this);
+    GeluKernel(kCPU, *this, approximate_type);
   }
 #else
-  GeluKernel(kCPU, *this);
+  GeluKernel(kCPU, *this, approximate_type);
 #endif
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
-  const Tensor& grad, const Tensor& self, const Tensor& grad_input
+  const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
 ) {
+  auto approximate_type = get_gelutype_enum(approximate);
 #if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self)) {
+  if (use_mkldnn(self) && (approximate_type == GeluType::None)) {
     const ideep::tensor& x = itensor_from_tensor(self);
     ideep::tensor grady = itensor_from_tensor(grad);
     ideep::tensor gradx = itensor_from_tensor(grad_input);
     ideep::eltwise_backward::compute(x, grady, gradx,
       ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
   } else {
-    GeluBackwardKernel(kCPU, *this);
+    GeluBackwardKernel(kCPU, *this, approximate_type);
   }
 #else
-  GeluBackwardKernel(kCPU, *this);
+  GeluBackwardKernel(kCPU, *this, approximate_type);
 #endif
 }
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index 963dc4665fd..6eb8182737b 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -14,6 +14,23 @@ class TensorBase;
 namespace at { namespace native {
 
+// These constants control the approximation behavior of gelu function.
+enum GeluType {
+  None,             // Baseline Gelu
+  Tanh,             // Tanh Gelu Approximation
+  END
+};
+
+static GeluType get_gelutype_enum(const c10::string_view approximate) {
+  if (approximate == "none") {
+    return GeluType::None;
+  } else if (approximate == "tanh") {
+    return GeluType::Tanh;
+  } else {
+    TORCH_CHECK(false, "approximate argument must be either none or tanh.");
+  }
+}
+
 using structured_activation_fn = void (*)(TensorIteratorBase&);
 using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
 
@@ -35,6 +52,8 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
 using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
 using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
 using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
+using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
+using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
 
 DECLARE_DISPATCH(elu_fn, elu_stub);
 DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
@@ -43,8 +62,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
 DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
 DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
-DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
-DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(gelu_fn, GeluKernel);
+DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
 DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
 DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
 DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
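The `GeluType` enum and `get_gelutype_enum` above pin down the string contract for the new `approximate` argument: only `"none"` and `"tanh"` are accepted, and anything else trips the `TORCH_CHECK`. A minimal sketch of that contract from the Python side (the error text is the one from the check above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8)
F.gelu(x, approximate="none")   # GeluType::None, exact erf-based GELU
F.gelu(x, approximate="tanh")   # GeluType::Tanh
try:
    F.gelu(x, approximate="fast")
except RuntimeError as e:
    print(e)  # approximate argument must be either none or tanh.
```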
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index b192d0c4d70..1eebcde30c9 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
 // TODO(yangxm): Add another fast kernel using formula
 // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
 // and the fast tanh impl from Eigen.
-void GeluKernelImpl(TensorIteratorBase& it) {
+void GeluKernelImpl(TensorIteratorBase& it, GeluType approximate) {
   auto grain_size = at::internal::GRAIN_SIZE;
   // Numbers based on benchmarking.
   // Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
@@ -187,53 +187,134 @@ void GeluKernelImpl(TensorIteratorBase& it) {
   if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
     grain_size = it.numel() / at::get_num_threads();
   }
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
-        using Vec = vec::Vectorized<scalar_t>;
-        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-        const Vec kOneVec(scalar_t(1));
-        const Vec kPointFiveVec(scalar_t(0.5));
-        cpu_kernel_vec(
-            it,
-            [](scalar_t x) {
-              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-              return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-            },
-            [&](Vec x_vec) {
-              return x_vec * kPointFiveVec *
-                  (kOneVec + (x_vec * kAlphaVec).erf());
-            },
-            grain_size);
-      });
+  if (approximate == GeluType::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
+          const Vec kKappaVec(scalar_t(0.044715));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t x) {
+                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+                const scalar_t kKappa = 0.044715;
+                auto x_cube = x * x * x;
+                auto inner = kBeta * (x + kKappa * x_cube);
+                return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
+              },
+              [&](Vec x_vec) {
+                auto x_cube = x_vec * x_vec * x_vec;
+                auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
+                return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
+              },
+              grain_size);
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t x) {
+                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+                return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+              },
+              [&](Vec x_vec) {
+                return x_vec * kPointFiveVec *
+                    (kOneVec + (x_vec * kAlphaVec).erf());
+              },
+              grain_size);
+        });
+  }
 }
 
-void GeluBackwardKernelImpl(TensorIteratorBase& it) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
-        using Vec = vec::Vectorized<scalar_t>;
-        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-        const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
-        const Vec kOneVec(scalar_t(1));
-        const Vec kPointFiveVec(scalar_t(0.5));
-        const Vec kMinusPointFiveVec(scalar_t(-0.5));
-        cpu_kernel_vec(
-            it,
-            [](scalar_t dy, scalar_t x) {
-              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-              const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
-              const scalar_t cdf =
-                  scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-              const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
-              return dy * (cdf + x * pdf);
-            },
-            [&](Vec dy_vec, Vec x_vec) {
-              const Vec cdf_vec =
-                  kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
-              const Vec pdf_vec =
-                  kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
-              return dy_vec * (cdf_vec + x_vec * pdf_vec);
-            });
-      });
+void GeluBackwardKernelImpl(TensorIteratorBase& it, GeluType approximate) {
+  if (approximate == GeluType::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
+          const Vec kKappaVec(scalar_t(0.044715));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kThreeVec(scalar_t(3));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t dy, scalar_t x) {
+                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+                const scalar_t kKappa = 0.044715;
+                auto x_sq = x * x;
+                auto x_cube = x_sq * x;
+                auto inner = kBeta * (x + kKappa * x_cube);
+                auto tanh_inner = std::tanh(inner);
+
+                auto left = scalar_t(0.5) * x;
+                auto right = scalar_t(1) + tanh_inner;
+
+                auto left_derivative = scalar_t(0.5) * right;
+
+                auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
+                auto inner_derivative =
+                    kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
+                auto right_derivative = left * tanh_derivative * inner_derivative;
+
+                return dy * (left_derivative + right_derivative);
+              },
+              [&](Vec dy_vec, Vec x_vec) {
+                auto x_sq = x_vec * x_vec;
+                auto x_cube = x_vec * x_vec * x_vec;
+                auto inner_vec =
+                    kBetaVec * (x_vec + kKappaVec * x_cube);
+                auto tanh_inner_vec = inner_vec.tanh();
+
+                auto left_vec = kPointFiveVec * x_vec;
+                auto right_vec = kOneVec + tanh_inner_vec;
+
+                auto left_derivative_vec = kPointFiveVec * right_vec;
+
+                auto tanh_derivative_vec =
+                    kOneVec - tanh_inner_vec * tanh_inner_vec;
+                auto inner_derivative_vec =
+                    kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
+                auto right_derivative_vec =
+                    left_vec * tanh_derivative_vec * inner_derivative_vec;
+
+                return dy_vec * (left_derivative_vec + right_derivative_vec);
+              });
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
+          using Vec = vec::Vectorized<scalar_t>;
+          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+          const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
+          const Vec kOneVec(scalar_t(1));
+          const Vec kPointFiveVec(scalar_t(0.5));
+          const Vec kMinusPointFiveVec(scalar_t(-0.5));
+          cpu_kernel_vec(
+              it,
+              [](scalar_t dy, scalar_t x) {
+                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+                const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
+                const scalar_t cdf =
+                    scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+                const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
+                return dy * (cdf + x * pdf);
+              },
+              [&](Vec dy_vec, Vec x_vec) {
+                const Vec cdf_vec =
+                    kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
+                const Vec pdf_vec =
+                    kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
+                return dy_vec * (cdf_vec + x_vec * pdf_vec);
+              });
+        });
+  }
 }
 
 void hardsigmoid_kernel(TensorIteratorBase& iter) {
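The two CPU branches above compute exact GELU, x·Φ(x), and the tanh approximation 0.5·x·(1 + tanh(β(x + κx³))) with β = √(2/π) (i.e. `M_SQRT2 * M_2_SQRTPI * 0.5`) and κ = 0.044715, plus the matching product-rule derivative. A scalar reference sketch in double precision, checked against the kernels through autograd (the helper names here are ours, not PyTorch's):

```python
import math
import torch
import torch.nn.functional as F

kBeta = math.sqrt(2.0 / math.pi)  # == M_SQRT2 * M_2_SQRTPI * 0.5
kKappa = 0.044715

def tanh_gelu_ref(x):
    # forward: 0.5 * x * (1 + tanh(kBeta * (x + kKappa * x^3)))
    inner = kBeta * (x + kKappa * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))

def tanh_gelu_backward_ref(dy, x):
    # product rule, mirroring left_derivative / right_derivative above
    inner = kBeta * (x + kKappa * x.pow(3))
    t = torch.tanh(inner)
    left_derivative = 0.5 * (1.0 + t)
    right_derivative = 0.5 * x * (1.0 - t * t) * kBeta * (1.0 + 3.0 * kKappa * x * x)
    return dy * (left_derivative + right_derivative)

x = torch.randn(1024, dtype=torch.double, requires_grad=True)
y = F.gelu(x, approximate="tanh")
torch.testing.assert_close(y, tanh_gelu_ref(x))

dy = torch.randn_like(x)
(gx,) = torch.autograd.grad(y, x, grad_outputs=dy)
torch.testing.assert_close(gx, tanh_gelu_backward_ref(dy, x))
```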
diff --git a/aten/src/ATen/native/cuda/Activation.cpp b/aten/src/ATen/native/cuda/Activation.cpp
index 2dfe0a862ea..23e8bc697f7 100644
--- a/aten/src/ATen/native/cuda/Activation.cpp
+++ b/aten/src/ATen/native/cuda/Activation.cpp
@@ -156,15 +156,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
 }
 
 TORCH_IMPL_FUNC(gelu_out_cuda) (
-  const Tensor& /*self*/, const Tensor& /*result*/
-  ) {
-  GeluCUDAKernelImpl(*this);
+  const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/
+) {
+  GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
-  const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
-  ) {
-  GeluBackwardCUDAKernelImpl(*this);
+  const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/
+) {
+  GeluBackwardCUDAKernelImpl(*this, get_gelutype_enum(approximate));
 }
 
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 168e142dd29..e3acad92f90 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -392,30 +392,71 @@ void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Sc
   });
 }
 
-void GeluCUDAKernelImpl(TensorIteratorBase& it) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
-    using T_ACC = acc_type<scalar_t, true>;
-    gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
-      return static_cast<T_ACC>(x) *
-          c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
+void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
+  if (approximate == GeluType::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
+      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
+        constexpr opmath_t kKappa = 0.044715;
+        auto x_cube = static_cast<opmath_t>(x) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
+        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
+        return opmath_t(0.5) * static_cast<opmath_t>(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner));
+      });
     });
-  });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
+      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        constexpr opmath_t kAlpha = M_SQRT1_2;
+        return static_cast<opmath_t>(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
+      });
+    });
+  }
 }
 
-void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-      it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
-    using T_ACC = acc_type<scalar_t, true>;
-    gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
-      constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
-      const T_ACC cdf = c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
-      const T_ACC pdf =
-          c10::cuda::compat::exp(
-              T_ACC(-0.5) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x)) *
-          kBeta;
-      return static_cast<T_ACC>(dy) * (cdf + static_cast<T_ACC>(x) * pdf);
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
+  if (approximate == GeluType::Tanh) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
+      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
+        constexpr opmath_t kKappa = 0.044715;
+        auto x_sq = static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
+        auto x_cube = x_sq * static_cast<opmath_t>(x);
+        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
+        auto tanh_inner = c10::cuda::compat::tanh(inner);
+
+        auto left = opmath_t(0.5) * static_cast<opmath_t>(x);
+        auto right = opmath_t(1) + tanh_inner;
+
+        auto left_derivative = 0.5 * right;
+
+        auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner;
+        auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq);
+        auto right_derivative = left * tanh_derivative * inner_derivative;
+
+        return static_cast<opmath_t>(dy) * (left_derivative + right_derivative);
       });
     });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
+      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+        using opmath_t = at::opmath_type<scalar_t>;
+        constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5);
+        constexpr opmath_t kAlpha = M_SQRT1_2;
+        const opmath_t cdf =
+            opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
+        const opmath_t pdf =
+            c10::cuda::compat::exp(
+                opmath_t(-0.5) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x)) *
+            kBeta;
+        return static_cast<opmath_t>(dy) * (cdf + static_cast<opmath_t>(x) * pdf);
+      });
+    });
+  }
 }
 
 namespace {
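Both CUDA branches upcast to `at::opmath_type<scalar_t>` (float for half/bfloat16) before the erf/tanh math, so reduced-precision results should track a float32 reference up to output rounding. A quick sanity sketch, assuming a CUDA build and device are available:

```python
import torch
import torch.nn.functional as F

# Create the input in half first so both paths see the same values.
xh = torch.randn(4096, device="cuda", dtype=torch.half)
y_ref = F.gelu(xh.float(), approximate="tanh")   # fp32 reference
y_half = F.gelu(xh, approximate="tanh")          # fp16 in/out, fp32 opmath inside
torch.testing.assert_close(y_half.float(), y_ref, rtol=1e-3, atol=1e-3)
```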
diff --git a/aten/src/ATen/native/cuda/Activation.h b/aten/src/ATen/native/cuda/Activation.h
index 5e798316c9b..ca0ad3828da 100644
--- a/aten/src/ATen/native/cuda/Activation.h
+++ b/aten/src/ATen/native/cuda/Activation.h
@@ -1,4 +1,5 @@
+#include <ATen/native/Activation.h>
 #include <cstdint>
 
 namespace at {
@@ -24,7 +25,7 @@ void launch_prelu_cuda_backward_kernel_multi_weights(
     const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out,
     const TensorBase &input_grad, const TensorBase &weight_grad_collector);
 
-void GeluCUDAKernelImpl(TensorIteratorBase& it);
-void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it);
+void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
 
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/mkldnn/Gelu.cpp b/aten/src/ATen/native/mkldnn/Gelu.cpp
index fa78cd1c3a9..1d2a6725151 100644
--- a/aten/src/ATen/native/mkldnn/Gelu.cpp
+++ b/aten/src/ATen/native/mkldnn/Gelu.cpp
@@ -1,17 +1,17 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Config.h>
-
+#include <ATen/native/Activation.h>
 #if !AT_MKLDNN_ENABLED()
 
 namespace at { namespace native {
 
-Tensor mkldnn_gelu(const Tensor& input) {
+Tensor mkldnn_gelu(const Tensor& input, c10::string_view approximate) {
   TORCH_CHECK(false, "mkldnn_gelu: ATen not compiled with MKLDNN support");
 }
 
-Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
+Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, c10::string_view approximate) {
   TORCH_CHECK(false, "mkldnn_gelu_backward: ATen not compiled with MKLDNN support");
 }
 
@@ -24,11 +24,13 @@ Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
 namespace at { namespace native {
 
-Tensor mkldnn_gelu(const Tensor& input) {
+Tensor mkldnn_gelu(const Tensor& input, c10::string_view approximate) {
   if (input.scalar_type() == ScalarType::BFloat16) {
     TORCH_CHECK(mkldnn_bf16_device_check(),
         "mkldnn_gelu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
   }
+  TORCH_CHECK(get_gelutype_enum(approximate) == GeluType::None,
+      "mkldnn_gelu: fast, approximate gelu is not supported");
   const ideep::tensor& x = itensor_from_tensor(input);
   ideep::tensor y;
   ideep::eltwise_forward::compute(
@@ -37,7 +39,9 @@
       input.options().device_opt());
 }
 
-Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
+Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, c10::string_view approximate) {
+  TORCH_CHECK(get_gelutype_enum(approximate) == GeluType::None,
+      "mkldnn_gelu_backward: fast, approximate gelu is not supported");
   const ideep::tensor& x = itensor_from_tensor(input);
   ideep::tensor grady = itensor_from_tensor(grad_output);
   ideep::tensor gradx;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8c333efd3bf..93c9ab24c79 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3736,7 +3736,7 @@
     CPU: prelu_backward_cpu
     CUDA: prelu_backward_cuda
 
-- func: gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+- func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
   structured: True
   structured_inherits: TensorIteratorBase
   device_check: NoCheck   # TensorIterator
@@ -3745,7 +3745,7 @@
     CPU: gelu_out_cpu
     CUDA: gelu_out_cuda
 
-- func: gelu(Tensor self) -> Tensor
+- func: gelu(Tensor self, *, str approximate='none') -> Tensor
  structured_delegate: gelu.out
   device_check: NoCheck   # TensorIterator
   python_module: nn
@@ -3753,7 +3753,7 @@
     MkldnnCPU: mkldnn_gelu
     QuantizedCPU: gelu_quantized_cpu
 
-- func: gelu_backward.grad_input(Tensor grad, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
   structured_inherits: TensorIteratorBase
   python_module: nn
@@ -3761,7 +3761,7 @@
     CPU: gelu_backward_out_cpu
     CUDA: gelu_backward_out_cuda
 
-- func: gelu_backward(Tensor grad, Tensor self) -> Tensor
+- func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
   structured_delegate: gelu_backward.grad_input
   python_module: nn
   dispatch:
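Per the schema change above, `approximate` is keyword-only (note the `*`) with a default of `'none'`, so existing call sites keep their exact erf-based behavior unchanged. A small sketch of that call surface:

```python
import torch
import torch.nn.functional as F

x = torch.randn(16)
# Omitting the kwarg is identical to passing 'none' explicitly.
torch.testing.assert_close(F.gelu(x), F.gelu(x, approximate="none"))
# approximate is keyword-only, matching the '*' in the schema.
try:
    F.gelu(x, "tanh")
except TypeError as e:
    print(e)
```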
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 23afea3e52c..77c9756e366 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
+#include <ATen/native/Activation.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
@@ -615,7 +616,7 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
   });
 }
 
-void qgelu_kernel(const Tensor& qx, Tensor& qy) {
+void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
   int64_t zero_point = qx.q_zero_point();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   float scale = qx.q_scale();
@@ -626,40 +627,83 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy) {
   float output_scale = scale;
   float inv_output_scale = 1.0 / output_scale;
   const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
+  const auto kBetaVec = Vectorized<float>(M_SQRT2 * M_2_SQRTPI * 0.5);
+  const auto kKappaVec = Vectorized<float>(0.044715);
   const auto kOneVec = Vectorized<float>(1);
   const auto kPointFiveVec = Vectorized<float>(0.5);
 
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
-    qy = at::_empty_affine_quantized(
-        qx.sizes(),
-        // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
-        at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
-        output_scale,
-        output_zero_point,
-        c10::nullopt);
-    auto iter = TensorIterator::unary_op(qy, qx);
-
-    using Vec = Vectorized<scalar_t>;
-    cpu_kernel_vec(
-        iter,
-        [&](scalar_t value_qx) -> scalar_t {
-          const auto value_dx =
-              at::native::dequantize_val(scale, zero_point, value_qx);
-          const auto value_dy =
-              value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
-          return at::native::quantize_val<scalar_t>(
-              output_scale, output_zero_point, value_dy);
-        },
-        [&](Vec value_qx) -> Vec {
-          auto value_dx = value_qx.dequantize(
-              scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
-          for (auto & value : value_dx) {
-            value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
-          }
-          return Vec::quantize(
-              value_dx, output_scale, output_zero_point, inv_output_scale);
-        });
-  });
+  if (approximate == GeluType::Tanh) {
+    AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
+      qy = at::_empty_affine_quantized(
+          qx.sizes(),
+          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+          at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
+          output_scale,
+          output_zero_point,
+          c10::nullopt);
+      auto iter = TensorIterator::unary_op(qy, qx);
+
+      using Vec = Vectorized<scalar_t>;
+      cpu_kernel_vec(
+          iter,
+          [&](scalar_t value_qx) -> scalar_t {
+            const auto value_dx =
+                at::native::dequantize_val(scale, zero_point, value_qx);
+
+            const auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+            const auto kKappa = 0.044715;
+            const auto x_cube = value_dx * value_dx * value_dx;
+            const auto inner = kBeta * (value_dx + kKappa * x_cube);
+            const auto value_dy = 0.5 * value_dx * (1.0 + std::tanh(inner));
+
+            return at::native::quantize_val<scalar_t>(
+                output_scale, output_zero_point, value_dy);
+          },
+          [&](Vec value_qx) -> Vec {
+            auto value_dx = value_qx.dequantize(
+                scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
+            for (auto & value : value_dx) {
+              auto value_cube = value * value * value;
+              auto inner = kBetaVec * (value + kKappaVec * value_cube);
+              value = kPointFiveVec * value * (kOneVec + inner.tanh());
+            }
+            return Vec::quantize(
+                value_dx, output_scale, output_zero_point, inv_output_scale);
+          });
+    });
+  } else {
+    AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
+      qy = at::_empty_affine_quantized(
+          qx.sizes(),
+          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+          at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
+          output_scale,
+          output_zero_point,
+          c10::nullopt);
+      auto iter = TensorIterator::unary_op(qy, qx);
+
+      using Vec = Vectorized<scalar_t>;
+      cpu_kernel_vec(
+          iter,
+          [&](scalar_t value_qx) -> scalar_t {
+            const auto value_dx =
+                at::native::dequantize_val(scale, zero_point, value_qx);
+            const auto value_dy =
+                value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
+            return at::native::quantize_val<scalar_t>(
+                output_scale, output_zero_point, value_dy);
+          },
+          [&](Vec value_qx) -> Vec {
+            auto value_dx = value_qx.dequantize(
+                scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
+            for (auto & value : value_dx) {
+              value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
+            }
+            return Vec::quantize(
+                value_dx, output_scale, output_zero_point, inv_output_scale);
+          });
+    });
+  }
 }
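The quantized kernel above dequantizes, applies whichever GELU branch was requested, and requantizes with the input's scale and zero point, so the result should match a dequantize→gelu→quantize round trip to within one quantization step. A sketch of that contract (the scale/zero-point values are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, 0.1, 128, torch.quint8)
qy = F.gelu(qx, approximate="tanh")              # QuantizedCPU dispatch
ref = F.gelu(qx.dequantize(), approximate="tanh")
# The output reuses the input scale, so error stays within one step.
assert (qy.dequantize() - ref).abs().max().item() <= 0.1
```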
diff --git a/aten/src/ATen/native/quantized/cpu/qgelu.cpp b/aten/src/ATen/native/quantized/cpu/qgelu.cpp
index 7c0ee3cd784..c07796f608d 100644
--- a/aten/src/ATen/native/quantized/cpu/qgelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qgelu.cpp
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Context.h>
+#include <ATen/native/Activation.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -15,9 +16,9 @@ namespace native {
 
 DEFINE_DISPATCH(qgelu_stub);
 
-Tensor gelu_quantized_cpu(const Tensor& qx) {
+Tensor gelu_quantized_cpu(const Tensor& qx, c10::string_view approximate) {
   Tensor qy;
-  qgelu_stub(qx.device().type(), qx, qy);
+  qgelu_stub(qx.device().type(), qx, qy, get_gelutype_enum(approximate));
   return qy;
 }
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/quantized/cpu/quantized_ops.h b/aten/src/ATen/native/quantized/cpu/quantized_ops.h
index a1766380fe5..bfa1f1f7756 100644
--- a/aten/src/ATen/native/quantized/cpu/quantized_ops.h
+++ b/aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/native/Activation.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 
@@ -8,7 +9,7 @@ namespace native {
 
 using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, const Scalar& /*negval_*/);
-using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
+using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */);
 using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
 using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qclamp_fn = void (*)(
diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h
index 9e89fe9acd6..fa18e46b2c6 100644
--- a/caffe2/serialize/versions.h
+++ b/caffe2/serialize/versions.h
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 
 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0x9L;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -79,7 +79,11 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 //     Bump the version number to 9 to update aten::logspace and
 //     and aten::logspace.out to error out when steps is not
 //     provided. (see: https://github.com/pytorch/pytorch/issues/55951)
-constexpr uint64_t kProducedFileFormatVersion = 0x9L;
+// 3) [02/11/2022]
+//     Bump the version number to 10 to update aten::gelu and
+//     aten::gelu.out to support the new approximate kwarg.
+//     (see: https://github.com/pytorch/pytorch/pull/61439)
+constexpr uint64_t kProducedFileFormatVersion = 0xAL;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index db0f4d25168..1c2a042a471 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -973,10 +973,17 @@ TEST_F(FunctionalTest, GLU) {
 }
 
 TEST_F(FunctionalTest, GELU) {
-  GELU model;
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
-  const auto y = F::gelu(x);
+  const auto y = F::gelu(x, F::GELUFuncOptions().approximate("none"));
+  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
+}
+
+TEST_F(FunctionalTest, TanhGELU) {
+  const auto x = torch::linspace(-3.0, 3.0, 100);
+  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
+  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
+  const auto y = F::gelu(x, F::GELUFuncOptions().approximate("tanh"));
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 8632f3e195c..cdf4f0ea0de 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2860,13 +2860,23 @@ TEST_F(ModulesTest, GLU) {
 }
 
 TEST_F(ModulesTest, GELU) {
-  GELU model;
+  GELU model(GELUOptions().approximate("none"));
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
   const auto y = model(x);
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }
 
+TEST_F(ModulesTest, TanhGELU) {
+  GELU model(GELUOptions().approximate("tanh"));
+  const auto x = torch::linspace(-3.0, 3.0, 100);
+  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
+  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
+  const auto y = model(x);
+  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST_F(ModulesTest, Mish) {
   Mish model;
   auto x = torch::randn(100) * 10;
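For reference, the same check as the `TanhGELU` C++ module test above, written against the Python `torch.nn.GELU` module:

```python
import math
import torch

model = torch.nn.GELU(approximate="tanh")
x = torch.linspace(-3.0, 3.0, 100)
inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))
y_exp = 0.5 * x * (1.0 + inner.tanh())
torch.testing.assert_close(model(x), y_exp, rtol=1.4e-6, atol=1e-5)
```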
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)), ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), - ("aten::_conv_depthwise2d_backward", datetime.date(2022, 1, 31)), - ("aten::conv_depthwise3d_backward", datetime.date(2022, 1, 31)), - ("aten::cudnn_convolution.deprecated", datetime.date(2022, 1, 31)), - ("aten::cudnn_convolution.deprecated2", datetime.date(2022, 1, 31)), - ("aten::cudnn_convolution_transpose.deprecated", datetime.date(2022, 1, 31)), - ("aten::cudnn_convolution_transpose.deprecated2", datetime.date(2022, 1, 31)), + ("aten::gelu", datetime.date(2022, 3, 1)), + ("aten::gelu_backward", datetime.date(2022, 3, 1)), ("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)), ("aten::cudnn_convolution_backward_input", datetime.date(2022, 1, 31)), ("aten::cudnn_convolution_backward_weight", datetime.date(2022, 1, 31)), diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index 8454f786edb..4b72fc6f456 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -447,7 +447,7 @@ def test_aliased_outputs(self): %0 : int[] = prim::Constant[value=[2, 2, 1]]() %1 : int = prim::Constant[value=0]() %2 : Tensor = aten::t(%b) - %3 : Tensor = aten::gelu(%2) + %3 : Tensor = aten::relu(%2) %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2) return (%4) """ @@ -471,7 +471,7 @@ def test_aliased_outputs(self): %1 : int = prim::Constant[value=0]() %d : Tensor = aten::t(%c) %2 : Tensor = aten::t(%b) - %3 : Tensor = aten::gelu(%2) + %3 : Tensor = aten::relu(%2) %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b) return (%4) """ diff --git a/test/onnx/test_custom_ops.py b/test/onnx/test_custom_ops.py index bed480fc2d8..c2d1ec27eed 100644 --- a/test/onnx/test_custom_ops.py +++ b/test/onnx/test_custom_ops.py @@ -137,7 +137,7 @@ def test_contrib_op_with_loop(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.gelu = torch.nn.GELU() + self.gelu = torch.nn.GELU(approximate='none') def forward(self, x): res = [] @@ -150,7 +150,7 @@ def forward(self, x): res.append(x[0]) return torch.stack(res), torch.stack(res2) - def symbolic_custom_gelu(g, input): + def symbolic_custom_gelu(g, input, approximate): return g.op("com.microsoft::Gelu", input).setType(input.type()) from torch.onnx import register_custom_op_symbolic @@ -158,7 +158,7 @@ def symbolic_custom_gelu(g, input): x = torch.randn(3, 3, 4, requires_grad=True) model = torch.jit.script(M()) - run_model_test(self, model, input=(x, )) + run_model_test(self, model, input=(x,)) if __name__ == "__main__": unittest.main() diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 77c2b85f27f..72ff9392254 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -2383,7 +2383,17 @@ def forward(self, input, batch1, batch2): def test_gelu(self): class GeluModel(torch.nn.Module): def forward(self, x): - return torch.nn.functional.gelu(x) + return torch.nn.functional.gelu(x, approximate='none') + + model = GeluModel() + inputs = torch.randn(2, 4, 5, 6, requires_grad=True) + self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_tanh_gelu(self): + class GeluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.gelu(x, approximate='tanh') model 
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 77c2b85f27f..72ff9392254 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -2383,7 +2383,17 @@ def forward(self, input, batch1, batch2):
     def test_gelu(self):
         class GeluModel(torch.nn.Module):
             def forward(self, x):
-                return torch.nn.functional.gelu(x)
+                return torch.nn.functional.gelu(x, approximate='none')
+
+        model = GeluModel()
+        inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_tanh_gelu(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.gelu(x, approximate='tanh')
 
         model = GeluModel()
         inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 2ae439d7705..c71a9756408 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -6256,7 +6256,16 @@ def forward(self, x):
     def test_gelu(self):
         class GeluModel(torch.nn.Module):
             def forward(self, x):
-                return torch.nn.functional.gelu(x)
+                return torch.nn.functional.gelu(x, approximate='none')
+
+        x = torch.randn(2, 4, 5, 6, requires_grad=True)
+        self.run_test(GeluModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_tanh_gelu(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.gelu(x, approximate='tanh')
 
         x = torch.randn(2, 4, 5, 6, requires_grad=True)
         self.run_test(GeluModel(), x)
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 01393915380..433a9d2cc75 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -822,11 +822,11 @@ def forward(self, input, other):
     def test_custom_opsets_gelu(self):
         self.addCleanup(unregister_custom_op_symbolic, "::gelu", 1)
 
-        def gelu(g, self):
+        def gelu(g, self, approximate):
             return g.op("com.microsoft::Gelu", self).setType(self.type())
 
         register_custom_op_symbolic("::gelu", gelu, 1)
-        model = torch.nn.GELU()
+        model = torch.nn.GELU(approximate='none')
         x = torch.randn(3, 3)
         f = io.BytesIO()
         torch.onnx.export(model, (x, ), f,
@@ -842,11 +842,11 @@ def gelu(g, self):
     def test_register_aten_custom_op_symbolic(self):
         self.addCleanup(unregister_custom_op_symbolic, "aten::gelu", 1)
 
-        def gelu(g, self):
+        def gelu(g, self, approximate):
             return g.op("com.microsoft::Gelu", self).setType(self.type())
 
         register_custom_op_symbolic("aten::gelu", gelu, 1)
-        model = torch.nn.GELU()
+        model = torch.nn.GELU(approximate='none')
         x = torch.randn(3, 3)
         f = io.BytesIO()
         torch.onnx.export(model, (x, ), f, opset_version=self.opset_version)
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index be84e7bd4e8..2097a710d89 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -441,8 +441,9 @@ def test_qgelu(self):
         shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
         dtypes = (torch.quint8, torch.qint8)
         memory_formats = (torch.channels_last, torch.contiguous_format)
-        test_cases = itertools.product(shapes, dtypes, memory_formats)
-        for shape, dtype, memory_format in test_cases:
+        approximation = ['none', 'tanh']
+        test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
+        for shape, dtype, memory_format, approximate in test_cases:
             if memory_format == torch.channels_last and len(shape) != 4:
                 continue
             X, scale, zero_point, torch_type = \
@@ -454,8 +455,8 @@ def test_qgelu(self):
             dqX = qX.dequantize()
 
             op = torch.nn.functional.gelu
-            dqY = op(dqX)
+            dqY = op(dqX, approximate=approximate)
             qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                            dtype=torch_type)
-            qY_hat = op(qX)
+            qY_hat = op(qX, approximate=approximate)
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index c03ff0b3119..1d78a19c5ad 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -2181,18 +2181,21 @@ def t2(x: torch.Tensor, p: float, train: bool):
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
     def test_gelu(self):
+        old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
         dtype = torch.float
         device = "cuda"
         x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
         grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
 
-        def t(x: torch.Tensor):
-            o = torch.nn.functional.gelu(x)
+        def t(x: torch.Tensor, mode : str):
+            o = torch.nn.functional.gelu(x, approximate=mode)
             o = o * 2.0
             return o
 
         t_jit = torch.jit.script(t)
-        self._run_training_helper(t_jit, t, grads, x)
+        self._run_training_helper(t_jit, t, grads, x, 'none')
+        self._run_training_helper(t_jit, t, grads, x, 'tanh')
+        torch._C._jit_set_nvfuser_guard_mode(old_guard)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index a548a8df4c8..ab2b85c6bb3 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1321,6 +1321,37 @@ def test_isnan(self):
                     " ".join(["Failed:", str(dtype), 'isnan', device])
                 )
 
+    def test_gelu(self):
+        def apply(fn):
+            return lambda x, approximate: fn(x, approximate)
+
+        unary_ops = [
+            F.gelu,
+        ]
+        sizes = [(1,), (2,), (4, 4)]
+        for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
+                continue
+            try:
+                x = self.data_for(dtype, device, size=size)
+                cond = self.data_for(torch.bool, device)
+                fn = apply(op)
+                ref = fn(x, cond)
+            except Exception:
+                # If eager mode doesn't support a dtype/op/device combo,
+                # neither does the fuser. Catch everything to avoid needing to
+                # guess what errors might be thrown by eager.
+                continue
+            try:
+                t = torch.jit.trace(fn, (x, cond))
+                torch.testing.assert_close(ref, t(x, cond))
+                self.assertAllFused(t.graph_for(x, cond))
+            except Exception as e:
+                raise RuntimeError(
+                    " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
+                )
+
     def test_unary_ops(self):
         def apply(fn):
             return lambda x: fn(x)
@@ -1355,7 +1386,6 @@ def apply(fn):
             F.softplus,
             torch.sqrt,
             torch.rsqrt,
-            F.gelu,
             torch.abs,
             torch.ceil,
             torch.floor,
@@ -2367,7 +2397,6 @@ class TestTEFuserDynamic(TestTEFuser):
         'mul',
         'ne',
         'neg',
-        'nn.functional.gelu',
         'nn.functional.hardshrink',
         'nn.functional.hardsigmoid',
         'nn.functional.hardswish',
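The fuser tests above lean on the new string argument surviving scripting and tracing. A minimal standalone sketch of that property, independent of any fuser:

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def t(x: torch.Tensor, mode: str):
    return F.gelu(x, approximate=mode) * 2.0

x = torch.randn(32, 32)
for mode in ("none", "tanh"):
    torch.testing.assert_close(t(x, mode), F.gelu(x, approximate=mode) * 2.0)
```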
            _test_gelu(n, m, torch.float32, True)
-                _test_gelu(n, m, torch.float32, False)
-                _test_gelu(n, m, torch.float64, True)
-                _test_gelu(n, m, torch.float64, False)
-
-        # Test multi threaded
-        num_threads = torch.get_num_threads()
-        torch.set_num_threads(4)
-        try:
-            _test_gelu(32, 32, torch.float32, False)
-        finally:
-            torch.set_num_threads(num_threads)
-
-
     def test_bce_loss_always_nonnegative(self):
         target = torch.ones(5)
         input = torch.ones(5)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 582ba69c362..7f7c13f01aa 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1799,10 +1799,15 @@
 - name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
   self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
 
-- name: gelu(Tensor self) -> Tensor
-  self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
+- name: gelu(Tensor self, *, str approximate='none') -> Tensor
+  self: gelu_backward(grad, self, approximate)
   result: auto_element_wise
 
+- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
+  grad_output: gelu_backward(grad, self, approximate)
+  self: gelu_double_backward(grad, grad_output, self, approximate)
+  result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate)
+
 - name: glu(Tensor self, int dim=-1) -> Tensor
   self: glu_backward(grad, self, dim)
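The new `gelu_backward` entry makes GELU doubly differentiable through registered analytic formulas rather than the old `infinitely_differentiable_gelu_backward` fallback. Both modes can be spot-checked numerically in double precision:

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
for mode in ("none", "tanh"):
    fn = lambda t, mode=mode: F.gelu(t, approximate=mode)
    assert gradcheck(fn, (x,))       # exercises gelu_backward
    assert gradgradcheck(fn, (x,))   # exercises gelu_double_backward
```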
diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h
index b038f1bce6b..2258dd0c431 100644
--- a/torch/csrc/api/include/torch/nn/functional/activation.h
+++ b/torch/csrc/api/include/torch/nn/functional/activation.h
@@ -336,8 +336,16 @@ inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
 
 // ============================================================================
 
-inline Tensor gelu(const Tensor& input) {
-  return torch::gelu(input);
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+namespace detail {
+inline Tensor gelu(const Tensor& input, string approximate) {
+  return torch::gelu(input, approximate);
+}
+} // namespace detail
+#endif /* DOXYGEN_SHOULD_SKIP_THIS */
+
+inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) {
+  return detail::gelu(input, options.approximate());
 }
 
 // ============================================================================
diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h
index 28225ee0f68..e4fc02f310d 100644
--- a/torch/csrc/api/include/torch/nn/modules/activation.h
+++ b/torch/csrc/api/include/torch/nn/modules/activation.h
@@ -570,12 +570,17 @@ TORCH_MODULE(GLU);
 // NOLINTNEXTLINE(bugprone-exception-escape)
 class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
  public:
+  explicit GELUImpl(GELUOptions options_ = {});
+
   Tensor forward(const Tensor& input);
 
   void reset() override;
 
   /// Pretty prints the `GELU` module into the given `stream`.
   void pretty_print(std::ostream& stream) const override;
+
+  /// The options with which this `Module` was constructed.
+  GELUOptions options;
 };
 
 /// A `ModuleHolder` subclass for `GELUImpl`.
diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h
index 651c800a84c..16ab0245fbb 100644
--- a/torch/csrc/api/include/torch/nn/options/activation.h
+++ b/torch/csrc/api/include/torch/nn/options/activation.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/arg.h>
+#include <string>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/types.h>
@@ -95,6 +96,33 @@ using GLUFuncOptions = GLUOptions;
 
 // ============================================================================
 
+/// Options for the `GELU` module.
+///
+/// Example:
+/// ```
+/// GELU model(GELUOptions().approximate("none"));
+/// ```
+struct TORCH_API GELUOptions {
+  /// Specifies the approximation to apply to the output.
+  TORCH_ARG(std::string, approximate) = "none";
+};
+
+namespace functional {
+/// Options for `torch::nn::functional::gelu`.
+///
+/// See the documentation for `torch::nn::GELUOptions` class to learn what
+/// arguments are supported.
+///
+/// Example:
+/// ```
+/// namespace F = torch::nn::functional;
+/// F::gelu(input, F::GELUFuncOptions().approximate("none"));
+/// ```
+using GELUFuncOptions = GELUOptions;
+} // namespace functional
+
+// ============================================================================
+
 /// Options for the `Hardshrink` module.
 ///
 /// Example:
diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp
index 677c9e1cc83..001199e98ed 100644
--- a/torch/csrc/api/src/nn/modules/activation.cpp
+++ b/torch/csrc/api/src/nn/modules/activation.cpp
@@ -284,8 +284,10 @@ void GLUImpl::pretty_print(std::ostream& stream) const {
 
 // ============================================================================
 
+GELUImpl::GELUImpl(GELUOptions options_) : options(std::move(options_)) {}
+
 Tensor GELUImpl::forward(const Tensor& input) {
-  return F::gelu(input);
+  return F::detail::gelu(input, options.approximate());
 }
 
 void GELUImpl::reset() {}
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index b4bcc4e4316..951b5eeca96 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -12,6 +12,7 @@
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/ExpandUtils.h>
+#include <ATen/native/Activation.h>
 #include <ATen/ScalarOps.h>
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/Utils.h>
@@ -2338,6 +2339,47 @@ std::tuple<Tensor, Tensor> prelu_double_backward(
   }
 }
 
+Tensor gelu_double_backward(
+    const Tensor & ggI,
+    const Tensor & gO,
+    const Tensor & input,
+    c10::string_view approximate) {
+  //if (at::native::get_gelutype_enum(approximate) == at::native::GeluType::Tanh) {
+  if (approximate == "tanh") {
+    constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+    constexpr auto kKappa = 0.044715;
+
+    auto inner = kBeta * (input + kKappa * pow(input, 3));
+    auto tanh_inner = tanh(inner);
+    auto sech_inner = 1 / cosh(inner);
+
+    auto f = 0.5 * input;
+    auto g = 1 - tanh_inner * tanh_inner;
+    auto h = kBeta * (1 + 3 * kKappa * input * input);
+
+    auto f_prime_gh = 0.5 * g * h;
+
+    auto g_prime = (2 * sech_inner) * (-sech_inner * tanh_inner) * h;
+    auto g_prime_fh = f * h * g_prime;
+
+    auto h_prime = 6 * kKappa * input * kBeta;
+    auto h_prime_fg = f * g * h_prime;
+
+    // left_derivative = f_prime_gh
+    // right_derivative = f_prime_gh + g_prime_fh + h_prime_fg
+    // dgrad_dX = left_derivative + right_derivative
+    auto gI = ggI * gO * (2 * f_prime_gh + g_prime_fh + h_prime_fg);
+    return gI;
+  } else {
+    constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
+    auto input_sq = input * input;
+    auto pdf = kBeta * at::exp(-0.5 * input_sq);
+    auto dgrad_dInput = 2 * pdf - input_sq * pdf;
+    auto gI = ggI * gO * dgrad_dInput;
+    return gI;
+  }
+}
+
 Tensor elu_double_backward(
     const Tensor& grad,
     const Tensor& grad_output,
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 739b44b4d62..9451f5f49d2 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -303,6 +303,11 @@ std::tuple<Tensor, Tensor> prelu_double_backward(
     const Tensor & grad_out,
     const Tensor & input_,
     const Tensor & weight_);
+Tensor gelu_double_backward(
+    const Tensor & ggI,
+    const Tensor & gO,
+    const Tensor & input,
+    c10::string_view approximate);
 Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
 std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
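The erf branch of `gelu_double_backward` above returns `ggI * gO * (2*pdf - x^2*pdf)`, which should be exactly the second derivative of exact GELU (here `kBeta = 1/sqrt(2*pi)`, i.e. `pdf` is the standard normal density). A symbolic sketch of that identity, assuming sympy is available:

```python
import sympy as sp

x = sp.symbols("x", real=True)
gelu = x * sp.Rational(1, 2) * (1 + sp.erf(x / sp.sqrt(2)))
pdf = sp.exp(-x**2 / 2) / sp.sqrt(2 * sp.pi)   # kBeta * exp(-x^2/2)
dgrad_dInput = 2 * pdf - x**2 * pdf
assert sp.simplify(sp.diff(gelu, x, 2) - dgrad_dInput) == 0
```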
parameter is required."); + const auto kApproximate = approximate.value(); + + Val* grad_in = nullptr; + if (at::native::get_gelutype_enum(kApproximate) == + at::native::GeluType::Tanh) { + grad_in = fast_gelu_backward(grad_out, self); + } else { + grad_in = gelu_backward(grad_out, self); + } + value_map.emplace( node->output()->unique(), ValueHolder(grad_in, format)); }, @@ -2453,9 +2484,13 @@ class IrParser { } value_map_.emplace(val->unique(), cg_val); return true; - } else if (val->type()->isSubtypeOf( - static_cast(NoneType::get()))) { + } else if ( + val->type()->isSubtypeOf( + static_cast(StringType::get())) || + val->type()->isSubtypeOf(static_cast(NoneType::get()))) { // TODO: should we consider adding support for NoneType; + // String scalars are only used in parsing rules; + // Do not register string with codegen IR. return true; } else if (val->type()->cast()) { // TODO: we don't support list type in codegen yet; @@ -2646,6 +2681,34 @@ void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) { pn->setCallback(ivalue_profiler); } +void profileString(ProfilingRecord* pr, Node* node, size_t offset) { + auto pn = insertProfileIValueOp(node, offset, pr); + + const auto ivalue_profiler = [pr, pn](Stack& stack) { + std::lock_guard lock(pr->mutex_); + + // TODO: we don't care about merging multiple profiling runs as we don't + // support it at all; + int64_t frame_id = 0; + pop(stack, frame_id); + IValue value; + pop(stack, value); + TORCH_INTERNAL_ASSERT( + value.isString(), "profiling seeing the wrong data type"); + if (!pn->hasAttribute(strAttr)) { + pn->s_(strAttr, value.toStringRef()); + } else { + const auto& profiled_str = pn->s(strAttr); + const auto& input_str = value.toStringRef(); + TORCH_INTERNAL_ASSERT( + input_str == profiled_str, "profiling ivalue doesn't support merge"); + } + push(stack, value); + }; + + pn->setCallback(ivalue_profiler); +} + void profileBool(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); @@ -3015,6 +3078,38 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { } } + static auto gelu_schema = + getOperatorForLiteral( + "aten::gelu(Tensor self, *, str approximate='none') -> Tensor") + ->schema(); + if (node->matches(gelu_schema)) { + switch (offset) { + // argument 1: approximate; + case 1: + profileString(pr, node, offset); + break; + default: + return false; + } + return true; + } + + static auto gelu_backward_schema = + getOperatorForLiteral( + "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor") + ->schema(); + if (node->matches(gelu_backward_schema)) { + switch (offset) { + // argument 2: approximate; + case 2: + profileString(pr, node, offset); + break; + default: + return false; + } + return true; + } + static auto softmax_backward_data_schema = getOperatorForLiteral( "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor") diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index 28192504558..83e23342d5c 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -43,21 +43,29 @@ getOperatorVersionMapForMobile() { std::vector({ Upgrader({0, 3, "div__Tensor_0_3", 3}) })}, + {std::string("aten::gelu"), + std::vector({ + Upgrader({0, 9, "gelu_0_9", 5}) + })}, + {std::string("aten::gelu.out"), + std::vector({ + Upgrader({0, 9, "gelu_out_0_9", 6}) + })}, 
{std::string("aten::linspace"), std::vector({ - Upgrader({0, 7, "linspace_0_7", 5}) + Upgrader({0, 7, "linspace_0_7", 7}) })}, {std::string("aten::linspace.out"), std::vector({ - Upgrader({0, 7, "linspace_out_0_7", 6}) + Upgrader({0, 7, "linspace_out_0_7", 8}) })}, {std::string("aten::logspace"), std::vector({ - Upgrader({0, 8, "logspace_0_8", 7}) + Upgrader({0, 8, "logspace_0_8", 9}) })}, {std::string("aten::logspace.out"), std::vector({ - Upgrader({0, 8, "logspace_out_0_8", 8}) + Upgrader({0, 8, "logspace_out_0_8", 10}) })}, }); return operatorVersionMapForMobile; @@ -292,6 +300,45 @@ const std::vector& getUpgraderBytecodeList() { OperatorString({"aten::div", "out_mode", 4}), }), // operators list }), + ByteCodeFunctionWithOperator({ + mobile::Function::registerFunc( + "gelu_0_9", + std::vector({ + Instruction{OpCode::STORE, 1, 0}, + Instruction{OpCode::MOVE, 1, 0}, + Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::RET, 0, 0}, + }), // instructions list, + std::vector({ + c10::IValue("none"), + }), // constants list, + std::vector(), // types list, + 1 + ), + std::vector({ + OperatorString({"aten::gelu", "", 1}), + }), // operators list + }), + ByteCodeFunctionWithOperator({ + mobile::Function::registerFunc( + "gelu_out_0_9", + std::vector({ + Instruction{OpCode::STOREN, 1, 2}, + Instruction{OpCode::MOVE, 1, 0}, + Instruction{OpCode::MOVE, 2, 0}, + Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::RET, 0, 0}, + }), // instructions list, + std::vector({ + c10::IValue("none"), + }), // constants list, + std::vector(), // types list, + 2 + ), + std::vector({ + OperatorString({"aten::gelu", "out", 2}), + }), // operators list + }), ByteCodeFunctionWithOperator({ mobile::Function::registerFunc( "linspace_0_7", diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index 75201cf5d67..7b09cc409a4 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -14,64 +14,64 @@ namespace torch { namespace jit { -static std::unordered_map kUpgradersEntryMap( - {{"logspace_0_8", R"SCRIPT( +static std::unordered_map kUpgradersEntryMap({ + {"logspace_0_8", R"SCRIPT( def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): if (steps is None): return torch.logspace(start=start, end=end, steps=100, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) return torch.logspace(start=start, end=end, steps=steps, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) )SCRIPT"}, - {"logspace_out_0_8", R"SCRIPT( + {"logspace_out_0_8", R"SCRIPT( def logspace_out_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, out: Tensor): if (steps is None): return torch.logspace(start=start, end=end, steps=100, base=base, out=out) return torch.logspace(start=start, end=end, steps=steps, base=base, out=out) )SCRIPT"}, - {"linspace_0_7", R"SCRIPT( + {"linspace_0_7", R"SCRIPT( def linspace_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): if (steps is None): return torch.linspace(start=start, end=end, steps=100, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory) 
diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
index 75201cf5d67..7b09cc409a4 100644
--- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
+++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
@@ -14,64 +14,64 @@ namespace torch {
 namespace jit {
 
-static std::unordered_map<std::string, std::string> kUpgradersEntryMap(
-    {{"logspace_0_8", R"SCRIPT(
+static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
+    {"logspace_0_8", R"SCRIPT(
 def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]):
   if (steps is None):
     return torch.logspace(start=start, end=end, steps=100, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
   return torch.logspace(start=start, end=end, steps=steps, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
 )SCRIPT"},
-     {"logspace_out_0_8", R"SCRIPT(
+    {"logspace_out_0_8", R"SCRIPT(
 def logspace_out_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, out: Tensor):
   if (steps is None):
     return torch.logspace(start=start, end=end, steps=100, base=base, out=out)
   return torch.logspace(start=start, end=end, steps=steps, base=base, out=out)
 )SCRIPT"},
-     {"linspace_0_7", R"SCRIPT(
+    {"linspace_0_7", R"SCRIPT(
 def linspace_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]):
   if (steps is None):
     return torch.linspace(start=start, end=end, steps=100, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
   return torch.linspace(start=start, end=end, steps=steps, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
 )SCRIPT"},
-     {"linspace_out_0_7", R"SCRIPT(
+    {"linspace_out_0_7", R"SCRIPT(
 def linspace_out_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, out: Tensor):
   if (steps is None):
     return torch.linspace(start=start, end=end, steps=100, out=out)
   return torch.linspace(start=start, end=end, steps=steps, out=out)
 )SCRIPT"},
-     {"div_Tensor_0_3", R"SCRIPT(
+    {"div_Tensor_0_3", R"SCRIPT(
 def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
   if (self.is_floating_point() or other.is_floating_point()):
     return self.true_divide(other)
   return self.divide(other, rounding_mode='trunc')
 )SCRIPT"},
-     {"div_Scalar_0_3", R"SCRIPT(
+    {"div_Scalar_0_3", R"SCRIPT(
 def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
   if (self.is_floating_point() or isinstance(other, float)):
     return self.true_divide(other)
   return self.divide(other, rounding_mode='trunc')
 )SCRIPT"},
-     {"div_out_0_3", R"SCRIPT(
+    {"div_out_0_3", R"SCRIPT(
 def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
   if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
     return self.true_divide(other, out=out)
   return self.divide(other, rounding_mode='trunc', out=out)
 )SCRIPT"},
-     {"div__Tensor_0_3", R"SCRIPT(
+    {"div__Tensor_0_3", R"SCRIPT(
 def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
   if (self.is_floating_point() or other.is_floating_point()):
     return self.true_divide_(other)
   return self.divide_(other, rounding_mode='trunc')
 )SCRIPT"},
-     {"div__Scalar_0_3", R"SCRIPT(
+    {"div__Scalar_0_3", R"SCRIPT(
 def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
   if (self.is_floating_point() or isinstance(other, float)):
     return self.true_divide_(other)
   return self.divide_(other, rounding_mode='trunc')
 )SCRIPT"},
-     {"full_0_4", R"SCRIPT(
+    {"full_0_4", R"SCRIPT(
 def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
              layout:Optional[int]=None, device:Optional[Device]=None,
              pin_memory:Optional[bool]=None) -> Tensor:
@@ -79,10 +79,19 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
     fill_value = float(fill_value)
   return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
 )SCRIPT"},
-     {"full_out_0_4", R"SCRIPT(
+    {"full_out_0_4", R"SCRIPT(
 def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
   return torch.full(size, fill_value, out=out)
-)SCRIPT"}});
+)SCRIPT"},
+    {"gelu_0_9", R"SCRIPT(
+def gelu_0_9(self: Tensor) -> Tensor:
+  return torch.gelu(self, approximate='none')
+)SCRIPT"},
+    {"gelu_out_0_9", R"SCRIPT(
+def gelu_out_0_9(self: Tensor, *, out: Tensor) -> Tensor:
+  return torch.gelu(self, approximate='none', out=out)
+)SCRIPT"},
+});
 
 std::shared_ptr<Graph> create_upgrader_graph(
     const std::string& upgrader_name,
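Behaviorally, the `gelu_0_9` upgrader just pins pre-bump call sites to the erf-based path. The TorchScript above is equivalent to this eager-mode sketch (`gelu_0_9` here is an illustrative stand-in, not the registered upgrader itself):

    import torch
    import torch.nn.functional as F

    def gelu_0_9(self: torch.Tensor) -> torch.Tensor:
        # Old aten::gelu had no 'approximate' argument, so upgraded call
        # sites must keep the exact (erf-based) behavior.
        return F.gelu(self, approximate='none')

    x = torch.randn(4)
    assert torch.equal(gelu_0_9(x), F.gelu(x))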
out) -> Tensor(a!)"}}}, + {"aten::gelu", {{10, "gelu_0_9", "aten::gelu(Tensor self) -> Tensor"}}}, + {"aten::gelu.out", + {{10, + "gelu_out_0_9", + "aten::gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor"}}}}); const std::unordered_map>& get_operator_version_map() { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 262e9b35110..0f79d01104a 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -872,7 +872,7 @@ class ShapePropagator : public PropertyPropBase { "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::selu(Tensor self) -> Tensor", - "aten::gelu(Tensor self) -> Tensor", + "aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::sign(Tensor self) -> Tensor", "aten::sin(Tensor self) -> Tensor", diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index f385ba3875b..64f8dd5b4c5 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -913,16 +913,10 @@ const std::vector functions = { return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None return result, backward - def gelu(self): - result = torch.gelu(self) - def backward(grad_output): - m_2_sqrtpi = 1.12837916709551257390 - m_sqrt1_2 = 0.707106781186547524401 - alpha = m_sqrt1_2 - beta = m_2_sqrtpi * m_sqrt1_2 * 0.5 - cdf = (torch.erf(self * m_sqrt1_2) + 1.0) * 0.5 - pdf = beta * torch.exp(self * self * -0.5) - return grad_output * (cdf + self * pdf) + def gelu(self : Tensor, *, approximate : str): + result = torch.gelu(self, approximate=approximate) + def backward(grad_output): + return torch.gelu_backward(grad_output, self, approximate=approximate), None return result, backward def hardswish(self): diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp index fbe3a0c36c7..71c9730f5ba 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp @@ -76,7 +76,7 @@ const OperatorMap& get_tensorexpr_elementwise_set() { {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"}, {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"}, {"aten::relu6(Tensor self) -> Tensor", "unary"}, - {"aten::gelu(Tensor self) -> Tensor", "unary"}, + {"aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "unary"}, {"aten::neg(Tensor self) -> Tensor", "unary"}, {"aten::reciprocal(Tensor self) -> Tensor", "unary"}, {"aten::expm1(Tensor self) -> Tensor", "unary"}, diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index cc6d2678685..a48c9d07e29 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -362,6 +362,8 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const { return val.toIntVector(); } else if (val.isDoubleList()) { return val.toDoubleVector(); + } else if (val.isString()) { + return val.toStringRef(); } else { throw unsupported_dtype(val.type()->str()); } diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index ee3a77084cd..c0905588c2d 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ 
diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp
index ee3a77084cd..c0905588c2d 100644
--- a/torch/csrc/jit/tensorexpr/lowerings.cpp
+++ b/torch/csrc/jit/tensorexpr/lowerings.cpp
@@ -3,6 +3,8 @@
 #include <torch/csrc/jit/tensorexpr/lowerings.h>
 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
+#include <ATen/native/Activation.h>
+
 namespace torch {
 namespace jit {
 namespace tensorexpr {
@@ -641,22 +643,44 @@ int nnc_lowerings_lazy_registration() {
       });
 
   RegisterNNCLoweringsFunction aten_gelu(
-      {"aten::gelu(Tensor self) -> (Tensor)"},
+      {"aten::gelu(Tensor self, *, str approximate='none') -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
          const std::vector<ExprHandle>& outputShape,
          const c10::optional<ScalarType>& outputType,
          at::Device device) {
-        return computeOneOperand(
-            "aten_gelu",
-            inputs,
-            outputShape,
-            outputType,
-            [](const ExprHandle& a) {
-              auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
-              auto one = Cast::make(a.dtype(), 1.);
-              auto point_five = Cast::make(a.dtype(), .5);
-              return a * point_five * (one + erf(a * m_sqrt1_2));
-            });
+        const auto& kApproximate = c10::get<std::string>(inputs[1]);
+        std::vector<ArgValue> operands = {inputs.front()};
+        if (at::native::get_gelutype_enum(kApproximate) ==
+            at::native::GeluType::Tanh) {
+          // approximate == 'tanh'
+          return computeOneOperand(
+              "aten_tanh_gelu",
+              operands,
+              outputShape,
+              outputType,
+              [](const ExprHandle& a) {
+                auto one = Cast::make(a.dtype(), 1.);
+                auto point_five = Cast::make(a.dtype(), .5);
+                auto beta = Cast::make(a.dtype(), M_SQRT2 * M_2_SQRTPI * 0.5);
+                auto kappa = Cast::make(a.dtype(), 0.044715);
+                auto a_cube = a * a * a;
+                auto inner = beta * (a + kappa * a_cube);
+                return point_five * a * (one + tanh(inner));
+              });
+        } else {
+          // approximate == 'none'
+          return computeOneOperand(
+              "aten_gelu",
+              operands,
+              outputShape,
+              outputType,
+              [](const ExprHandle& a) {
+                auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
+                auto one = Cast::make(a.dtype(), 1.);
+                auto point_five = Cast::make(a.dtype(), .5);
+                return a * point_five * (one + erf(a * m_sqrt1_2));
+              });
+        }
       });
 
   RegisterNNCLoweringsFunction aten_batch_norm(
diff --git a/torch/csrc/jit/tensorexpr/lowerings.h b/torch/csrc/jit/tensorexpr/lowerings.h
index 19aa85810b9..aac507ff132 100644
--- a/torch/csrc/jit/tensorexpr/lowerings.h
+++ b/torch/csrc/jit/tensorexpr/lowerings.h
@@ -26,6 +26,7 @@ using ArgValue = c10::variant<
     BufList,
     DoubleList,
     IntList,
+    std::string,
     ArgNone>;
 
 using NNCLoweringFunction = std::function<Tensor(
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -?,? +?,? @@
-gelu(input) -> Tensor
+gelu(input, approximate = 'none') -> Tensor
 
-Applies element-wise the function
+When the approximate argument is 'none', it applies element-wise the function
 :math:`\text{GELU}(x) = x * \Phi(x)`
 
 where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
 
+When the approximate argument is 'tanh', GELU is estimated with:
+
+.. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
+
 See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
 """)
 
-
 hardshrink = _add_docstr(
     torch.hardshrink,
     r"""
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index 8e92e29d6c6..0ab153991ca 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -141,7 +141,7 @@ def rrelu(input: Tensor, lower: float = ..., upper: float = ..., training: bool = ...,
           inplace: bool = ...) -> Tensor: ...
 
 
-def gelu(input: Any): ...
+def gelu(input: Any, approximate: str = ...): ...
 
 
 def hardshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
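The two lowering branches are exactly the erf- and tanh-based formulas from the docstring; note that `M_SQRT2 * M_2_SQRTPI * 0.5` is just `sqrt(2/pi)`. A plain NumPy/SciPy transcription of both branches (an illustrative sketch, not the NNC code path):

    import numpy as np
    from scipy.special import erf

    def gelu_exact(x):
        # "aten_gelu" branch: x * 0.5 * (1 + erf(x * 1/sqrt(2)))
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

    def gelu_tanh(x):
        # "aten_tanh_gelu" branch: beta = sqrt(2) * (2/sqrt(pi)) * 0.5
        # = sqrt(2/pi), kappa = 0.044715
        beta = np.sqrt(2.0 / np.pi)
        return 0.5 * x * (1.0 + np.tanh(beta * (x + 0.044715 * x ** 3)))

    x = np.linspace(-4.0, 4.0, 9)
    # The approximation is close but not exact, which is why the two
    # branches must stay distinct lowerings.
    print(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))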
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 6066d855c8c..aeb5590bd47 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -654,6 +654,13 @@ class GELU(Module):
 
     where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
 
+    When the approximate argument is 'tanh', GELU is estimated with:
+
+    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
+
+    Args:
+        approximate (str, optional): the gelu approximation algorithm to use:
+            ``'none'`` | ``'tanh'``. Default: ``'none'``
+
     Shape:
         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
         - Output: :math:`(*)`, same shape as the input.
@@ -666,8 +673,18 @@ class GELU(Module):
 
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['approximate']
+    approximate: str
+
+    def __init__(self, approximate: str = 'none') -> None:
+        super(GELU, self).__init__()
+        self.approximate = approximate
+
     def forward(self, input: Tensor) -> Tensor:
-        return F.gelu(input)
+        return F.gelu(input, approximate=self.approximate)
+
+    def extra_repr(self) -> str:
+        return 'approximate={}'.format(self.approximate)
 
 
 class Hardshrink(Module):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 776a239fb10..acdde766120 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2953,12 +2953,27 @@ def remainder(g, input, other):
     quo = g.op("Mul", div, other)
     return g.op("Sub", input, quo)
 
-
-def gelu(g, self):
-    _sqrt2 = 1.4142135623730951
-    erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
-    erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
-    return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
+
+@parse_args("v", "s")
+def gelu(g, self, approximate):
+    # 'approximate' arrives as a constant string: 'none' or 'tanh'
+    if approximate == 'tanh':
+        kBeta = math.sqrt(2 / math.pi)
+        kKappa = 0.044715
+
+        beta = torch.tensor(kBeta, dtype=torch.double)
+        kappa = torch.tensor(kKappa, dtype=torch.double)
+        one = torch.tensor(1., dtype=torch.double)
+        half = torch.tensor(0.5, dtype=torch.double)
+
+        self_cube = mul(g, self, mul(g, self, self))
+        inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube)))
+        return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
+    else:
+        _sqrt2 = 1.4142135623730951
+        erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
+        erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
+        return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
 
 
 @parse_args("v", "i", "v", "v", "f", "i")
 def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
diff --git a/torch/overrides.py b/torch/overrides.py
index 408012ea6e9..76a5fe67069 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -730,7 +730,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
                                                    lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
                                                    _random_samples=None: -1),
     torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
-    torch.nn.functional.gelu: lambda input: -1,
+    torch.nn.functional.gelu: lambda input, approximate='none': -1,
     torch.nn.functional.glu: lambda input, dim=-1: -1,
     torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
     torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
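At the module level the new argument is used like this (eager-mode sketch; `extra_repr` makes the choice visible when printing the module):

    import torch
    import torch.nn as nn

    m_exact = nn.GELU()                    # approximate='none' (erf-based)
    m_tanh = nn.GELU(approximate='tanh')   # tanh approximation

    x = torch.randn(2, 3)
    print(m_tanh)                          # GELU(approximate=tanh)
    print((m_exact(x) - m_tanh(x)).abs().max())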
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index 00ed8072495..4b1058fe35a 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -327,7 +327,8 @@ def __init__(self, dev):
         self.nn_fp32 = [
             ("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
             ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
-            ("gelu", dummy_bf16[3]),
+            ("gelu", dummy_bf16[3], {"approximate": 'none'}),
+            ("gelu", dummy_bf16[3], {"approximate": 'tanh'}),
             ("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}),
             ("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}),
             ("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}),
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c25d04ebe61..e9332bd4f01 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -43,6 +43,7 @@
 has_scipy_fft = False
 if TEST_SCIPY:
+    from scipy import stats
     import scipy.special
     try:
         import scipy.fft
@@ -3903,7 +3904,6 @@ def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
     # With `None` weight and bias (tests failing for this, see the link above)
     # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
 
-
 def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -3925,7 +3925,6 @@ def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
     for input_shape, size, kwargs in cases:
         yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
 
-
 def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
     N = 5
     # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
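The new autocast entries assert that both string values still take the fp32 route on CPU; what they exercise boils down to this sketch (assuming a CPU build with bfloat16 autocast support):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, dtype=torch.bfloat16)
    with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
        for approximate in ('none', 'tanh'):
            # gelu is on the fp32 cast-policy list, so the bf16 input is
            # upcast and the result comes back as float32.
            y = F.gelu(x, approximate=approximate)
            assert y.dtype == torch.float32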
@@ -4080,10 +4079,16 @@ def shape(size, rank, with_batch_channel=True):
 
     return sample_inputs
 
+
 def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
     N = 5
-    tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype,
-                                       requires_grad=requires_grad, low=-3, high=3)) for _ in range(1, N)]
+    tensors = []
+    for _ in range(1, N):
+        for approximate in ['none', 'tanh']:
+            tensors.append(SampleInput(
+                make_tensor((N * 2, N * 2), device=device, dtype=dtype,
+                            requires_grad=requires_grad, low=-3, high=3),
+                kwargs=dict(approximate=approximate)))
     return tensors
 
 def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
@@ -7965,6 +7970,20 @@ def reference_softplus(input, beta=1, threshold=20):
     output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta
     return output
 
+def reference_gelu(X, *, approximate='none'):
+    def _gelu_ref(X):
+        return X * stats.norm.cdf(X)
+
+    def _tanh_gelu_ref(X):
+        M_SQRT_2_PI = math.sqrt(2 / math.pi)
+        Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
+        return 0.5 * X * (1.0 + np.tanh(Z))
+
+    if approximate == 'tanh':
+        return _tanh_gelu_ref(X)
+    else:
+        return _gelu_ref(X)
+
 
 def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray:
     if num_classes == -1:
@@ -11772,6 +11791,7 @@ def ref_pairwise_distance(input1, input2):
            ),
     OpInfo('nn.functional.gelu',
            aten_name="gelu",
+           ref=reference_gelu if TEST_SCIPY else _NOTHING,
            supports_autograd=True,
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_gelu,
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 599e4fecabe..dadcac9285f 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -3716,12 +3716,16 @@ def unsqueeze_inp(inp):
     ),
     dict(
         module_name='GELU',
+        constructor_args=('none',),
+        cpp_constructor_args='torch::nn::GELUOptions().approximate("none")',
         input_size=(),
         desc='scalar',
         reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
     ),
     dict(
         module_name='GELU',
+        constructor_args=('none',),
+        cpp_constructor_args='torch::nn::GELUOptions().approximate("none")',
         input_size=(3, 2, 5),
         reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
     ),
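Finally, `reference_gelu` gives the OpInfo a NumPy/SciPy oracle; checking it directly against the ATen kernels looks like this (sketch; requires SciPy, mirrors the OpInfo `ref`):

    import numpy as np
    import torch
    import torch.nn.functional as F
    from scipy import stats

    x = np.random.randn(16)
    t = torch.from_numpy(x)

    # Exact path: x * Phi(x), with Phi the standard normal CDF.
    assert np.allclose(F.gelu(t).numpy(), x * stats.norm.cdf(x))

    # Tanh path: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    z = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3.0))
    assert np.allclose(F.gelu(t, approximate='tanh').numpy(),
                       0.5 * x * (1.0 + np.tanh(z)))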