diff --git a/src/mlpack/methods/ann/layer/gru_impl.hpp b/src/mlpack/methods/ann/layer/gru_impl.hpp index a13f8952ae5..d489c875aba 100644 --- a/src/mlpack/methods/ann/layer/gru_impl.hpp +++ b/src/mlpack/methods/ann/layer/gru_impl.hpp @@ -131,7 +131,7 @@ void GRU::Forward( hiddenStateModule)); forwardStep++; - if (forwardStep == rho && !deterministic) + if (forwardStep == rho) { forwardStep = 0; if (!deterministic) diff --git a/src/mlpack/methods/ann/layer/lstm.hpp b/src/mlpack/methods/ann/layer/lstm.hpp index 773f97bb695..8d2791fd406 100644 --- a/src/mlpack/methods/ann/layer/lstm.hpp +++ b/src/mlpack/methods/ann/layer/lstm.hpp @@ -93,6 +93,8 @@ class LSTM void Gradient(arma::Mat&& input, arma::Mat&& /* error */, arma::Mat&& /* gradient */); + + void ResetCell(); //! The value of the deterministic parameter. bool Deterministic() const { return deterministic; } @@ -152,10 +154,10 @@ class LSTM OutputDataType weights; //! Locally-stored previous output. - arma::mat prevOutput; + std::list::iterator prevOutput; //! Locally-stored previous cell state. - arma::mat prevCell; + std::list::iterator prevCell; //! Locally-stored input 2 gate module. LayerTypes input2GateModule; @@ -200,17 +202,17 @@ class LSTM size_t gradientStep; //! Locally-stored cell parameters. - std::vector cellParameter; + std::list cellParameter; //! Locally-stored output parameters. - std::vector outParameter; + std::list outParameter; + + std::list::iterator backIterator; + std::list::iterator gradIterator; //! Locally-stored previous error. arma::mat prevError; - //! Locally-stored cell activation error. - arma::mat cellActivationError; - //! Locally-stored foget gate error. arma::mat forgetGateError; diff --git a/src/mlpack/methods/ann/layer/lstm_impl.hpp b/src/mlpack/methods/ann/layer/lstm_impl.hpp index 29f1a7877b0..58f1f48b140 100644 --- a/src/mlpack/methods/ann/layer/lstm_impl.hpp +++ b/src/mlpack/methods/ann/layer/lstm_impl.hpp @@ -64,13 +64,22 @@ LSTM::LSTM( network.push_back(cellModule); network.push_back(cellActivationModule); - prevOutput = arma::zeros(outSize, 1); - prevCell = arma::zeros(outSize, 1); + //prevOutput = arma::zeros(outSize, 1); + //prevCell = arma::zeros(outSize, 1); prevError = arma::zeros(4 * outSize, 1); - cellActivationError = arma::zeros(outSize, 1); - - cellParameter.reserve(rho); - outParameter.reserve(rho); + //cellActivationError = arma::zeros(outSize, 1); + + outParameter.push_back(arma::zeros(outSize, 1)); + cellParameter.push_back(arma::zeros(outSize, 1)); + + prevOutput = outParameter.begin(); + prevCell = cellParameter.begin(); + + backIterator = cellParameter.end(); + gradIterator = outParameter.end(); + + //cellParameter.reserve(rho); + //outParameter.reserve(rho); } template @@ -78,17 +87,11 @@ template void LSTM::Forward( arma::Mat&& input, arma::Mat&& output) { - if (!deterministic) - { - cellParameter.push_back(prevCell); - outParameter.push_back(prevOutput); - } - boost::apply_visitor(ForwardVisitor(std::move(input), std::move( boost::apply_visitor(outputParameterVisitor, input2GateModule))), input2GateModule); - boost::apply_visitor(ForwardVisitor(std::move(prevOutput), std::move( + boost::apply_visitor(ForwardVisitor(std::move(*prevOutput), std::move( boost::apply_visitor(outputParameterVisitor, output2GateModule))), output2GateModule); @@ -114,12 +117,12 @@ void LSTM::Forward( // Update the cell (nextCell): cmul1 + cmul2 // where cmul1 is input gate * hidden state and // cmul2 is forget gate * cell (prevCell). - prevCell = (boost::apply_visitor(outputParameterVisitor, + arma::mat tempPrevCell = (boost::apply_visitor(outputParameterVisitor, inputGateModule) % boost::apply_visitor(outputParameterVisitor, hiddenStateModule)) + (boost::apply_visitor(outputParameterVisitor, - forgetGateModule) % prevCell); + forgetGateModule) % *prevCell); - boost::apply_visitor(ForwardVisitor(std::move(prevCell), std::move( + boost::apply_visitor(ForwardVisitor(std::move(tempPrevCell), std::move( boost::apply_visitor(outputParameterVisitor, cellModule))), cellModule); boost::apply_visitor(ForwardVisitor(std::move(boost::apply_visitor( @@ -130,14 +133,34 @@ void LSTM::Forward( cellActivationModule) % boost::apply_visitor(outputParameterVisitor, outputGateModule); - prevOutput = output; - forwardStep++; if (forwardStep == rho) { forwardStep = 0; - prevOutput.zeros(); - prevCell.zeros(); + if (!deterministic) + { + outParameter.push_back(arma::zeros(outSize, 1)); + cellParameter.push_back(arma::zeros(outSize, 1)); + prevOutput = --outParameter.end(); + prevCell = --cellParameter.end(); + } + else + { + *prevOutput = arma::zeros(outSize, 1); + *prevCell = arma::zeros(outSize, 1); + } + } + else if (!deterministic) + { + outParameter.push_back(output); + cellParameter.push_back(std::move(tempPrevCell)); + prevOutput = --outParameter.end(); + prevCell = --cellParameter.end(); + } + else + { + *prevOutput = output; + *prevCell = std::move(tempPrevCell); } } @@ -146,10 +169,15 @@ template void LSTM::Backward( const arma::Mat&& /* input */, arma::Mat&& gy, arma::Mat&& g) { - if (backwardStep > 0) + if ((outParameter.size() - backwardStep - 1) % rho != 0 && backwardStep != 0) { gy += boost::apply_visitor(deltaVisitor, output2GateModule); } + + if (backIterator == cellParameter.end()) + { + backIterator = --(--cellParameter.end()); + } arma::mat g1 = boost::apply_visitor(outputParameterVisitor, cellActivationModule) % gy; @@ -162,7 +190,7 @@ void LSTM::Backward( std::move(boost::apply_visitor(deltaVisitor, cellActivationModule))), cellActivationModule); - cellActivationError = boost::apply_visitor(deltaVisitor, + arma::mat cellActivationError = boost::apply_visitor(deltaVisitor, cellActivationModule); if (backwardStep > 0) @@ -179,8 +207,7 @@ void LSTM::Backward( forgetGateError = boost::apply_visitor(outputParameterVisitor, forgetGateModule) % cellActivationError; - arma::mat g7 = cellParameter[cellParameter.size() - - backwardStep - 1] % cellActivationError; + arma::mat g7 = *backIterator % cellActivationError; boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor( outputParameterVisitor, inputGateModule)), std::move(g5), @@ -222,11 +249,7 @@ void LSTM::Backward( output2GateModule); backwardStep++; - if (backwardStep == rho) - { - backwardStep = 0; - cellParameter.clear(); - } + backIterator--; g = boost::apply_visitor(deltaVisitor, input2GateModule); } @@ -238,19 +261,38 @@ void LSTM::Gradient( arma::Mat&& /* error */, arma::Mat&& /* gradient */) { + if (gradIterator == outParameter.end()) + { + gradIterator = --(--outParameter.end()); + } + boost::apply_visitor(GradientVisitor(std::move(input), std::move(prevError)), input2GateModule); boost::apply_visitor(GradientVisitor( - std::move(outParameter[outParameter.size() - gradientStep - 1]), + std::move(*gradIterator), std::move(prevError)), output2GateModule); - gradientStep++; - if (gradientStep == rho) - { - gradientStep = 0; - outParameter.clear(); - } + gradIterator--; +} + +template +void LSTM::ResetCell() +{ + outParameter.clear(); + outParameter.push_back(arma::zeros(outSize, 1)); + + cellParameter.clear(); + cellParameter.push_back(arma::zeros(outSize, 1)); + + prevOutput = outParameter.begin(); + prevCell = cellParameter.begin(); + + backIterator = cellParameter.end(); + gradIterator = outParameter.end(); + + forwardStep = 0; + backwardStep = 0; } template