Skip to content

Commit

Permalink
allow checks to return CheckFailure (#1252)
Browse files Browse the repository at this point in the history
* allow checks to return CheckFailure

* add _ipython_display_ to CheckFailure
  • Loading branch information
benisraeldan committed Apr 12, 2022
1 parent b09450b commit d6ce78f
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 37 deletions.
9 changes: 9 additions & 0 deletions deepchecks/core/check_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,12 @@ def __repr__(self):
tb_str = traceback.format_exception(etype=type(self.exception), value=self.exception,
tb=self.exception.__traceback__)
return ''.join(tb_str)

def _ipython_display_(self):
"""Display the check failure."""
check_html = f'<h4>{self.header}</h4>'
if hasattr(self.check.__class__, '__doc__'):
summary = get_docs_summary(self.check)
check_html += f'<p>{summary}</p>'
check_html += f'<p style="color:red"> {self.exception}</p>'
display_html(check_html, raw=True)
5 changes: 4 additions & 1 deletion deepchecks/core/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import OrderedDict
from typing import Any, Callable, List, Union, Dict, Type, ClassVar, Optional

from deepchecks.core.check_result import CheckResult
from deepchecks.core.check_result import CheckResult, CheckFailure
from deepchecks.core.condition import Condition, ConditionCategory, ConditionResult
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.utils.strings import split_camel_case
Expand Down Expand Up @@ -124,6 +124,9 @@ def params(self, show_defaults: bool = False) -> Dict:

def finalize_check_result(self, check_result: CheckResult) -> CheckResult:
"""Finalize the check result by adding the check instance and processing the conditions."""
if isinstance(check_result, CheckFailure):
return check_result

if not isinstance(check_result, CheckResult):
raise DeepchecksValueError(f'Check {self.name()} expected to return CheckResult but got: '
+ type(check_result).__name__)
Expand Down
22 changes: 13 additions & 9 deletions deepchecks/tabular/checks/performance/model_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from typing import Callable, Dict, Tuple, Union
from sklearn import preprocessing

from deepchecks import CheckFailure
from deepchecks.core import CheckResult, ConditionResult, ConditionCategory
from deepchecks.core.errors import DeepchecksProcessError
from deepchecks.tabular import Context, TrainTestCheck, Dataset
from deepchecks.utils.metrics import ModelType
from deepchecks.utils.performance.error_model import model_error_contribution, error_model_display
Expand Down Expand Up @@ -138,15 +140,17 @@ def scoring_func(dataset: Dataset):

cat_features = train_dataset.cat_features
numeric_features = train_dataset.numerical_features

error_fi, error_model_predicted = model_error_contribution(train_dataset.features_columns,
train_scores,
test_dataset.features_columns,
test_scores,
numeric_features,
cat_features,
min_error_model_score=self.min_error_model_score,
random_state=self.random_state)
try:
error_fi, error_model_predicted = model_error_contribution(train_dataset.features_columns,
train_scores,
test_dataset.features_columns,
test_scores,
numeric_features,
cat_features,
min_error_model_score=self.min_error_model_score,
random_state=self.random_state)
except DeepchecksProcessError as e:
return CheckFailure(self, e)

display, value = error_model_display(error_fi,
error_model_predicted,
Expand Down
14 changes: 8 additions & 6 deletions deepchecks/vision/base_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def run(

p_bar = ProgressBar('Computing Check', 1, unit='Check')
result = self.compute(context, DatasetKind.TRAIN)
footnote = context.get_is_sampled_footnote(DatasetKind.TRAIN)
if footnote:
result.display.append(footnote)
if isinstance(result, CheckResult):
footnote = context.get_is_sampled_footnote(DatasetKind.TRAIN)
if footnote:
result.display.append(footnote)
result = self.finalize_check_result(result)
p_bar.inc_progress()
p_bar.close()
Expand Down Expand Up @@ -156,9 +157,10 @@ def run(

p_bar = ProgressBar('Computing Check', 1, unit='Check')
result = self.compute(context)
footnote = context.get_is_sampled_footnote()
if footnote:
result.display.append(footnote)
if isinstance(result, CheckResult):
footnote = context.get_is_sampled_footnote()
if footnote:
result.display.append(footnote)
result = self.finalize_check_result(result)
p_bar.inc_progress()
p_bar.close()
Expand Down
24 changes: 14 additions & 10 deletions deepchecks/vision/checks/performance/model_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import pandas as pd
import torch

from deepchecks import CheckFailure
from deepchecks.core import CheckResult, DatasetKind
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.core.errors import DeepchecksValueError, DeepchecksProcessError
from deepchecks.utils.performance.error_model import error_model_display_dataframe, model_error_contribution
from deepchecks.utils.single_sample_metrics import per_sample_cross_entropy
from deepchecks.vision.utils.image_properties import default_image_properties, validate_properties
Expand Down Expand Up @@ -144,15 +145,18 @@ def compute(self, context: Context) -> CheckResult:
train_property_df = pd.DataFrame(self._train_properties).dropna(axis=1, how='all')
test_property_df = pd.DataFrame(self._test_properties)[train_property_df.columns]

error_fi, error_model_predicted = \
model_error_contribution(train_property_df,
self._train_scores,
test_property_df,
self._test_scores,
train_property_df.columns.to_list(),
[],
min_error_model_score=self.min_error_model_score,
random_state=self.random_state)
try:
error_fi, error_model_predicted = \
model_error_contribution(train_property_df,
self._train_scores,
test_property_df,
self._test_scores,
train_property_df.columns.to_list(),
[],
min_error_model_score=self.min_error_model_score,
random_state=self.random_state)
except DeepchecksProcessError as e:
return CheckFailure(self, e)

display, value = error_model_display_dataframe(error_fi,
error_model_predicted,
Expand Down
11 changes: 4 additions & 7 deletions tests/checks/performance/model_error_analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
#
"""Tests for segment performance check."""
import numpy as np
from hamcrest import assert_that, calling, raises, has_length, has_items, close_to
from hamcrest import assert_that, calling, raises, has_length, has_items, close_to, instance_of
from scipy.special import softmax
from sklearn.metrics import log_loss

from deepchecks import CheckFailure
from deepchecks.core import ConditionCategory
from deepchecks.core.errors import DeepchecksValueError, DeepchecksNotSupportedError, DeepchecksProcessError
from deepchecks.tabular.checks.performance.model_error_analysis import ModelErrorAnalysis
Expand Down Expand Up @@ -46,9 +47,7 @@ def test_model_error_analysis_regression_not_meaningful(diabetes_split_dataset_a
train, val, model = diabetes_split_dataset_and_model

# Assert
assert_that(calling(ModelErrorAnalysis().run).with_args(train, val, model),
raises(DeepchecksProcessError,
'Unable to train meaningful error model'))
assert_that(ModelErrorAnalysis().run(train, val, model), instance_of(CheckFailure))


def test_model_error_analysis_classification(iris_labeled_dataset, iris_adaboost):
Expand All @@ -64,9 +63,7 @@ def test_binary_string_model_info_object(iris_binary_string_split_dataset_and_mo
train_ds, test_ds, clf = iris_binary_string_split_dataset_and_model

# Assert
assert_that(calling(ModelErrorAnalysis().run).with_args(train_ds, test_ds, clf),
raises(DeepchecksProcessError,
'Unable to train meaningful error model'))
assert_that(ModelErrorAnalysis().run(train_ds, test_ds, clf), instance_of(CheckFailure))


def test_condition_fail(iris_labeled_dataset, iris_adaboost):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#
"""Test functions of the VISION model error analysis."""

from hamcrest import assert_that, equal_to, calling, raises, close_to
from hamcrest import assert_that, equal_to, instance_of

from deepchecks.core.errors import DeepchecksProcessError
from deepchecks import CheckFailure
from deepchecks.vision.checks import ModelErrorAnalysis


Expand Down Expand Up @@ -50,6 +50,6 @@ def test_classification_not_interesting(mnist_dataset_train, mock_trained_mnist,
train, test = mnist_dataset_train, mnist_dataset_train

# Assert
assert_that(calling(check.run).with_args(
assert_that(check.run(
train, test, mock_trained_mnist,
device=device), raises(DeepchecksProcessError))
device=device), instance_of(CheckFailure))

0 comments on commit d6ce78f

Please sign in to comment.