Adapt MultiheadAttention and LayerNorm to new Layer interface #3547

Merged
Merged 160 commits on Nov 22, 2023

Changes from 156 commits
Commits (160)
53cd8e7
Adapt MultiheadAttention to new Layer interface.
akropp Jun 7, 2023
727323b
Adapt LayerNorm to new Layer interface.
akropp Jun 7, 2023
a79ca49
Adapt MultiheadAttention to new Layer interface.
akropp Jun 7, 2023
8ef9f59
Adapt LayerNorm to new Layer interface.
akropp Jun 7, 2023
beb2873
Pass the correct "input" values to the sub-layers during the Backward…
akropp May 25, 2023
0a48537
Pass the correct "input" values to the sub-layers during the Backward…
akropp May 25, 2023
4de7035
Use MakeAlias to slice input data
Oct 19, 2023
ed0b423
Remove unnecessary template arg
Oct 19, 2023
b686012
Remove unnecessary template arg
Oct 19, 2023
441248f
Update layer_norm.hpp
akropp Oct 20, 2023
dcf4ce3
Adapt multihead_attention
Jun 7, 2023
89b78f8
Add self-attention flag to MultiheadAttention layer
Jun 13, 2023
7993469
Fix LayerNorm
Jun 13, 2023
2791789
Fix references to "input" in the Backwards call.
Jun 7, 2023
a0c3747
Update softmax_impl.hpp
akropp Oct 23, 2023
15c2c0f
Adding input and output parameters to Backward() method on Layer.
akropp Oct 27, 2023
b26a2fc
Merge remote-tracking branch 'upstream/master' into adapt_multihead
akropp Oct 31, 2023
91a62a5
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
efae2fd
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
a4ab4a6
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
a3466f2
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
e4309d0
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
b891e25
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 1, 2023
aec1c16
Update src/mlpack/methods/ann/layer/add.hpp
akropp Nov 1, 2023
e5320bd
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 1, 2023
e8e8432
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 1, 2023
77e41af
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 1, 2023
9c2bf7f
Update src/mlpack/methods/ann/layer/alpha_dropout.hpp
akropp Nov 1, 2023
dd3e259
Update src/mlpack/methods/ann/layer/batch_norm.hpp
akropp Nov 1, 2023
4723829
Update src/mlpack/methods/ann/layer/c_relu_impl.hpp
akropp Nov 1, 2023
bad9160
Update src/mlpack/methods/ann/layer/celu.hpp
akropp Nov 1, 2023
b7c25a8
Update src/mlpack/methods/ann/layer/concat.hpp
akropp Nov 1, 2023
6417e30
Update src/mlpack/methods/ann/layer/concat.hpp
akropp Nov 1, 2023
c7d94a6
Update src/mlpack/methods/ann/layer/concat_impl.hpp
akropp Nov 1, 2023
1dab4e2
Update src/mlpack/methods/ann/layer/concat_impl.hpp
akropp Nov 1, 2023
72a26ba
Update src/mlpack/methods/ann/layer/concatenate.hpp
akropp Nov 1, 2023
06c6dd9
Update src/mlpack/methods/ann/layer/convolution.hpp
akropp Nov 1, 2023
4bae489
Update src/mlpack/methods/ann/layer/dropout.hpp
akropp Nov 1, 2023
bf18a3d
Update src/mlpack/methods/ann/layer/elu.hpp
akropp Nov 1, 2023
28328e2
Update src/mlpack/methods/ann/layer/grouped_convolution.hpp
akropp Nov 1, 2023
b2b0590
Update src/mlpack/methods/ann/layer/identity.hpp
akropp Nov 1, 2023
5cc5e9b
Update src/mlpack/methods/ann/layer/leaky_relu.hpp
akropp Nov 1, 2023
50c8f09
Update src/mlpack/methods/ann/layer/linear.hpp
akropp Nov 1, 2023
bc08ce2
Update src/mlpack/methods/ann/layer/linear3d.hpp
akropp Nov 1, 2023
9aa7e1d
Update src/mlpack/methods/ann/layer/linear_no_bias.hpp
akropp Nov 1, 2023
2f8bf78
Update src/mlpack/methods/ann/layer/lstm.hpp
akropp Nov 1, 2023
a8782e6
Update src/mlpack/methods/ann/layer/log_softmax.hpp
akropp Nov 1, 2023
845ae5b
Update src/mlpack/methods/ann/layer/max_pooling.hpp
akropp Nov 1, 2023
b35fbc0
Update src/mlpack/methods/ann/layer/mean_pooling.hpp
akropp Nov 1, 2023
83f89f7
Update src/mlpack/methods/ann/layer/multi_layer_impl.hpp
akropp Nov 1, 2023
8505b46
Update src/mlpack/methods/ann/layer/multi_layer_impl.hpp
akropp Nov 1, 2023
d874e78
Update src/mlpack/methods/ann/layer/noisylinear.hpp
akropp Nov 1, 2023
85d7e44
Update src/mlpack/methods/ann/layer/padding.hpp
akropp Nov 1, 2023
d2c8e89
Update src/mlpack/methods/ann/layer/parametric_relu.hpp
akropp Nov 1, 2023
8bbb2d9
Update src/mlpack/methods/ann/layer/relu6.hpp
akropp Nov 1, 2023
0872c2a
Update src/mlpack/methods/ann/layer/softmax.hpp
akropp Nov 1, 2023
7acd8a8
Update src/mlpack/methods/ann/layer/softmin.hpp
akropp Nov 1, 2023
38a1fd8
Update src/mlpack/tests/ann/activation_functions_test.cpp
akropp Nov 1, 2023
f46312f
Update src/mlpack/tests/ann/layer/parametric_relu.cpp
akropp Nov 1, 2023
103db42
Update src/mlpack/methods/ann/layer/layer_norm.hpp
akropp Nov 1, 2023
99348fc
Update src/mlpack/methods/ann/layer/layer_norm.hpp
akropp Nov 1, 2023
490b55f
Update src/mlpack/methods/ann/layer/layer_norm.hpp
akropp Nov 1, 2023
99bdb8f
Update src/mlpack/methods/ann/layer/layer_norm_impl.hpp
akropp Nov 1, 2023
4e282dd
Update src/mlpack/methods/ann/layer/layer_norm_impl.hpp
akropp Nov 2, 2023
7d84005
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 2, 2023
bec84f4
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 2, 2023
9388c6f
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 2, 2023
1907e52
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
b79f690
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
346ea05
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
6dd71c1
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
21fe33e
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
c482407
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
d9e7d61
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
b6a2c58
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 2, 2023
45f88d1
Clean-up to fix compilation/test.
akropp Nov 2, 2023
e24bfd2
Remove input sizing (srcSeqLen, embedDim) from constructor.
akropp Nov 2, 2023
8dce94d
Update multihead_attention_impl.hpp
akropp Nov 2, 2023
b0a60bd
Remove dependency on input in Backward call
akropp Nov 2, 2023
636c4fc
consolidate constructors with default arg
akropp Nov 2, 2023
8cec72d
(correctly) change Backward to use input instead of output
akropp Nov 2, 2023
8401f64
Doc updates
akropp Nov 2, 2023
ce846ea
Defining Backward in terms of the inputs
akropp Nov 2, 2023
a553782
Optimization of Backward()
akropp Nov 2, 2023
20a48b8
Unnecessary alias
akropp Nov 2, 2023
b7ab3ba
Fix ftswish to use input instead of output
akropp Nov 3, 2023
360c17c
Fix c_relu test and backward method
akropp Nov 3, 2023
4596810
Add LayerNorm and MultiAttentionHead tests
akropp Nov 3, 2023
cc99b4a
Update activation functions to take input and output values
akropp Nov 8, 2023
517bb0b
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 8, 2023
ef3183e
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 8, 2023
c40787c
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 8, 2023
1b8313e
Update src/mlpack/methods/ann/layer/flexible_relu_impl.hpp
akropp Nov 8, 2023
f6d78ab
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 8, 2023
d91c5f0
Update src/mlpack/methods/ann/layer/c_relu_impl.hpp
akropp Nov 8, 2023
65b4878
Clarify input dimensions documentation, check valid input dims
akropp Nov 8, 2023
7b5e4d2
Merge remote-tracking branch 'upstream/master' into adapt_multihead
akropp Nov 9, 2023
3363707
Add in LayerNorm Test
akropp Nov 10, 2023
2bdae7d
Add MultiheadAttention test, fix dimension calculation
akropp Nov 10, 2023
39f7a10
Comment out unused params
akropp Nov 10, 2023
c990b93
Comment unused parameter
akropp Nov 10, 2023
f5be365
Correct calculation of srcSeqLen with more than 2 input dimensions
akropp Nov 13, 2023
eb3e7f2
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
5f385ff
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
89f93d1
Update src/mlpack/tests/ann/layer/layer_norm.cpp
akropp Nov 13, 2023
fb94b6d
Update src/mlpack/tests/ann/activation_functions_test.cpp
akropp Nov 13, 2023
008aabb
Update src/mlpack/tests/ann/layer/multihead_attention.cpp
akropp Nov 13, 2023
9451a8f
Update src/mlpack/tests/ann/activation_functions_test.cpp
akropp Nov 13, 2023
68f15ed
Update src/mlpack/methods/ann/activation_functions/elish_function.hpp
akropp Nov 13, 2023
d9f9e4f
Style changes, author
akropp Nov 13, 2023
214a731
Merge remote-tracking branch 'origin/adapt_multihead' into adapt_mult…
akropp Nov 13, 2023
a4ead8b
author
akropp Nov 13, 2023
e3a293b
simplify dy expression
akropp Nov 13, 2023
ae578a7
Update src/mlpack/methods/ann/activation_functions/gaussian_function.hpp
akropp Nov 13, 2023
7846bc0
Update src/mlpack/methods/ann/activation_functions/gaussian_function.hpp
akropp Nov 13, 2023
bc94bb9
Update src/mlpack/methods/ann/activation_functions/mish_function.hpp
akropp Nov 13, 2023
8f89db7
Update src/mlpack/methods/ann/activation_functions/lisht_function.hpp
akropp Nov 13, 2023
8de759e
Update src/mlpack/methods/ann/activation_functions/gelu_function.hpp
akropp Nov 13, 2023
1456ab9
Update src/mlpack/methods/ann/activation_functions/hard_swish_functio…
akropp Nov 13, 2023
18bac82
Update src/mlpack/methods/ann/activation_functions/hard_sigmoid_funct…
akropp Nov 13, 2023
d99c88e
Update src/mlpack/methods/ann/activation_functions/rectifier_function…
akropp Nov 13, 2023
6d6768f
Update src/mlpack/methods/ann/activation_functions/multi_quadratic_fu…
akropp Nov 13, 2023
4f66883
Update src/mlpack/methods/ann/activation_functions/poisson1_function.hpp
akropp Nov 13, 2023
7ab7d41
Update src/mlpack/methods/ann/activation_functions/quadratic_function…
akropp Nov 13, 2023
ce538a8
Update src/mlpack/methods/ann/activation_functions/rectifier_function…
akropp Nov 13, 2023
8914a1d
Update src/mlpack/methods/ann/activation_functions/silu_function.hpp
akropp Nov 13, 2023
e8f64d5
Update src/mlpack/methods/ann/activation_functions/softplus_function.hpp
akropp Nov 13, 2023
c8c8dbe
Update src/mlpack/methods/ann/activation_functions/softsign_function.hpp
akropp Nov 13, 2023
1cf92e5
Update src/mlpack/methods/ann/activation_functions/swish_function.hpp
akropp Nov 13, 2023
26b0311
Update src/mlpack/methods/ann/activation_functions/tanh_exponential_f…
akropp Nov 13, 2023
cd2c13e
Update src/mlpack/methods/ann/activation_functions/tanh_function.hpp
akropp Nov 13, 2023
4244a33
Update src/mlpack/methods/ann/activation_functions/spline_function.hpp
akropp Nov 13, 2023
d004462
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
dfcf40a
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
4cb0cd1
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
94bb5ff
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 13, 2023
5c7fde6
Fixes
akropp Nov 13, 2023
deb68b9
Fix deriv
akropp Nov 13, 2023
1539b6e
Remove commentary
akropp Nov 13, 2023
62cedfd
Fix conditionals
akropp Nov 13, 2023
a389184
Update elish_function.hpp
akropp Nov 14, 2023
b788e4a
Perf improvement
akropp Nov 14, 2023
fdb29d3
Update elish_function.hpp
akropp Nov 14, 2023
d90282a
Update elish_function.hpp
akropp Nov 14, 2023
677a6d8
input dimension size calc was wrong in Reset()
akropp Nov 15, 2023
8ad2e22
Temporary debugging spew to see why this fails in azure
akropp Nov 15, 2023
74f42a4
Fix bug with old armadillo version
akropp Nov 16, 2023
130b075
Update src/mlpack/methods/ann/activation_functions/hard_swish_functio…
akropp Nov 16, 2023
27ad35c
Update src/mlpack/methods/ann/layer/layer_norm.hpp
akropp Nov 16, 2023
2328247
Update src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
akropp Nov 16, 2023
def2346
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 16, 2023
13cb8c2
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 16, 2023
0d4eb70
Update src/mlpack/methods/ann/layer/multihead_attention.hpp
akropp Nov 16, 2023
511fe69
Add author lines
akropp Nov 16, 2023
8e0039d
Update HISTORY.md
akropp Nov 16, 2023
df24810
Merge branch 'adapt_multihead' of https://github.com/akropp/mlpack in…
akropp Nov 16, 2023
7bce96a
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 20, 2023
0f7f472
Update src/mlpack/methods/ann/ffn_impl.hpp
akropp Nov 20, 2023
27eae7c
Update HISTORY.md
akropp Nov 20, 2023
cb7c6d3
Update HISTORY.md
akropp Nov 20, 2023
4 changes: 4 additions & 0 deletions HISTORY.md
@@ -11,6 +11,10 @@
* Fix setting number of classes correctly in `SoftmaxRegression::Train()`
(#3553).

* Adapt MultiheadAttention and LayerNorm to new Layer interface (#3547)

* Inconsistent use of the "input" parameter to the Backward method in ANNs (#3551)

### mlpack 4.2.1
###### 2023-09-05
* Reinforcement Learning: Gaussian noise (#3515).
@@ -5,6 +5,24 @@
* Convenience include for all activation functions implemented for mlpack's
* neural network toolkit.
*
* An activation function should define methods to evaluate the function
* and its derivative.
*
* For the forward pass, a class should define
* static double Fn(double x) -- evaluate y = F(x) at a single point
* and
* static void Fn(const InputVecType& x, OutputVecType& y) -- evaluate y = F(x)
* for a vector
*
* For the backward pass, a class should define the derivative function. For
* efficiency of implementation, it will be provided both x (the inputs) and
* y (the result of F(x)). The following should be defined
* static double Deriv(double x, double y) -- evaluate dF(x)/dx for one value
* of x given both x and y=F(x)
* static void Deriv(const InputVecType& x, const OutputVecType& y,
* DerivVecType& dy) -- evaluate dF(x)/dx for a vector x
* and a vector y=F(x)
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
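To make the new interface concrete, below is a minimal sketch (not part of this PR) of an activation function class that satisfies it; the class name SquareFunction and the function F(x) = x^2 are hypothetical and used only for illustration.

class SquareFunction
{
 public:
  // Forward pass at a single point: y = F(x) = x^2.
  static double Fn(const double x)
  {
    return x * x;
  }

  // Forward pass for a vector of inputs.
  template<typename InputVecType, typename OutputVecType>
  static void Fn(const InputVecType& x, OutputVecType& y)
  {
    y = arma::square(x);
  }

  // Backward pass at a single point: dF(x)/dx = 2x, given both x and y = F(x).
  static double Deriv(const double x, const double /* y */)
  {
    return 2 * x;
  }

  // Backward pass for a vector x and a vector y = F(x).
  template<typename InputVecType, typename OutputVecType, typename DerivVecType>
  static void Deriv(const InputVecType& x,
                    const OutputVecType& /* y */,
                    DerivVecType& dy)
  {
    dy = 2 * x;
  }
}; // class SquareFunction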
@@ -54,24 +54,26 @@ class BipolarSigmoidFunction
/**
* Computes the first derivative of the Bipolar Sigmoid function.
*
* @param y Input activation.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x)
*/
static double Deriv(const double y)
static double Deriv(const double /* x */, const double y)
{
return (1.0 - std::pow(y,2 )) / 2.0;
}

/**
* Computes the first derivatives of the Bipolar Sigmoid function.
*
* @param y Input activations.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType& /* x */, const OutputVecType& y, DerivVecType& dy)
{
x = (1.0 - arma::pow(y, 2)) / 2.0;
dy = (1.0 - arma::pow(y, 2)) / 2.0;
}
}; // class BipolarSigmoidFunction

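For context, a hypothetical usage sketch of the two-argument Deriv() interface shown above (assuming Armadillo vectors and the mlpack namespace; not taken from the PR's tests):

arma::vec x = arma::linspace(-3.0, 3.0, 7);   // sample inputs
arma::vec y, dy;
mlpack::BipolarSigmoidFunction::Fn(x, y);     // forward pass: y = F(x)
// The backward pass now receives both the inputs x and the activations y = F(x).
mlpack::BipolarSigmoidFunction::Deriv(x, y, dy);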
49 changes: 33 additions & 16 deletions src/mlpack/methods/ann/activation_functions/elish_function.hpp
@@ -1,6 +1,7 @@
/**
* @file methods/ann/activation_functions/elish_function.hpp
* @author Bisakh Mondal
* @author Adam Kropp
*
* Definition and implementation of the ELiSH function as described by
* Mina Basirat and Peter M. Roth.
@@ -70,40 +71,56 @@ class ElishFunction
template<typename InputVecType, typename OutputVecType>
static void Fn(const InputVecType& x, OutputVecType& y)
{
y = ((x < 0.0) % ((arma::exp(x) -1) / (1 + arma::exp(-x))))
y = ((x < 0.0) % ((arma::exp(x) - 1) / (1 + arma::exp(-x))))
+ ((x >= 0.0) % (x / (1 + arma::exp(-x))));
}

/**
* Computes the first derivatives of ELiSH function.
*
* @param y Input data.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x).
*/
static double Deriv(const double y)
static double Deriv(const double x, const double y)
{
if (y < 0.0)
if (x < 0.0)
{
return std::exp(y) - 2 / (1 + std::exp(y)) +
2 / std::pow(1 + std::exp(y) , 2);
return std::exp(x) - 2 / (1 + std::exp(x)) +
2 / std::pow(1 + std::exp(x) , 2);
}
else if (x == 0) {
return 0.5; // the expression below is indeterminate at 0, even though
// the expression solely in terms of x is defined (= 0.5)
} else {
return (y / x) * (1 + x - y);
}

return 1 / (1 + std::exp(-y)) + y * std::exp(-y) /
std::pow(1 + std::exp(-y) , 2);
}

/**
* Computes the first derivatives of the ELiSH function.
*
* @param y Input data.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType& x,
const OutputVecType& y,
DerivVecType& dy)
{
x = ((y < 0.0) % (arma::exp(y) - 2 / (1 + arma::exp(y)) + 2 / arma::pow(
1 + arma::exp(y), 2))) + ((y >= 0.0) % (1 / (1 + arma::exp(-y)) + y %
arma::exp(-y) / arma::pow(1 + arma::exp(-y), 2)));
// simplified the x>=0 part to be in terms of x and y -- maybe
// the x<0 part can be as well?
// the expression is indeterminate at 0, even though
// the expression solely in terms of x is defined (= 0.5)
// only calculate exp(x) once for each element where x < 0
// this gives approx 3x speedup, despite allocating the temp vector
DerivVecType ex = (x < 0) % arma::exp(x);
dy = ((x < 0) % ((ex - 2 / (1 + ex) + 2 / arma::pow(1 + ex, 2)))) +
((x > 0) % ((y / x) % (1.0 + x - y)));
// need to do this here, because the /x above gives nans even when the
// condition is not met (e.g. when x > 0 is false)
dy(arma::find(x == 0)).fill(0.5);
}
}; // class ElishFunction

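As a side note (not part of the diff), the simplified x > 0 branch above follows from writing ELiSH as F(x) = x * s(x) for x >= 0, where s(x) = 1 / (1 + e^-x) is the logistic sigmoid, so y = x * s(x). Then

  F'(x) = s(x) + x * s(x) * (1 - s(x))
        = (y / x) * (1 + x - x * s(x))
        = (y / x) * (1 + x - y),

which is why the derivative can be computed from the stored activation y without re-evaluating the sigmoid, and why the x == 0 case (where y / x is indeterminate) is filled separately with the limit value 0.5.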
20 changes: 12 additions & 8 deletions src/mlpack/methods/ann/activation_functions/elliot_function.hpp
@@ -65,24 +65,28 @@ class ElliotFunction
/**
* Computes the first derivative of the Elliot function.
*
* @param y Input data.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x).
*/
static double Deriv(const double y)
static double Deriv(const double x, const double /* y */)
{
return std::pow(1.0 - std::abs(y), 2);
return 1.0 / std::pow(1.0 + std::abs(x), 2);
}

/**
* Computes the first derivatives of the Elliot function.
*
* @param y Input activations.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template <typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType &y, OutputVecType &x)
template <typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType & x,
const OutputVecType& /* y */,
DerivVecType &dy)
{
x = arma::pow(1.0 - arma::abs(y), 2);
dy = 1.0 / arma::pow(1.0 + arma::abs(x), 2);
}
}; // class ElliotFunction

19 changes: 11 additions & 8 deletions src/mlpack/methods/ann/activation_functions/gaussian_function.hpp
@@ -1,6 +1,7 @@
/**
* @file gaussian_function.hpp
* @author Himanshu Pathak
* @author Adam Kropp
*
* Definition and implementation of the gaussian function.
*
@@ -54,24 +55,26 @@ class GaussianFunction
/**
* Computes the first derivative of the gaussian function.
*
* @param y Input data.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x)
*/
static double Deriv(const double y)
static double Deriv(const double x, const double y)
{
return 2 * -y * std::exp(-1 * std::pow(y, 2));
return -2 * x * y;
}

/**
* Computes the first derivatives of the gaussian function.
*
* @param y Input activations.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType& x, const OutputVecType& y, DerivVecType& dy)
{
x = 2 * -y % arma::exp(-1 * arma::pow(y, 2));
dy = -2 * x % y;
}
}; // class GaussianFunction

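For reference (not part of the diff), the simplification above uses the stored activation y = F(x) = e^(-x^2): by the chain rule,

  F'(x) = -2x * e^(-x^2) = -2 * x * y,

so the derivative needs no second evaluation of the exponential.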
37 changes: 22 additions & 15 deletions src/mlpack/methods/ann/activation_functions/gelu_function.hpp
@@ -1,6 +1,7 @@
/**
* @file methods/ann/activation_functions/gelu_function.hpp
* @author Himanshu Pathak
* @author Adam Kropp
*
* Definition and implementation of the Gaussian Error Linear Unit (GELU)
* function.
@@ -22,7 +23,7 @@ namespace mlpack {
*
* @f{eqnarray*}{
* f(x) = 0.5 * x * {1 + tanh[(2/pi)^(1/2) * (x + 0.044715 * x^3)]} \\
* f'(x) = 0.5 * tanh(0.0356774 * x^3) + 0.797885 * x) +
* f'(x) = 0.5 * tanh(0.0356774 * x^3 + 0.797885 * x) +
* (0.0535161x^3 + 0.398942 * x) *
* sech^2(0.0356774 * x^3+0.797885 * x) + 0.5\\
* @f}
@@ -58,30 +59,36 @@ class GELUFunction
/**
* Computes the first derivative of the GELU function.
*
* @param y Input data.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x)
*/
static double Deriv(const double y)
static double Deriv(const double x, const double /* y */)
{
return 0.5 * std::tanh(0.0356774 * std::pow(y, 3) + 0.797885 * y) +
(0.0535161 * std::pow(y, 3) + 0.398942 * y) *
std::pow(1 / std::cosh(0.0356774 * std::pow(y, 3) +
0.797885 * y), 2) + 0.5;
if (x < -10) return 0.0; // catch overflows
return 0.5 * std::tanh(0.0356774 * std::pow(x, 3) + 0.797885 * x) +
(0.0535161 * std::pow(x, 3) + 0.398942 * x) *
std::pow(1 / std::cosh(0.0356774 * std::pow(x, 3) +
0.797885 * x), 2) + 0.5;
}

/**
* Computes the first derivatives of the GELU function.
*
* @param y Input data.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType& x,
const OutputVecType& /* y */,
DerivVecType& dy)
{
x = 0.5 * arma::tanh(0.0356774 * arma::pow(y, 3) + 0.797885 * y) +
(0.0535161 * arma::pow(y, 3) + 0.398942 * y) %
arma::pow(1 / arma::cosh(0.0356774 * arma::pow(y, 3) +
0.797885 * y), 2) + 0.5;
dy = 0.5 * arma::tanh(0.0356774 * arma::pow(x, 3) + 0.797885 * x) +
(0.0535161 * arma::pow(x, 3) + 0.398942 * x) %
arma::pow(1 / arma::cosh(0.0356774 * arma::pow(x, 3) +
0.797885 * x), 2) + 0.5;
dy(arma::find(x < -10)).fill(0); // catch overflows
}
}; // class GELUFunction

@@ -63,10 +63,11 @@ class HardSigmoidFunction
/**
* Computes the first derivatives of hard sigmoid function.
*
* @param y Input data.
* @param x Input activation.
* @param y Result of Fn(x).
* @return f'(x)
*/
static double Deriv(const double y)
static double Deriv(const double /* x */, const double y)
{
if (y == 0.0 || y == 1.0)
{
@@ -78,18 +79,21 @@ class HardSigmoidFunction
/**
* Computes the first derivatives of the hard sigmoid function.
*
* @param y Input data.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType& x,
const OutputVecType& y,
DerivVecType& dy)
{
x.set_size(size(y));
dy.set_size(size(y));

#pragma omp for
for (size_t i = 0; i < (size_t) y.n_elem; ++i)
{
x(i) = Deriv(y(i));
dy(i) = Deriv(x(i), y(i));
}
}
}; // class HardSigmoidFunction
28 changes: 16 additions & 12 deletions src/mlpack/methods/ann/activation_functions/hard_swish_function.hpp
@@ -81,33 +81,37 @@ class HardSwishFunction
/**
* Computes the first derivative of the Hard Swish function.
*
* @param y Input data.
* @param x Input activation.
* @param * (y) Result of Fn(x).
* @return f'(x).
*/
static double Deriv(const double y)
static double Deriv(const double x, const double /* y */)
{
if (y <= -3)
if (x <= -3)
return 0;
else if (y >= 3)
else if (x >= 3)
return 1;

return (2 * y + 3.0) / 6.0;
return (2 * x + 3.0) / 6.0;
}

/**
* Computes the first derivatives of the Hard Swish function.
*
* @param y Input data.
* @param x The resulting derivatives.
* @param x Input activation.
* @param y Result of Fn(x).
* @param dy The resulting derivatives.
*/
template <typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType &y, OutputVecType &x)
template <typename InputVecType, typename OutputVecType, typename DerivVecType>
static void Deriv(const InputVecType &x,
const OutputVecType& y,
DerivVecType &dy)
{
x.set_size(size(y));
dy.set_size(size(x));

#pragma omp for
for (size_t i = 0; i < (size_t) y.n_elem; i++)
x(i) = Deriv(y(i));
for (size_t i = 0; i < (size_t) x.n_elem; i++)
dy(i) = Deriv(x(i), y(i));
}
}; // class HardSwishFunction
