Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Confusionmatrix function #1798

Merged
merged 29 commits into from
May 22, 2019
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1338e89
confusion matrix along with test
jeffin143 Mar 20, 2019
7ca4546
syncing with the original one
jeffin143 Mar 20, 2019
7824713
original one
jeffin143 Mar 20, 2019
a1a0e00
styling issue
jeffin143 Mar 20, 2019
21965a3
update cv_test.cpp
jeffin143 Mar 20, 2019
7a5fbc2
Merge pull request #2 from mlpack/master
jeffin143 Mar 21, 2019
4134fdc
merge issue
jeffin143 Mar 22, 2019
61df5a4
meging
jeffin143 Mar 22, 2019
7cf6c27
changes
jeffin143 Mar 22, 2019
513bed7
Merge branch 'master' of https://github.com/jeffin143/mlpack into con…
jeffin143 Mar 22, 2019
f31017c
minor fixes
jeffin143 Mar 22, 2019
26e9cd2
handling some issues
jeffin143 Mar 22, 2019
8daaa14
Merge pull request #3 from mlpack/master
jeffin143 Mar 22, 2019
4a7c9e1
Merge branch 'confusionmatrix' of https://github.com/jeffin143/mlpack…
jeffin143 Mar 22, 2019
ac79b90
throw
jeffin143 Mar 22, 2019
1ea895e
big blunder -> int to double
jeffin143 Mar 22, 2019
b9dfe7a
DesignGuidelines
jeffin143 Mar 25, 2019
f1c5ec5
descriptive name
jeffin143 Mar 27, 2019
221c30a
typo
jeffin143 Mar 29, 2019
fcefbb7
changes
jeffin143 Apr 2, 2019
ddfe5df
documentation and parameter fixup
jeffin143 Apr 3, 2019
dd1e6d7
adding description about file
jeffin143 Apr 3, 2019
6f9cde0
resolving merge conflicts
jeffin143 Apr 26, 2019
7857e5a
Merge branch 'master' into confusionmatrix
jeffin143 Apr 26, 2019
438a996
small style changes
jeffin143 May 9, 2019
7e9dd20
Merge branch 'confusionmatrix' of https://github.com/jeffin143/mlpack…
jeffin143 May 9, 2019
794a0db
Merge branch 'master' into confusionmatrix
jeffin143 May 12, 2019
3228b7e
Changed documentation and add .
jeffin143 May 12, 2019
5500647
Merge branch 'confusionmatrix' of https://github.com/jeffin143/mlpack…
jeffin143 May 12, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/mlpack/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
#include <mlpack/core/dists/laplace_distribution.hpp>
#include <mlpack/core/dists/gamma_distribution.hpp>
#include <mlpack/core/dists/diagonal_gaussian_distribution.hpp>
#include <mlpack/core/data/confusion_matrix.hpp>

jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
// mlpack::backtrace only for linux
#ifdef HAS_BFD_DL
Expand Down
1 change: 1 addition & 0 deletions src/mlpack/core/data/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(SOURCES
split_data.hpp
imputer.hpp
binarize.hpp
confusion_matrix.hpp
)

# add directory name to sources
Expand Down
53 changes: 53 additions & 0 deletions src/mlpack/core/data/confusion_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @file confusion_matrix_impl.hpp
* @author Jeffin Sam
*
* Compute confusion matrix to evaluate the accuracy of a classification.
* The function works only for discrete data/categorical data.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_CORE_DATA_CONFUSION_MATRIX_HPP
#define MLPACK_CORE_DATA_CONFUSION_MATRIX_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace data {

/**
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
* A confusion matrix is a summary of prediction results on a classification
* problem.The number of correct and incorrect predictions are summarized
* with count values and broken down by each class.
* for example for 2 classes the function will be
* confusionmatrix(predictors, responses, output, 2)
* output matrix will be of size 2 * 2
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
*
* 0 1
* 0 TP FN
* 1 FP TN
*
* Confusion matrix for two labels will look like above.
* Row is the predicted values and column are actual values.
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
*
* @param predictors Vector of data points.
* @param responses The measured data for each point in X.
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
* @param output Matrix which is represented as confusion matrix.
* @param countlables No of classes
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
*
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
*/
template<typename eT>
void ConfusionMatrix(const arma::Row<size_t> predictors,
const arma::Row<size_t> responses,
arma::Mat<eT>& output,
const size_t countlabels);
} // namespace data
} // namespace mlpack

// Include implementation.
#include "confusion_matrix_impl.hpp"

#endif
59 changes: 59 additions & 0 deletions src/mlpack/core/data/confusion_matrix_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/**
* @file confusion_matrix_impl.hpp
* @author Jeffin Sam
*
* Compute confusion matrix to evaluate the accuracy of a classification.
* The function works only for discrete data/categorical data.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_CORE_DATA_CONFUSION_MATRIX_IMPL_HPP
#define MLPACK_CORE_DATA_CONFUSION_MATRIX_IMPL_HPP

// In case it hasn't been included yet.
#include "confusion_matrix.hpp"

namespace mlpack {
namespace data {

/**
* A confusion matrix is a summary of prediction results on a classification
* problem.The number of correct and incorrect predictions are summarized
* with count values and broken down by each class.
* for example for 2 classes the function will be
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
* confusionmatrix(predictors, responses, output, 2)
* output matrix will be of size 2 * 2
*
* 0 1
* 0 TP FN
* 1 FP TN
*
* Confusion matrix for two labels will look like above.
* Row is the predicted values and column are actual values.
*
* @param predictors Vector of data points.
* @param responses The measured data for each point in X.
* @param output Matrix which is represented as confusion matrix.
* @param countlables No of classes
*
*/
template<typename eT>
void ConfusionMatrix(const arma::Row<size_t> predictors,
const arma::Row<size_t> responses,
arma::Mat<eT>& output,
const size_t countlabels)
{
// Loop over the actual labels and predicted labels and add the count
jeffin143 marked this conversation as resolved.
Show resolved Hide resolved
output = arma::zeros<arma::Mat<eT> >(countlabels, countlabels);
for (size_t i = 0; i < predictors.n_elem; ++i)
{
output.at(predictors[i], responses[i])++;
}
}
} // namespace data
} // namespace mlpack

#endif
22 changes: 21 additions & 1 deletion src/mlpack/tests/cv_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>

#include <mlpack/core/data/confusion_matrix.hpp>
#include <ensmallen.hpp>

#include <boost/test/unit_test.hpp>
Expand All @@ -43,6 +43,7 @@ using namespace mlpack::cv;
using namespace mlpack::naive_bayes;
using namespace mlpack::regression;
using namespace mlpack::tree;
using namespace mlpack::data;

BOOST_AUTO_TEST_SUITE(CVTest);

Expand Down Expand Up @@ -73,6 +74,25 @@ BOOST_AUTO_TEST_CASE(BinaryClassificationMetricsTest)
BOOST_REQUIRE_CLOSE(F1<Binary>::Evaluate(lr, data, labels), f1, 1e-5);
}

/**
* Test for confusion matrix.
*/
BOOST_AUTO_TEST_CASE(ConfusionMatrixTest)
{
// Labels that will be considered as "ground truth".
arma::Row<size_t> labels("0 0 1 0 0 1 0 1 0 1");

// Predicted labels.
arma::Row<size_t> predictedLabels("0 0 0 0 0 1 1 1 1 1");
// Confusion matrix.
arma::Mat<int> output;
data::ConfusionMatrix(predictedLabels, labels, output, 2);
BOOST_REQUIRE_EQUAL(output(0, 0), 4);
BOOST_REQUIRE_EQUAL(output(0, 1), 1);
BOOST_REQUIRE_EQUAL(output(1, 0), 2);
BOOST_REQUIRE_EQUAL(output(1, 1), 3);
}

/**
* Test metrics for multiclass classification.
*/
Expand Down