Skip to content

Commit

Permalink
invert the lower triangular matrix the right way
Browse files Browse the repository at this point in the history
thanks to rcurtin for the note
  • Loading branch information
stephentu committed Jan 22, 2015
1 parent 10ad21b commit 046bdc4
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/mlpack/core/dists/gaussian_distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,25 @@ void GaussianDistribution::Covariance(arma::mat&& covariance)
void GaussianDistribution::FactorCovariance()
{
covLower = arma::chol(covariance, "lower");
// tell arma that this is lower triangular matrix (for faster inversion)
covLower = arma::trimatl(covLower);
const arma::mat invCovLower = covLower.i();

// Comment from rcurtin:
//
// I think the use of the word "interpret" in the Armadillo documentation
// about trimatl and trimatu is somewhat misleading. What the function will
// actually do, when used in that context, is loop over the upper triangular
// part of the matrix and set it all to 0, so this ends up actually just
// burning cycles---also because the operator=() evaluates the expression and
// strips the knowledge that it's a lower triangular matrix. So then the call
// to .i() doesn't actually do anything smarter.
//
// But perusing fn_inv.hpp more closely, there is a specialization that will
// work when called like this: inv(trimatl(covLower)), and will use LAPACK's
// ?trtri functions. However, it will still set the upper triangular part to
// 0 after the method. That last part is unnecessary, but baked into
// Armadillo, so there's not really much that can be done about that without
// discussion with the Armadillo maintainer.
const arma::mat invCovLower = arma::inv(arma::trimatl(covLower));

invCov = invCovLower.t() * invCovLower;
double sign = 0.;
arma::log_det(logDetCov, sign, covLower);
Expand Down Expand Up @@ -146,7 +162,7 @@ void GaussianDistribution::Estimate(const arma::mat& observations,
// Nothing in this Gaussian! At least set the covariance so that it's
// invertible.
covariance.diag() += 1e-50;
// TODO(stephentu): why do we allow this case?
FactorCovariance();
return;
}

Expand Down

0 comments on commit 046bdc4

Please sign in to comment.