Skip to content

Commit

Permalink
Merge pull request #1392 from haritha1313/multmerge
Browse files Browse the repository at this point in the history
MultiplyMerge layer added.
  • Loading branch information
zoq committed May 16, 2018
2 parents a30f5f8 + a6a99bf commit d44d4cd
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/layer/CMakeLists.txt
Expand Up @@ -59,6 +59,8 @@ set(SOURCES
mean_pooling_impl.hpp
multiply_constant.hpp
multiply_constant_impl.hpp
multiply_merge.hpp
multiply_merge_impl.hpp
negative_log_likelihood.hpp
negative_log_likelihood_impl.hpp
parametric_relu.hpp
Expand Down
2 changes: 1 addition & 1 deletion src/mlpack/methods/ann/layer/add_merge.hpp
Expand Up @@ -45,7 +45,7 @@ class AddMerge
/**
* Create the AddMerge object using the specified parameters.
*
* @param model Expose the all network modules.
* @param model Expose all the network modules.
*/
AddMerge(const bool model = false);

Expand Down
1 change: 1 addition & 0 deletions src/mlpack/methods/ann/layer/layer.hpp
Expand Up @@ -22,6 +22,7 @@
#include "linear.hpp"
#include "linear_no_bias.hpp"
#include "lstm.hpp"
#include "multiply_merge.hpp"
#include "gru.hpp"
#include "fast_lstm.hpp"
#include "recurrent.hpp"
Expand Down
7 changes: 7 additions & 0 deletions src/mlpack/methods/ann/layer/layer_types.hpp
Expand Up @@ -111,6 +111,12 @@ template<
>
class RecurrentAttention;

/**
 * Forward declaration of the MultiplyMerge layer so it can appear in the
 * LayerTypes variant below; the full definition lives in multiply_merge.hpp.
 */
template<typename InputDataType,
         typename OutputDataType,
         typename... CustomLayers
>
class MultiplyMerge;

template <typename... CustomLayers>
using LayerTypes = boost::variant<
Add<arma::mat, arma::mat>*,
Expand Down Expand Up @@ -150,6 +156,7 @@ using LayerTypes = boost::variant<
MaxPooling<arma::mat, arma::mat>*,
MeanPooling<arma::mat, arma::mat>*,
MultiplyConstant<arma::mat, arma::mat>*,
MultiplyMerge<arma::mat, arma::mat>*,
NegativeLogLikelihood<arma::mat, arma::mat>*,
PReLU<arma::mat, arma::mat>*,
Recurrent<arma::mat, arma::mat>*,
Expand Down
172 changes: 172 additions & 0 deletions src/mlpack/methods/ann/layer/multiply_merge.hpp
@@ -0,0 +1,172 @@
/**
* @file multiply_merge.hpp
* @author Haritha Nair
*
* Definition of the MultiplyMerge module which multiplies the output of the
* given modules element-wise.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP
#define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP

#include <utility>

#include <mlpack/prereqs.hpp>

#include "../visitor/delete_visitor.hpp"
#include "../visitor/delta_visitor.hpp"
#include "../visitor/output_parameter_visitor.hpp"

#include "layer_types.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* Implementation of the MultiplyMerge module class. The MultiplyMerge class
* multiplies the output of various modules element-wise.
*
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam CustomLayers Additional custom layers that can be added.
*/
template<
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat,
    typename... CustomLayers
>
class MultiplyMerge
{
 public:
  /**
   * Create the MultiplyMerge object using the specified parameters.
   *
   * @param model Expose all the network modules.
   */
  MultiplyMerge(const bool model = false);

  //! Destructor to release allocated memory.
  ~MultiplyMerge();

  /**
   * Ordinary feed forward pass of a neural network, evaluating the function
   * f(x) by propagating the activity forward through f.
   *
   * @param input Input data used for evaluating the specified function.
   * @param output Resulting output activation.
   */
  template<typename InputType, typename OutputType>
  void Forward(const InputType&& /* input */, OutputType&& output);

  /**
   * Ordinary feed backward pass of a neural network, calculating the function
   * f(x) by propagating x backwards through f, using the results from the
   * feed forward pass.
   *
   * @param input The propagated input activation.
   * @param gy The backpropagated error.
   * @param g The calculated gradient.
   */
  template<typename eT>
  void Backward(const arma::Mat<eT>&& /* input */,
                arma::Mat<eT>&& gy,
                arma::Mat<eT>&& g);

  /**
   * Add a new module to the model; the model takes ownership of the pointer
   * unless it was constructed with model == true.
   *
   * @param layer The Layer to be added to the model.
   */
  void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }

  /**
   * Add a copy of the given module to the model.
   *
   * @param layer The Layer to be added to the model.
   */
  template<typename LayerType>
  void Add(const LayerType& layer) { network.push_back(new LayerType(layer)); }

  /**
   * Construct a new module in-place and add it to the model.
   *
   * @param args The layer constructor arguments (perfectly forwarded, so
   *     movable arguments are not copied).
   */
  template <class LayerType, class... Args>
  void Add(Args&&... args)
  {
    network.push_back(new LayerType(std::forward<Args>(args)...));
  }

  //! Get the input parameter.
  InputDataType const& InputParameter() const { return inputParameter; }
  //! Modify the input parameter.
  InputDataType& InputParameter() { return inputParameter; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }

  //! Return the model modules (an empty list unless the object was
  //! constructed with model == true).
  std::vector<LayerTypes<CustomLayers...> >& Model()
  {
    if (model)
    {
      return network;
    }

    return empty;
  }

  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Parameter which indicates if the modules should be exposed.
  bool model;

  //! We need this to know whether we should delete the layer in the
  //! destructor.
  bool ownsLayer;

  //! Locally-stored network modules.
  std::vector<LayerTypes<CustomLayers...> > network;

  //! Locally-stored empty list of modules.
  std::vector<LayerTypes<CustomLayers...> > empty;

  //! Locally-stored delete visitor module object.
  DeleteVisitor deleteVisitor;

  //! Locally-stored output parameter visitor module object.
  OutputParameterVisitor outputParameterVisitor;

  //! Locally-stored delta visitor module object.
  DeltaVisitor deltaVisitor;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored input parameter object.
  InputDataType inputParameter;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;
}; // class MultiplyMerge

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "multiply_merge_impl.hpp"

#endif
79 changes: 79 additions & 0 deletions src/mlpack/methods/ann/layer/multiply_merge_impl.hpp
@@ -0,0 +1,79 @@
/**
* @file multiply_merge_impl.hpp
* @author Haritha Nair
*
* Implementation of the MultiplyMerge module which multiplies the output of the
* given modules element-wise.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP

// In case it hasn't yet been included.
#include "multiply_merge.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

// Construct the MultiplyMerge object. When the modules are not exposed
// (model == false), this object owns the held layers and deletes them in the
// destructor; hence ownsLayer is the negation of model.
template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::MultiplyMerge(
    const bool model) : model(model), ownsLayer(!model)
{
  // Nothing to do here.
}

// Destructor: release the held modules, but only when this object owns them
// (i.e. it was constructed with model == false).
template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::~MultiplyMerge()
{
  if (!ownsLayer)
    return;

  // Apply the delete visitor to every stored layer variant.
  for (LayerTypes<CustomLayers...>& layer : network)
    boost::apply_visitor(deleteVisitor, layer);
}

// Forward pass: merge the already-computed outputs of the held modules into
// a single activation by element-wise (Schur) product. The input argument is
// unused; each module's output parameter must already hold its activation.
template <typename InputDataType, typename OutputDataType,
          typename... CustomLayers>
template<typename InputType, typename OutputType>
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::Forward(
    const InputType&& /* input */, OutputType&& output)
{
  // Seed the result with the first module's activation, then fold the rest in.
  auto module = network.begin();
  output = boost::apply_visitor(outputParameterVisitor, *module);
  while (++module != network.end())
    output %= boost::apply_visitor(outputParameterVisitor, *module);
}

// Backward pass: the backpropagated error is passed through to the previous
// layer unchanged.
//
// NOTE(review): for a true element-wise product, the gradient w.r.t. one
// branch would be scaled by the product of the other branches' outputs. This
// pass-through mirrors AddMerge's convention — confirm it is intentional.
template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  g = gy;
}

/**
 * Serialize (save or load) the layer.
 *
 * Fixes two loading bugs: previously the old layers were leaked because the
 * network was cleared without deleting owned modules, and ownsLayer was not
 * serialized, so a loaded object's ownership state (and hence its destructor
 * behavior) was wrong.
 */
template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename Archive>
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::serialize(
    Archive& ar, const unsigned int /* version */)
{
  // When loading, delete any layers we currently own before the archive
  // overwrites the network; clearing alone would leak them.
  if (Archive::is_loading::value)
  {
    if (ownsLayer)
    {
      for (LayerTypes<CustomLayers...>& layer : network)
        boost::apply_visitor(deleteVisitor, layer);
    }

    network.clear();
  }

  ar & BOOST_SERIALIZATION_NVP(network);
  // Persist ownership so a loaded object deletes (or leaves alone) its layers
  // exactly as the saved one would have.
  ar & BOOST_SERIALIZATION_NVP(ownsLayer);
}

} // namespace ann
} // namespace mlpack

#endif
31 changes: 31 additions & 0 deletions src/mlpack/tests/ann_layer_test.cpp
Expand Up @@ -1607,4 +1607,35 @@ BOOST_AUTO_TEST_CASE(GradientTransposedConvolutionLayerTest)
BOOST_REQUIRE_LE(CheckGradient(function), 1e-4);
}

/**
 * Simple multiply merge module test.
 */
BOOST_AUTO_TEST_CASE(SimpleMultiplyMergeLayerTest)
{
  arma::mat input = arma::ones(10, 1);
  arma::mat output, delta;

  for (size_t trial = 0; trial < 5; ++trial)
  {
    MultiplyMerge<> mergeLayer;

    // Merge between 2 and 9 identity modules; run each module's forward pass
    // up front so its output parameter holds the activation to be merged.
    const size_t moduleCount = math::RandInt(2, 10);
    for (size_t j = 0; j < moduleCount; ++j)
    {
      IdentityLayer<> identity;
      identity.Forward(std::move(input),
          std::move(identity.OutputParameter()));

      mergeLayer.Add(identity);
    }

    // Test the Forward function: the element-wise product of all-ones
    // activations is all ones, so the sum is 10.
    mergeLayer.Forward(std::move(input), std::move(output));
    BOOST_REQUIRE_EQUAL(10, arma::accu(output));

    // Test the Backward function: the error is passed through unchanged.
    mergeLayer.Backward(std::move(input), std::move(output), std::move(delta));
    BOOST_REQUIRE_EQUAL(arma::accu(output), arma::accu(delta));
  }
}

BOOST_AUTO_TEST_SUITE_END();

0 comments on commit d44d4cd

Please sign in to comment.