Skip to content

Commit

Permalink
Merge pull request #390 from stephentu/master
Browse files Browse the repository at this point in the history
Add covariance factorization caching to gaussian distribution
  • Loading branch information
rcurtin committed Jan 26, 2015
2 parents cfe8c10 + 046bdc4 commit 71a9f87
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 80 deletions.
66 changes: 59 additions & 7 deletions src/mlpack/core/dists/gaussian_distribution.cpp
Expand Up @@ -10,21 +10,65 @@
using namespace mlpack;
using namespace mlpack::distribution;


/**
 * Construct a Gaussian distribution from an explicit mean and covariance.
 *
 * The covariance is copied into the object and factorized right away, so the
 * cached quantities derived from it (lower-triangular factor, inverse, and
 * log-determinant) are populated by the time construction finishes.
 *
 * @param mean Mean of the distribution.
 * @param covariance Covariance of the distribution.
 */
GaussianDistribution::GaussianDistribution(const arma::vec& mean,
                                           const arma::mat& covariance) :
    mean(mean),
    covariance(covariance)
{
  FactorCovariance();
}

/**
 * Set the covariance from a matrix that the caller keeps.
 *
 * @param covariance New covariance; it is copied, not referenced.
 */
void GaussianDistribution::Covariance(const arma::mat& covariance)
{
  // Take a private copy and hand it to the rvalue overload, which stores it
  // and refreshes the cached factorization.
  Covariance(arma::mat(covariance));
}

/**
 * Set the covariance by taking over the given matrix.
 *
 * @param covariance New covariance; it is moved from.
 */
void GaussianDistribution::Covariance(arma::mat&& covariance)
{
  // The parameter shadows the member, so name the member explicitly.
  arma::mat& stored = this->covariance;
  stored = std::move(covariance);

  // The cached quantities are derived from the covariance; rebuild them.
  FactorCovariance();
}

/**
 * Rebuild every cached quantity that is derived from the covariance: its
 * lower-triangular Cholesky factor (covLower), its inverse (invCov), and the
 * logarithm of its determinant (logDetCov).
 *
 * This must run after any change to the covariance.
 */
void GaussianDistribution::FactorCovariance()
{
  // covariance = covLower * covLower^T.
  covLower = arma::chol(covariance, "lower");

  // Comment from rcurtin:
  //
  // I think the use of the word "interpret" in the Armadillo documentation
  // about trimatl and trimatu is somewhat misleading.  What the function will
  // actually do, when used in that context, is loop over the upper triangular
  // part of the matrix and set it all to 0, so this ends up actually just
  // burning cycles---also because the operator=() evaluates the expression and
  // strips the knowledge that it's a lower triangular matrix.  So then the call
  // to .i() doesn't actually do anything smarter.
  //
  // But perusing fn_inv.hpp more closely, there is a specialization that will
  // work when called like this: inv(trimatl(covLower)), and will use LAPACK's
  // ?trtri functions.  However, it will still set the upper triangular part to
  // 0 after the method.  That last part is unnecessary, but baked into
  // Armadillo, so there's not really much that can be done about that without
  // discussion with the Armadillo maintainer.
  const arma::mat invCovLower = arma::inv(arma::trimatl(covLower));

  // inv(L * L^T) = inv(L)^T * inv(L).
  invCov = invCovLower.t() * invCovLower;

  // det(covariance) = det(covLower)^2, and the determinant of a triangular
  // matrix is the product of its diagonal.  The Cholesky factor's diagonal is
  // positive, so the log-determinant is simply twice the sum of the logs of
  // that diagonal; no general-purpose determinant routine is needed.
  logDetCov = 2 * arma::accu(arma::log(covLower.diag()));
}

