Skip to content

Commit

Permalink
Fixing 'rho' issue in RNN
Browse files Browse the repository at this point in the history
  • Loading branch information
sumedhghaisas committed Aug 13, 2017
1 parent 759445d commit 5ee5f24
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 181 deletions.
4 changes: 3 additions & 1 deletion src/mlpack/methods/ann/layer/gru.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define MLPACK_METHODS_ANN_LAYER_GRU_HPP

#include <list>
#include <limits>

#include <mlpack/prereqs.hpp>

Expand Down Expand Up @@ -66,7 +67,8 @@ class GRU
* @param outSize The number of output units.
* @param rho Maximum number of steps to backpropagate through time (BPTT).
*/
GRU(const size_t inSize, const size_t outSize, const size_t rho);
GRU(const size_t inSize, const size_t outSize, const size_t rho =
std::numeric_limits<size_t>::max());

/**
* Ordinary feed forward pass of a neural network, evaluating the function
Expand Down
5 changes: 4 additions & 1 deletion src/mlpack/methods/ann/layer/lstm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include <mlpack/prereqs.hpp>

#include <limits>

#include "../visitor/delta_visitor.hpp"
#include "../visitor/output_parameter_visitor.hpp"

Expand Down Expand Up @@ -54,7 +56,8 @@ class LSTM
* @param outSize The number of output units.
* @param rho Maximum number of steps to backpropagate through time (BPTT).
*/
LSTM(const size_t inSize, const size_t outSize, const size_t rho);
LSTM(const size_t inSize, const size_t outSize, const size_t rho =
std::numeric_limits<size_t>::max());

/**
* Ordinary feed forward pass of a neural network, evaluating the function
Expand Down
5 changes: 0 additions & 5 deletions src/mlpack/methods/ann/rnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,6 @@ class RNN
//! Serialize the model.
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);

//! Get the maximum number of steps to backpropagate through time (BPTT).
size_t Rho() const { return rho; }
//! Modify the maximum number of steps to backpropagate through time (BPTT).
size_t& Rho() { return rho; }
private:
// Helper functions.
/**
Expand Down
15 changes: 1 addition & 14 deletions src/mlpack/methods/ann/rnn_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "visitor/gradient_set_visitor.hpp"
#include "visitor/gradient_visitor.hpp"
#include "visitor/weight_set_visitor.hpp"
#include "visitor/rho_set_visitor.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
Expand Down Expand Up @@ -184,15 +183,6 @@ template<typename OutputLayerType, typename InitializationRuleType>
void RNN<OutputLayerType, InitializationRuleType>::SinglePredict(
const arma::mat& predictors, arma::mat& results)
{
if (prevRho != rho)
{
for (size_t i = 1; i < network.size(); ++i)
boost::apply_visitor(RhoSetVisitor(rho), network[i]);

inputSize = predictors.n_elem / rho;
prevRho = rho;
}

for (size_t seqNum = 0; seqNum < rho; ++seqNum)
{
currentInput = predictors.rows(seqNum * inputSize,
Expand Down Expand Up @@ -225,13 +215,10 @@ double RNN<OutputLayerType, InitializationRuleType>::Evaluate(
arma::mat target = arma::mat(responses.colptr(i), responses.n_rows,
1, false, true);

if (prevRho != rho)
if (!inputSize)
{
for (size_t i = 1; i < network.size(); ++i)
boost::apply_visitor(RhoSetVisitor(rho), network[i]);
inputSize = input.n_elem / rho;
targetSize = target.n_elem / rho;
prevRho = rho;
}

ResetCells();
Expand Down
4 changes: 2 additions & 2 deletions src/mlpack/methods/ann/visitor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ set(SOURCES
parameters_set_visitor_impl.hpp
parameters_visitor.hpp
parameters_visitor_impl.hpp
reset_cell_visitor.hpp
reset_cell_visitor_impl.hpp
reset_visitor.hpp
reset_visitor_impl.hpp
reward_set_visitor.hpp
reward_set_visitor_impl.hpp
rho_set_visitor.hpp
rho_set_visitor_impl.hpp
save_output_parameter_visitor.hpp
save_output_parameter_visitor_impl.hpp
set_input_height_visitor.hpp
Expand Down
77 changes: 0 additions & 77 deletions src/mlpack/methods/ann/visitor/rho_set_visitor.hpp

This file was deleted.

81 changes: 0 additions & 81 deletions src/mlpack/methods/ann/visitor/rho_set_visitor_impl.hpp

This file was deleted.

0 comments on commit 5ee5f24

Please sign in to comment.