Remove check from check result init #228

Merged · 8 commits · Dec 7, 2021
31 changes: 24 additions & 7 deletions deepchecks/base/check.py
@@ -12,8 +12,10 @@
import abc
import enum
import re
+import typing
from collections import OrderedDict
from dataclasses import dataclass
+from functools import wraps
from typing import Any, Callable, List, Union, Dict, cast

__all__ = ['CheckResult', 'BaseCheck', 'SingleDatasetBaseCheck', 'CompareDatasetsBaseCheck', 'TrainTestBaseCheck',
@@ -128,8 +130,9 @@ class CheckResult:
header: str
display: List[Union[Callable, str, pd.DataFrame, Styler]]
condition_results: List[ConditionResult]
+check: typing.ClassVar

-def __init__(self, value, header: str = None, check=None, display: Any = None):
+def __init__(self, value, header: str = None, display: Any = None):
"""Init check result.

Args:
@@ -140,8 +143,7 @@ def __init__(self, value, header: str = None, check=None, display: Any = None):
display (List): Objects to be displayed (dataframe or function or html)
"""
self.value = value
-self.header = header or (check and check.name()) or None
-self.check = check
+self.header = header
self.condition_results = []

if display is not None and not isinstance(display, List):
@@ -154,10 +156,9 @@ def __init__(self, value, header: str = None, check=None, display: Any = None)
raise DeepchecksValueError(f'Can\'t display item of type: {type(item)}')

def _ipython_display_(self):
-if self.header:
-    display_html(f'<h4>{self.header}</h4>', raw=True)
-if self.check and '__doc__' in dir(self.check):
-    docs = self.check.__doc__
+display_html(f'<h4>{self.get_header()}</h4>', raw=True)
+if hasattr(self.check, '__doc__'):
+    docs = self.check.__doc__ or ''
# Take first non-whitespace line.
summary = next((s for s in docs.split('\n') if not re.match('^\\s*$', s)), '')
display_html(f'<p>{summary}</p>', raw=True)
@@ -179,6 +180,10 @@ def __repr__(self):
"""Return default __repr__ function uses value."""
return self.value.__repr__()

+def get_header(self):
+    """Return header for display. If header was defined return it, else extract name of check class."""
+    return self.header or self.check.name()

def set_condition_results(self, results: List[ConditionResult]):
"""Set the conditions results for current check result."""
self.conditions_results = results
@@ -200,6 +205,16 @@ def get_conditions_sort_value(self):
return max([r.get_sort_value() for r in self.conditions_results])


+def wrap_run(func, class_instance):
+    """Wrap the run function of checks, and set the `check` property on the check result."""
+    @wraps(func)
+    def wrapped(*args, **kwargs):
+        result = func(*args, **kwargs)
+        result.check = class_instance.__class__
+        return result
+    return wrapped


class BaseCheck(metaclass=abc.ABCMeta):
"""Base class for check."""

@@ -209,6 +224,8 @@ class BaseCheck(metaclass=abc.ABCMeta):
def __init__(self):
self._conditions = OrderedDict()
self._conditions_index = 0
+# Replace the run function with wrapped run function
+setattr(self, 'run', wrap_run(getattr(self, 'run'), self))

def conditions_decision(self, result: CheckResult) -> List[ConditionResult]:
"""Run conditions on given result."""
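For orientation, the check.py changes above are the heart of this PR: the `check` argument is removed from `CheckResult.__init__`, `BaseCheck.__init__` rebinds each instance's `run` through `wrap_run` so the producing check class is stamped onto every returned result, and the new `get_header()` falls back to that class's name when no explicit header was set. A minimal, runnable sketch of the pattern (`MyCheck` is hypothetical, `CheckResult` and `BaseCheck` are stripped to just the pieces this diff touches, and the real `name()` helper may format the class name differently):

from functools import wraps

def wrap_run(func, class_instance):
    """Wrap a check's `run` so the returned result records which check produced it."""
    @wraps(func)
    def wrapped(*args, **kwargs):
        result = func(*args, **kwargs)
        result.check = class_instance.__class__
        return result
    return wrapped

class CheckResult:
    def __init__(self, value, header=None, display=None):
        self.value = value
        self.header = header  # may stay None; get_header() then falls back to the check class

    def get_header(self):
        return self.header or self.check.name()

class BaseCheck:
    @classmethod
    def name(cls):
        return cls.__name__  # simplified; the real helper may prettify the name

    def __init__(self):
        # Rebind `run` on the instance so every subclass is wrapped automatically.
        setattr(self, 'run', wrap_run(getattr(self, 'run'), self))

class MyCheck(BaseCheck):  # hypothetical check, for illustration only
    def run(self):
        return CheckResult(value=42)

print(MyCheck().run().get_header())  # -> 'MyCheck'

The instance-level `setattr` is the notable design choice: subclasses override `run` freely and pick up the wrapper at construction time, with no metaclass and no decorator repeated on every check.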
4 changes: 2 additions & 2 deletions deepchecks/base/display_suite.py
@@ -67,12 +67,12 @@ def display_suite_result(suite_name: str, results: List[Union[CheckResult, Check
for cond_result in result.conditions_results:
sort_value = cond_result.get_sort_value()
icon = cond_result.get_icon()
-conditions_table.append([icon, result.header, cond_result.name,
+conditions_table.append([icon, result.get_header(), cond_result.name,
cond_result.details, sort_value])
if result.have_display():
display_table.append(result)
else:
-others_table.append([result.header, 'Nothing found', 2])
+others_table.append([result.get_header(), 'Nothing found', 2])
elif isinstance(result, CheckFailure):
msg = result.exception.__class__.__name__ + ': ' + str(result.exception)
name = result.check.name()
4 changes: 2 additions & 2 deletions deepchecks/base/suite.py
@@ -112,12 +112,12 @@ def run(
elif isinstance(check, SingleDatasetBaseCheck):
if check_datasets_policy in ['both', 'train'] and train_dataset is not None:
check_result = check.run(dataset=train_dataset, model=model)
-check_result.header = f'{check_result.header} - Train Dataset'
+check_result.header = f'{check_result.get_header()} - Train Dataset'
check_result.set_condition_results(check.conditions_decision(check_result))
results.append(check_result)
if check_datasets_policy in ['both', 'test'] and test_dataset is not None:
check_result = check.run(dataset=test_dataset, model=model)
-check_result.header = f'{check_result.header} - Test Dataset'
+check_result.header = f'{check_result.get_header()} - Test Dataset'
check_result.set_condition_results(check.conditions_decision(check_result))
results.append(check_result)
elif isinstance(check, ModelOnlyBaseCheck):
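The suite reads the header through `get_header()` before decorating it because a fresh result's `header` is typically `None` now; only `result.check`, set by the wrapped `run`, carries the name. Continuing the hypothetical sketch from above:

result = MyCheck().run()                                    # header is None; check was set by the wrapper
result.header = f'{result.get_header()} - Train Dataset'   # falls back to the class name
print(result.header)                                        # -> 'MyCheck - Train Dataset'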
7 changes: 1 addition & 6 deletions deepchecks/checks/distribution/train_test_drift.py
@@ -253,12 +253,7 @@ def _calc_drift(self, train_dataset: Dataset, test_dataset: Dataset, feature_imp

displays = [headnote] + [displays_dict[col] for col in columns_order]

-return CheckResult(
-    value=values_dict,
-    display=displays,
-    header='Train Test Drift',
-    check=self.__class__
-)
+return CheckResult(value=values_dict, display=displays, header='Train Test Drift')

def _calc_drift_per_column(self, train_column: pd.Series, test_column: pd.Series, column_name: Hashable,
column_type: str, feature_importances: pd.Series = None
3 changes: 1 addition & 2 deletions deepchecks/checks/distribution/trust_score_comparison.py
@@ -199,8 +199,7 @@ def filter_quantile(data):
'<h5>Top Trust Score Samples</h5>', top_k]

result = {'test': np.mean(test_trust_scores), 'train': np.mean(train_trust_scores)}
-return CheckResult(result, check=self.__class__, display=display,
-                   header='Trust Score Comparison: Train vs. Test')
+return CheckResult(result, display=display, header='Trust Score Comparison: Train vs. Test')

def add_condition_mean_score_percent_decline_not_greater_than(self, threshold: float = 0.2):
"""Add condition.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/data_duplicates.py
@@ -85,7 +85,7 @@ def run(self, dataset: Dataset, model=None) -> CheckResult:
else:
display = None

-return CheckResult(value=percent_duplicate, check=self.__class__, display=display)
+return CheckResult(value=percent_duplicate, display=display)

def add_condition_ratio_not_greater_than(self, max_ratio: float = 0):
"""Add condition - require duplicate ratio to not surpass max_ratio.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/dominant_frequency_change.py
@@ -157,7 +157,7 @@ def _dominant_frequency_change(self, dataset: Dataset, baseline_dataset: Dataset
else:
sorted_p_df = None

-return CheckResult(p_dict, check=self.__class__, display=sorted_p_df)
+return CheckResult(p_dict, display=sorted_p_df)

def add_condition_p_value_not_less_than(self, p_value_threshold: float = 0.0001):
"""Add condition - require min p value allowed per column.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/is_single_value.py
@@ -73,7 +73,7 @@ def _is_single_value(self, dataset: Union[pd.DataFrame, Dataset]) -> CheckResult
value = None
display = None

-return CheckResult(value, header='Single Value in Column', check=self.__class__, display=display)
+return CheckResult(value, header='Single Value in Column', display=display)

def add_condition_not_single_value(self):
"""Add condition - not single value."""
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/label_ambiguity.py
@@ -88,7 +88,7 @@ def run(self, dataset: Dataset, model=None) -> CheckResult:

percent_ambiguous = num_ambiguous/dataset.n_samples

-return CheckResult(value=percent_ambiguous, check=self.__class__, display=display)
+return CheckResult(value=percent_ambiguous, display=display)

def add_condition_ambiguous_sample_ratio_not_greater_than(self, max_ratio=0):
"""Add condition - require samples with multiple labels to not be more than max_ratio.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/mixed_nulls.py
@@ -149,7 +149,7 @@ def _mixed_nulls(self, dataset: Union[pd.DataFrame, Dataset], feature_importance
else:
display = None

-return CheckResult(result_dict, check=self.__class__, display=display)
+return CheckResult(result_dict, display=display)

def add_condition_different_nulls_not_more_than(self, max_allowed_null_types: int = 1):
"""Add condition - require column not to have more than given number of different null values.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/mixed_types.py
@@ -98,7 +98,7 @@ def _mixed_types(self, dataset: Union[pd.DataFrame, Dataset], feature_importance
else:
display = None

-return CheckResult(result_dict, check=self.__class__, display=display)
+return CheckResult(result_dict, display=display)

def _get_data_mix(self, column_data: pd.Series) -> dict:
if is_string_column(column_data):
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/new_category.py
@@ -124,7 +124,7 @@ def _new_category_train_test(self, train_dataset: Dataset, test_dataset: Dataset
else:
display = None
new_categories = {}
-return CheckResult(new_categories, check=self.__class__, display=display)
+return CheckResult(new_categories, display=display)

def add_condition_new_categories_not_greater_than(self, max_new: int = 0):
"""Add condition - require column not to have greater than given number of different new categories.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/new_label.py
@@ -82,7 +82,7 @@ def _new_label_train_test(self, train_dataset: Dataset, test_dataset: Dataset):
display = None
result = {}

-return CheckResult(result, check=self.__class__, display=display)
+return CheckResult(result, display=display)

def add_condition_new_labels_not_greater_than(self, max_new: int = 0):
"""Add condition - require label column not to have greater than given number of different new labels.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/rare_format_detection.py
@@ -377,7 +377,7 @@ def _rare_format_detection(self, dataset: t.Union[Dataset, pd.DataFrame],
display.append(f'\n\nColumn {key}:')
display.append(value)

-return CheckResult(value=filtered_res, header='Rare Format Detection', check=self.__class__, display=display)
+return CheckResult(value=filtered_res, header='Rare Format Detection', display=display)

def add_condition_ratio_of_rare_formats_not_greater_than(self, var: float = 0):
"""
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/special_chars.py
@@ -116,7 +116,7 @@ def _special_characters(self, dataset: Union[pd.DataFrame, Dataset],
self.n_top_columns, col='Column Name')
display = df_graph if len(df_graph) > 0 else None

-return CheckResult(result, check=self.__class__, display=display)
+return CheckResult(result, display=display)

def add_condition_ratio_of_special_characters_not_grater_than(self, max_ratio: float = 0.001):
"""Add condition - ratio of entirely special character in column.
@@ -197,7 +197,7 @@ def _string_length_out_of_bounds(self, dataset: Union[pd.DataFrame, Dataset],
self.n_top_columns, col='Column Name')
display = df_graph if len(df_graph) > 0 else None

-return CheckResult(results, check=self.__class__, display=display)
+return CheckResult(results, display=display)

def add_condition_number_of_outliers_not_greater_than(self, max_outliers: int = 0):
"""Add condition - require column not to have more than given number of string length outliers.
2 changes: 1 addition & 1 deletion deepchecks/checks/integrity/string_mismatch.py
@@ -119,7 +119,7 @@ def _string_mismatch(self, dataset: Union[pd.DataFrame, Dataset],
else:
display = None

-return CheckResult(result_dict, check=self.__class__, display=display)
+return CheckResult(result_dict, display=display)

def add_condition_not_more_variants_than(self, num_max_variants: int):
"""Add condition - no more than given number of variants are allowed (per string baseform).
@@ -155,7 +155,7 @@ def _string_mismatch_comparison(self, dataset: Union[pd.DataFrame, Dataset],
else:
display = None

-return CheckResult(result_dict, check=self.__class__, display=display)
+return CheckResult(result_dict, display=display)

def add_condition_no_new_variants(self):
"""Add condition - no new variants allowed in test data."""
2 changes: 1 addition & 1 deletion deepchecks/checks/methodology/boosting_overfit.py
@@ -185,7 +185,7 @@ def display_func():
axes.xaxis.set_major_locator(MaxNLocator(integer=True))

result = {'test': test_scores, 'train': train_scores}
-return CheckResult(result, check=self.__class__, display=display_func, header='Boosting Overfit')
+return CheckResult(result, display=display_func, header='Boosting Overfit')

def add_condition_test_score_percent_decline_not_greater_than(self, threshold: float = 0.05):
"""Add condition.
@@ -50,7 +50,6 @@ def run(self, train_dataset: Dataset, test_dataset: Dataset, model: object = Non
})
return CheckResult(
value=result,
-header='Datasets size.',
display=result
)

@@ -68,8 +68,7 @@ def _date_train_test_leakage_duplicates(self, train_dataset: Dataset, test_datas
display = None
return_value = 0

-return CheckResult(value=return_value, header='Date Train-Test Leakage (duplicates)',
-                   check=self.__class__, display=display)
+return CheckResult(value=return_value, header='Date Train-Test Leakage (duplicates)', display=display)

def add_condition_leakage_ratio_not_greater_than(self, max_ratio: float = 0):
"""Add condition - require leakage ratio to not surpass max_ratio.
@@ -58,10 +58,7 @@ def _date_train_test_leakage_overlap(self, train_dataset: Dataset, test_dataset:
display = None
return_value = 0

-return CheckResult(value=return_value,
-                   header='Date Train-Test Leakage (overlap)',
-                   check=self.__class__,
-                   display=display)
+return CheckResult(value=return_value, header='Date Train-Test Leakage (overlap)', display=display)

def add_condition_leakage_ratio_not_greater_than(self, max_ratio: float = 0):
"""Add condition - require leakage ratio to not surpass max_ratio.
2 changes: 1 addition & 1 deletion deepchecks/checks/methodology/identifier_leakage.py
@@ -78,7 +78,7 @@ def plot():
'For Identifier columns (Index/Date) PPS should be nearly 0, otherwise date and index have some '
'predictive effect on the label.']

-return CheckResult(value=s_ppscore.to_dict(), display=[plot, *text], check=self.__class__)
+return CheckResult(value=s_ppscore.to_dict(), display=[plot, *text])

def add_condition_pps_not_greater_than(self, max_pps: float = 0):
"""Add condition - require columns not to have a greater pps than given max.
3 changes: 1 addition & 2 deletions deepchecks/checks/methodology/index_leakage.py
@@ -68,8 +68,7 @@ def _index_train_test_leakage(self, train_dataset: Dataset, test_dataset: Datase
size_in_test = 0
display = None

-return CheckResult(value=size_in_test, header='Index Train-Test Leakage', check=self.__class__,
-                   display=display)
+return CheckResult(value=size_in_test, header='Index Train-Test Leakage', display=display)

def add_condition_ratio_not_greater_than(self, max_ratio: float = 0):
"""Add condition - require index leakage ratio to not surpass max_ratio.
2 changes: 1 addition & 1 deletion deepchecks/checks/methodology/model_inference_time.py
@@ -84,7 +84,7 @@ def _model_inference_time_check(

result = result / number_of_samples

-return CheckResult(value=result, check=type(self), display=(
+return CheckResult(value=result, display=(
'Average model inference time for one sample (in seconds): '
f'{format_number(result, floating_point=8)}'
))
3 changes: 1 addition & 2 deletions deepchecks/checks/methodology/performance_overfit.py
@@ -107,8 +107,7 @@ def plot_overfit():
plt.xticks(rotation=30)
plt.legend(res_df.columns, loc='upper right', bbox_to_anchor=(1.45, 1.02))

-return CheckResult(result, check=self.__class__, header='Train-Test Difference Overfit',
-                   display=[plot_overfit])
+return CheckResult(result, header='Train-Test Difference Overfit', display=[plot_overfit])

def add_condition_difference_not_greater_than(self: TD, threshold: float) -> TD:
"""
@@ -80,8 +80,7 @@ def plot(n_show_top=self.n_show_top):
' actually due to data',
'leakage - meaning that the feature holds information that is based on the label to begin with.']

-return CheckResult(value=s_ppscore.to_dict(), display=[plot, *text], check=self.__class__,
-                   header='Single Feature Contribution')
+return CheckResult(value=s_ppscore.to_dict(), display=[plot, *text], header='Single Feature Contribution')

def add_condition_feature_pps_not_greater_than(self: FC, threshold: float = 0.8) -> FC:
"""
@@ -98,7 +98,7 @@ def plot():
'that was powerful in train but not in test can be explained by leakage in train that is not '
'relevant to a new dataset.']

-return CheckResult(value=s_difference.to_dict(), display=[plot, *text], check=self.__class__,
+return CheckResult(value=s_difference.to_dict(), display=[plot, *text],
header='Single Feature Contribution Train-Test')

def add_condition_feature_pps_difference_not_greater_than(self: FC, threshold: float = 0.2) -> FC:
2 changes: 1 addition & 1 deletion deepchecks/checks/methodology/train_test_samples_mix.py
@@ -131,7 +131,7 @@ def _data_sample_leakage_report(self, test_dataset: Dataset, train_dataset: Data
of test data samples appear in train data'
display = [user_msg, duplicate_rows_df.head(10)] if dup_ratio else None

-return CheckResult(dup_ratio, header='Train Test Samples Mix', check=self.__class__, display=display)
+return CheckResult(dup_ratio, header='Train Test Samples Mix', display=display)

def add_condition_duplicates_ratio_not_greater_than(self, max_ratio: float = 0.1):
"""Add condition - require max allowed ratio of test data samples to appear in train data.
2 changes: 1 addition & 1 deletion deepchecks/checks/methodology/unused_features.py
@@ -210,7 +210,7 @@ def plot_feature_importance():
last_variable_feature_index:].values.tolist()
}}

-return CheckResult(return_value, check=self.__class__, header='Unused Features', display=display_list)
+return CheckResult(return_value, header='Unused Features', display=display_list)

def add_condition_number_of_high_variance_unused_features_not_greater_than(
self, max_high_variance_unused_features: int = 5):
2 changes: 1 addition & 1 deletion deepchecks/checks/overview/columns_info.py
@@ -51,5 +51,5 @@ def _columns_info(self, dataset: Dataset, feature_importances: pd.Series=None):
df = pd.DataFrame.from_dict(value, orient='index', columns=['role'])
df = df.transpose()

-return CheckResult(value, check=self.__class__, header='Columns Info', display=df)
+return CheckResult(value, header='Columns Info', display=df)

2 changes: 1 addition & 1 deletion deepchecks/checks/overview/dataset_info.py
@@ -50,4 +50,4 @@ def display():
profile = ProfileReport(dataset, title='Dataset Report', explorative=True, minimal=True)
profile.to_notebook_iframe()

-return CheckResult(dataset.shape, check=self.__class__, display=display)
+return CheckResult(dataset.shape, display=display)
2 changes: 1 addition & 1 deletion deepchecks/checks/overview/model_info.py
@@ -58,4 +58,4 @@ def highlight_not_default(data):
footnote = '<p style="font-size:0.7em"><i>Colored rows are parameters with non-default values</i></p>'
display = [f'Model Type: {model_type}', model_param_df, footnote]

-return CheckResult(value, check=self.__class__, header='Model Info', display=display)
+return CheckResult(value, header='Model Info', display=display)
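The remaining files all make the same mechanical edit, so migrating any check comes down to dropping the `check=` keyword from its `CheckResult` calls. Reusing the hypothetical sketch classes from earlier:

class DuplicatesCheck(BaseCheck):  # hypothetical check, for illustration only
    def run(self):
        # Before this PR: return CheckResult(value=0.0, check=self.__class__)
        return CheckResult(value=0.0)  # `check` is now attached by the wrapped run

print(DuplicatesCheck().run().get_header())  # -> 'DuplicatesCheck'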