# ----------------------------------------------------------------------------
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""The confusion_matrix_report check module."""
import numpy as np

from deepchecks.core import CheckResult
from deepchecks.tabular import Context, SingleDatasetCheck
from deepchecks.utils.abstracts.confusion_matrix_abstract import (misclassified_samples_lower_than_condition,
                                                                  run_confusion_matrix_check)
from deepchecks.utils.strings import format_percent

__all__ = ['ConfusionMatrixReport']

class ConfusionMatrixReport(SingleDatasetCheck):
    """Calculate the confusion matrix of the model on the given dataset.

    Parameters
    ----------
    normalize_display : bool, default: True
        Whether to normalize the values of the matrix in the display.
    n_samples : int, default: 1_000_000
        Number of samples to use for this check.
    random_state : int, default: 42
        Random seed for all check internals.
    """

    def __init__(self,
                 normalize_display: bool = True,
                 n_samples: int = 1_000_000,
                 random_state: int = 42,
                 **kwargs):
        super().__init__(**kwargs)
        self.normalize_display = normalize_display
        self.n_samples = n_samples
        self.random_state = random_state

    def run_logic(self, context: Context, dataset_kind) -> CheckResult:
        """Run check.

        Returns
        -------
        CheckResult
            value is a numpy array of the confusion matrix; displays the confusion matrix

        Raises
        ------
        DeepchecksValueError
            If the data is not a Dataset instance with a label
        """
        # Sample the dataset for speed before validating the task type.
        dataset = context.get_data_by_kind(dataset_kind).sample(self.n_samples, random_state=self.random_state)
        context.assert_classification_task()

        y_true = dataset.label_col
        # Flatten the model predictions into a 1-D array aligned with y_true.
        y_pred = np.array(context.model.predict(dataset.features_columns)).reshape(len(y_true), )

        return run_confusion_matrix_check(y_pred, y_true, context.with_display, self.normalize_display)
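
    # Note (illustrative): per the docstring above, the check's value is the
    # confusion matrix as a numpy array of raw counts, indexed by
    # [true label, predicted label]; e.g. a hypothetical binary task might
    # yield np.array([[50, 3], [7, 40]]). The normalize_display flag affects
    # only how the matrix is rendered, not the returned value.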

    def add_condition_misclassified_samples_lower_than_condition(self, misclassified_samples_threshold: float = 0.2):
        """Add condition - misclassified samples lower than threshold.

        The condition validates that the size (number of samples) of each misclassified confusion matrix cell
        is lower than the threshold given by the `misclassified_samples_threshold` parameter.

        Parameters
        ----------
        misclassified_samples_threshold : float, default: 0.2
            Ratio of the total samples used for comparison in the condition (value should be between 0 and 1,
            inclusive).
        """
        return self.add_condition(
            f'Misclassified cell size lower than {format_percent(misclassified_samples_threshold)} '
            'of the total samples',
            misclassified_samples_lower_than_condition,
            misclassified_samples_threshold=misclassified_samples_threshold
        )
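

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library code): a minimal,
# hedged example of running the check with its condition attached. It assumes
# scikit-learn and pandas are installed; the toy frame and model below are
# hypothetical and exist only to make the sketch self-contained and runnable.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import pandas as pd
    from sklearn.tree import DecisionTreeClassifier

    from deepchecks.tabular import Dataset

    # Tiny synthetic binary-classification frame (hypothetical data).
    df = pd.DataFrame({
        'feature_a': [0, 1, 2, 3, 4, 5, 6, 7],
        'feature_b': [1, 0, 1, 0, 1, 0, 1, 0],
        'target':    [0, 0, 0, 1, 1, 1, 0, 1],
    })
    ds = Dataset(df, label='target', cat_features=[])

    model = DecisionTreeClassifier(random_state=0)
    model.fit(ds.features_columns, ds.label_col)

    check = ConfusionMatrixReport(normalize_display=True)
    # Fail the condition if any misclassified cell holds more than 10% of samples.
    check.add_condition_misclassified_samples_lower_than_condition(0.1)
    result = check.run(ds, model)
    print(result.value)  # the confusion matrix counts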