Skip to content

Commit

Permalink
Add passed function to suite result (#1594)
Browse files Browse the repository at this point in the history
* Add passed function to suite result

* Fix tests

* Fix few tests and conditions

* Fix imports
  • Loading branch information
matanper committed Jun 12, 2022
1 parent 2e40cca commit 8b33000
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 96 deletions.
6 changes: 3 additions & 3 deletions deepchecks/core/check_result.py
Expand Up @@ -49,7 +49,7 @@
from deepchecks.core.checks import BaseCheck


__all__ = ['CheckResult', 'CheckFailure']
__all__ = ['CheckResult', 'CheckFailure', 'BaseCheckResult']


TDisplayCallable = Callable[[], None]
Expand Down Expand Up @@ -161,9 +161,9 @@ def have_display(self) -> bool:
"""Return if this check has display."""
return bool(self.display)

def passed_conditions(self) -> bool:
def passed_conditions(self, fail_if_warning=True) -> bool:
"""Return True if every condition result of this check has passed."""
return all((r.is_pass for r in self.conditions_results))
return all((r.is_pass(fail_if_warning) for r in self.conditions_results))

@property
def priority(self) -> int:
Expand Down
10 changes: 5 additions & 5 deletions deepchecks/core/condition.py
Expand Up @@ -75,7 +75,6 @@ class ConditionResult:
"""

is_pass: bool
category: ConditionCategory
details: str
name: str
Expand Down Expand Up @@ -114,10 +113,11 @@ def priority(self) -> int:
return 2
return 3 # if error

@property
def is_pass(self) -> bool:
"""Return true if the category is PASS."""
return self.category == ConditionCategory.PASS
def is_pass(self, fail_if_warning=True) -> bool:
    """Tell whether this condition result counts as passed.

    Parameters
    ----------
    fail_if_warning : bool, default: True
        When true, only the PASS category counts as passed. When false, the WARN category
        counts as passed as well.

    Returns
    -------
    bool
    """
    # PASS is accepted regardless of how warnings are treated.
    if self.category == ConditionCategory.PASS:
        return True
    # WARN is accepted only when warnings are not considered failures.
    return not fail_if_warning and self.category == ConditionCategory.WARN

def get_icon(self):
"""Return icon of the result to display."""
Expand Down
49 changes: 40 additions & 9 deletions deepchecks/core/suite.py
Expand Up @@ -14,7 +14,7 @@
import io
import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple, Union

import jsonpickle
from IPython.core.display import display, display_html
Expand Down Expand Up @@ -290,19 +290,50 @@ def to_wandb(
with wandb_run(**wandb_kwargs) as run:
run.log(WandbSerializer(self).serialize())

def get_failures(self) -> Dict[str, CheckFailure]:
"""Get all the failed checks.
def get_checks_not_ran(self) -> List[CheckFailure]:
"""Get all the check results which did not run (unable to run due to missing parameters, exception, etc).
Returns
-------
Dict[str, CheckFailure]
List[CheckFailure]
All the check failures in the suite.
"""
failures = {}
for res in self.results:
if isinstance(res, CheckFailure):
failures[res.header] = res
return failures
return self.select_results(self.failures)

def get_checks_not_passed(self, fail_if_warning=True) -> List[CheckResult]:
    """Collect the check results whose conditions did not all pass.

    Parameters
    ----------
    fail_if_warning : bool, default: True
        Whether conditions should fail on status of warning.

    Returns
    -------
    List[CheckResult]
        The check results in the suite that have at least one condition which did not pass.
    """
    # Only results that carry conditions can fail a condition.
    with_conditions = self.select_results(self.results_with_conditions)
    return [result for result in with_conditions if not result.passed_conditions(fail_if_warning)]

def passed(self, fail_if_warning: bool = True, fail_if_check_not_run: bool = False) -> bool:
    """Return whether this suite result has passed.

    The pass value is derived from the condition results of all individual checks, and may
    also consider checks that didn't run.

    Parameters
    ----------
    fail_if_warning : bool, default: True
        Whether conditions should fail on status of warning.
    fail_if_check_not_run : bool, default: False
        Whether checks that didn't run (missing parameters, exception, etc) should fail the suite result.

    Returns
    -------
    bool
        True when no check has a not-passing condition and, if ``fail_if_check_not_run`` is set,
        no check is reported as not ran.
    """
    # Checks that did not run count against the suite only when explicitly requested.
    if fail_if_check_not_run and self.get_checks_not_ran():
        return False
    return not self.get_checks_not_passed(fail_if_warning)

@classmethod
def from_json(cls, json_res: str):
Expand Down
2 changes: 1 addition & 1 deletion deepchecks/tabular/dataset.py
Expand Up @@ -744,7 +744,7 @@ def numerical_features(self) -> t.List[Hashable]:
Returns
-------
t.List[Hashable]
List of categorical feature names.
List of numerical feature names.
"""
return list(self._numerical_features)

Expand Down
22 changes: 10 additions & 12 deletions deepchecks/vision/checks/model_evaluation/class_performance.py
Expand Up @@ -159,11 +159,13 @@ def condition(check_result: pd.DataFrame):
test_scores = check_result.loc[check_result['Dataset'] == 'Test']
not_passed_test = test_scores.loc[test_scores['Value'] <= min_score]
if len(not_passed_test):
details = f'Found metrics with scores below threshold:\n' \
f'{not_passed_test[["Class Name", "Metric", "Value"]].to_dict("records")}'
not_passed_list = [{'Class': row['Class Name'], 'Metric': row['Metric'],
'Score': format_number(row['Value'])}
for _, row in not_passed_test.iterrows()]
details = f'Found metrics with scores below threshold:\n{not_passed_list}'
return ConditionResult(ConditionCategory.FAIL, details)
else:
min_metric = test_scores.iloc[test_scores['Value'].idmin()]
min_metric = test_scores.iloc[test_scores['Value'].idxmin()]
details = f'Found minimum score for {min_metric["Metric"]} metric of value ' \
f'{format_number(min_metric["Value"])} for class {min_metric["Class Name"]}'
return ConditionResult(ConditionCategory.PASS, details)
Expand All @@ -189,6 +191,7 @@ def condition(check_result: pd.DataFrame) -> ConditionResult:
test_scores = check_result.loc[check_result['Dataset'] == 'Test']
train_scores = check_result.loc[check_result['Dataset'] == 'Train']
max_degradation = ('', -np.inf)
num_failures = 0

def update_max_degradation(diffs, class_name):
nonlocal max_degradation
Expand All @@ -198,7 +201,6 @@ def update_max_degradation(diffs, class_name):
f'{max_scorer} and class {class_name}', max_diff

classes = check_result['Class Name'].unique()
explained_failures = []

for class_name in classes:
test_scores_class = test_scores.loc[test_scores['Class Name'] == class_name]
Expand All @@ -211,14 +213,10 @@ def update_max_degradation(diffs, class_name):
diff = {score_name: _ratio_of_change_calc(score, test_scores_dict[score_name])
for score_name, score in train_scores_dict.items()}
update_max_degradation(diff, class_name)
failed_scores = [k for k, v in diff.items() if v >= threshold]
for score_name in failed_scores:
explained_failures.append(f'{score_name} for class {class_name} '
f'(train={format_number(train_scores_dict[score_name])} '
f'test={format_number(test_scores_dict[score_name])})')

if explained_failures:
message = '\n'.join(explained_failures)
num_failures += len([v for v in diff.values() if v >= threshold])

if num_failures > 0:
message = f'{num_failures} classes scores failed. ' + max_degradation[0]
return ConditionResult(ConditionCategory.FAIL, message)
else:
message = max_degradation[0]
Expand Down
Expand Up @@ -19,6 +19,7 @@
from deepchecks.core import CheckResult, ConditionResult, DatasetKind
from deepchecks.core.condition import ConditionCategory
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.utils.strings import format_number
from deepchecks.vision import Batch, Context, SingleDatasetCheck
from deepchecks.vision.metrics_utils.object_detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.vision_data import TaskType
Expand Down Expand Up @@ -105,7 +106,7 @@ def compute(self, context: Context, dataset_kind: DatasetKind.TRAIN) -> CheckRes
def add_condition_greater_than(self, threshold: float) -> ConditionResult:
"""Add condition - the result is greater than the threshold."""
def condition(check_result):
details = f'The score {self.metric_name} is {check_result["score"]}'
details = f'The score {self.metric_name} is {format_number(check_result["score"])}'
if check_result['score'] > threshold:
return ConditionResult(ConditionCategory.PASS, details)
else:
Expand All @@ -115,7 +116,7 @@ def condition(check_result):
def add_condition_greater_or_equal(self, threshold: float) -> ConditionResult:
"""Add condition - the result is greater or equal to the threshold."""
def condition(check_result):
details = f'The score {self.metric_name} is {check_result["score"]}'
details = f'The score {self.metric_name} is {format_number(check_result["score"])}'
if check_result['score'] >= threshold:
return ConditionResult(ConditionCategory.PASS, details)
else:
Expand All @@ -125,7 +126,7 @@ def condition(check_result):
def add_condition_less_than(self, threshold: float) -> ConditionResult:
"""Add condition - the result is less than the threshold."""
def condition(check_result):
details = f'The score {self.metric_name} is {check_result["score"]}'
details = f'The score {self.metric_name} is {format_number(check_result["score"])}'
if check_result['score'] < threshold:
return ConditionResult(ConditionCategory.PASS, details)
else:
Expand All @@ -136,7 +137,7 @@ def add_condition_less_or_equal(self, threshold: float) -> ConditionResult:
"""Add condition - the result is less or equal to the threshold."""

def condition(check_result):
details = f'The score {self.metric_name} is {check_result["score"]}'
details = f'The score {self.metric_name} is {format_number(check_result["score"])}'
if check_result['score'] <= threshold:
return ConditionResult(ConditionCategory.PASS, details)
else:
Expand Down
Expand Up @@ -34,16 +34,19 @@

import pandas as pd

from deepchecks.tabular import Dataset
from deepchecks.tabular.checks import StringMismatch

data = {'col1': ['Deep', 'deep', 'deep!!!', '$deeP$', 'earth', 'foo', 'bar', 'foo?']}
df = pd.DataFrame(data=data)
result = StringMismatch().run(df)
dataset = Dataset(df, cat_features=['col1'])
result = StringMismatch().run(dataset)
result.show()

#%%
# Define a Condition
# ==================

check = StringMismatch().add_condition_no_variants()
result = check.run(df)
result = check.run(dataset)
result.show(show_additional_outputs=False)
46 changes: 44 additions & 2 deletions tests/base/check_suite_test.py
Expand Up @@ -13,6 +13,7 @@

from hamcrest import assert_that, calling, equal_to, has_length, instance_of, is_, raises

from deepchecks import SuiteResult, ConditionResult, ConditionCategory
from deepchecks.core import CheckFailure, CheckResult
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.tabular import SingleDatasetCheck, Suite, TrainTestCheck
Expand Down Expand Up @@ -158,5 +159,46 @@ def test_get_error(iris_split_dataset_and_model_custom):
tabular_checks.ModelErrorAnalysis())

result = suite.run(train_dataset=iris_train, test_dataset=iris_test, model=iris_model)
assert_that(result.get_failures(), has_length(1))
assert_that(result.get_failures()["Model Error Analysis"], instance_of(CheckFailure))
assert_that(result.get_checks_not_ran(), has_length(1))
assert_that(result.get_checks_not_ran()[0], instance_of(CheckFailure))


def test_suite_result_checks_not_passed():
    """get_checks_not_passed reports checks with WARN/FAIL conditions, WARN only when warnings fail."""
    # Arrange
    def check_with(name, category):
        check_result = CheckResult(0, name)
        check_result.conditions_results = [ConditionResult(category)]
        return check_result

    passing = check_with('check1', ConditionCategory.PASS)
    warning = check_with('check2', ConditionCategory.WARN)
    failing = check_with('check3', ConditionCategory.FAIL)

    # Each scenario: (results in the suite, keyword arguments, expected number of not-passed checks)
    scenarios = [
        ([passing, warning], {}, 1),
        ([passing, warning], {'fail_if_warning': False}, 0),
        ([passing, warning, failing], {}, 2),
    ]

    # Act & Assert
    for results, kwargs, expected_count in scenarios:
        not_passed_checks = SuiteResult('test', results).get_checks_not_passed(**kwargs)
        assert_that(not_passed_checks, has_length(expected_count))


def test_suite_result_passed_fn():
    """SuiteResult.passed honours fail_if_warning and fail_if_check_not_run."""
    # Arrange
    def check_with(name, category):
        check_result = CheckResult(0, name)
        check_result.conditions_results = [ConditionResult(category)]
        return check_result

    passing = check_with('check1', ConditionCategory.PASS)
    warning = check_with('check2', ConditionCategory.WARN)
    failing = check_with('check3', ConditionCategory.FAIL)
    not_ran = CheckFailure(tabular_checks.IsSingleValue(), DeepchecksValueError(''))

    # Each scenario: (results in the suite, keyword arguments, expected pass value)
    scenarios = [
        ([passing, warning], {}, False),
        ([passing, warning], {'fail_if_warning': False}, True),
        ([passing, warning, failing], {'fail_if_warning': False}, False),
        ([passing, not_ran], {}, True),
        ([passing, not_ran], {'fail_if_check_not_run': True}, False),
    ]

    # Act & Assert
    for results, kwargs, expected in scenarios:
        passed = SuiteResult('test', results).passed(**kwargs)
        assert_that(passed, equal_to(expected))
4 changes: 0 additions & 4 deletions tests/base/check_test.py
Expand Up @@ -121,25 +121,21 @@ def raise_(ex): # just to test error in condition
assert_that(decisions, has_items(
all_of(
has_property('name', 'condition A'),
has_property('is_pass', equal_to(True)),
has_property('category', ConditionCategory.PASS),
has_property('details', '')
),
all_of(
has_property('name', 'condition B'),
has_property('is_pass', equal_to(False)),
has_property('category', ConditionCategory.FAIL),
has_property('details', 'some result')
),
all_of(
has_property('name', 'condition C'),
has_property('is_pass', equal_to(False)),
has_property('category', ConditionCategory.WARN),
has_property('details', 'my actual')
),
all_of(
has_property('name', 'condition F'),
has_property('is_pass', equal_to(False)),
has_property('category', ConditionCategory.ERROR),
has_property('details', 'Exception in condition: Exception: fail')
)
Expand Down
9 changes: 5 additions & 4 deletions tests/base/utils.py
Expand Up @@ -12,7 +12,7 @@
import re
from typing import Pattern, Union

from hamcrest import all_of, has_property, matches_regexp
from hamcrest import all_of, has_property, matches_regexp, is_in
from hamcrest.core.matcher import Matcher

from deepchecks.core import ConditionCategory
Expand All @@ -31,7 +31,9 @@ def equal_condition_result(
category: ConditionCategory = None
) -> Matcher[Matcher[object]]:
if category is None:
category = ConditionCategory.PASS if is_pass else ConditionCategory.FAIL
possible_categories = [ConditionCategory.PASS] if is_pass else [ConditionCategory.FAIL, ConditionCategory.WARN]
else:
possible_categories = [category]

# Check if details is a regex class
if hasattr(details, 'match'):
Expand All @@ -40,8 +42,7 @@ def equal_condition_result(
details_matcher = details

return all_of(
has_property('is_pass', is_pass),
has_property('category', category),
has_property('category', is_in(possible_categories)),
has_property('details', details_matcher),
has_property('name', name)
)

0 comments on commit 8b33000

Please sign in to comment.