Improved heatmap display for confusion matrix report (#2617)
* Improved heatmap display for confusion matrix report

* Resolved Comments

* Minor fix
harsh-deepchecks committed Jul 3, 2023
1 parent 3ec7ac8 commit e481a87
Showing 5 changed files with 95 additions and 25 deletions.
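In short, after this commit the confusion-matrix checks return a list of display items (an accuracy summary string followed by a Plotly heatmap) instead of a single figure. A rough usage sketch of the tabular check, assuming an already-built deepchecks Dataset `test_ds` and a fitted classifier `clf` (both placeholders, not part of this commit):

from deepchecks.tabular.checks import ConfusionMatrixReport

# `test_ds` (a deepchecks tabular Dataset) and `clf` (a fitted sklearn classifier)
# are assumed to exist; they are placeholders for the caller's own data and model.
result = ConfusionMatrixReport().run(test_ds, clf)

# After this change, result.display is expected to hold two items:
print(result.display[0])      # accuracy summary text
result.display[1].show()      # Plotly heatmap with the new hover percentages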
55 changes: 36 additions & 19 deletions deepchecks/utils/abstracts/confusion_matrix_abstract.py
@@ -31,14 +31,14 @@ def run_confusion_matrix_check(y_pred: np.ndarray, y_true: np.ndarray, with_disp
     result = confusion_matrix(y_true, y_pred)

     if with_display:
-        fig = create_confusion_matrix_figure(result, total_classes, normalize_display)
+        displays = create_confusion_matrix_figure(result, total_classes, normalize_display)
     else:
-        fig = None
+        displays = None

     # For accessing the class names from the condition
     result = pd.DataFrame(result, index=total_classes, columns=total_classes)

-    return CheckResult(result, display=fig)
+    return CheckResult(result, display=displays)


 def create_confusion_matrix_figure(confusion_matrix_data: np.ndarray, classes_names: List[str],
@@ -60,29 +60,46 @@ def create_confusion_matrix_figure(confusion_matrix_data: np.ndarray, classes_na
         confusion matrix figure
     """
+    confusion_matrix_norm = confusion_matrix_data.astype('float') / \
+        (confusion_matrix_data.sum(axis=1)[:, np.newaxis] + np.finfo(float).eps) * 100
     if normalize_display:
-        confusion_matrix_norm = confusion_matrix_data.astype('float') / \
-            (confusion_matrix_data.sum(axis=1)[:, np.newaxis] + np.finfo(float).eps) * 100
         z = np.vectorize(format_number_if_not_nan)(confusion_matrix_norm)
-        text_template = '%{z}%<br>(%{text})'
-        color_bar_title = '% out of<br>True Values'
-        plot_title = 'Percent Out of True Values (Count)'
     else:
         z = confusion_matrix_data
-        color_bar_title = None
-        text_template = '%{text}'
-        plot_title = 'Value Count'

-    fig = go.Figure(data=go.Heatmap(
-        x=classes_names, y=classes_names, z=z,
-        text=confusion_matrix_data, texttemplate=text_template))
-    fig.data[0].colorbar.title = color_bar_title
-    fig.update_layout(title=plot_title)

+    accuracy_array = np.diag(confusion_matrix_norm).round(decimals=2)
+
+    display = []
+    display_msg = f'The overall accuracy of your model is: {round(np.sum(accuracy_array)/len(accuracy_array), 2)}%.'
+
+    if min(accuracy_array) < 100:
+        display_msg += f'<br>Best accuracy achieved on samples with <b>{classes_names[np.argmax(accuracy_array)]}' \
+                       f'</b> label ({np.max(accuracy_array)}%).'
+        display_msg += f'<br>Worst accuracy achieved on samples with <b>{classes_names[np.argmin(accuracy_array)]}' \
+                       f'</b> label ({np.min(accuracy_array)}%).'
+    display.append(display_msg)
+
+    total_samples = np.nansum(confusion_matrix_data)
+    percent_data_each_row = np.round(confusion_matrix_norm, decimals=2)
+    percent_data_each_cell = np.round(np.divide(np.nan_to_num(confusion_matrix_data, nan=0.0), total_samples) * 100,
+                                      decimals=2)
+    percent_data_each_col = (confusion_matrix_data.astype('float') /
+                             (confusion_matrix_data.sum(axis=0)[:, np.newaxis] +
+                              np.finfo(float).eps) * 100).round(decimals=2)
+    custom_hoverdata = np.dstack((percent_data_each_cell, percent_data_each_row, percent_data_each_col))
+
+    fig = go.Figure(data=go.Heatmap(x=classes_names, y=classes_names, z=z, customdata=custom_hoverdata,
+                                    xgap=1, ygap=1, text=confusion_matrix_data, texttemplate='%{text}',
+                                    hovertemplate='% out of all data: <b>%{customdata[0]}%</b><br>% out '
+                                                  'of row: <b>%{customdata[1]}%</b><br>% out of column: '
+                                                  '<b>%{customdata[2]}%</b><extra></extra>',
+                                    showscale=False))
+    fig.update_layout(title='Confusion Matrix (# Samples)', title_x=0.5)
     fig.update_layout(height=600)
     fig.update_xaxes(title='Predicted Value', type='category', scaleanchor='y', constrain='domain')
     fig.update_yaxes(title='True Value', type='category', constrain='domain', autorange='reversed')

-    return fig
+    display.append(fig)
+    return display


def misclassified_samples_lower_than_condition(value: pd.DataFrame,
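For intuition, the heatmap's new hover text reports three views of each cell: its share of all samples, of its row (true class), and of its column (predicted class). A small standalone sketch, not from the commit, with the percentages computed the plain NumPy way for a toy 2x2 matrix; note that the diagonal of the row-normalised matrix is the per-class accuracy used for the new text summary:

import numpy as np

# Toy confusion matrix: rows = true class, columns = predicted class.
cm = np.array([[8., 2.],
               [1., 9.]])
eps = np.finfo(float).eps

pct_of_all = cm / cm.sum() * 100                                # share of every sample
pct_of_row = cm / (cm.sum(axis=1, keepdims=True) + eps) * 100   # share of the true class
pct_of_col = cm / (cm.sum(axis=0, keepdims=True) + eps) * 100   # share of the predicted class

print(np.round(pct_of_all, 2))   # 40%, 10% / 5%, 45% of all samples
print(np.round(pct_of_row, 2))   # 80%, 20% / 10%, 90%; the diagonal is the per-class accuracy
print(np.round(pct_of_col, 2))   # 88.89%, 18.18% / 11.11%, 81.82%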
@@ -145,7 +145,7 @@ def compute(self, context: Context, dataset_kind: DatasetKind = None) -> CheckRe
                x.append('No overlapping')
                y.append('No overlapping')

-            description.append(
+            description.extend(
                create_confusion_matrix_figure(confusion_matrix, x, self.normalized)
            )
        else:
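The append-to-extend change above follows from the new return type: create_confusion_matrix_figure now returns a list (summary text plus figure), so appending it would nest a list inside the display list. A minimal illustration with placeholder strings standing in for the real display items:

# Placeholder items standing in for the accuracy string and the go.Figure object.
figure_items = ['accuracy summary', 'heatmap figure']

wrong = ['Showing 10 of 10 classes:']
wrong.append(figure_items)    # nests the whole list as a single display entry
right = ['Showing 10 of 10 classes:']
right.extend(figure_items)    # adds each item separately, as CheckResult.display expects

print(wrong)   # ['Showing 10 of 10 classes:', ['accuracy summary', 'heatmap figure']]
print(right)   # ['Showing 10 of 10 classes:', 'accuracy summary', 'heatmap figure']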
24 changes: 20 additions & 4 deletions tests/nlp/checks/model_evaluation/confusion_matrix_test.py
@@ -100,16 +100,16 @@ def test_condition_misclassified_samples_lower_than_raises_error(tweet_emotion_t
     assert_that(result.conditions_results[0], equal_condition_result(
         is_pass=False,
         name=f'Misclassified cell size lower than {format_number(-0.1 * 100)}% of the total samples',
-        details='Exception in condition: DeepchecksValueError: Condition requires the parameter "misclassified_samples_threshold" '
-                'to be between 0 and 1 inclusive but got -0.1',
+        details='Exception in condition: DeepchecksValueError: Condition requires the parameter '
+                '"misclassified_samples_threshold" to be between 0 and 1 inclusive but got -0.1',
         category=ConditionCategory.ERROR
     ))

     assert_that(result.conditions_results[1], equal_condition_result(
         is_pass=False,
         name=f'Misclassified cell size lower than {format_number(1.1 * 100)}% of the total samples',
-        details='Exception in condition: DeepchecksValueError: Condition requires the parameter "misclassified_samples_threshold" '
-                'to be between 0 and 1 inclusive but got 1.1',
+        details='Exception in condition: DeepchecksValueError: Condition requires the parameter '
+                '"misclassified_samples_threshold" to be between 0 and 1 inclusive but got 1.1',
         category=ConditionCategory.ERROR
     ))

@@ -189,3 +189,19 @@ def test_condition_misclassified_samples_lower_than_fails(tweet_emotion_train_te
                f'Largest misclassified cell ({format_percent(max_misclassified_samples_ratio)} of the data) ' \
                f'is samples with a true value of "{class_names[x]}" and a predicted value of "{class_names[y]}".'
     ))
+
+
+def test_confusion_matrix_report_display(tweet_emotion_train_test_textdata, tweet_emotion_train_test_predictions):
+    # Arrange and Act
+    check = ConfusionMatrixReport()
+    result = check.run(tweet_emotion_train_test_textdata[0], predictions=tweet_emotion_train_test_predictions[0])
+
+    # Assert
+    assert_that(result.display[0],
+                equal_to('The overall accuracy of your model is: 92.04%.<br>Best accuracy achieved on samples with '
+                         '<b>anger</b> label (96.59%).<br>Worst accuracy achieved on samples with <b>sadness</b> '
+                         'label (88.86%).'))
+    # First is the text description and second is the heatmap
+    assert_that(len(result.display), equal_to(2))
+    assert_that(len(result.display[1].data), equal_to(1))
+    assert_that(result.display[1].data[0].type, equal_to('heatmap'))
@@ -10,7 +10,7 @@
 #
 """Contains unit tests for the confusion_matrix_report check."""
 import numpy as np
-from hamcrest import assert_that, calling, greater_than, has_length, raises
+from hamcrest import assert_that, calling, greater_than, has_length, raises, equal_to

 from deepchecks.core.condition import ConditionCategory
 from deepchecks.core.errors import DeepchecksNotSupportedError, DeepchecksValueError, ModelValidationError
@@ -183,3 +183,22 @@ def test_condition_misclassified_samples_lower_than_fails(iris_split_dataset_and
                f'Largest misclassified cell ({format_percent(max_misclassified_samples_ratio)} of the data) ' \
                f'is samples with a true value of "{class_names[x]}" and a predicted value of "{class_names[y]}".'
     ))
+
+
+def test_confusion_matrix_report_display(iris_split_dataset_and_model):
+    # Arrange
+    _, test, clf = iris_split_dataset_and_model
+
+    # Act
+    check = ConfusionMatrixReport()
+
+    result = check.run(test, clf)
+
+    # Assert
+    assert_that(result.display[0],
+                equal_to('The overall accuracy of your model is: 91.67%.<br>Best accuracy achieved on samples with <b>'
+                         '0</b> label (100.0%).<br>Worst accuracy achieved on samples with <b>2</b> label (75.0%).'))
+    # First is the text description and second is the heatmap
+    assert_that(len(result.display), equal_to(2))
+    assert_that(len(result.display[1].data), equal_to(1))
+    assert_that(result.display[1].data[0].type, equal_to('heatmap'))
18 changes: 18 additions & 0 deletions tests/vision/checks/model_evaluation/confusion_matrix_test.py
@@ -56,3 +56,21 @@ def test_detection(coco_visiondata_train):
     # Assert
     num_of_classes = len(coco_visiondata_train.get_observed_classes()) + 1  # plus no-overlapping
     assert_that(result.value.shape, le((num_of_classes, num_of_classes)))
+
+
+def test_confusion_matrix_report_display(mnist_visiondata_train):
+    # Arrange
+    check = ConfusionMatrixReport()
+
+    # Act
+    result = check.run(mnist_visiondata_train)
+
+    # Assert
+    assert_that(result.display[0], equal_to('Showing 10 of 10 classes:'))
+    assert_that(result.display[1],
+                equal_to('The overall accuracy of your model is: 97.45%.<br>Best accuracy achieved on samples with <b>'
+                         '0</b> label (100.0%).<br>Worst accuracy achieved on samples with <b>9</b> label (86.96%).'))
+    # First and second are the text descriptions and third is a heatmap
+    assert_that(len(result.display), equal_to(3))
+    assert_that(len(result.display[2].data), equal_to(1))
+    assert_that(result.display[2].data[0].type, equal_to('heatmap'))
