diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp index 2b0b3c76f9b..b40719c3685 100644 --- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp @@ -186,10 +186,10 @@ void NaiveBayesClassifier::Classify(const MatType& data, arma::mat rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs; arma::vec exponents(diffs.n_cols); for (size_t j = 0; j < diffs.n_cols; ++j) - exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j))); + exponents(j) = arma::accu(diffs.col(j) % rhs.unsafe_col(j)); //log( exp (value) ) == value - testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) * - std::pow(arma::det(arma::diagmat(invVar.col(i))), -0.5) * exponents); + //calculate prob as sum of logarithm to decrease floating point errors + testProbs.col(i) += (data.n_rows / -2.0 * log(2 * M_PI) - 0.5 * log(arma::det(arma::diagmat(variances.col(i)))) + exponents); } // Now calculate the label.