GELU (#7132)
* GELU activation

* GELU derivative

* GELU wrappers

* optional PI definition

* fast gelu & derivative
raver119 committed Feb 9, 2019
1 parent 62f3895 commit d3cb594
Showing 6 changed files with 227 additions and 9 deletions.
4 changes: 3 additions & 1 deletion libnd4j/include/loops/legacy_ops.h
@@ -159,7 +159,9 @@
(49, LogSigmoid), \
(50, Erfc) ,\
(51, Expm1), \
(52, ATanh) ,\
(53, GELU) ,\
(54, GELUDerivative)
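
These enum slots must stay in sync with the opNum() values of the Java wrappers added below: GELU.opNum() returns 53 and GELUDerivative.opNum() returns 54.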

// these ops return one of FLOAT data types
#define TRANSFORM_FLOAT_OPS \
61 changes: 53 additions & 8 deletions libnd4j/include/ops/ops.h
@@ -1509,15 +1509,60 @@ namespace simdOps {
};

template <typename X>
class Swish {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda

op_def static X op(X d1, X *params) {
return d1 * nd4j::math::nd4j_sigmoid<X,X>(d1);
}
};

template <typename X>
class GELU {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda

op_def static X op(X d1, X *params) {
bool precise = params != nullptr && params[0] > static_cast<X>(0.0f);

if (precise) {
// precise GELU: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
auto sp = nd4j::math::nd4j_sqrt<X, X>(static_cast<X>(2) / static_cast<X>(M_PI));
auto xp = d1 + static_cast<X>(0.044715) * nd4j::math::nd4j_pow<X, X, X>(d1, static_cast<X>(3));
return (d1 / static_cast<X>(2)) * (static_cast<X>(1) + nd4j::math::nd4j_tanh<X, X>(sp * xp));
} else {
// fast GELU: x * sigmoid(1.702 * x)
return d1 * nd4j::math::nd4j_sigmoid<X,X>(static_cast<X>(1.702f) * d1);
}
}
};
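
For reference, a minimal standalone sketch of the two branches above, using plain <cmath> outside the simdOps framework; geluPrecise, geluFast, and kPi are illustrative names, not part of the commit:

#include <cmath>
#include <cstdio>
#include <initializer_list>

static const double kPi = 3.14159265358979323846;

// precise GELU: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
static double geluPrecise(double x) {
    double sp = std::sqrt(2.0 / kPi);
    return 0.5 * x * (1.0 + std::tanh(sp * (x + 0.044715 * x * x * x)));
}

// fast GELU: x * sigmoid(1.702 * x)
static double geluFast(double x) {
    return x / (1.0 + std::exp(-1.702 * x));
}

int main() {
    for (double x : {-2.0, -1.0, 0.0, 1.0, 2.0})
        std::printf("x=%5.2f  precise=%9.6f  fast=%9.6f\n", x, geluPrecise(x), geluFast(x));
    return 0;
}

The two branches agree closely near zero and diverge only slightly in the tails, which is why the sigmoid form is usable as the fast default.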

template <typename X>
class GELUDerivative {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda

op_def static X op(X d1, X *params) {
bool precise = params != nullptr && params[0] > static_cast<X>(0.0f);

if (precise) {
// 0.5 + (0.398942 x + 0.0535161 x^3) Sech[0.797885 x + 0.0356774 x^3]^2 + 0.5 Tanh[0.797885 x + 0.0356774 x^3]
auto x79 = static_cast<X>(0.797885) * d1;
auto x03 = static_cast<X>(0.0356774) * nd4j::math::nd4j_pow<X, int, X>(d1, 3);
auto x39 = static_cast<X>(0.398942) * d1;
auto x05 = static_cast<X>(0.0535161) * nd4j::math::nd4j_pow<X, int, X>(d1, 3);
auto scz = nd4j::math::nd4j_sech<X, X>(x79 + x03);
return static_cast<X>(0.5) + (x39 + x05) * nd4j::math::nd4j_pow<X, int, X>(scz, 2) + static_cast<X>(0.5) * nd4j::math::nd4j_tanh<X, X>(x79 + x03);
} else {
// derivative of the fast form: (E^(1.702 x) (1. + E^(1.702 x) + 1.702 x))/(1. + E^(1.702 x))^2
auto x17 = static_cast<X>(1.702f) * d1;
auto ep = nd4j::math::nd4j_pow<X,X,X>(static_cast<X>(M_E), x17);
return (ep * (static_cast<X>(1.f) + ep + x17)) / nd4j::math::nd4j_pow<X, int, X>((static_cast<X>(1.f) + ep), 2);
}
}
};
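
A quick way to sanity-check the closed form above is a central finite difference against the precise GELU. A standalone sketch with illustrative names, again outside the simdOps framework:

#include <cmath>
#include <cstdio>
#include <initializer_list>

// precise GELU with the constants folded in, matching the comment above
static double gelu(double x) {
    return 0.5 * x * (1.0 + std::tanh(0.797885 * x + 0.0356774 * x * x * x));
}

// closed-form derivative: 0.5 + (0.398942 x + 0.0535161 x^3) sech(u)^2 + 0.5 tanh(u),
// with u = 0.797885 x + 0.0356774 x^3
static double geluGrad(double x) {
    double u = 0.797885 * x + 0.0356774 * x * x * x;
    double sech = 1.0 / std::cosh(u);
    return 0.5 + (0.398942 * x + 0.0535161 * x * x * x) * sech * sech + 0.5 * std::tanh(u);
}

int main() {
    const double h = 1e-6;
    for (double x : {-1.5, -0.5, 0.5, 1.5}) {
        double fd = (gelu(x + h) - gelu(x - h)) / (2.0 * h);
        std::printf("x=%5.2f  analytic=%9.6f  finite-diff=%9.6f\n", x, geluGrad(x), fd);
    }
    return 0;
}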


template <typename X>
12 changes: 12 additions & 0 deletions libnd4j/include/templatemath.h
@@ -39,6 +39,10 @@
#define M_E 2.718281828459
#endif

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif


namespace nd4j {
#ifdef __CUDACC__
@@ -250,6 +254,9 @@ namespace nd4j {
template<typename T, typename Z>
math_def inline Z nd4j_acos(T val);

template<typename T, typename Z>
math_def inline Z nd4j_sech(T val);

template<typename T, typename Z>
math_def inline Z nd4j_acosh(T val);

@@ -606,6 +613,11 @@ namespace nd4j {
return p_acos<Z>(static_cast<Z>(val));
}

template <typename X, typename Z>
math_def inline Z nd4j_sech(X val) {
return static_cast<Z>(1) / nd4j_cosh<X,Z>(val);
}
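
The new helper delegates to nd4j_cosh via the identity sech(x) = 1 / cosh(x), which the GELUDerivative op above relies on. A one-line standalone equivalent with plain <cmath> (illustrative, not part of the commit):

#include <cmath>

// reference implementation of the same identity: sech(x) = 1 / cosh(x)
double sech_ref(double x) { return 1.0 / std::cosh(x); }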

template <typename X, typename Z>
math_def inline Z nd4j_acosh(X val) {
return p_acosh<Z>(static_cast<Z>(val));
@@ -1123,6 +1123,9 @@ public SDVariable swishDerivative(SDVariable iX) {
return new SwishDerivative(sameDiff(), iX, false).outputVariable();
}

public SDVariable geluDerivative(SDVariable iX) {
return new GELUDerivative(sameDiff(), iX, false).outputVariable();
}

public SDVariable sign(SDVariable iX) {
return new Sign(sameDiff(), iX, false).outputVariable();
@@ -0,0 +1,77 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;

import java.util.Arrays;
import java.util.List;

/**
* GELU function
*
* @author raver119@gmail.com
*/
public class GELU extends BaseTransformStrictOp {
public GELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}

public GELU() {
}

public GELU(INDArray x, INDArray z) {
super(x, z);
}

public GELU(INDArray ndArray) {
super(ndArray);
}

@Override
public int opNum() {
return 53;
}

@Override
public String opName() {
return "gelu";
}

@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}

@Override
public String tensorflowName() {
return "GELU";
}


@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().geluDerivative(arg()).mul(i_v.get(0));
return Arrays.asList(ret);
}


}
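
Note how doDiff wires up the backward pass through the chain rule: the gradient with respect to the input is the elementwise product GELU'(x) * dL/dy, with GELU'(x) supplied by the geluDerivative factory method added above and dL/dy arriving as i_v.get(0).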
@@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;

import java.util.List;

/**
* GELU derivative
*
* @author Adam Gibson
*/
public class GELUDerivative extends BaseTransformStrictOp {
public GELUDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}

public GELUDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}

public GELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}

public GELUDerivative() {}

public GELUDerivative(INDArray x, INDArray z) {
super(x, z);
}

public GELUDerivative(INDArray x) {
super(x);
}

@Override
public int opNum() {
return 54;
}

@Override
public String opName() {
return "_geluderivative";
}

@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}

@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException();
}
}
