From d7d34937b172b6b05ec16151a3d7ed653842d420 Mon Sep 17 00:00:00 2001 From: Pavel Date: Thu, 10 Dec 2015 18:52:47 +0500 Subject: [PATCH] Fix classification error Change the "invVar" matrix to the "variances" matrix when calculating testProbs. With "invVar", the product of the variances ends up in the numerator, but it needs to be in the denominator. In addition, there was a potential accuracy problem caused by calculating the exponentials and then taking their logarithm. I fixed that too by summing the logarithms directly. --- .../methods/naive_bayes/naive_bayes_classifier_impl.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp index 2b0b3c76f9b..b40719c3685 100644 --- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp @@ -186,10 +186,10 @@ void NaiveBayesClassifier::Classify(const MatType& data, arma::mat rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs; arma::vec exponents(diffs.n_cols); for (size_t j = 0; j < diffs.n_cols; ++j) - exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j))); + exponents(j) = arma::accu(diffs.col(j) % rhs.unsafe_col(j)); //log( exp (value) ) == value - testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) * - std::pow(arma::det(arma::diagmat(invVar.col(i))), -0.5) * exponents); + //calculate prob as sum of logarithm to decrease floating point errors + testProbs.col(i) += (data.n_rows / -2.0 * log(2 * M_PI) - 0.5 * log(arma::det(arma::diagmat(variances.col(i)))) + exponents); } // Now calculate the label.