Merge pull request #2261 from himanshupathak21061998/add-rbf
Implementing Simple Radial Basis Function Layer
saksham189 committed Jun 16, 2020
2 parents 5dc47f3 + c93c0aa commit 398a276
Showing 10 changed files with 477 additions and 3 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
@@ -34,6 +34,8 @@
* Fix incorrect neighbors for `k > 1` searches in `approx_kfn` binding, for
the `QDAFN` algorithm (#2448).

* Add `RBF` layer to the ann module to allow building the `RBFN` architecture (#2261).

### mlpack 3.3.1
###### 2020-04-29
* Minor Julia and Python documentation fixes (#2373).
1 change: 1 addition & 0 deletions src/mlpack/methods/ann/activation_functions/CMakeLists.txt
@@ -18,6 +18,7 @@ set(SOURCES
spline_function.hpp
multi_quadratic_function.hpp
poisson1_function.hpp
gaussian_function.hpp
)

# Add directory name to sources.
82 changes: 82 additions & 0 deletions src/mlpack/methods/ann/activation_functions/gaussian_function.hpp
@@ -0,0 +1,82 @@
/**
* @file gaussian_function.hpp
* @author Himanshu Pathak
*
* Definition and implementation of the Gaussian function.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GAUSSIAN_FUNCTION_HPP
#define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GAUSSIAN_FUNCTION_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* The Gaussian function, defined by
*
* @f{eqnarray*}{
* f(x) &=& e^{-x^2} \\
* f'(x) &=& -2 x f(x)
* @f}
*/
class GaussianFunction
{
public:
/**
* Computes the Gaussian function.
*
* @param x Input data.
* @return f(x).
*/
template<typename eT>
static double Fn(const eT x)
{
return std::exp(-1 * std::pow(x, 2));
}

/**
* Computes the Gaussian function.
*
* @param x Input data.
* @param y The resulting output activation.
*/
template<typename InputVecType, typename OutputVecType>
static void Fn(const InputVecType& x, OutputVecType& y)
{
y = arma::exp(-1 * arma::pow(x, 2));
}

/**
* Computes the first derivative of the Gaussian function.
*
* @param y Input data.
* @return f'(x).
*/
static double Deriv(const double y)
{
return 2 * -y * std::exp(-1 * std::pow(y, 2));
}

/**
* Computes the first derivatives of the Gaussian function.
*
* @param y Input activations.
* @param x The resulting derivatives.
*/
template<typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType& y, OutputVecType& x)
{
x = 2 * -y % arma::exp(-1 * arma::pow(y, 2));
}
}; // class GaussianFunction

} // namespace ann
} // namespace mlpack

#endif
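
For reference (not part of the patch): the element-wise arithmetic above can be reproduced with plain Armadillo. A minimal standalone sketch mirroring what Fn() and Deriv() compute:

#include <armadillo>

int main()
{
  arma::vec x = {-2.0, -1.0, 0.0, 1.0, 2.0};

  // f(x) = exp(-x^2), the same expression used in GaussianFunction::Fn().
  arma::vec y = arma::exp(-arma::pow(x, 2));

  // f'(x) = -2 x exp(-x^2), the same expression used in GaussianFunction::Deriv()
  // when it is given the pre-activation values.
  arma::vec dy = -2.0 * x % arma::exp(-arma::pow(x, 2));

  y.print("f(x):");
  dy.print("f'(x):");
  return 0;
}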
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/layer/CMakeLists.txt
@@ -87,6 +87,8 @@ set(SOURCES
reinforce_normal_impl.hpp
reparametrization.hpp
reparametrization_impl.hpp
radial_basis_function.hpp
radial_basis_function_impl.hpp
select.hpp
select_impl.hpp
sequential.hpp
13 changes: 13 additions & 0 deletions src/mlpack/methods/ann/layer/base_layer.hpp
@@ -26,6 +26,7 @@
#include <mlpack/methods/ann/activation_functions/gelu_function.hpp>
#include <mlpack/methods/ann/activation_functions/elliot_function.hpp>
#include <mlpack/methods/ann/activation_functions/elish_function.hpp>
#include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
@@ -48,6 +49,7 @@ namespace ann /** Artificial Neural Network. */ {
* - GELULayer
* - ELiSHLayer
* - ElliotLayer
* - GaussianFunctionLayer
*
* @tparam ActivationFunction Activation function used for the embedding layer.
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
@@ -264,6 +266,17 @@ template <
using ElishFunctionLayer = BaseLayer<
ActivationFunction, InputDataType, OutputDataType>;

/**
* Standard Gaussian-Layer using the Gaussian activation function.
*/
template <
class ActivationFunction = GaussianFunction,
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
using GaussianFunctionLayer = BaseLayer<
ActivationFunction, InputDataType, OutputDataType>;

} // namespace ann
} // namespace mlpack
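
For context (not part of the patch): the new GaussianFunctionLayer alias is used like the other BaseLayer aliases (e.g. SigmoidLayer or TanHLayer). A minimal sketch, assuming the standard mlpack 3.x FFN API:

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>

using namespace mlpack::ann;

int main()
{
  // Gaussian activation applied element-wise between two dense layers.
  FFN<MeanSquaredError<>> net;
  net.Add<Linear<>>(10, 5);            // Dense layer: 10 -> 5.
  net.Add<GaussianFunctionLayer<>>();  // exp(-x^2), element-wise.
  net.Add<Linear<>>(5, 1);             // Output layer: 5 -> 1.
  return 0;
}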

12 changes: 10 additions & 2 deletions src/mlpack/methods/ann/layer/layer_types.hpp
@@ -49,6 +49,7 @@
#include <mlpack/methods/ann/layer/hardshrink.hpp>
#include <mlpack/methods/ann/layer/celu.hpp>
#include <mlpack/methods/ann/layer/softshrink.hpp>
#include <mlpack/methods/ann/layer/radial_basis_function.hpp>

// Convolution modules.
#include <mlpack/methods/ann/convolution_rules/border_modes.hpp>
@@ -80,6 +81,11 @@ template<typename InputDataType,
typename RegularizerType>
class Linear;

template<typename InputDataType,
typename OutputDataType,
typename Activation>
class RBF;

template<typename InputDataType,
typename OutputDataType,
typename RegularizerType>
@@ -209,7 +215,9 @@ using MoreTypes = boost::variant<
Sequential<arma::mat, arma::mat, true>*,
Subview<arma::mat, arma::mat>*,
VRClassReward<arma::mat, arma::mat>*,
VirtualBatchNorm<arma::mat, arma::mat>*
VirtualBatchNorm<arma::mat, arma::mat>*,
RBF<arma::mat, arma::mat, GaussianFunction>*,
BaseLayer<GaussianFunction, arma::mat, arma::mat>*
>;

template <typename... CustomLayers>
@@ -226,8 +234,8 @@ using LayerTypes = boost::variant<
BaseLayer<LogisticFunction, arma::mat, arma::mat>*,
BaseLayer<IdentityFunction, arma::mat, arma::mat>*,
BaseLayer<TanhFunction, arma::mat, arma::mat>*,
BaseLayer<RectifierFunction, arma::mat, arma::mat>*,
BaseLayer<SoftplusFunction, arma::mat, arma::mat>*,
BaseLayer<RectifierFunction, arma::mat, arma::mat>*,
BatchNorm<arma::mat, arma::mat>*,
BilinearInterpolation<arma::mat, arma::mat>*,
CELU<arma::mat, arma::mat>*,
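
Background (not part of the patch): layers held by FFN are stored through the LayerTypes/MoreTypes boost::variant, which is why the new RBF and BaseLayer<GaussianFunction> pointer types must be registered above. A toy sketch of the mechanism, with hypothetical layer types A and B:

#include <boost/variant.hpp>
#include <iostream>

struct A { void Forward() { std::cout << "A\n"; } };
struct B { void Forward() { std::cout << "B\n"; } };

// Toy analogue of LayerTypes: a variant of layer pointers.
using ToyLayerTypes = boost::variant<A*, B*>;

struct ForwardVisitor : public boost::static_visitor<void>
{
  template<typename LayerType>
  void operator()(LayerType* layer) const { layer->Forward(); }
};

int main()
{
  A a; B b;
  ToyLayerTypes layer = &b;
  boost::apply_visitor(ForwardVisitor(), layer);  // Prints "B".
  return 0;
}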
154 changes: 154 additions & 0 deletions src/mlpack/methods/ann/layer/radial_basis_function.hpp
@@ -0,0 +1,154 @@
/**
* @file radial_basis_function.hpp
* @author Himanshu Pathak
*
* Definition of the Radial Basis Function module class.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LAYER_RBF_HPP
#define MLPACK_METHODS_ANN_LAYER_RBF_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>

#include "layer_types.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {


/**
* Implementation of the Radial Basis Function (RBF) layer. When used with a
* non-linear activation function, the RBF class acts as a radial basis function
* layer that can be used within a feed-forward neural network.
*
* For more information, refer to the following paper:
*
* @code
* @inproceedings{que2016back,
*   author    = {Que, Qichao and Belkin, Mikhail},
*   title     = {Back to the Future: Radial Basis Function Networks Revisited},
*   booktitle = {Artificial Intelligence and Statistics (AISTATS), PMLR 51},
*   year      = {2016},
*   url       = {http://proceedings.mlr.press/v51/que16.pdf},
* }
* @endcode
*
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam Activation Type of the activation function (mlpack::ann::GaussianFunction).
*/

template <
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat,
typename Activation = GaussianFunction
>
class RBF
{
public:
//! Create the RBF object.
RBF();

/**
* Create the Radial Basis Function layer object using the specified
* parameters.
*
* @param inSize The number of input units.
* @param outSize The number of output units.
* @param centres The centres, e.g. computed by running k-means on the data.
* @param betas The beta value to be used with the centres.
*/
RBF(const size_t inSize,
const size_t outSize,
arma::mat& centres,
double betas = 0);

/**
* Ordinary feed forward pass of the radial basis function.
*
* @param input Input data used for evaluating the specified function.
* @param output Resulting output activation.
*/
template<typename eT>
void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);

/**
* Ordinary feed backward pass of the radial basis function.
*/
template<typename eT>
void Backward(const arma::Mat<eT>& /* input */,
const arma::Mat<eT>& /* gy */,
arma::Mat<eT>& /* g */);

//! Get the output parameter.
OutputDataType const& OutputParameter() const { return outputParameter; }
//! Modify the output parameter.
OutputDataType& OutputParameter() { return outputParameter; }

//! Get the input parameter.
InputDataType const& InputParameter() const { return inputParameter; }
//! Modify the input parameter.
InputDataType& InputParameter() { return inputParameter; }

//! Get the input size.
size_t InputSize() const { return inSize; }

//! Get the output size.
size_t OutputSize() const { return outSize; }

//! Get the delta.
OutputDataType const& Delta() const { return delta; }
//! Modify the delta.
OutputDataType& Delta() { return delta; }

/**
* Serialize the layer.
*/
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);

private:
//! Locally-stored number of input units.
size_t inSize;

//! Locally-stored number of output units.
size_t outSize;

//! Locally-stored delta object.
OutputDataType delta;

//! Locally-stored output parameter object.
OutputDataType outputParameter;

//! Locally-stored sigma value.
double sigmas;

//! Locally-stored beta value.
double betas;

//! Locally-stored centres of the radial basis functions.
InputDataType centres;

//! Locally-stored input parameter object.
InputDataType inputParameter;

//! Locally-stored distances between the input and the centres.
OutputDataType distances;
}; // class RBF

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "radial_basis_function_impl.hpp"

#endif
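
Putting the pieces together (not part of the patch): a hedged sketch of an RBFN built from the new RBF layer followed by a Linear output layer, with the centres computed by k-means as the constructor documentation suggests. It assumes the standard mlpack 3.x FFN, Linear, MeanSquaredError, and KMeans APIs.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/kmeans/kmeans.hpp>

using namespace mlpack::ann;
using namespace mlpack::kmeans;

int main()
{
  // Toy regression data: one observation per column.
  arma::mat data = arma::randu<arma::mat>(4, 200);
  arma::mat responses = arma::randu<arma::mat>(1, 200);

  // Centres for the RBF units, computed with k-means.
  const size_t numCentres = 10;
  arma::Row<size_t> assignments;
  arma::mat centres;
  KMeans<> kmeans;
  kmeans.Cluster(data, numCentres, assignments, centres);

  // RBFN: RBF hidden layer (Gaussian activation) + linear output layer.
  FFN<MeanSquaredError<>> model;
  model.Add<RBF<>>(data.n_rows, numCentres, centres);
  model.Add<Linear<>>(numCentres, responses.n_rows);

  model.Train(data, responses);
  return 0;
}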
