Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Nb/feat/confusion matrix nlp (#2426)
- Loading branch information
1 parent
180b455
commit 73cb2d9
Showing
17 changed files
with
260 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
deepchecks/nlp/checks/model_evaluation/confusion_matrix_report.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# ---------------------------------------------------------------------------- | ||
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com) | ||
# | ||
# This file is part of Deepchecks. | ||
# Deepchecks is distributed under the terms of the GNU Affero General | ||
# Public License (version 3 or later). | ||
# You should have received a copy of the GNU Affero General Public License | ||
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>. | ||
# ---------------------------------------------------------------------------- | ||
# | ||
"""The confusion_matrix_report check module.""" | ||
import numpy as np | ||
|
||
from deepchecks.core import CheckResult | ||
from deepchecks.nlp import Context, SingleDatasetCheck | ||
from deepchecks.utils.abstracts.confusion_matrix_abstract import run_confusion_matrix_check | ||
|
||
__all__ = ['ConfusionMatrixReport'] | ||
|
||
|
||
class ConfusionMatrixReport(SingleDatasetCheck):
    """Calculate the confusion matrix of the model on the given dataset.

    Parameters
    ----------
    normalize_display : bool , default: True
        boolean that determines whether to normalize the values of the matrix in the display.
    n_samples : int , default: 1_000_000
        number of samples to use for this check.
    random_state : int, default: 42
        random seed for all check internals.
    """

    def __init__(self,
                 normalize_display: bool = True,
                 n_samples: int = 1_000_000,
                 random_state: int = 42,
                 **kwargs):
        super().__init__(**kwargs)
        self.normalize_display = normalize_display
        self.n_samples = n_samples
        self.random_state = random_state

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check.

        Samples up to ``n_samples`` rows of the requested dataset, collects the
        labels and the model predictions for that sample, and delegates the
        matrix computation and display to ``run_confusion_matrix_check``.

        Parameters
        ----------
        context : Context
            the run context providing the data, the model and the display flag.
        dataset_kind
            which dataset of the context to run on (passed to ``context.get_data_by_kind``).

        Returns
        -------
        CheckResult
            value is numpy array of the confusion matrix, displays the confusion matrix
        """
        # NOTE(review): an earlier docstring listed DeepchecksValueError for data
        # without a label; nothing in this method raises it directly - presumably
        # it originates from the context/dataset accessors. Confirm.
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
        y_true = np.asarray(dataset.label)
        # Force predictions into a 1-D array with one entry per label so both
        # arrays line up element-wise.
        y_pred = np.array(context.model.predict(dataset)).reshape(len(y_true), )

        return run_confusion_matrix_check(y_pred, y_true, context.with_display, self.normalize_display)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# ---------------------------------------------------------------------------- | ||
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com) | ||
# | ||
# This file is part of Deepchecks. | ||
# Deepchecks is distributed under the terms of the GNU Affero General | ||
# Public License (version 3 or later). | ||
# You should have received a copy of the GNU Affero General Public License | ||
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>. | ||
# ---------------------------------------------------------------------------- | ||
# | ||
"""Shared utilities for computing and plotting a confusion matrix, used by the confusion matrix checks.""" | ||
from typing import List | ||
|
||
import numpy as np | ||
import plotly.graph_objects as go | ||
from sklearn.metrics import confusion_matrix | ||
|
||
from deepchecks.core import CheckResult | ||
from deepchecks.utils.strings import format_number_if_not_nan | ||
|
||
__all__ = ['create_confusion_matrix_figure', 'run_confusion_matrix_check'] | ||
|
||
|
||
def run_confusion_matrix_check(y_pred: np.ndarray, y_true: np.ndarray, with_display=True,
                               normalize_display=True) -> CheckResult:
    """Calculate confusion matrix based on predictions and true label values.

    Parameters
    ----------
    y_pred : np.ndarray
        predicted labels.
    y_true : np.ndarray
        true labels, aligned element-wise with ``y_pred``.
    with_display : bool , default: True
        whether to build the confusion matrix figure.
    normalize_display : bool , default: True
        forwarded to ``create_confusion_matrix_figure``.

    Returns
    -------
    CheckResult
        value is the confusion matrix; display is the figure, or None when
        ``with_display`` is False.
    """
    result = confusion_matrix(y_true, y_pred)

    fig = None
    if with_display:
        # confusion_matrix (called without ``labels``) orders rows/columns by the
        # sorted raw labels. Sort the raw values first and only then convert to
        # strings: sorting the string forms would put e.g. '10' before '2' and
        # attach the wrong class names to the matrix axes.
        ordered_labels = sorted(set(y_pred).union(set(y_true)))
        total_classes = [str(label) for label in ordered_labels]
        fig = create_confusion_matrix_figure(result, total_classes, normalize_display)

    return CheckResult(result, display=fig)
|
||
|
||
def create_confusion_matrix_figure(confusion_matrix_data: np.ndarray, classes_names: List[str],
                                   normalize_display: bool):
    """Build a plotly heatmap for a confusion matrix.

    Parameters
    ----------
    confusion_matrix_data: np.ndarray
        2D array containing the confusion matrix.
    classes_names: List[str]
        the names of the classes to display as the axis.
    normalize_display: bool
        if True, cells are colored by the percentage out of each row's total
        (the true values) and the raw counts are shown in parentheses.

    Returns
    -------
    plotly Figure object
        confusion matrix figure
    """
    if not normalize_display:
        heat_values = confusion_matrix_data
        cell_template = '%{text}'
        colorbar_caption = None
        figure_title = 'Value Count'
    else:
        # Row totals as a column vector; eps avoids division by zero for
        # classes that never appear among the true values.
        row_totals = confusion_matrix_data.sum(axis=1)[:, np.newaxis] + np.finfo(float).eps
        percentages = confusion_matrix_data.astype('float') / row_totals * 100
        heat_values = np.vectorize(format_number_if_not_nan)(percentages)
        cell_template = '%{z}%<br>(%{text})'
        colorbar_caption = '% out of<br>True Values'
        figure_title = 'Percent Out of True Values (Count)'

    heatmap = go.Heatmap(
        x=classes_names,
        y=classes_names,
        z=heat_values,
        text=confusion_matrix_data,
        texttemplate=cell_template,
    )
    fig = go.Figure(data=heatmap)
    fig.data[0].colorbar.title = colorbar_caption
    fig.update_layout(title=figure_title, height=600)
    fig.update_xaxes(title='Predicted Value', type='category', scaleanchor='y', constrain='domain')
    fig.update_yaxes(title='True value', type='category', constrain='domain', autorange='reversed')

    return fig
File renamed without changes.
Oops, something went wrong.