Implemented FTSwish Activation Function #3485

Merged · 2 commits · May 30, 2023
Changes from 1 commit
114 changes: 114 additions & 0 deletions src/mlpack/methods/ann/layer/ftswish.hpp
@@ -0,0 +1,114 @@
/**
* @file methods/ann/layer/ftswish.hpp
* @author Mayank Raj
*
* Definition of the Flatten T Swish layer, as introduced in:
* Hock Hung Chieng, Noorhaniza Wahid, Pauline Ong, Sai Raj Kishore Perla,
* "Flatten-T Swish: a thresholded ReLU-Swish-like activation function for
* deep learning", 2018.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LAYER_FTSWISH_HPP
#define MLPACK_METHODS_ANN_LAYER_FTSWISH_HPP

#include <mlpack/prereqs.hpp>

#include "layer.hpp"

namespace mlpack {

/**
* The Flatten T Swish activation function, defined by
*
* @f{eqnarray*}{
* f(x) &=& \left\{
*   \begin{array}{lr}
*     \frac{x}{1 + \exp(-x)} + T & : x \ge 0 \\
*     T & : x < 0
*   \end{array}
* \right. \\
* f'(x) &=& \left\{
*   \begin{array}{lr}
*     \sigma(x)(1 - (f(x) - T)) + (f(x) - T) & : x \ge 0 \\
*     0 & : x < 0
*   \end{array}
* \right.
* @f}
*
* @tparam MatType Matrix representation to accept as input and use for
* computation.
*/
template<typename MatType = arma::mat>
class FTSwishType : public Layer<MatType>
{
public:
/**
* Create the Flatten T Swish object using the specified parameters.
* The threshold value T replaces the output for any x < 0.
* The default value of T is -0.20, as suggested in the paper.
*
* @param T Threshold value used for negative inputs.
*/
FTSwishType(const double T = -0.20);

//! Clone the FTSwishType object. This handles polymorphism correctly.
FTSwishType* Clone() const { return new FTSwishType(*this); }

//! Virtual destructor.
virtual ~FTSwishType() { }

//! Copy the given FTSwishType.
FTSwishType(const FTSwishType& other);
//! Take ownership of the given FTSwishType.
FTSwishType(FTSwishType&& other);
//! Copy the given FTSwishType.
FTSwishType& operator=(const FTSwishType& other);
//! Take ownership of the given FTSwishType.
FTSwishType& operator=(FTSwishType&& other);

/**
* Ordinary feed forward pass of a neural network, evaluating the function
* f(x) by propagating the activity forward through f.
*
* @param input Input data used for evaluating the specified function.
* @param output Resulting output activation.
*/
void Forward(const MatType& input, MatType& output);

/**
* Ordinary feed backward pass of a neural network, calculating the function
* f(x) by propagating x backwards through f, using the results from the
* feed forward pass.
*
* @param input The propagated input activation.
* @param gy The backpropagated error.
* @param g The calculated gradient.
*/
void Backward(const MatType& input, const MatType& gy, MatType& g);

//! Get the threshold value.
const double& Threshold() const { return T; }
//! Modify the threshold value.
double& Threshold() { return T; }

//! Serialize the layer.
template<typename Archive>
void serialize(Archive& ar, const uint32_t /* version */);

private:
//! Threshold value for x < 0.
double T;
}; // class FTSwishType

// Convenience typedefs.
typedef FTSwishType<arma::mat> FTSwish;

} // namespace mlpack

// Include implementation.
#include "ftswish_impl.hpp"

#endif
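For reviewers who want to try the layer, a minimal usage sketch (not part of this PR — the loss type, layer sizes, and random data are illustrative assumptions about the mlpack 4.x `FFN` API):

```c++
#include <mlpack/core.hpp>
#include <mlpack/methods/ann.hpp>

using namespace mlpack;

int main()
{
  // Toy regression data: 10-dimensional inputs, 2-dimensional targets.
  arma::mat trainData(10, 100, arma::fill::randu);
  arma::mat trainResponses(2, 100, arma::fill::randu);

  FFN<MeanSquaredError> model;
  model.Add<Linear>(16);
  model.Add<FTSwish>(-0.20); // Threshold T = -0.20, the paper's default.
  model.Add<Linear>(2);

  model.Train(trainData, trainResponses);
}
```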
119 changes: 119 additions & 0 deletions src/mlpack/methods/ann/layer/ftswish_impl.hpp
@@ -0,0 +1,119 @@
/**
* @file methods/ann/layer/ftswish_impl.hpp
* @author Mayank Raj
*
* Implementation of the Flatten T Swish layer, as defined in:
* Hock Hung Chieng, Noorhaniza Wahid, Pauline Ong, Sai Raj Kishore Perla,
* "Flatten-T Swish: a thresholded ReLU-Swish-like activation function for
* deep learning", 2018.
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LAYER_FTSWISH_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_FTSWISH_IMPL_HPP

// In case it hasn't yet been included.
#include "ftswish.hpp"

namespace mlpack {

template<typename MatType>
FTSwishType<MatType>::FTSwishType(const double T) :
Layer<MatType>(),
T(T)
{
// Nothing to do here.
}

template<typename MatType>
FTSwishType<MatType>::FTSwishType(const FTSwishType& other) :
Layer<MatType>(other),
T(other.T)
{
// Nothing to do.
}

template<typename MatType>
FTSwishType<MatType>::FTSwishType(FTSwishType&& other) :
Layer<MatType>(std::move(other)),
T(std::move(other.T))
{
// Nothing to do.
}

template<typename MatType>
FTSwishType<MatType>&
FTSwishType<MatType>::operator=(const FTSwishType& other)
{
if (&other != this)
{
Layer<MatType>::operator=(other);
T = other.T;
}

return *this;
}

template<typename MatType>
FTSwishType<MatType>&
FTSwishType<MatType>::operator=(FTSwishType&& other)
{
if (&other != this)
{
Layer<MatType>::operator=(std::move(other));
T = std::move(other.T);
}

return *this;
}

template<typename MatType>
void FTSwishType<MatType>::Forward(const MatType& input, MatType& output)
{
#pragma omp parallel for
for (size_t i = 0; i < (size_t) input.n_elem; ++i)
{
if (input(i) >= 0)
output(i) = input(i) / (1 + std::exp(-input(i))) + T;
else
output(i) = T;
}
}

template<typename MatType>
void FTSwishType<MatType>::Backward(
const MatType& input, const MatType& gy, MatType& g)
{
#pragma omp parallel for
for (size_t i = 0; i < (size_t) input.n_elem; ++i)
{
if (input(i) >= 0)
{
// Reuse the sigmoid value so std::exp() is only evaluated once.
const double sigmoidX = 1 / (1 + std::exp(-input(i)));
const double fX = input(i) * sigmoidX;

g(i) = gy(i) * (sigmoidX * (1 - fX) + fX);
}
else
{
g(i) = 0;
}
}
}

template<typename MatType>
template<typename Archive>
void FTSwishType<MatType>::serialize(
Archive& ar,
const uint32_t /* version */)
{
ar(cereal::base_class<Layer<MatType>>(this));

ar(CEREAL_NVP(T));
}

} // namespace mlpack

#endif
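As a sanity check of the gradient used in `Backward()`: with s(x) = x * sigmoid(x), the closed form sigmoid(x) * (1 - s(x)) + s(x) is exactly s'(x), and the constant threshold T contributes nothing to the derivative. A standalone snippet (not part of this PR) comparing it against a central finite difference of the forward map:

```c++
#include <cmath>
#include <cstdio>

int main()
{
  const double T = -0.20;
  // Scalar forward map of FTSwish.
  auto f = [T](const double x)
  { return (x >= 0) ? x / (1.0 + std::exp(-x)) + T : T; };

  for (const double x : { 0.1, 0.5, 1.0, 2.5 })
  {
    const double sig = 1.0 / (1.0 + std::exp(-x));
    const double s = x * sig; // Equal to f(x) - T on the positive branch.
    const double analytic = sig * (1.0 - s) + s;
    const double numeric = (f(x + 1e-6) - f(x - 1e-6)) / 2e-6;
    std::printf("x = %.2f  analytic = %.8f  numeric = %.8f\n",
                x, analytic, numeric);
  }
}
```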
1 change: 1 addition & 0 deletions src/mlpack/methods/ann/layer/layer_types.hpp
@@ -48,6 +48,7 @@
#include <mlpack/methods/ann/layer/radial_basis_function.hpp>
#include <mlpack/methods/ann/layer/softmax.hpp>
#include <mlpack/methods/ann/layer/softmin.hpp>
#include <mlpack/methods/ann/layer/ftswish.hpp>

// Convolution modes.
#include <mlpack/methods/ann/convolution_rules/border_modes.hpp>
1 change: 1 addition & 0 deletions src/mlpack/methods/ann/layer/serialization.hpp
@@ -70,6 +70,7 @@
CEREAL_REGISTER_TYPE(mlpack::SoftmaxType<__VA_ARGS__>); \
CEREAL_REGISTER_TYPE(mlpack::SoftminType<__VA_ARGS__>); \
CEREAL_REGISTER_TYPE(mlpack::HardTanHType<__VA_ARGS__>); \
CEREAL_REGISTER_TYPE(mlpack::FTSwishType<__VA_ARGS__>); \

CEREAL_REGISTER_MLPACK_LAYERS(arma::mat);

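The `CEREAL_REGISTER_TYPE` entry above is what allows cereal to reconstruct an `FTSwishType<arma::mat>` behind a polymorphic `Layer<arma::mat>*` pointer when a saved model is reloaded. A round-trip sketch (the model shape and filename are illustrative assumptions):

```c++
FFN<MeanSquaredError> model;
model.Add<Linear>(8);
model.Add<FTSwish>(-0.20);

// Serialize, then reload into a fresh object; without the registration
// above, cereal could not restore the FTSwish layer polymorphically.
data::Save("model.bin", "model", model);

FFN<MeanSquaredError> model2;
data::Load("model.bin", "model", model2);
```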
62 changes: 62 additions & 0 deletions src/mlpack/tests/ann/layer/ftswish.cpp
@@ -0,0 +1,62 @@
/**
* @file tests/ann/layer/ftswish.cpp
* @author Mayank Raj
*
* Tests the FTSwish layer.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#include <mlpack/core.hpp>
#include <mlpack/methods/ann.hpp>

#include "../../test_catch_tools.hpp"
#include "../../catch.hpp"
#include "../../serialization.hpp"
#include "../ann_test_tools.hpp"

using namespace mlpack;

/**
* Simple test case for the FTSwish layer.
*/
TEST_CASE("FTSwishTest", "[ANNLayerTest]")
{
// Set the threshold value for the FTSwish layer.
double threshold = -0.2;

// Create the FTSwish layer.
FTSwishType<> layer(threshold);

// Input and output matrices.
arma::mat input = {{ 0.234, 1.23, -1.34},
                   { 1.45, 2.001, -0.98},
                   {-3.14, 3.43, 9.9}};
arma::mat actualOutput = {{-0.06937312, 0.75179685, -0.2},
                          { 0.97449773, 1.56268497, -0.2},
                          {-0.2, 3.1223977, 9.6995033}};
arma::mat output;
output.set_size(3, 3);
// Forward pass.
layer.Forward(input, output);

// Test the Forward function.
REQUIRE(std::abs(arma::accu(output - actualOutput)) <= 0.0001);

arma::mat delta = {{0.0, 0.84327731, 0.0},
                   {0.91985943, 1.05058055, 0.0},
                   {0.0, 1.08399125, 1.00053333}};

arma::mat gy, g;
gy.set_size(3, 3);
gy.fill(1);
g.set_size(3, 3);
// Backward pass.
layer.Backward(output, gy, g);

// Test the Backward function.
REQUIRE(std::abs(arma::accu(g - delta)) <= 0.0001);
}
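For reference, the expected values can be reproduced by hand; a tiny snippet (not part of the PR) checking the first forward entry:

```c++
#include <cmath>
#include <cstdio>

int main()
{
  // First entry of the test input, with the paper's default threshold.
  const double x = 0.234, T = -0.2;
  const double y = x / (1.0 + std::exp(-x)) + T;
  std::printf("%.8f\n", y); // Prints approximately -0.06937312.
}
```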
1 change: 1 addition & 0 deletions src/mlpack/tests/ann/layer_test.cpp
@@ -37,3 +37,4 @@
#include "layer/parametric_relu.cpp"
#include "layer/softmax.cpp"
#include "layer/softmin.cpp"
#include "layer/ftswish.cpp"