From 4538f9f836751f753ddd3d371b8e2674c54c4057 Mon Sep 17 00:00:00 2001 From: kuza55 Date: Tue, 20 Sep 2016 11:46:54 -0400 Subject: [PATCH] Matthews Correlation fix and test --- keras/metrics.py | 6 +++--- tests/keras/test_metrics.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/keras/metrics.py b/keras/metrics.py index c479e8f6b8b5..51afe719299e 100644 --- a/keras/metrics.py +++ b/keras/metrics.py @@ -83,9 +83,9 @@ def matthews_correlation(y_true, y_pred): tp = K.sum(y_pos * y_pred_pos) tn = K.sum(y_neg * y_pred_neg) - fp = K.sum(1 - y_neg * y_pred_pos) - fn = K.sum(1 - y_pos * y_pred_neg) - + fp = K.sum(y_neg * y_pred_pos) + fn = K.sum(y_pos * y_pred_neg) + numerator = (tp * tn - fp * fn) denominator = K.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) diff --git a/tests/keras/test_metrics.py b/tests/keras/test_metrics.py index f3bae663da06..ce19b8f463a6 100644 --- a/tests/keras/test_metrics.py +++ b/tests/keras/test_metrics.py @@ -34,6 +34,18 @@ def test_metrics(): assert K.eval(output).shape == () +def test_matthews_correlation(): + y_true = K.variable(np.array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0])) + y_pred = K.variable(np.array([1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0])) + + # Calculated using sklearn.metrics.matthews_corrcoef + actual = -0.14907119849998601 + + calc = K.eval(metrics.matthews_correlation(y_true, y_pred)) + epsilon = 1e-05 + assert actual - epsilon <= calc <= actual + epsilon + + def test_sparse_metrics(): for metric in all_sparse_metrics: y_a = K.variable(np.random.randint(0, 7, (6,)), dtype=K.floatx())