From 914567679e12b3b9179909d711ad59c7ba4b7abe Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 6 Nov 2019 10:03:12 -0500 Subject: [PATCH 1/8] convert Transition and Initial matrix to log space --- src/mlpack/methods/hmm/hmm.hpp | 27 +++-- src/mlpack/methods/hmm/hmm_impl.hpp | 154 +++++++++++++++++----------- 2 files changed, 114 insertions(+), 67 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp index 1c113d36c18..021ff74dadd 100644 --- a/src/mlpack/methods/hmm/hmm.hpp +++ b/src/mlpack/methods/hmm/hmm.hpp @@ -323,14 +323,14 @@ class HMM arma::mat& smoothSeq) const; //! Return the vector of initial state probabilities. - const arma::vec& Initial() const { return initial; } + const arma::vec Initial() const { return initialProxy; } //! Modify the vector of initial state probabilities. - arma::vec& Initial() { return initial; } + arma::vec& Initial() { return initialProxy; } //! Return the transition matrix. - const arma::mat& Transition() const { return transition; } + const arma::mat Transition() const { return transitionProxy; } //! Return a modifiable transition matrix reference. - arma::mat& Transition() { return transition; } + arma::mat& Transition() { return transitionProxy; } //! Return the emission distributions. const std::vector& Emission() const { return emission; } @@ -351,7 +351,13 @@ class HMM * Serialize the object. */ template - void serialize(Archive& ar, const unsigned int version); + void load(Archive& ar, const unsigned int version); + + template + void save(Archive& ar, const unsigned int version) const; + + BOOST_SERIALIZATION_SPLIT_MEMBER() + protected: // Helper functions. @@ -387,12 +393,19 @@ class HMM //! Set of emission probability distributions; one for each state. std::vector emission; + arma::mat transitionProxy; + //! Transition probability matrix. - arma::mat transition; + mutable arma::mat logTransition; private: + + void ConvertToLogSpace() const; + + arma::vec initialProxy; + //! 
Initial state probability vector. - arma::vec initial; + mutable arma::vec logInitial; //! Dimensionality of observations. size_t dimensionality; diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index ae9521df9f0..ef857f4ee56 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -30,15 +30,18 @@ HMM::HMM(const size_t states, const Distribution emissions, const double tolerance) : emission(states, /* default distribution */ emissions), - transition(arma::randu(states, states)), - initial(arma::randu(states) / (double) states), + transitionProxy(arma::randu(states, states)), + initialProxy(arma::randu(states) / (double) states), dimensionality(emissions.Dimensionality()), tolerance(tolerance) { // Normalize the transition probabilities and initial state probabilities. - initial /= arma::accu(initial); - for (size_t i = 0; i < transition.n_cols; ++i) - transition.col(i) /= arma::accu(transition.col(i)); + initialProxy /= arma::accu(initialProxy); + for (size_t i = 0; i < transitionProxy.n_cols; ++i) + transitionProxy.col(i) /= arma::accu(transitionProxy.col(i)); + + logTransition = log(transitionProxy); + logInitial = log(initialProxy); } /** @@ -51,8 +54,10 @@ HMM::HMM(const arma::vec& initial, const std::vector& emission, const double tolerance) : emission(emission), - transition(transition), - initial(initial), + transitionProxy(transition), + logTransition(log(transition)), + initialProxy(initial), + logInitial(log(initial)), tolerance(tolerance) { // Set the dimensionality, if we can. @@ -91,7 +96,7 @@ double HMM::Train(const std::vector& dataSeq) // Maximum iterations? size_t iterations = 1000; - + // Find length of all sequences and ensure they are the correct size. size_t totalLength = 0; for (size_t seq = 0; seq < dataSeq.size(); seq++) @@ -106,7 +111,7 @@ double HMM::Train(const std::vector& dataSeq) // These are used later for training of each distribution. 
We initialize it // all now so we don't have to do any allocation later on. - std::vector emissionProb(transition.n_cols, + std::vector emissionProb(logTransition.n_cols, arma::vec(totalLength)); arma::mat emissionList(dimensionality, totalLength); @@ -116,9 +121,9 @@ double HMM::Train(const std::vector& dataSeq) for (size_t iter = 0; iter < iterations; iter++) { // Clear new transition matrix and emission probabilities. - arma::vec newLogInitial(transition.n_rows); + arma::vec newLogInitial(logTransition.n_rows); newLogInitial.fill(-std::numeric_limits::infinity()); - arma::mat newLogTransition(transition.n_rows, transition.n_cols); + arma::mat newLogTransition(logTransition.n_rows, logTransition.n_cols); newLogTransition.fill(-std::numeric_limits::infinity()); // Reset log likelihood. @@ -140,7 +145,7 @@ double HMM::Train(const std::vector& dataSeq) backwardLog, logScales); // Add to estimate of initial probability for state j. - for (size_t j = 0; j < transition.n_cols; ++j) + for (size_t j = 0; j < logTransition.n_cols; ++j) newLogInitial[j] = math::LogAdd(newLogInitial[j], stateLogProb(j, 0)); // Now re-estimate the parameters. This is the M-step. @@ -151,13 +156,13 @@ double HMM::Train(const std::vector& dataSeq) // We store the new estimates in a different matrix. for (size_t t = 0; t < dataSeq[seq].n_cols; ++t) { - for (size_t j = 0; j < transition.n_cols; ++j) + for (size_t j = 0; j < logTransition.n_cols; ++j) { if (t < dataSeq[seq].n_cols - 1) { // Estimate of T_ij (probability of transition from state j to state // i). We postpone multiplication of the old T_ij until later. - for (size_t i = 0; i < transition.n_rows; i++) + for (size_t i = 0; i < logTransition.n_rows; i++) { newLogTransition(i, j) = math::LogAdd(newLogTransition(i, j), forwardLog(j, t) + backwardLog(i, t + 1) + @@ -184,28 +189,30 @@ double HMM::Train(const std::vector& dataSeq) // Normalize the new initial probabilities. 
if (dataSeq.size() > 1) - initial = exp(newLogInitial) / dataSeq.size(); + logInitial = newLogInitial - log(dataSeq.size()); else - initial = exp(newLogInitial); + logInitial = newLogInitial; // Assign the new transition matrix. We use %= (element-wise // multiplication) because every element of the new transition matrix must // still be multiplied by the old elements (this is the multiplication we // earlier postponed). - transition %= exp(newLogTransition); - + logTransition += newLogTransition; + // Now we normalize the transition matrix. - for (size_t i = 0; i < transition.n_cols; i++) + for (size_t i = 0; i < logTransition.n_cols; i++) { - const double sum = accu(transition.col(i)); - if (sum > 0.0) - transition.col(i) /= sum; + const double sum = math::AccuLog(logTransition.col(i)); + if (std::isfinite(sum)) + logTransition.col(i) -= sum; else - transition.col(i).fill(1.0 / (double) transition.n_rows); + logTransition.col(i).fill(-log((double) logTransition.n_rows)); } + initialProxy = exp(logInitial); + transitionProxy = exp(logTransition); // Now estimate emission probabilities. - for (size_t state = 0; state < transition.n_cols; state++) + for (size_t state = 0; state < logTransition.n_cols; state++) emission[state].Train(emissionList, emissionProb[state]); Log::Debug << "Iteration " << iter << ": log-likelihood " << loglik @@ -230,8 +237,8 @@ void HMM::Train(const std::vector& dataSeq, << ")." << std::endl; } - initial.zeros(); - transition.zeros(); + arma::mat initial = arma::zeros(logInitial.n_elem); + arma::mat transition = arma::zeros(logTransition.n_rows, logTransition.n_cols); // Estimate the transition and emission matrices directly from the // observations. 
The emission list holds the time indices for observations @@ -283,6 +290,11 @@ void HMM::Train(const std::vector& dataSeq, if (sum > 0) transition.col(col) /= sum; } + + initialProxy = initial; + transitionProxy = transition; + logTransition = log(transition); + logInitial = log(initial); // Estimate emission matrix. for (size_t state = 0; state < transition.n_cols; state++) @@ -406,6 +418,8 @@ void HMM::Generate(const size_t length, // distribution of emissions for our starting state. dataSequence.col(0) = emission[startState].Random(); + ConvertToLogSpace(); + // Now choose the states and emissions for the rest of the sequence. for (size_t t = 1; t < length; t++) { @@ -415,9 +429,9 @@ void HMM::Generate(const size_t length, // Now find where our random value sits in the probability distribution of // state changes. double probSum = 0; - for (size_t st = 0; st < transition.n_rows; st++) + for (size_t st = 0; st < logTransition.n_rows; st++) { - probSum += transition(st, stateSequence[t - 1]); + probSum += exp(logTransition(st, stateSequence[t - 1])); if (randValue <= probSum) { stateSequence[t] = st; @@ -444,20 +458,18 @@ double HMM::Predict(const arma::mat& dataSeq, // don't use log-likelihoods to save that little bit of time, but we'll // calculate the log-likelihood at the end of it all. stateSeq.set_size(dataSeq.n_cols); - arma::mat logStateProb(transition.n_rows, dataSeq.n_cols); - arma::mat stateSeqBack(transition.n_rows, dataSeq.n_cols); - - // Store the logs of the transposed transition matrix. This is because we - // will be using the rows of the transition matrix. - arma::mat logTrans(log(trans(transition))); + arma::mat logStateProb(logTransition.n_rows, dataSeq.n_cols); + arma::mat stateSeqBack(logTransition.n_rows, dataSeq.n_cols); + ConvertToLogSpace(); + // The calculation of the first state is slightly different; the probability // of the first state being state j is the maximum probability that the state // came to be j from another state. 
logStateProb.col(0).zeros(); - for (size_t state = 0; state < transition.n_rows; state++) + for (size_t state = 0; state < logTransition.n_rows; state++) { - logStateProb(state, 0) = log(initial[state]) + + logStateProb(state, 0) = logInitial[state] + emission[state].LogProbability(dataSeq.unsafe_col(0)); stateSeqBack(state, 0) = state; } @@ -469,9 +481,9 @@ double HMM::Predict(const arma::mat& dataSeq, // Assemble the state probability for this element. // Given that we are in state j, we use state with the highest probability // of being the previous state. - for (size_t j = 0; j < transition.n_rows; j++) + for (size_t j = 0; j < logTransition.n_rows; j++) { - arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j); + arma::vec prob = logStateProb.col(t - 1) + logTransition.row(j).t(); logStateProb(j, t) = prob.max(index) + emission[j].LogProbability(dataSeq.unsafe_col(t)); stateSeqBack(j, t) = index; @@ -518,11 +530,11 @@ void HMM::Filter(const arma::mat& dataSeq, arma::vec logScales; Forward(dataSeq, logScales, forwardLogProb); - arma::mat forwardProb = exp(forwardLogProb); - // Propagate state ahead. if (ahead != 0) - forwardProb = pow(transition, ahead) * forwardProb; + forwardLogProb += ahead * logTransition; + + arma::mat forwardProb = exp(forwardLogProb); // Compute expected emissions. // Will not work for distributions without a Mean() function. @@ -563,21 +575,21 @@ void HMM::Forward(const arma::mat& dataSeq, { // Our goal is to calculate the forward probabilities: // P(X_k | o_{1:k}) for all possible states X_k, for each time point k. 
- forwardLogProb.resize(transition.n_rows, dataSeq.n_cols); + forwardLogProb.resize(logTransition.n_rows, dataSeq.n_cols); forwardLogProb.fill(-std::numeric_limits::infinity()); logScales.resize(dataSeq.n_cols); logScales.fill(-std::numeric_limits::infinity()); - - arma::mat logTrans = trans(log(transition)); - + + ConvertToLogSpace(); + // The first entry in the forward algorithm uses the initial state // probabilities. Note that MATLAB assumes that the starting state (at // t = -1) is state 0; this is not our assumption here. To force that // behavior, you could append a single starting state to every single data // sequence and that should produce results in line with MATLAB. - for (size_t state = 0; state < transition.n_rows; state++) + for (size_t state = 0; state < logTransition.n_rows; state++) { - forwardLogProb(state, 0) = log(initial(state)) + + forwardLogProb(state, 0) = logInitial(state) + emission[state].LogProbability(dataSeq.unsafe_col(0)); } @@ -589,12 +601,12 @@ void HMM::Forward(const arma::mat& dataSeq, // Now compute the probabilities for each successive observation. for (size_t t = 1; t < dataSeq.n_cols; t++) { - for (size_t j = 0; j < transition.n_rows; j++) + for (size_t j = 0; j < logTransition.n_rows; j++) { // The forward probability of state j at time t is the sum over all states // of the probability of the previous state transitioning to the current // state and emitting the given observation. - arma::vec tmp = forwardLogProb.col(t - 1) + logTrans.col(j); + arma::vec tmp = forwardLogProb.col(t - 1) + logTransition.row(j).t(); forwardLogProb(j, t) = math::AccuLog(tmp) + emission[j].LogProbability(dataSeq.unsafe_col(t)); } @@ -613,9 +625,8 @@ void HMM::Backward(const arma::mat& dataSeq, { // Our goal is to calculate the backward probabilities: // P(X_k | o_{k + 1:T}) for all possible states X_k, for each time point k. 
- backwardLogProb.resize(transition.n_rows, dataSeq.n_cols); + backwardLogProb.resize(logTransition.n_rows, dataSeq.n_cols); backwardLogProb.fill(-std::numeric_limits::infinity()); - arma::mat logTrans = log(transition); // The last element probability is 1. backwardLogProb.col(dataSeq.n_cols - 1).fill(0); @@ -623,16 +634,16 @@ void HMM::Backward(const arma::mat& dataSeq, // Now step backwards through all other observations. for (size_t t = dataSeq.n_cols - 2; t + 1 > 0; t--) { - for (size_t j = 0; j < transition.n_rows; j++) + for (size_t j = 0; j < logTransition.n_rows; j++) { // The backward probability of state j at time t is the sum over all state // of the probability of the next state having been a transition from the // current state multiplied by the probability of each of those states // emitting the given observation. - for (size_t state = 0; state < transition.n_rows; state++) + for (size_t state = 0; state < logTransition.n_rows; state++) { backwardLogProb(j, t) = math::LogAdd(backwardLogProb(j, t), - logTrans(state, j) + backwardLogProb(state, t + 1) + logTransition(state, j) + backwardLogProb(state, t + 1) + emission[state].LogProbability(dataSeq.unsafe_col(t + 1))); } @@ -643,23 +654,46 @@ void HMM::Backward(const arma::mat& dataSeq, } } +template +void HMM::ConvertToLogSpace() const { + logInitial = log(initialProxy); + logTransition = log(transitionProxy); +} + //! Serialize the HMM. template template -void HMM::serialize(Archive& ar, const unsigned int /* version */) +void HMM::load(Archive& ar, const unsigned int /* version */) { + arma::mat transition; + arma::vec initial; ar & BOOST_SERIALIZATION_NVP(dimensionality); ar & BOOST_SERIALIZATION_NVP(tolerance); ar & BOOST_SERIALIZATION_NVP(transition); ar & BOOST_SERIALIZATION_NVP(initial); - + // Now serialize each emission. If we are loading, we must resize the vector // of emissions correctly. 
- if (Archive::is_loading::value) - emission.resize(transition.n_rows); - + emission.resize(transition.n_rows); // Load the emissions; generate the correct name for each one. - ar & BOOST_SERIALIZATION_NVP(emission); + ar & BOOST_SERIALIZATION_NVP(emission); + + logTransition = log(transition); + logInitial = log(initial); + } + +//! Serialize the HMM. +template +template +void HMM::save(Archive& ar, const unsigned int /* version */) const +{ + arma::mat transition = exp(logTransition); + arma::vec initial = exp(logInitial); + ar & BOOST_SERIALIZATION_NVP(dimensionality); + ar & BOOST_SERIALIZATION_NVP(tolerance); + ar & BOOST_SERIALIZATION_NVP(transition); + ar & BOOST_SERIALIZATION_NVP(initial); + ar & BOOST_SERIALIZATION_NVP(emission); } } // namespace hmm From efb1eff25ce904ed19cad9a6881dc13b359eea93 Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Fri, 8 Nov 2019 05:50:39 -0500 Subject: [PATCH 2/8] init proxy matrices at load --- src/mlpack/methods/hmm/hmm_impl.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index ef857f4ee56..c65fcc4b2da 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -678,6 +678,8 @@ void HMM::load(Archive& ar, const unsigned int /* version */) // Load the emissions; generate the correct name for each one. 
ar & BOOST_SERIALIZATION_NVP(emission); + initialProxy = initial; + transitionProxy = transition; logTransition = log(transition); logInitial = log(initial); } From 98e2e5631e5bc418b43d64cee95afab2e3fe3b54 Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 10:57:08 -0500 Subject: [PATCH 3/8] * Introduced boolean variables to reduce the overhead of keeping linear and log space in sync for initial and transition matrices * use std::move() to avoid extra copy in load() --- src/mlpack/methods/hmm/hmm.hpp | 18 ++++++++++--- src/mlpack/methods/hmm/hmm_impl.hpp | 39 ++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp index 021ff74dadd..dccacfab988 100644 --- a/src/mlpack/methods/hmm/hmm.hpp +++ b/src/mlpack/methods/hmm/hmm.hpp @@ -325,12 +325,12 @@ class HMM //! Return the vector of initial state probabilities. const arma::vec Initial() const { return initialProxy; } //! Modify the vector of initial state probabilities. - arma::vec& Initial() { return initialProxy; } + arma::vec& Initial() { recalculateInitial = true; return initialProxy; } //! Return the transition matrix. const arma::mat Transition() const { return transitionProxy; } //! Return a modifiable transition matrix reference. - arma::mat& Transition() { return transitionProxy; } + arma::mat& Transition() { recalculateTransition = true; return transitionProxy; } //! Return the emission distributions. const std::vector& Emission() const { return emission; } @@ -393,15 +393,19 @@ class HMM //! Set of emission probability distributions; one for each state. std::vector emission; + //! a proxy vriable in linear space for logTransition arma::mat transitionProxy; //! Transition probability matrix. 
mutable arma::mat logTransition; private: - - void ConvertToLogSpace() const; + /** + * Make sure the variables in log space are in sync with the linear counter parts + */ + void ConvertToLogSpace() const; + //! a proxy vriable in linear space for logInitial arma::vec initialProxy; //! Initial state probability vector. @@ -412,6 +416,12 @@ class HMM //! Tolerance of Baum-Welch algorithm. double tolerance; + + //! Whether or not we need to update the logInitial from initialProxy + mutable bool recalculateInitial; + + //! Whether or not we need to update the logTransition from transitionProxy + mutable bool recalculateTransition; }; } // namespace hmm diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index c65fcc4b2da..abc136b0a66 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -33,13 +33,15 @@ HMM::HMM(const size_t states, transitionProxy(arma::randu(states, states)), initialProxy(arma::randu(states) / (double) states), dimensionality(emissions.Dimensionality()), - tolerance(tolerance) + tolerance(tolerance), + recalculateInitial(false), + recalculateTransition(false) { // Normalize the transition probabilities and initial state probabilities. initialProxy /= arma::accu(initialProxy); for (size_t i = 0; i < transitionProxy.n_cols; ++i) - transitionProxy.col(i) /= arma::accu(transitionProxy.col(i)); - + transitionProxy.col(i) /= arma::accu(transitionProxy.col(i)); + logTransition = log(transitionProxy); logInitial = log(initialProxy); } @@ -58,7 +60,9 @@ HMM::HMM(const arma::vec& initial, logTransition(log(transition)), initialProxy(initial), logInitial(log(initial)), - tolerance(tolerance) + tolerance(tolerance), + recalculateInitial(false), + recalculateTransition(false) { // Set the dimensionality, if we can. 
if (emission.size() > 0) @@ -198,7 +202,7 @@ double HMM::Train(const std::vector& dataSeq) // still be multiplied by the old elements (this is the multiplication we // earlier postponed). logTransition += newLogTransition; - + // Now we normalize the transition matrix. for (size_t i = 0; i < logTransition.n_cols; i++) { @@ -238,7 +242,8 @@ void HMM::Train(const std::vector& dataSeq, } arma::mat initial = arma::zeros(logInitial.n_elem); - arma::mat transition = arma::zeros(logTransition.n_rows, logTransition.n_cols); + arma::mat transition = arma::zeros(logTransition.n_rows, + logTransition.n_cols); // Estimate the transition and emission matrices directly from the // observations. The emission list holds the time indices for observations @@ -290,7 +295,7 @@ void HMM::Train(const std::vector& dataSeq, if (sum > 0) transition.col(col) /= sum; } - + initialProxy = initial; transitionProxy = transition; logTransition = log(transition); @@ -419,7 +424,7 @@ void HMM::Generate(const size_t length, dataSequence.col(0) = emission[startState].Random(); ConvertToLogSpace(); - + // Now choose the states and emissions for the rest of the sequence. for (size_t t = 1; t < length; t++) { @@ -654,10 +659,20 @@ void HMM::Backward(const arma::mat& dataSeq, } } +/** + * Make sure the variables in log space are in sync with the linear counter parts + */ template void HMM::ConvertToLogSpace() const { - logInitial = log(initialProxy); - logTransition = log(transitionProxy); + if(recalculateInitial){ + logInitial = log(initialProxy); + recalculateInitial = false; + } + + if(recalculateTransition){ + logTransition = log(transitionProxy); + recalculateTransition = false; + } } //! Serialize the HMM. @@ -678,10 +693,10 @@ void HMM::load(Archive& ar, const unsigned int /* version */) // Load the emissions; generate the correct name for each one. 
ar & BOOST_SERIALIZATION_NVP(emission); - initialProxy = initial; - transitionProxy = transition; logTransition = log(transition); logInitial = log(initial); + initialProxy = std::move(initial); + transitionProxy = std::move(transition); } //! Serialize the HMM. From 28474ef69f8de4eb47f03e4f1195dc1f04fe8265 Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 11:58:53 -0500 Subject: [PATCH 4/8] fixed white spaces and other style issues --- src/mlpack/methods/hmm/hmm.hpp | 5 +++-- src/mlpack/methods/hmm/hmm_impl.hpp | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp index dccacfab988..31e1a4a7f52 100644 --- a/src/mlpack/methods/hmm/hmm.hpp +++ b/src/mlpack/methods/hmm/hmm.hpp @@ -330,7 +330,8 @@ class HMM //! Return the transition matrix. const arma::mat Transition() const { return transitionProxy; } //! Return a modifiable transition matrix reference. - arma::mat& Transition() { recalculateTransition = true; return transitionProxy; } + arma::mat& Transition() { recalculateTransition = true; + return transitionProxy; } //! Return the emission distributions. const std::vector& Emission() const { return emission; } @@ -404,7 +405,7 @@ class HMM * Make sure the variables in log space are in sync with the linear counter parts */ void ConvertToLogSpace() const; - + //! a proxy vriable in linear space for logInitial arma::vec initialProxy; diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index abc136b0a66..591a4ad1994 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -100,7 +100,7 @@ double HMM::Train(const std::vector& dataSeq) // Maximum iterations? size_t iterations = 1000; - + // Find length of all sequences and ensure they are the correct size. 
size_t totalLength = 0; for (size_t seq = 0; seq < dataSeq.size(); seq++) @@ -242,8 +242,8 @@ void HMM::Train(const std::vector& dataSeq, } arma::mat initial = arma::zeros(logInitial.n_elem); - arma::mat transition = arma::zeros(logTransition.n_rows, - logTransition.n_cols); + arma::mat transition = arma::zeros(logTransition.n_rows, + logTransition.n_cols); // Estimate the transition and emission matrices directly from the // observations. The emission list holds the time indices for observations @@ -467,7 +467,7 @@ double HMM::Predict(const arma::mat& dataSeq, arma::mat stateSeqBack(logTransition.n_rows, dataSeq.n_cols); ConvertToLogSpace(); - + // The calculation of the first state is slightly different; the probability // of the first state being state j is the maximum probability that the state // came to be j from another state. @@ -584,9 +584,9 @@ void HMM::Forward(const arma::mat& dataSeq, forwardLogProb.fill(-std::numeric_limits::infinity()); logScales.resize(dataSeq.n_cols); logScales.fill(-std::numeric_limits::infinity()); - + ConvertToLogSpace(); - + // The first entry in the forward algorithm uses the initial state // probabilities. Note that MATLAB assumes that the starting state (at // t = -1) is state 0; this is not our assumption here. To force that @@ -664,12 +664,12 @@ void HMM::Backward(const arma::mat& dataSeq, */ template void HMM::ConvertToLogSpace() const { - if(recalculateInitial){ + if( recalculateInitial ){ logInitial = log(initialProxy); recalculateInitial = false; } - - if(recalculateTransition){ + + if( recalculateTransition ){ logTransition = log(transitionProxy); recalculateTransition = false; } @@ -686,13 +686,13 @@ void HMM::load(Archive& ar, const unsigned int /* version */) ar & BOOST_SERIALIZATION_NVP(tolerance); ar & BOOST_SERIALIZATION_NVP(transition); ar & BOOST_SERIALIZATION_NVP(initial); - + // Now serialize each emission. If we are loading, we must resize the vector // of emissions correctly. 
emission.resize(transition.n_rows); // Load the emissions; generate the correct name for each one. ar & BOOST_SERIALIZATION_NVP(emission); - + logTransition = log(transition); logInitial = log(initial); initialProxy = std::move(initial); From 6bb592cdd95f01f8528b997fb53a410acf67d897 Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 12:06:56 -0500 Subject: [PATCH 5/8] fixed style issues --- src/mlpack/methods/hmm/hmm_impl.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index 591a4ad1994..5f880613f0b 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -664,12 +664,14 @@ void HMM::Backward(const arma::mat& dataSeq, */ template void HMM::ConvertToLogSpace() const { - if( recalculateInitial ){ + if (recalculateInitial) + { logInitial = log(initialProxy); recalculateInitial = false; } - if( recalculateTransition ){ + if (recalculateTransition) + { logTransition = log(transitionProxy); recalculateTransition = false; } @@ -697,12 +699,13 @@ void HMM::load(Archive& ar, const unsigned int /* version */) logInitial = log(initial); initialProxy = std::move(initial); transitionProxy = std::move(transition); - } +} //! Serialize the HMM. 
template template -void HMM::save(Archive& ar, const unsigned int /* version */) const +void HMM::save(Archive& ar, + const unsigned int /* version */) const { arma::mat transition = exp(logTransition); arma::vec initial = exp(logInitial); From ec4d1bd176e55598cb46d4665129bcf3512b0b6b Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 12:14:34 -0500 Subject: [PATCH 6/8] fixed style issues --- src/mlpack/methods/hmm/hmm_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index 5f880613f0b..9820b455ccb 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -704,7 +704,7 @@ void HMM::load(Archive& ar, const unsigned int /* version */) //! Serialize the HMM. template template -void HMM::save(Archive& ar, +void HMM::save(Archive& ar, const unsigned int /* version */) const { arma::mat transition = exp(logTransition); From 760c9e4680851668528999bf43145e58bdc08d8c Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 14:06:32 -0500 Subject: [PATCH 7/8] added comment regarding mlpack 4.0 --- src/mlpack/methods/hmm/hmm.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp index 31e1a4a7f52..4d851fa1dff 100644 --- a/src/mlpack/methods/hmm/hmm.hpp +++ b/src/mlpack/methods/hmm/hmm.hpp @@ -394,22 +394,23 @@ class HMM //! Set of emission probability distributions; one for each state. std::vector emission; - //! a proxy vriable in linear space for logTransition + //! a proxy vriable in linear space for logTransition. Should be removed in mlpack 4.0. arma::mat transitionProxy; - //! Transition probability matrix. + //! Transition probability matrix. No need to be mutable in mlpack 4.0. 
mutable arma::mat logTransition; private: /** - * Make sure the variables in log space are in sync with the linear counter parts + * Make sure the variables in log space are in sync with the linear counter parts. + * Should be removed in mlpack 4.0. */ void ConvertToLogSpace() const; - //! a proxy vriable in linear space for logInitial + //! a proxy vriable in linear space for logInitial. Should be removed in mlpack 4.0. arma::vec initialProxy; - //! Initial state probability vector. + //! Initial state probability vector. No need to be mutable in mlpack 4.0. mutable arma::vec logInitial; //! Dimensionality of observations. @@ -418,10 +419,10 @@ class HMM //! Tolerance of Baum-Welch algorithm. double tolerance; - //! Whether or not we need to update the logInitial from initialProxy + //! Whether or not we need to update the logInitial from initialProxy. Should be removed in mlpack 4.0. mutable bool recalculateInitial; - //! Whether or not we need to update the logTransition from transitionProxy + //! Whether or not we need to update the logTransition from transitionProxy. Should be removed in mlpack 4.0. mutable bool recalculateTransition; }; From 086de1ee1e5a7d26cdb6fcfa31bc0e8ff901cd4c Mon Sep 17 00:00:00 2001 From: Arash Abghari Date: Wed, 20 Nov 2019 14:27:21 -0500 Subject: [PATCH 8/8] make sure lines are <= 80 characters long --- src/mlpack/methods/hmm/hmm.hpp | 35 +++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp index 4d851fa1dff..ff782668da9 100644 --- a/src/mlpack/methods/hmm/hmm.hpp +++ b/src/mlpack/methods/hmm/hmm.hpp @@ -201,14 +201,14 @@ class HMM * of the most probable sequence is returned. * * @param dataSeq Sequence of observations. - * @param stateProb Matrix in which the log probabilities of each state at each - * time interval will be stored. 
- * @param forwardProb Matrix in which the forward log probabilities of each state - * at each time interval will be stored. + * @param stateProb Matrix in which the log probabilities of each state at + * each time interval will be stored. + * @param forwardProb Matrix in which the forward log probabilities of each + * state at each time interval will be stored. * @param backwardProb Matrix in which the backward log probabilities of each * state at each time interval will be stored. - * @param scales Vector in which the log of scaling factors at each time interval - * will be stored. + * @param scales Vector in which the log of scaling factors at each time + * interval will be stored. * @return Log-likelihood of most likely state sequence. */ double LogEstimate(const arma::mat& dataSeq, @@ -394,7 +394,10 @@ class HMM //! Set of emission probability distributions; one for each state. std::vector emission; - //! a proxy vriable in linear space for logTransition. Should be removed in mlpack 4.0. + /** + * A proxy variable in linear space for logTransition. + * Should be removed in mlpack 4.0. + */ arma::mat transitionProxy; //! Transition probability matrix. No need to be mutable in mlpack 4.0. @@ -402,12 +405,16 @@ class HMM private: /** - * Make sure the variables in log space are in sync with the linear counter parts. + * Make sure the variables in log space are in sync + * with the linear counterparts. * Should be removed in mlpack 4.0. */ void ConvertToLogSpace() const; - //! a proxy vriable in linear space for logInitial. Should be removed in mlpack 4.0. + /** + * A proxy variable in linear space for logInitial. + * Should be removed in mlpack 4.0. + */ arma::vec initialProxy; //! Initial state probability vector. No need to be mutable in mlpack 4.0. @@ -419,10 +426,16 @@ class HMM //! Tolerance of Baum-Welch algorithm. double tolerance; - //! Whether or not we need to update the logInitial from initialProxy. Should be removed in mlpack 4.0. 
+ /** + * Whether or not we need to update the logInitial from initialProxy. + * Should be removed in mlpack 4.0. + */ mutable bool recalculateInitial; - //! Whether or not we need to update the logTransition from transitionProxy. Should be removed in mlpack 4.0. + /** + * Whether or not we need to update the logTransition from transitionProxy. + * Should be removed in mlpack 4.0. + */ mutable bool recalculateTransition; };