From d000a5a5230b7ea18866e3e94c5edb09e02d29fa Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 12 Dec 2016 14:44:20 -0500 Subject: [PATCH 1/2] Don't use equal initial probabilities. That can cause training to fail sometimes. Instead, optimization seems to perform better when using random initialization. --- src/mlpack/methods/hmm/hmm_impl.hpp | 11 ++++++++--- src/mlpack/tests/hmm_test.cpp | 5 +++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp index 8e854d2b427..8b98a0c8755 100644 --- a/src/mlpack/methods/hmm/hmm_impl.hpp +++ b/src/mlpack/methods/hmm/hmm_impl.hpp @@ -29,11 +29,16 @@ HMM::HMM(const size_t states, const Distribution emissions, const double tolerance) : emission(states, /* default distribution */ emissions), - transition(arma::ones(states, states) / (double) states), - initial(arma::ones(states) / (double) states), + transition(arma::randu(states, states)), + initial(arma::randu(states) / (double) states), dimensionality(emissions.Dimensionality()), tolerance(tolerance) -{ /* nothing to do */ } +{ + // Normalize the transition probabilities and initial state probabilities. + initial /= arma::accu(initial); + for (size_t i = 0; i < transition.n_cols; ++i) + transition.col(i) /= arma::accu(transition.col(i)); +} /** * Create the Hidden Markov Model with the given transition matrix and the given diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp index ab41e904d32..c9016c51bff 100644 --- a/src/mlpack/tests/hmm_test.cpp +++ b/src/mlpack/tests/hmm_test.cpp @@ -412,9 +412,10 @@ BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest) BOOST_AUTO_TEST_CASE(DiscreteHMMSimpleGenerateTest) { // Very simple HMM. 4 emissions with equal probability and 2 states with - // equal probability. The default transition and emission matrices satisfy - // this property. + // equal probability. 
HMM hmm(2, DiscreteDistribution(4)); + hmm.Initial() = arma::ones(2) / 2.0; + hmm.Transition() = arma::ones(2, 2) / 2.0; // Now generate a really, really long sequence. arma::mat dataSeq; From c76f5ac9cd71aad46f7a9a395bb7ff95b3c29582 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Thu, 22 Dec 2016 10:14:12 -0500 Subject: [PATCH 2/2] Remove random initialization since it is done by default now. --- src/mlpack/methods/hmm/hmm_train_main.cpp | 66 ++--------------------- 1 file changed, 3 insertions(+), 63 deletions(-) diff --git a/src/mlpack/methods/hmm/hmm_train_main.cpp b/src/mlpack/methods/hmm/hmm_train_main.cpp index 599bc15cfa7..1bf4b10ddd1 100644 --- a/src/mlpack/methods/hmm/hmm_train_main.cpp +++ b/src/mlpack/methods/hmm/hmm_train_main.cpp @@ -29,9 +29,9 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a " "\n\n" "The HMM is trained with the Baum-Welch algorithm if no labels are " "provided. The tolerance of the Baum-Welch algorithm can be set with the " - "--tolerance option. In general it is a good idea to use random " - "initialization in this case, which can be specified with the " - "--random_initialization (-r) option." + "--tolerance option. By default, the transition matrix is randomly " + "initialized and the emission distributions are initialized to fit the " + "extent of the data." "\n\n" "Optionally, a pre-created HMM model can be used as a guess for the " "transition matrix and emission probabilities; this is specifiable with " @@ -54,8 +54,6 @@ PARAM_STRING_OUT("output_model_file", "File to save trained HMM to.", "M"); PARAM_INT_IN("seed", "Random seed. 
If 0, 'std::time(NULL)' is used.", "s", 0); PARAM_DOUBLE_IN("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5); -PARAM_FLAG("random_initialization", "Initialize emissions and transition " - "matrices with a uniform random distribution.", "r"); using namespace mlpack; using namespace mlpack::hmm; @@ -305,21 +303,6 @@ int main(int argc, char** argv) HMM hmm(size_t(states), DiscreteDistribution(maxEmission), tolerance); - // Initialize with random starting point. - if (CLI::HasParam("random_initialization")) - { - hmm.Transition().randu(); - for (size_t c = 0; c < hmm.Transition().n_cols; ++c) - hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c)); - - for (size_t e = 0; e < hmm.Emission().size(); ++e) - { - hmm.Emission()[e].Probabilities().randu(); - hmm.Emission()[e].Probabilities() /= - arma::accu(hmm.Emission()[e].Probabilities()); - } - } - // Now train it. Pass the already-loaded training data. Train::Apply(hmm, &trainSeq); } @@ -338,22 +321,6 @@ int main(int argc, char** argv) HMM hmm(size_t(states), GaussianDistribution(dimensionality), tolerance); - // Initialize with random starting point. - if (CLI::HasParam("random_initialization")) - { - hmm.Transition().randu(); - for (size_t c = 0; c < hmm.Transition().n_cols; ++c) - hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c)); - - for (size_t e = 0; e < hmm.Emission().size(); ++e) - { - hmm.Emission()[e].Mean().randu(); - // Generate random covariance. - arma::mat r = arma::randu(dimensionality, dimensionality); - hmm.Emission()[e].Covariance(r * r.t()); - } - } - // Now train it. Train::Apply(hmm, &trainSeq); } @@ -376,33 +343,6 @@ int main(int argc, char** argv) HMM hmm(size_t(states), GMM(size_t(gaussians), dimensionality), tolerance); - // Initialize with random starting point. 
- if (CLI::HasParam("random_initialization")) - { - hmm.Transition().randu(); - for (size_t c = 0; c < hmm.Transition().n_cols; ++c) - hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c)); - - for (size_t e = 0; e < hmm.Emission().size(); ++e) - { - // Random weights. - hmm.Emission()[e].Weights().randu(); - hmm.Emission()[e].Weights() /= - arma::accu(hmm.Emission()[e].Weights()); - - // Random means and covariances. - for (int g = 0; g < gaussians; ++g) - { - hmm.Emission()[e].Component(g).Mean().randu(); - - // Generate random covariance. - arma::mat r = arma::randu(dimensionality, - dimensionality); - hmm.Emission()[e].Component(g).Covariance(r * r.t()); - } - } - } - // Issue a warning if the user didn't give labels. if (!CLI::HasParam("labels_file")) Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "