Skip to content

Commit

Permalink
Fix confusion matrix heatmap (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
matanper committed Jan 2, 2022
1 parent 7fc29b5 commit eb1300d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
14 changes: 5 additions & 9 deletions deepchecks/checks/performance/confusion_matrix_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
# ----------------------------------------------------------------------------
#
"""The confusion_matrix_report check module."""
import numpy as np
import sklearn
from sklearn.base import BaseEstimator

import plotly.figure_factory as ff
import plotly.express as px
from deepchecks import CheckResult, Dataset
from deepchecks.base.check import SingleDatasetBaseCheck
from deepchecks.utils.metrics import ModelType, task_type_validation
Expand Down Expand Up @@ -49,15 +48,12 @@ def _confusion_matrix_report(self, dataset: Dataset, model):
ds_x = dataset.data[dataset.features]
ds_y = dataset.data[label]
y_pred = model.predict(ds_x)

confusion_matrix = sklearn.metrics.confusion_matrix(ds_y, y_pred)

labels = [str(val) for val in np.unique(ds_y)]
fig = ff.create_annotated_heatmap(confusion_matrix, x=labels, y=labels, colorscale='Viridis')
# Figure
fig = px.imshow(confusion_matrix, x=dataset.classes, y=dataset.classes, text_auto=True)
fig.update_layout(width=600, height=600)
fig.update_xaxes(title='Predicted Value')
fig.update_yaxes(title='True value', autorange='reversed')
fig['data'][0]['showscale'] = True
fig['layout']['xaxis']['side'] = 'bottom'
fig.update_xaxes(title='Predicted Value', type='category')
fig.update_yaxes(title='True value', type='category')

return CheckResult(confusion_matrix, display=fig)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ tqdm>=4.62.3
category-encoders>=2.3.0
scipy>=1.5.0
dataclasses>=0.6; python_version < '3.7'
plotly>=5.4.0
plotly>=5.5.0
matplotlib>=3.3.3

0 comments on commit eb1300d

Please sign in to comment.