// Log-density of a single observation under this Gaussian:
//   -0.5 * k * log(2 pi) - 0.5 * log|covariance| - 0.5 * diff^T inv(covariance) diff
// where k is the number of elements in the observation.
//
// NOTE(review): this text comes from a rendered diff whose +/- markers were
// stripped, so it appears to interleave the pre-change and post-change bodies:
// `v` is declared twice and there are two return statements.  Presumably the
// lines that use the cached invCov and logDetCov are the ones that survive the
// commit -- confirm against the repository before relying on this span.
double GaussianDistribution::LogProbability(const arma::vec& observation) const
{
const size_t k = observation.n_elem;
// Recomputes the log-determinant of the covariance on every call.
double logdetsigma = 0;
double sign = 0.;
arma::log_det(logdetsigma, sign, covariance);
const arma::vec diff = mean - observation;
// Inverts the covariance on every call.
const arma::vec v = (diff.t() * arma::inv(covariance) * diff);
return -0.5 * k * log2pi - 0.5 * logdetsigma - 0.5 * v(0);
// Same quadratic form and result expression, but read from the cached
// invCov and logDetCov members instead of recomputing them.
const arma::vec v = (diff.t() * invCov * diff);
return -0.5 * k * log2pi - 0.5 * logDetCov - 0.5 * v(0);
}

// Draw one sample: a standard-normal vector with mean.n_elem elements is
// multiplied by a lower-triangular factor of the covariance and shifted by the
// mean.
//
// NOTE(review): this text comes from a rendered diff whose +/- markers were
// stripped, so it appears to hold both the pre-change and post-change return
// statements; as written the second one is unreachable.  Presumably only the
// covLower form survives the commit -- confirm against the repository.
arma::vec GaussianDistribution::Random() const
{
// Should we store chol(covariance) for easier calculation later?
return trans(chol(covariance)) * arma::randn<arma::vec>(mean.n_elem) + mean;
// Uses the cached factor covLower instead of refactorizing the covariance.
return covLower * arma::randn<arma::vec>(mean.n_elem) + mean;
}

/**
Expand All @@ -41,6 +85,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations)
}
else // This will end up just being empty.
{
// TODO(stephentu): why do we allow this case? why not throw an error?
mean.zeros(0);
covariance.zeros(0);
return;
Expand Down Expand Up @@ -77,6 +122,8 @@ void GaussianDistribution::Estimate(const arma::mat& observations)
perturbation *= 10; // Slow, but we don't want to add too much.
}
}

FactorCovariance();
}

/**
Expand All @@ -94,6 +141,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
}
else // This will end up just being empty.
{
// TODO(stephentu): same as above
mean.zeros(0);
covariance.zeros(0);
return;
Expand All @@ -114,6 +162,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
// Nothing in this Gaussian! At least set the covariance so that it's
// invertible.
covariance.diag() += 1e-50;
FactorCovariance();
return;
}

Expand Down Expand Up @@ -143,6 +192,8 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
perturbation *= 10; // Slow, but we don't want to add too much.
}
}

FactorCovariance();
}

/**
Expand Down Expand Up @@ -180,4 +231,5 @@ void GaussianDistribution::Load(const util::SaveRestoreUtility& sr)
{
sr.LoadParameter(mean, "mean");
sr.LoadParameter(covariance, "covariance");
FactorCovariance();
}
36 changes: 24 additions & 12 deletions src/mlpack/core/dists/gaussian_distribution.hpp
Expand Up @@ -21,8 +21,14 @@ class GaussianDistribution
private:
//! Mean of the distribution.
arma::vec mean;
//! Covariance of the distribution.
//! Positive definite covariance of the distribution.
arma::mat covariance;
//! Lower triangular Cholesky factor of cov (i.e. cov = LL^T).
arma::mat covLower;
//! Cached inverse of covariance.
arma::mat invCov;
//! Cached logdet(cov).
double logDetCov;

//! log(2pi)
static const constexpr double log2pi = 1.83787706640934533908193770912475883;
Expand All @@ -39,14 +45,20 @@ class GaussianDistribution
*/
// Builds a distribution with a zero mean of the given dimension and an
// identity covariance.  The cached members are set to match: covLower and
// invCov are identity matrices and logDetCov is 0.
//
// NOTE(review): this text comes from a rendered diff whose +/- markers were
// stripped, so the `covariance(...)` initializer appears twice (once without
// and once with the trailing comma).  Presumably only the second survives the
// commit -- confirm against the repository.
GaussianDistribution(const size_t dimension) :
mean(arma::zeros<arma::vec>(dimension)),
covariance(arma::eye<arma::mat>(dimension, dimension))
covariance(arma::eye<arma::mat>(dimension, dimension)),
covLower(arma::eye<arma::mat>(dimension, dimension)),
invCov(arma::eye<arma::mat>(dimension, dimension)),
logDetCov(0)
{ /* Nothing to do. */ }

