Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HMM initialization: don't use equal initial probabilities. #828

Merged
merged 2 commits into from
Dec 22, 2016
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 3 additions & 63 deletions src/mlpack/methods/hmm/hmm_train_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
"\n\n"
"The HMM is trained with the Baum-Welch algorithm if no labels are "
"provided. The tolerance of the Baum-Welch algorithm can be set with the "
"--tolerance option. In general it is a good idea to use random "
"initialization in this case, which can be specified with the "
"--random_initialization (-r) option."
"--tolerance option. By default, the transition matrix is randomly "
"initialized and the emission distributions are initialized to fit the "
"extent of the data."
"\n\n"
"Optionally, a pre-created HMM model can be used as a guess for the "
"transition matrix and emission probabilities; this is specifiable with "
Expand All @@ -54,8 +54,6 @@ PARAM_STRING_OUT("output_model_file", "File to save trained HMM to.", "M");
PARAM_INT_IN("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_DOUBLE_IN("tolerance", "Tolerance of the Baum-Welch algorithm.", "T",
1e-5);
PARAM_FLAG("random_initialization", "Initialize emissions and transition "
"matrices with a uniform random distribution.", "r");

using namespace mlpack;
using namespace mlpack::hmm;
Expand Down Expand Up @@ -305,21 +303,6 @@ int main(int argc, char** argv)
HMM<DiscreteDistribution> hmm(size_t(states),
DiscreteDistribution(maxEmission), tolerance);

// Initialize with random starting point.
if (CLI::HasParam("random_initialization"))
{
hmm.Transition().randu();
for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));

for (size_t e = 0; e < hmm.Emission().size(); ++e)
{
hmm.Emission()[e].Probabilities().randu();
hmm.Emission()[e].Probabilities() /=
arma::accu(hmm.Emission()[e].Probabilities());
}
}

// Now train it. Pass the already-loaded training data.
Train::Apply(hmm, &trainSeq);
}
Expand All @@ -338,22 +321,6 @@ int main(int argc, char** argv)
HMM<GaussianDistribution> hmm(size_t(states),
GaussianDistribution(dimensionality), tolerance);

// Initialize with random starting point.
if (CLI::HasParam("random_initialization"))
{
hmm.Transition().randu();
for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));

for (size_t e = 0; e < hmm.Emission().size(); ++e)
{
hmm.Emission()[e].Mean().randu();
// Generate random covariance.
arma::mat r = arma::randu<arma::mat>(dimensionality, dimensionality);
hmm.Emission()[e].Covariance(r * r.t());
}
}

// Now train it.
Train::Apply(hmm, &trainSeq);
}
Expand All @@ -376,33 +343,6 @@ int main(int argc, char** argv)
HMM<GMM> hmm(size_t(states), GMM(size_t(gaussians), dimensionality),
tolerance);

// Initialize with random starting point.
if (CLI::HasParam("random_initialization"))
{
hmm.Transition().randu();
for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));

for (size_t e = 0; e < hmm.Emission().size(); ++e)
{
// Random weights.
hmm.Emission()[e].Weights().randu();
hmm.Emission()[e].Weights() /=
arma::accu(hmm.Emission()[e].Weights());

// Random means and covariances.
for (int g = 0; g < gaussians; ++g)
{
hmm.Emission()[e].Component(g).Mean().randu();

// Generate random covariance.
arma::mat r = arma::randu<arma::mat>(dimensionality,
dimensionality);
hmm.Emission()[e].Component(g).Covariance(r * r.t());
}
}
}

// Issue a warning if the user didn't give labels.
if (!CLI::HasParam("labels_file"))
Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
Expand Down