Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1392 from haritha1313/multmerge
MultiplyMerge layer added.
- Loading branch information
Showing
7 changed files
with
293 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/** | ||
* @file multiply_merge.hpp | ||
* @author Haritha Nair | ||
* | ||
* Definition of the MultiplyMerge module which multiplies the output of the | ||
* given modules element-wise. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP | ||
#define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
#include "../visitor/delete_visitor.hpp" | ||
#include "../visitor/delta_visitor.hpp" | ||
#include "../visitor/output_parameter_visitor.hpp" | ||
|
||
#include "layer_types.hpp" | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
/** | ||
* Implementation of the MultiplyMerge module class. The MultiplyMerge class | ||
* multiplies the output of various modules element-wise. | ||
* | ||
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
* @tparam CustomLayers Additional custom layers that can be added. | ||
*/ | ||
template< | ||
typename InputDataType = arma::mat, | ||
typename OutputDataType = arma::mat, | ||
typename... CustomLayers | ||
> | ||
class MultiplyMerge | ||
{ | ||
public: | ||
/** | ||
* Create the MultiplyMerge object using the specified parameters. | ||
* | ||
* @param model Expose all the network modules. | ||
*/ | ||
MultiplyMerge(const bool model = false); | ||
|
||
//! Destructor to release allocated memory. | ||
~MultiplyMerge(); | ||
|
||
/** | ||
* Ordinary feed forward pass of a neural network, evaluating the function | ||
* f(x) by propagating the activity forward through f. | ||
* | ||
* @param input Input data used for evaluating the specified function. | ||
* @param output Resulting output activation. | ||
*/ | ||
template<typename InputType, typename OutputType> | ||
void Forward(const InputType&& /* input */, OutputType&& output); | ||
|
||
/** | ||
* Ordinary feed backward pass of a neural network, calculating the function | ||
* f(x) by propagating x backwards trough f, using the results from the feed | ||
* forward pass. | ||
* | ||
* @param input The propagated input activation. | ||
* @param gy The backpropagated error. | ||
* @param g The calculated gradient. | ||
*/ | ||
template<typename eT> | ||
void Backward(const arma::Mat<eT>&& /* input */, | ||
arma::Mat<eT>&& gy, | ||
arma::Mat<eT>&& g); | ||
|
||
/* | ||
* Add a new module to the model. | ||
* | ||
* @param layer The Layer to be added to the model. | ||
*/ | ||
void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); } | ||
|
||
/* | ||
* Add a new module to the model. | ||
* | ||
* @param layer The Layer to be added to the model. | ||
*/ | ||
template<typename LayerType> | ||
void Add(const LayerType& layer) { network.push_back(new LayerType(layer)); } | ||
|
||
/* | ||
* Add a new module to the model. | ||
* | ||
* @param args The layer parameter. | ||
*/ | ||
template <class LayerType, class... Args> | ||
void Add(Args... args) { network.push_back(new LayerType(args...)); } | ||
|
||
//! Get the input parameter. | ||
InputDataType const& InputParameter() const { return inputParameter; } | ||
//! Modify the input parameter. | ||
InputDataType& InputParameter() { return inputParameter; } | ||
|
||
//! Get the output parameter. | ||
OutputDataType const& OutputParameter() const { return outputParameter; } | ||
//! Modify the output parameter. | ||
OutputDataType& OutputParameter() { return outputParameter; } | ||
|
||
//! Get the delta. | ||
OutputDataType const& Delta() const { return delta; } | ||
//! Modify the delta. | ||
OutputDataType& Delta() { return delta; } | ||
|
||
//! Return the model modules. | ||
std::vector<LayerTypes<CustomLayers...> >& Model() | ||
{ | ||
if (model) | ||
{ | ||
return network; | ||
} | ||
|
||
return empty; | ||
} | ||
|
||
/** | ||
* Serialize the layer. | ||
*/ | ||
template<typename Archive> | ||
void serialize(Archive& ar, const unsigned int /* version */); | ||
|
||
private: | ||
//! Parameter which indicates if the modules should be exposed. | ||
bool model; | ||
|
||
//! We need this to know whether we should delete the layer in the destructor. | ||
bool ownsLayer; | ||
|
||
//! Locally-stored network modules. | ||
std::vector<LayerTypes<CustomLayers...> > network; | ||
|
||
//! Locally-stored empty list of modules. | ||
std::vector<LayerTypes<CustomLayers...> > empty; | ||
|
||
//! Locally-stored delete visitor module object. | ||
DeleteVisitor deleteVisitor; | ||
|
||
//! Locally-stored output parameter visitor module object. | ||
OutputParameterVisitor outputParameterVisitor; | ||
|
||
//! Locally-stored delta visitor module object. | ||
DeltaVisitor deltaVisitor; | ||
|
||
//! Locally-stored delta object. | ||
OutputDataType delta; | ||
|
||
//! Locally-stored input parameter object. | ||
InputDataType inputParameter; | ||
|
||
//! Locally-stored output parameter object. | ||
OutputDataType outputParameter; | ||
}; // class MultiplyMerge | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// Include implementation. | ||
#include "multiply_merge_impl.hpp" | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/** | ||
* @file multiply_merge_impl.hpp | ||
* @author Haritha Nair | ||
* | ||
* Implementation of the MultiplyMerge module which multiplies the output of the | ||
* given modules element-wise. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP | ||
#define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP | ||
|
||
// In case it hasn't yet been included. | ||
#include "multiply_merge.hpp" | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
template<typename InputDataType, typename OutputDataType, | ||
typename... CustomLayers> | ||
MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::MultiplyMerge( | ||
const bool model) : model(model), ownsLayer(!model) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType, | ||
typename... CustomLayers> | ||
MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::~MultiplyMerge() | ||
{ | ||
if (ownsLayer) | ||
{ | ||
std::for_each(network.begin(), network.end(), | ||
boost::apply_visitor(deleteVisitor)); | ||
} | ||
} | ||
|
||
template <typename InputDataType, typename OutputDataType, | ||
typename... CustomLayers> | ||
template<typename InputType, typename OutputType> | ||
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::Forward( | ||
const InputType&& /* input */, OutputType&& output) | ||
{ | ||
output = boost::apply_visitor(outputParameterVisitor, network.front()); | ||
|
||
for (size_t i = 1; i < network.size(); ++i) | ||
{ | ||
output %= boost::apply_visitor(outputParameterVisitor, network[i]); | ||
} | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType, | ||
typename... CustomLayers> | ||
template<typename eT> | ||
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::Backward( | ||
const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g) | ||
{ | ||
g = gy; | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType, | ||
typename... CustomLayers> | ||
template<typename Archive> | ||
void MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>::serialize( | ||
Archive& ar, const unsigned int /* version */) | ||
{ | ||
if (Archive::is_loading::value) | ||
network.clear(); | ||
|
||
ar & BOOST_SERIALIZATION_NVP(network); | ||
} | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters