Skip to content

Commit

Permalink
[MRG+1] Added support for multiclass Matthews correlation coefficient (
Browse files Browse the repository at this point in the history
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
  • Loading branch information
Erotemic authored and Jeremiah Johnson committed Dec 18, 2017
1 parent 164e4ba commit 44582f9
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 25 deletions.
34 changes: 30 additions & 4 deletions doc/modules/model_evaluation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ Some of these are restricted to the binary classification case:
.. autosummary::
:template: function.rst

matthews_corrcoef
precision_recall_curve
roc_curve

Expand All @@ -236,6 +235,7 @@ Others also work in the multiclass case:
cohen_kappa_score
confusion_matrix
hinge_loss
matthews_corrcoef


Some also work in the multilabel case:
Expand Down Expand Up @@ -909,14 +909,40 @@ for binary classes. Quoting Wikipedia:
prediction, 0 an average random prediction and -1 an inverse prediction.
The statistic is also known as the phi coefficient."

If :math:`tp`, :math:`tn`, :math:`fp` and :math:`fn` are respectively the
number of true positives, true negatives, false positives and false negatives,
the MCC coefficient is defined as

In the binary (two-class) case, :math:`tp`, :math:`tn`, :math:`fp` and
:math:`fn` are respectively the number of true positives, true negatives, false
positives and false negatives, the MCC is defined as

.. math::
MCC = \frac{tp \times tn - fp \times fn}{\sqrt{(tp + fp)(tp + fn)(tn + fp)(tn + fn)}}.
In the multiclass case, the Matthews correlation coefficient can be `defined
<http://rk.kvl.dk/introduction/index.html>`_ in terms of a
:func:`confusion_matrix` :math:`C` for :math:`K` classes. To simplify the
definition consider the following intermediate variables:

* :math:`t_k=\sum_{i}^{K} C_{ik}` the number of times class :math:`k` truly occurred,
* :math:`p_k=\sum_{i}^{K} C_{ki}` the number of times class :math:`k` was predicted,
* :math:`c=\sum_{k}^{K} C_{kk}` the total number of samples correctly predicted,
* :math:`s=\sum_{i}^{K} \sum_{j}^{K} C_{ij}` the total number of samples.

Then the multiclass MCC is defined as:

.. math::
MCC = \frac{
c \times s - \sum_{k}^{K} p_k \times t_k
}{\sqrt{
(s^2 - \sum_{k}^{K} p_k^2) \times
(s^2 - \sum_{k}^{K} t_k^2)
}}
When there are more than two labels, the value of the MCC will no longer range
between -1 and +1. Instead the minimum value will be somewhere between -1 and 0
depending on the number and distribution of ground true labels. The maximum
value is always +1.

Here is a small example illustrating the usage of the :func:`matthews_corrcoef`
function:

Expand Down
3 changes: 2 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ New features
Enhancements
............

- :func:`metrics.matthews_corrcoef` now support multiclass classification.
:issue:`8094` by :user:`Jon Crall <Erotemic>`.
- Update Sphinx-Gallery from 0.1.4 to 0.1.7 for resolving links in
documentation build with Sphinx>1.5 :issue:`8010`, :issue:`7986` by
:user:`Oscar Najera <Titan-C>`

- :class:`multioutput.MultiOutputRegressor` and :class:`multioutput.MultiOutputClassifier`
now support online learning using `partial_fit`.
issue: `8053` by :user:`Peng Yu <yupbank>`.
Expand Down
48 changes: 30 additions & 18 deletions sklearn/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
raise ValueError("At least one label specified must be in y_true")

if sample_weight is None:
sample_weight = np.ones(y_true.shape[0], dtype=np.int)
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
else:
sample_weight = np.asarray(sample_weight)

Expand All @@ -351,8 +351,14 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
# also eliminate weights of eliminated items
sample_weight = sample_weight[ind]

# Choose the accumulator dtype to always have high precision
if sample_weight.dtype.kind in {'i', 'u', 'b'}:
dtype = np.int64
else:
dtype = np.float64

CM = coo_matrix((sample_weight, (y_true, y_pred)),
shape=(n_labels, n_labels)
shape=(n_labels, n_labels), dtype=dtype,
).toarray()

return CM
Expand Down Expand Up @@ -525,7 +531,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,


def matthews_corrcoef(y_true, y_pred, sample_weight=None):
"""Compute the Matthews correlation coefficient (MCC) for binary classes
"""Compute the Matthews correlation coefficient (MCC)
The Matthews correlation coefficient is used in machine learning as a
measure of the quality of binary (two-class) classifications. It takes into
Expand All @@ -536,8 +542,9 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
an average random prediction and -1 an inverse prediction. The statistic
is also known as the phi coefficient. [source: Wikipedia]
Only in the binary case does this relate to information about true and
false positives and negatives. See references below.
Binary and multiclass labels are supported. Only in the binary case does
this relate to information about true and false positives and negatives.
See references below.
Read more in the :ref:`User Guide <matthews_corrcoef>`.
Expand Down Expand Up @@ -568,35 +575,40 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
.. [2] `Wikipedia entry for the Matthews Correlation Coefficient
<https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_
.. [3] `Gorodkin, (2004). Comparing two K-category assignments by a
K-category correlation coefficient
<http://www.sciencedirect.com/science/article/pii/S1476927104000799>`_
.. [4] `Jurman, Riccadonna, Furlanello, (2012). A Comparison of MCC and CEN
Error Measures in MultiClass Prediction
<http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0041882>`_
Examples
--------
>>> from sklearn.metrics import matthews_corrcoef
>>> y_true = [+1, +1, +1, -1]
>>> y_pred = [+1, -1, +1, +1]
>>> matthews_corrcoef(y_true, y_pred) # doctest: +ELLIPSIS
-0.33...
"""
y_type, y_true, y_pred = _check_targets(y_true, y_pred)

if y_type != "binary":
if y_type not in {"binary", "multiclass"}:
raise ValueError("%s is not supported" % y_type)

lb = LabelEncoder()
lb.fit(np.hstack([y_true, y_pred]))
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)
mean_yt = np.average(y_true, weights=sample_weight)
mean_yp = np.average(y_pred, weights=sample_weight)

y_true_u_cent = y_true - mean_yt
y_pred_u_cent = y_pred - mean_yp

cov_ytyp = np.average(y_true_u_cent * y_pred_u_cent, weights=sample_weight)
var_yt = np.average(y_true_u_cent ** 2, weights=sample_weight)
var_yp = np.average(y_pred_u_cent ** 2, weights=sample_weight)

mcc = cov_ytyp / np.sqrt(var_yt * var_yp)
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
t_sum = C.sum(axis=1)
p_sum = C.sum(axis=0)
n_correct = np.trace(C)
n_samples = p_sum.sum()
cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)
cov_ytyt = n_samples ** 2 - np.dot(t_sum, t_sum)
mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

if np.isnan(mcc):
return 0.
Expand Down
115 changes: 114 additions & 1 deletion sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,37 @@ def test_matthews_corrcoef_against_numpy_corrcoef():
np.corrcoef(y_true, y_pred)[0, 1], 10)


def test_matthews_corrcoef_against_jurman():
# Check that the multiclass matthews_corrcoef agrees with the definition
# presented in Jurman, Riccadonna, Furlanello, (2012). A Comparison of MCC
# and CEN Error Measures in MultiClass Prediction
rng = np.random.RandomState(0)
y_true = rng.randint(0, 2, size=20)
y_pred = rng.randint(0, 2, size=20)
sample_weight = rng.rand(20)

C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
N = len(C)
cov_ytyp = sum([
C[k, k] * C[m, l] - C[l, k] * C[k, m]
for k in range(N) for m in range(N) for l in range(N)
])
cov_ytyt = sum([
C[:, k].sum() *
np.sum([C[g, f] for f in range(N) for g in range(N) if f != k])
for k in range(N)
])
cov_ypyp = np.sum([
C[k, :].sum() *
np.sum([C[f, g] for f in range(N) for g in range(N) if f != k])
for k in range(N)
])
mcc_jurman = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
mcc_ours = matthews_corrcoef(y_true, y_pred, sample_weight)

assert_almost_equal(mcc_ours, mcc_jurman, 10)


def test_matthews_corrcoef():
rng = np.random.RandomState(0)
y_true = ["a" if i == 0 else "b" for i in rng.randint(0, 2, size=20)]
Expand All @@ -380,8 +411,8 @@ def test_matthews_corrcoef():

# corrcoef, when the two vectors are opposites of each other, should be -1
y_true_inv = ["b" if i == "a" else "a" for i in y_true]

assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)

y_true_inv2 = label_binarize(y_true, ["a", "b"])
y_true_inv2 = np.where(y_true_inv2, 'a', 'b')
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)
Expand Down Expand Up @@ -414,6 +445,61 @@ def test_matthews_corrcoef():
matthews_corrcoef(y_1, y_2, sample_weight=mask), 0.)


def test_matthews_corrcoef_multiclass():
rng = np.random.RandomState(0)
ord_a = ord('a')
n_classes = 4
y_true = [chr(ord_a + i) for i in rng.randint(0, n_classes, size=20)]

# corrcoef of same vectors must be 1
assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)

# with multiclass > 2 it is not possible to achieve -1
y_true = [0, 0, 1, 1, 2, 2]
y_pred_bad = [2, 2, 0, 0, 1, 1]
assert_almost_equal(matthews_corrcoef(y_true, y_pred_bad), -.5)

# Maximizing false positives and negatives minimizes the MCC
# The minimum will be different for depending on the input
y_true = [0, 0, 1, 1, 2, 2]
y_pred_min = [1, 1, 0, 0, 0, 0]
assert_almost_equal(matthews_corrcoef(y_true, y_pred_min),
-12 / np.sqrt(24 * 16))

# Zero variance will result in an mcc of zero and a Runtime Warning
y_true = [0, 1, 2]
y_pred = [3, 3, 3]
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
matthews_corrcoef, y_true, y_pred)
assert_almost_equal(mcc, 0.0)

# These two vectors have 0 correlation and hence mcc should be 0
y_1 = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_2 = [1, 1, 1, 2, 2, 2, 0, 0, 0]
assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.)

# We can test that binary assumptions hold using the multiclass computation
# by masking the weight of samples not in the first two classes

# Masking the last label should let us get an MCC of -1
y_true = [0, 0, 1, 1, 2]
y_pred = [1, 1, 0, 0, 2]
sample_weight = [1, 1, 1, 1, 0]
assert_almost_equal(matthews_corrcoef(y_true, y_pred, sample_weight), -1)

# For the zero vector case, the corrcoef cannot be calculated and should
# result in a RuntimeWarning
y_true = [0, 0, 1, 2]
y_pred = [0, 0, 1, 2]
sample_weight = [1, 1, 0, 0]
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
matthews_corrcoef, y_true, y_pred,
sample_weight)

# But will output 0
assert_almost_equal(mcc, 0.)


def test_precision_recall_f1_score_multiclass():
# Test Precision Recall and F1 Score for multiclass classification task
y_true, y_pred, _ = make_prediction(binary=False)
Expand Down Expand Up @@ -593,6 +679,33 @@ def test_confusion_matrix_multiclass_subset_labels():
labels=[extra_label, extra_label + 1])


def test_confusion_matrix_dtype():
y = [0, 1, 1]
weight = np.ones(len(y))
# confusion_matrix returns int64 by default
cm = confusion_matrix(y, y)
assert_equal(cm.dtype, np.int64)
# The dtype of confusion_matrix is always 64 bit
for dtype in [np.bool_, np.int32, np.uint64]:
cm = confusion_matrix(y, y, sample_weight=weight.astype(dtype))
assert_equal(cm.dtype, np.int64)
for dtype in [np.float32, np.float64, None, object]:
cm = confusion_matrix(y, y, sample_weight=weight.astype(dtype))
assert_equal(cm.dtype, np.float64)

# np.iinfo(np.uint32).max should be accumulated correctly
weight = np.ones(len(y), dtype=np.uint32) * 4294967295
cm = confusion_matrix(y, y, sample_weight=weight)
assert_equal(cm[0, 0], 4294967295)
assert_equal(cm[1, 1], 8589934590)

# np.iinfo(np.int64).max should cause an overflow
weight = np.ones(len(y), dtype=np.int64) * 9223372036854775807
cm = confusion_matrix(y, y, sample_weight=weight)
assert_equal(cm[0, 0], 9223372036854775807)
assert_equal(cm[1, 1], -2)


def test_classification_report_multiclass():
# Test performance report
iris = datasets.load_iris()
Expand Down
1 change: 0 additions & 1 deletion sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@
# Those metrics don't support multiclass inputs
METRIC_UNDEFINED_MULTICLASS = [
"brier_score_loss",
"matthews_corrcoef_score",

# with default average='binary', multiclass is prohibited
"precision_score",
Expand Down

0 comments on commit 44582f9

Please sign in to comment.