/**
* Create a Gaussian distribution with the given mean and covariance.
*
* covariance is expected to be positive definite.
*/
GaussianDistribution(const arma::vec& mean, const arma::mat& covariance) :
mean(mean), covariance(covariance) { /* Nothing to do. */ }
GaussianDistribution(const arma::vec& mean, const arma::mat& covariance);

// TODO(stephentu): do we want a (arma::vec&&, arma::mat&&) ctor?

//! Return the dimensionality of this distribution.
size_t Dimensionality() const { return mean.n_elem; }
Expand Down Expand Up @@ -119,9 +131,11 @@ class GaussianDistribution
const arma::mat& Covariance() const { return covariance; }

/**
* Return a modifiable copy of the covariance.
* Set the covariance.
*/
arma::mat& Covariance() { return covariance; }
void Covariance(const arma::mat& covariance);

void Covariance(arma::mat&& covariance);

/**
* Returns a string representation of this object.
Expand All @@ -135,7 +149,8 @@ class GaussianDistribution
void Load(const util::SaveRestoreUtility& n);
static std::string const Type() { return "GaussianDistribution"; }


private:
void FactorCovariance();

};

Expand All @@ -156,17 +171,14 @@ inline void GaussianDistribution::LogProbability(const arma::mat& x,
// diffs). We just don't need any of the other elements. We can calculate
// the right hand part of the equation (instead of the left side) so that
// later we are referencing columns, not rows -- that is faster.
arma::mat rhs = -0.5 * inv(covariance) * diffs;
const arma::mat rhs = -0.5 * invCov * diffs;
arma::vec logExponents(diffs.n_cols); // We will now fill this.
for (size_t i = 0; i < diffs.n_cols; i++)
logExponents(i) = accu(diffs.unsafe_col(i) % rhs.unsafe_col(i));

double logdetsigma = 0;
double sign = 0.;
arma::log_det(logdetsigma, sign, covariance);
const size_t k = x.n_rows;

logProbabilities = -0.5 * k * log2pi - 0.5 * logdetsigma + logExponents;
logProbabilities = -0.5 * k * log2pi - 0.5 * logDetCov + logExponents;
}


Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/core/dists/regression_distribution.hpp
Expand Up @@ -48,7 +48,9 @@ class RegressionDistribution
rf(regression::LinearRegression(predictors, responses))
{
err = GaussianDistribution(1);
err.Covariance() = rf.ComputeError(predictors, responses);
arma::mat cov(1, 1);
cov(0, 0) = rf.ComputeError(predictors, responses);
err.Covariance(std::move(cov));
}

/**
Expand Down
47 changes: 28 additions & 19 deletions src/mlpack/methods/gmm/em_fit_impl.hpp
Expand Up @@ -93,11 +93,12 @@ void EMFit<InitialClusteringType, CovarianceConstraintPolicy>::Estimate(
trans(condProb.col(i)));

// Don't update if there's no probability of the Gaussian having points.
if (probRowSums[i] != 0.0)
dists[i].Covariance() = (tmp * trans(tmpB)) / probRowSums[i];

// Apply covariance constraint.
constraint.ApplyConstraint(dists[i].Covariance());
if (probRowSums[i] != 0.0) {
arma::mat covariance = (tmp * trans(tmpB)) / probRowSums[i];
// Apply covariance constraint.
constraint.ApplyConstraint(covariance);
dists[i].Covariance(std::move(covariance));
}
}

// Calculate the new values for omega using the updated conditional
Expand Down Expand Up @@ -180,10 +181,12 @@ void EMFit<InitialClusteringType, CovarianceConstraintPolicy>::Estimate(
arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
trans(condProb.col(i) % probabilities));

dists[i].Covariance() = (tmp * trans(tmpB)) / probRowSums[i];
arma::mat cov = (tmp * trans(tmpB)) / probRowSums[i];

// Apply covariance constraint.
constraint.ApplyConstraint(dists[i].Covariance());
constraint.ApplyConstraint(cov);

dists[i].Covariance(std::move(cov));
}

// Calculate the new values for omega using the updated conditional
Expand All @@ -210,12 +213,16 @@ InitialClustering(const arma::mat& observations,
// Run clustering algorithm.
clusterer.Cluster(observations, dists.size(), assignments);

std::vector<arma::vec> means(dists.size());
std::vector<arma::mat> covs(dists.size());

// Now calculate the means, covariances, and weights.
weights.zeros();
for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Mean().zeros();
dists[i].Covariance().zeros();
means[i].zeros(dists[i].Mean().n_elem);
covs[i].zeros(dists[i].Covariance().n_rows,
dists[i].Covariance().n_cols);
}

// From the assignments, generate our means, covariances, and weights.
Expand All @@ -224,11 +231,10 @@ InitialClustering(const arma::mat& observations,
const size_t cluster = assignments[i];

// Add this to the relevant mean.
dists[cluster].Mean() += observations.col(i);
means[cluster] += observations.col(i);

// Add this to the relevant covariance.
dists[cluster].Covariance() += observations.col(i) *
trans(observations.col(i));
covs[cluster] += observations.col(i) * trans(observations.col(i));

// Now add one to the weights (we will normalize).
weights[cluster]++;
Expand All @@ -237,22 +243,25 @@ InitialClustering(const arma::mat& observations,
// Now normalize the mean and covariance.
for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Mean() /= (weights[i] > 1) ? weights[i] : 1;
means[i] /= (weights[i] > 1) ? weights[i] : 1;
}

for (size_t i = 0; i < observations.n_cols; ++i)
{
const size_t cluster = assignments[i];
const arma::vec normObs = observations.col(i) - dists[cluster].Mean();
dists[cluster].Covariance() += normObs * normObs.t();
const arma::vec normObs = observations.col(i) - means[cluster];
covs[cluster] += normObs * normObs.t();
}

for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Covariance() /= (weights[i] > 1) ? weights[i] : 1;
covs[i] /= (weights[i] > 1) ? weights[i] : 1;

// Apply constraints to covariance matrix.
constraint.ApplyConstraint(dists[i].Covariance());
constraint.ApplyConstraint(covs[i]);

std::swap(dists[i].Mean(), means[i]);
dists[i].Covariance(std::move(covs[i]));
}

// Finally, normalize weights.
Expand All @@ -269,7 +278,7 @@ double EMFit<InitialClusteringType, CovarianceConstraintPolicy>::LogLikelihood(

arma::vec phis;
arma::mat likelihoods(dists.size(), observations.n_cols);

for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Probability(observations, phis);
Expand All @@ -283,7 +292,7 @@ double EMFit<InitialClusteringType, CovarianceConstraintPolicy>::LogLikelihood(
<< "outlier." << std::endl;
logLikelihood += log(accu(likelihoods.col(j)));
}

return logLikelihood;
}

Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/methods/gmm/gmm_convert_main.cpp
Expand Up @@ -47,12 +47,14 @@ int main(int argc, char* argv[])
for (size_t i = 0; i < gaussians; ++i)
{
stringstream o;
arma::mat covariance;
o << i;
string meanName = "mean" + o.str();
string covName = "covariance" + o.str();

load.LoadParameter(gmm.Component(i).Mean(), meanName);
load.LoadParameter(gmm.Component(i).Covariance(), covName);
load.LoadParameter(covariance, covName);
gmm.Component(i).Covariance(std::move(covariance));
}

gmm.Save(CLI::GetParam<string>("output_file"));
Expand Down
8 changes: 6 additions & 2 deletions src/mlpack/methods/hmm/hmm_util_impl.hpp
Expand Up @@ -115,8 +115,10 @@ void ConvertHMM(HMM<distribution::GaussianDistribution>& hmm,
sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());

s.str("");
arma::mat covariance;
s << "hmm_emission_covariance_" << i;
sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
sr.LoadParameter(covariance, s.str());
hmm.Emission()[i].Covariance(std::move(covariance));
}

hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
Expand Down Expand Up @@ -168,7 +170,9 @@ void ConvertHMM(HMM<gmm::GMM<> >& hmm, const util::SaveRestoreUtility& sr)

s.str("");
s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
sr.LoadParameter(hmm.Emission()[i].Component(g).Covariance(), s.str());
arma::mat covariance;
sr.LoadParameter(covariance, s.str());
hmm.Emission()[i].Component(g).Covariance(std::move(covariance));
}

s.str("");
Expand Down

0 comments on commit 71a9f87

Please sign in to comment.