Add covariance factorization caching to gaussian distribution #390

Merged
8 commits merged on Jan 26, 2015
66 changes: 59 additions & 7 deletions src/mlpack/core/dists/gaussian_distribution.cpp
@@ -10,21 +10,65 @@
using namespace mlpack;
using namespace mlpack::distribution;


GaussianDistribution::GaussianDistribution(const arma::vec& mean,
const arma::mat& covariance)
: mean(mean)
{
Covariance(covariance);
}

void GaussianDistribution::Covariance(const arma::mat& covariance)
{
this->covariance = covariance;
FactorCovariance();
}

void GaussianDistribution::Covariance(arma::mat&& covariance)
{
this->covariance = std::move(covariance);
FactorCovariance();
}

void GaussianDistribution::FactorCovariance()
{
covLower = arma::chol(covariance, "lower");

// Comment from rcurtin:
//
// I think the use of the word "interpret" in the Armadillo documentation
// about trimatl and trimatu is somewhat misleading. What the function will
// actually do, when used in that context, is loop over the upper triangular
// part of the matrix and set it all to 0, so this ends up actually just
// burning cycles---also because the operator=() evaluates the expression and
// strips the knowledge that it's a lower triangular matrix. So then the call
// to .i() doesn't actually do anything smarter.
//
// But perusing fn_inv.hpp more closely, there is a specialization that will
// work when called like this: inv(trimatl(covLower)), and will use LAPACK's
// ?trtri functions. However, it will still set the upper triangular part to
// 0 after the method. That last part is unnecessary, but baked into
// Armadillo, so there's not really much that can be done about that without
// discussion with the Armadillo maintainer.
const arma::mat invCovLower = arma::inv(arma::trimatl(covLower));

invCov = invCovLower.t() * invCovLower;
double sign = 0.;
arma::log_det(logDetCov, sign, covLower);
logDetCov *= 2;
}
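
As a sanity check on what FactorCovariance() caches, the identities it relies on are inv(sigma) = (L^-1)^T (L^-1) and log|sigma| = 2 log|L| for sigma = L L^T. The following standalone sketch (not part of this patch; the matrix is a made-up positive definite example) compares the cached quantities against the direct computations:

#include <armadillo>
#include <iostream>

int main()
{
  // Build an arbitrary symmetric positive definite matrix.
  const arma::mat X = arma::randu<arma::mat>(50, 4);
  const arma::mat sigma = X.t() * X / 50.0 + 0.1 * arma::eye<arma::mat>(4, 4);

  // Same steps as FactorCovariance().
  const arma::mat covLower = arma::chol(sigma, "lower");
  const arma::mat invCovLower = arma::inv(arma::trimatl(covLower));
  const arma::mat invCov = invCovLower.t() * invCovLower;
  double logDetCov = 0.0, sign = 0.0;
  arma::log_det(logDetCov, sign, covLower);
  logDetCov *= 2;

  // Direct (uncached) computations for comparison.
  double directLogDet = 0.0, directSign = 0.0;
  arma::log_det(directLogDet, directSign, sigma);

  std::cout << "max inverse error:  "
            << arma::abs(invCov - arma::inv(sigma)).max() << std::endl;
  std::cout << "log-det difference: "
            << (logDetCov - directLogDet) << std::endl;

  return 0;
}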

double GaussianDistribution::LogProbability(const arma::vec& observation) const
{
const size_t k = observation.n_elem;
double logdetsigma = 0;
double sign = 0.;
arma::log_det(logdetsigma, sign, covariance);
const arma::vec diff = mean - observation;
const arma::vec v = (diff.t() * arma::inv(covariance) * diff);
return -0.5 * k * log2pi - 0.5 * logdetsigma - 0.5 * v(0);
const arma::vec v = (diff.t() * invCov * diff);
return -0.5 * k * log2pi - 0.5 * logDetCov - 0.5 * v(0);
}

arma::vec GaussianDistribution::Random() const
{
// Should we store chol(covariance) for easier calculation later?
return trans(chol(covariance)) * arma::randn<arma::vec>(mean.n_elem) + mean;
return covLower * arma::randn<arma::vec>(mean.n_elem) + mean;
}
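
Random() can use covLower directly because if z ~ N(0, I), then cov(L z + mean) = L L^T, which is exactly the stored covariance. A quick empirical check of that identity (a standalone sketch, not part of this patch; mu and sigma are made-up values):

#include <armadillo>
#include <iostream>

int main()
{
  arma::vec mu(2);
  mu(0) = 1.0; mu(1) = -2.0;
  arma::mat sigma(2, 2);
  sigma(0, 0) = 2.0; sigma(0, 1) = 0.5;
  sigma(1, 0) = 0.5; sigma(1, 1) = 1.0;

  const arma::mat covLower = arma::chol(sigma, "lower");

  // Sample the same way Random() does.
  arma::mat samples(2, 100000);
  for (arma::uword i = 0; i < samples.n_cols; ++i)
    samples.col(i) = covLower * arma::randn<arma::vec>(2) + mu;

  // The sample covariance (rows are observations) should be close to sigma.
  std::cout << arma::cov(samples.t()) << std::endl;

  return 0;
}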

/**
@@ -41,6 +85,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations)
}
else // This will end up just being empty.
{
// TODO(stephentu): why do we allow this case? why not throw an error?
mean.zeros(0);
covariance.zeros(0);
return;
@@ -77,6 +122,8 @@ void GaussianDistribution::Estimate(const arma::mat& observations)
perturbation *= 10; // Slow, but we don't want to add too much.
}
}

FactorCovariance();
}

/**
@@ -94,6 +141,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
}
else // This will end up just being empty.
{
// TODO(stephentu): same as above
mean.zeros(0);
covariance.zeros(0);
return;
@@ -114,6 +162,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
// Nothing in this Gaussian! At least set the covariance so that it's
// invertible.
covariance.diag() += 1e-50;
FactorCovariance();
return;
}

@@ -143,6 +192,8 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
perturbation *= 10; // Slow, but we don't want to add too much.
}
}

FactorCovariance();
}

/**
@@ -180,4 +231,5 @@ void GaussianDistribution::Load(const util::SaveRestoreUtility& sr)
{
sr.LoadParameter(mean, "mean");
sr.LoadParameter(covariance, "covariance");
FactorCovariance();
}
36 changes: 24 additions & 12 deletions src/mlpack/core/dists/gaussian_distribution.hpp
@@ -21,8 +21,14 @@ class GaussianDistribution
private:
//! Mean of the distribution.
arma::vec mean;
//! Covariance of the distribution.
//! Positive definite covariance of the distribution.
arma::mat covariance;
//! Lower triangular factor of cov (i.e. cov = LL^T).
arma::mat covLower;
//! Cached inverse of covariance.
arma::mat invCov;
//! Cached logdet(cov).
double logDetCov;

//! log(2pi)
static const constexpr double log2pi = 1.83787706640934533908193770912475883;
@@ -39,14 +45,20 @@
*/
GaussianDistribution(const size_t dimension) :
mean(arma::zeros<arma::vec>(dimension)),
covariance(arma::eye<arma::mat>(dimension, dimension))
covariance(arma::eye<arma::mat>(dimension, dimension)),
covLower(arma::eye<arma::mat>(dimension, dimension)),
invCov(arma::eye<arma::mat>(dimension, dimension)),
logDetCov(0)
{ /* Nothing to do. */ }

/**
* Create a Gaussian distribution with the given mean and covariance.
*
* covariance is expected to be positive definite.
*/
GaussianDistribution(const arma::vec& mean, const arma::mat& covariance) :
mean(mean), covariance(covariance) { /* Nothing to do. */ }
GaussianDistribution(const arma::vec& mean, const arma::mat& covariance);

// TODO(stephentu): do we want a (arma::vec&&, arma::mat&&) ctor?

//! Return the dimensionality of this distribution.
size_t Dimensionality() const { return mean.n_elem; }
@@ -119,9 +131,11 @@ class GaussianDistribution
const arma::mat& Covariance() const { return covariance; }

/**
* Return a modifiable copy of the covariance.
* Set the covariance.
*/
arma::mat& Covariance() { return covariance; }
void Covariance(const arma::mat& covariance);

void Covariance(arma::mat&& covariance);

/**
* Returns a string representation of this object.
@@ -135,7 +149,8 @@
void Load(const util::SaveRestoreUtility& n);
static std::string const Type() { return "GaussianDistribution"; }


private:
void FactorCovariance();

};

@@ -156,17 +171,14 @@ inline void GaussianDistribution::LogProbability(const arma::mat& x,
// diffs). We just don't need any of the other elements. We can calculate
// the right hand part of the equation (instead of the left side) so that
// later we are referencing columns, not rows -- that is faster.
arma::mat rhs = -0.5 * inv(covariance) * diffs;
const arma::mat rhs = -0.5 * invCov * diffs;
arma::vec logExponents(diffs.n_cols); // We will now fill this.
for (size_t i = 0; i < diffs.n_cols; i++)
logExponents(i) = accu(diffs.unsafe_col(i) % rhs.unsafe_col(i));

double logdetsigma = 0;
double sign = 0.;
arma::log_det(logdetsigma, sign, covariance);
const size_t k = x.n_rows;

logProbabilities = -0.5 * k * log2pi - 0.5 * logdetsigma + logExponents;
logProbabilities = -0.5 * k * log2pi - 0.5 * logDetCov + logExponents;
}
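
Taken together, the new interface is used roughly like this (an illustrative sketch, not code from this patch; the numbers are made up):

#include <mlpack/core.hpp>
#include <mlpack/core/dists/gaussian_distribution.hpp>

using namespace mlpack::distribution;

int main()
{
  arma::vec mu(2);
  mu(0) = 0.0; mu(1) = 1.0;
  arma::mat sigma(2, 2);
  sigma(0, 0) = 1.0; sigma(0, 1) = 0.3;
  sigma(1, 0) = 0.3; sigma(1, 1) = 2.0;

  // The constructor factors sigma once up front.
  GaussianDistribution g(mu, sigma);

  // Evaluating many observations reuses the cached invCov and logDetCov.
  arma::mat points = arma::randn<arma::mat>(2, 1000);  // One point per column.
  arma::vec logProbs;
  g.LogProbability(points, logProbs);

  // Updates must go through the setter so the factorization stays in sync;
  // the rvalue overload hands the matrix over without copying it.
  arma::mat newSigma = 2.0 * sigma;
  g.Covariance(std::move(newSigma));

  return 0;
}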


4 changes: 3 additions & 1 deletion src/mlpack/core/dists/regression_distribution.hpp
@@ -48,7 +48,9 @@ class RegressionDistribution
rf(regression::LinearRegression(predictors, responses))
{
err = GaussianDistribution(1);
err.Covariance() = rf.ComputeError(predictors, responses);
arma::mat cov(1, 1);
cov(0, 0) = rf.ComputeError(predictors, responses);
err.Covariance(std::move(cov));
}

/**
47 changes: 28 additions & 19 deletions src/mlpack/methods/gmm/em_fit_impl.hpp
@@ -93,11 +93,12 @@ void EMFit<InitialClusteringType, CovarianceConstraintPolicy>::Estimate(
trans(condProb.col(i)));

// Don't update if there's no probability of the Gaussian having points.
if (probRowSums[i] != 0.0)
dists[i].Covariance() = (tmp * trans(tmpB)) / probRowSums[i];

// Apply covariance constraint.
constraint.ApplyConstraint(dists[i].Covariance());
if (probRowSums[i] != 0.0) {
arma::mat covariance = (tmp * trans(tmpB)) / probRowSums[i];
// Apply covariance constraint.
constraint.ApplyConstraint(covariance);
dists[i].Covariance(std::move(covariance));
}
}
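
The constraint now has to run before the hand-off, because Covariance() factors whatever it receives. The shape of that pattern in isolation, with a hypothetical stand-in for the CovarianceConstraintPolicy (a sketch, not code from this patch):

#include <mlpack/core.hpp>
#include <mlpack/core/dists/gaussian_distribution.hpp>

// Hypothetical constraint policy: symmetrize and add a small ridge so that
// the Cholesky factorization inside Covariance() cannot fail.
struct SymmetrizeAndRidgeConstraint
{
  static void ApplyConstraint(arma::mat& covariance)
  {
    covariance = 0.5 * (covariance + covariance.t());
    covariance.diag() += 1e-10;
  }
};

// Mirrors the M-step update above: build, constrain, then commit once.
void UpdateComponent(mlpack::distribution::GaussianDistribution& dist,
                     arma::mat newCovariance)  // by value so it can be moved from
{
  SymmetrizeAndRidgeConstraint::ApplyConstraint(newCovariance);
  dist.Covariance(std::move(newCovariance));  // Single factorization here.
}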

// Calculate the new values for omega using the updated conditional
@@ -180,10 +181,12 @@ void EMFit<InitialClusteringType, CovarianceConstraintPolicy>::Estimate(
arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
trans(condProb.col(i) % probabilities));

dists[i].Covariance() = (tmp * trans(tmpB)) / probRowSums[i];
arma::mat cov = (tmp * trans(tmpB)) / probRowSums[i];

// Apply covariance constraint.
constraint.ApplyConstraint(dists[i].Covariance());
constraint.ApplyConstraint(cov);

dists[i].Covariance(std::move(cov));
}

// Calculate the new values for omega using the updated conditional
@@ -210,12 +213,16 @@ InitialClustering(const arma::mat& observations,
// Run clustering algorithm.
clusterer.Cluster(observations, dists.size(), assignments);

std::vector<arma::vec> means(dists.size());
std::vector<arma::mat> covs(dists.size());

// Now calculate the means, covariances, and weights.
weights.zeros();
for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Mean().zeros();
dists[i].Covariance().zeros();
means[i].zeros(dists[i].Mean().n_elem);
covs[i].zeros(dists[i].Covariance().n_rows,
dists[i].Covariance().n_cols);
}

// From the assignments, generate our means, covariances, and weights.
@@ -224,11 +231,10 @@
const size_t cluster = assignments[i];

// Add this to the relevant mean.
dists[cluster].Mean() += observations.col(i);
means[cluster] += observations.col(i);

// Add this to the relevant covariance.
dists[cluster].Covariance() += observations.col(i) *
trans(observations.col(i));
covs[cluster] += observations.col(i) * trans(observations.col(i));

// Now add one to the weights (we will normalize).
weights[cluster]++;
@@ -237,22 +243,25 @@
// Now normalize the mean and covariance.
for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Mean() /= (weights[i] > 1) ? weights[i] : 1;
means[i] /= (weights[i] > 1) ? weights[i] : 1;
}

for (size_t i = 0; i < observations.n_cols; ++i)
{
const size_t cluster = assignments[i];
const arma::vec normObs = observations.col(i) - dists[cluster].Mean();
dists[cluster].Covariance() += normObs * normObs.t();
const arma::vec normObs = observations.col(i) - means[cluster];
covs[cluster] += normObs * normObs.t();
}

for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Covariance() /= (weights[i] > 1) ? weights[i] : 1;
covs[i] /= (weights[i] > 1) ? weights[i] : 1;

// Apply constraints to covariance matrix.
constraint.ApplyConstraint(dists[i].Covariance());
constraint.ApplyConstraint(covs[i]);

std::swap(dists[i].Mean(), means[i]);
dists[i].Covariance(std::move(covs[i]));
}

// Finally, normalize weights.
Expand All @@ -269,7 +278,7 @@ double EMFit<InitialClusteringType, CovarianceConstraintPolicy>::LogLikelihood(

arma::vec phis;
arma::mat likelihoods(dists.size(), observations.n_cols);

for (size_t i = 0; i < dists.size(); ++i)
{
dists[i].Probability(observations, phis);
Expand All @@ -283,7 +292,7 @@ double EMFit<InitialClusteringType, CovarianceConstraintPolicy>::LogLikelihood(
<< "outlier." << std::endl;
logLikelihood += log(accu(likelihoods.col(j)));
}

return logLikelihood;
}

4 changes: 3 additions & 1 deletion src/mlpack/methods/gmm/gmm_convert_main.cpp
@@ -47,12 +47,14 @@ int main(int argc, char* argv[])
for (size_t i = 0; i < gaussians; ++i)
{
stringstream o;
arma::mat covariance;
o << i;
string meanName = "mean" + o.str();
string covName = "covariance" + o.str();

load.LoadParameter(gmm.Component(i).Mean(), meanName);
load.LoadParameter(gmm.Component(i).Covariance(), covName);
load.LoadParameter(covariance, covName);
gmm.Component(i).Covariance(std::move(covariance));
}

gmm.Save(CLI::GetParam<string>("output_file"));
8 changes: 6 additions & 2 deletions src/mlpack/methods/hmm/hmm_util_impl.hpp
@@ -115,8 +115,10 @@ void ConvertHMM(HMM<distribution::GaussianDistribution>& hmm,
sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());

s.str("");
arma::mat covariance;
s << "hmm_emission_covariance_" << i;
sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
sr.LoadParameter(covariance, s.str());
hmm.Emission()[i].Covariance(std::move(covariance));
}

hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
@@ -168,7 +170,9 @@ void ConvertHMM(HMM<gmm::GMM<> >& hmm, const util::SaveRestoreUtility& sr)

s.str("");
s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
sr.LoadParameter(hmm.Emission()[i].Component(g).Covariance(), s.str());
arma::mat covariance;
sr.LoadParameter(covariance, s.str());
hmm.Emission()[i].Component(g).Covariance(std::move(covariance));
}

s.str("");