Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/aabghari/mlpack into aabg…
Browse files Browse the repository at this point in the history
…hari-master
  • Loading branch information
rcurtin committed Dec 5, 2019
2 parents e35ef2e + f65e5e0 commit 631ccf5
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 71 deletions.
68 changes: 53 additions & 15 deletions src/mlpack/methods/hmm/hmm.hpp
Expand Up @@ -201,14 +201,14 @@ class HMM
* of the most probable sequence is returned.
*
* @param dataSeq Sequence of observations.
* @param stateProb Matrix in which the log probabilities of each state at each
* time interval will be stored.
* @param forwardProb Matrix in which the forward log probabilities of each state
* at each time interval will be stored.
* @param stateProb Matrix in which the log probabilities of each state at
* each time interval will be stored.
* @param forwardProb Matrix in which the forward log probabilities of each
* state at each time interval will be stored.
* @param backwardProb Matrix in which the backward log probabilities of each
* state at each time interval will be stored.
* @param scales Vector in which the log of scaling factors at each time interval
* will be stored.
* @param scales Vector in which the log of scaling factors at each time
* interval will be stored.
* @return Log-likelihood of most likely state sequence.
*/
double LogEstimate(const arma::mat& dataSeq,
Expand Down Expand Up @@ -323,14 +323,15 @@ class HMM
arma::mat& smoothSeq) const;

//! Return the vector of initial state probabilities.
const arma::vec& Initial() const { return initial; }
const arma::vec Initial() const { return initialProxy; }
//! Modify the vector of initial state probabilities.
arma::vec& Initial() { return initial; }
arma::vec& Initial() { recalculateInitial = true; return initialProxy; }

//! Return the transition matrix.
const arma::mat& Transition() const { return transition; }
const arma::mat Transition() const { return transitionProxy; }
//! Return a modifiable transition matrix reference.
arma::mat& Transition() { return transition; }
arma::mat& Transition() { recalculateTransition = true;
return transitionProxy; }

//! Return the emission distributions.
const std::vector<Distribution>& Emission() const { return emission; }
Expand All @@ -351,7 +352,13 @@ class HMM
* Serialize the object.
*/
template<typename Archive>
void serialize(Archive& ar, const unsigned int version);
void load(Archive& ar, const unsigned int version);

template<typename Archive>
void save(Archive& ar, const unsigned int version) const;

BOOST_SERIALIZATION_SPLIT_MEMBER()


protected:
// Helper functions.
Expand Down Expand Up @@ -387,18 +394,49 @@ class HMM
//! Set of emission probability distributions; one for each state.
std::vector<Distribution> emission;

//! Transition probability matrix.
arma::mat transition;
/**
*a proxy vriable in linear space for logTransition.
*Should be removed in mlpack 4.0.
*/
arma::mat transitionProxy;

//! Transition probability matrix. No need to be mutable in mlpack 4.0.
mutable arma::mat logTransition;

private:
//! Initial state probability vector.
arma::vec initial;
/**
* Make sure the variables in log space are in sync
* with the linear counter parts.
* Should be removed in mlpack 4.0.
*/
void ConvertToLogSpace() const;

/**
* a proxy vriable in linear space for logInitial.
* Should be removed in mlpack 4.0.
*/
arma::vec initialProxy;

//! Initial state probability vector. No need to be mutable in mlpack 4.0.
mutable arma::vec logInitial;

//! Dimensionality of observations.
size_t dimensionality;

//! Tolerance of Baum-Welch algorithm.
double tolerance;

/**
* Whether or not we need to update the logInitial from initialProxy.
* Should be removed in mlpack 4.0.
*/
mutable bool recalculateInitial;

/**
* Whether or not we need to update the logTransition from transitionProxy.
* Should be removed in mlpack 4.0.
*/
mutable bool recalculateTransition;
};

} // namespace hmm
Expand Down

0 comments on commit 631ccf5

Please sign in to comment.