Skip to content

Commit

Permalink
Adding variable length support for LSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
sumedhghaisas committed Jun 23, 2017
1 parent e2f4bb2 commit e241f5e
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/mlpack/methods/ann/layer/gru_impl.hpp
Expand Up @@ -131,7 +131,7 @@ void GRU<InputDataType, OutputDataType>::Forward(
hiddenStateModule));

forwardStep++;
if (forwardStep == rho && !deterministic)
if (forwardStep == rho)
{
forwardStep = 0;
if (!deterministic)
Expand Down
16 changes: 9 additions & 7 deletions src/mlpack/methods/ann/layer/lstm.hpp
Expand Up @@ -93,6 +93,8 @@ class LSTM
void Gradient(arma::Mat<eT>&& input,
arma::Mat<eT>&& /* error */,
arma::Mat<eT>&& /* gradient */);

//! Clear the stored output/cell history, re-seed both with a single zero
//! column, and reset the forward and backward step counters.
void ResetCell();

//! The value of the deterministic parameter.
bool Deterministic() const { return deterministic; }
Expand Down Expand Up @@ -152,10 +154,10 @@ class LSTM
OutputDataType weights;

//! Locally-stored previous output.
arma::mat prevOutput;
std::list<arma::mat>::iterator prevOutput;

//! Locally-stored previous cell state.
arma::mat prevCell;
std::list<arma::mat>::iterator prevCell;

//! Locally-stored input 2 gate module.
LayerTypes input2GateModule;
Expand Down Expand Up @@ -200,17 +202,17 @@ class LSTM
size_t gradientStep;

//! Locally-stored cell parameters.
std::vector<arma::mat> cellParameter;
std::list<arma::mat> cellParameter;

//! Locally-stored output parameters.
std::vector<arma::mat> outParameter;
std::list<arma::mat> outParameter;

//! Position in cellParameter read by Backward(); equals cellParameter.end()
//! until the first Backward() call of a pass positions it.
std::list<arma::mat>::iterator backIterator;
//! Position in outParameter read by Gradient(); equals outParameter.end()
//! until the first Gradient() call of a pass positions it.
std::list<arma::mat>::iterator gradIterator;

//! Locally-stored previous error.
arma::mat prevError;

//! Locally-stored cell activation error.
arma::mat cellActivationError;

//! Locally-stored forget gate error.
arma::mat forgetGateError;

Expand Down
114 changes: 78 additions & 36 deletions src/mlpack/methods/ann/layer/lstm_impl.hpp
Expand Up @@ -64,31 +64,34 @@ LSTM<InputDataType, OutputDataType>::LSTM(
network.push_back(cellModule);
network.push_back(cellActivationModule);

prevOutput = arma::zeros<arma::mat>(outSize, 1);
prevCell = arma::zeros<arma::mat>(outSize, 1);
//prevOutput = arma::zeros<arma::mat>(outSize, 1);
//prevCell = arma::zeros<arma::mat>(outSize, 1);
prevError = arma::zeros<arma::mat>(4 * outSize, 1);
cellActivationError = arma::zeros<arma::mat>(outSize, 1);

cellParameter.reserve(rho);
outParameter.reserve(rho);
//cellActivationError = arma::zeros<arma::mat>(outSize, 1);

outParameter.push_back(arma::zeros<arma::mat>(outSize, 1));
cellParameter.push_back(arma::zeros<arma::mat>(outSize, 1));

prevOutput = outParameter.begin();
prevCell = cellParameter.begin();

backIterator = cellParameter.end();
gradIterator = outParameter.end();

//cellParameter.reserve(rho);
//outParameter.reserve(rho);
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void LSTM<InputDataType, OutputDataType>::Forward(
arma::Mat<eT>&& input, arma::Mat<eT>&& output)
{
if (!deterministic)
{
cellParameter.push_back(prevCell);
outParameter.push_back(prevOutput);
}

boost::apply_visitor(ForwardVisitor(std::move(input), std::move(
boost::apply_visitor(outputParameterVisitor, input2GateModule))),
input2GateModule);

boost::apply_visitor(ForwardVisitor(std::move(prevOutput), std::move(
boost::apply_visitor(ForwardVisitor(std::move(*prevOutput), std::move(
boost::apply_visitor(outputParameterVisitor, output2GateModule))),
output2GateModule);

Expand All @@ -114,12 +117,12 @@ void LSTM<InputDataType, OutputDataType>::Forward(
// Update the cell (nextCell): cmul1 + cmul2
// where cmul1 is input gate * hidden state and
// cmul2 is forget gate * cell (prevCell).
prevCell = (boost::apply_visitor(outputParameterVisitor,
arma::mat tempPrevCell = (boost::apply_visitor(outputParameterVisitor,
inputGateModule) % boost::apply_visitor(outputParameterVisitor,
hiddenStateModule)) + (boost::apply_visitor(outputParameterVisitor,
forgetGateModule) % prevCell);
forgetGateModule) % *prevCell);

boost::apply_visitor(ForwardVisitor(std::move(prevCell), std::move(
boost::apply_visitor(ForwardVisitor(std::move(tempPrevCell), std::move(
boost::apply_visitor(outputParameterVisitor, cellModule))), cellModule);

boost::apply_visitor(ForwardVisitor(std::move(boost::apply_visitor(
Expand All @@ -130,14 +133,34 @@ void LSTM<InputDataType, OutputDataType>::Forward(
cellActivationModule) % boost::apply_visitor(outputParameterVisitor,
outputGateModule);

prevOutput = output;

forwardStep++;
if (forwardStep == rho)
{
forwardStep = 0;
prevOutput.zeros();
prevCell.zeros();
if (!deterministic)
{
outParameter.push_back(arma::zeros<arma::mat>(outSize, 1));
cellParameter.push_back(arma::zeros<arma::mat>(outSize, 1));
prevOutput = --outParameter.end();
prevCell = --cellParameter.end();
}
else
{
*prevOutput = arma::zeros<arma::mat>(outSize, 1);
*prevCell = arma::zeros<arma::mat>(outSize, 1);
}
}
else if (!deterministic)
{
outParameter.push_back(output);
cellParameter.push_back(std::move(tempPrevCell));
prevOutput = --outParameter.end();
prevCell = --cellParameter.end();
}
else
{
*prevOutput = output;
*prevCell = std::move(tempPrevCell);
}
}

Expand All @@ -146,10 +169,15 @@ template<typename eT>
void LSTM<InputDataType, OutputDataType>::Backward(
const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
if (backwardStep > 0)
if ((outParameter.size() - backwardStep - 1) % rho != 0 && backwardStep != 0)
{
gy += boost::apply_visitor(deltaVisitor, output2GateModule);
}

if (backIterator == cellParameter.end())
{
backIterator = --(--cellParameter.end());
}

arma::mat g1 = boost::apply_visitor(outputParameterVisitor,
cellActivationModule) % gy;
Expand All @@ -162,7 +190,7 @@ void LSTM<InputDataType, OutputDataType>::Backward(
std::move(boost::apply_visitor(deltaVisitor, cellActivationModule))),
cellActivationModule);

cellActivationError = boost::apply_visitor(deltaVisitor,
arma::mat cellActivationError = boost::apply_visitor(deltaVisitor,
cellActivationModule);

if (backwardStep > 0)
Expand All @@ -179,8 +207,7 @@ void LSTM<InputDataType, OutputDataType>::Backward(
forgetGateError = boost::apply_visitor(outputParameterVisitor,
forgetGateModule) % cellActivationError;

arma::mat g7 = cellParameter[cellParameter.size() -
backwardStep - 1] % cellActivationError;
arma::mat g7 = *backIterator % cellActivationError;

boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
outputParameterVisitor, inputGateModule)), std::move(g5),
Expand Down Expand Up @@ -222,11 +249,7 @@ void LSTM<InputDataType, OutputDataType>::Backward(
output2GateModule);

backwardStep++;
if (backwardStep == rho)
{
backwardStep = 0;
cellParameter.clear();
}
backIterator--;

g = boost::apply_visitor(deltaVisitor, input2GateModule);
}
Expand All @@ -238,19 +261,38 @@ void LSTM<InputDataType, OutputDataType>::Gradient(
arma::Mat<eT>&& /* error */,
arma::Mat<eT>&& /* gradient */)
{
if (gradIterator == outParameter.end())
{
gradIterator = --(--outParameter.end());
}

boost::apply_visitor(GradientVisitor(std::move(input), std::move(prevError)),
input2GateModule);

boost::apply_visitor(GradientVisitor(
std::move(outParameter[outParameter.size() - gradientStep - 1]),
std::move(*gradIterator),
std::move(prevError)), output2GateModule);

gradientStep++;
if (gradientStep == rho)
{
gradientStep = 0;
outParameter.clear();
}
gradIterator--;
}

/**
 * Return the layer's recurrent bookkeeping to its freshly-constructed state.
 *
 * Both history lists are emptied and re-seeded with one zero column of height
 * outSize, prevCell/prevOutput are pointed at that seed entry, and the
 * backward/gradient iterators are parked on end() so that the next
 * Backward()/Gradient() call re-positions them. The forward and backward step
 * counters are set back to zero.
 *
 * NOTE(review): gradientStep is not touched here -- confirm that this is
 * intentional.
 */
template<typename InputDataType, typename OutputDataType>
void LSTM<InputDataType, OutputDataType>::ResetCell()
{
  // Shared initial value for both the cell state and the output history.
  const arma::mat initialState = arma::zeros<arma::mat>(outSize, 1);

  // Output history: a single zero entry, nothing pending for Gradient().
  outParameter.clear();
  outParameter.push_back(initialState);
  prevOutput = outParameter.begin();
  gradIterator = outParameter.end();

  // Cell history: a single zero entry, nothing pending for Backward().
  cellParameter.clear();
  cellParameter.push_back(initialState);
  prevCell = cellParameter.begin();
  backIterator = cellParameter.end();

  // Restart the time-step counters.
  forwardStep = 0;
  backwardStep = 0;
}

template<typename InputDataType, typename OutputDataType>
Expand Down

0 comments on commit e241f5e

Please sign in to comment.