Fix read more (#1217)
* Fix read more

* Fix test

* Rename files for the docs

* Fix lint

* Fix init

* Fix docs import
matanper committed Apr 8, 2022
1 parent b6824ff commit 904e010
Showing 12 changed files with 112 additions and 89 deletions.
4 changes: 2 additions & 2 deletions deepchecks/tabular/checks/integrity/__init__.py
@@ -18,8 +18,8 @@
from .string_mismatch_comparison import StringMismatchComparison
from .dominant_frequency_change import DominantFrequencyChange
from .data_duplicates import DataDuplicates
-from .new_category import CategoryMismatchTrainTest
-from .new_label import NewLabelTrainTest
+from .category_mismatch_train_test import CategoryMismatchTrainTest
+from .new_label_train_test import NewLabelTrainTest
from .label_ambiguity import LabelAmbiguity


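With the rename, the old deep module paths (`deepchecks.tabular.checks.integrity.new_category` / `.new_label`) no longer resolve, but the package-level imports re-exported above are unchanged. A minimal sketch of the import style the updated docs examples below use:

```python
# Package-level imports are unaffected by the file renames,
# since the integrity __init__.py re-exports both classes.
from deepchecks.tabular.checks.integrity import CategoryMismatchTrainTest, NewLabelTrainTest
```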
3 changes: 2 additions & 1 deletion deepchecks/tabular/checks/performance/__init__.py
@@ -9,7 +9,8 @@
# ----------------------------------------------------------------------------
#
"""Module contains checks of model performance metrics."""
-from .performance_report import PerformanceReport, MultiModelPerformanceReport
+from .performance_report import PerformanceReport
+from .multi_model_performance_report import MultiModelPerformanceReport
from .confusion_matrix_report import ConfusionMatrixReport
from .roc_report import RocReport
from .simple_model_comparison import SimpleModelComparison
94 changes: 94 additions & 0 deletions deepchecks/tabular/checks/performance/multi_model_performance_report.py
@@ -0,0 +1,94 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module containing multi model performance report check."""
from typing import Dict, Callable, cast

import pandas as pd
import plotly.express as px

from deepchecks.core import CheckResult
from deepchecks.tabular import ModelComparisonCheck, ModelComparisonContext
from deepchecks.utils.metrics import ModelType


__all__ = ['MultiModelPerformanceReport']


class MultiModelPerformanceReport(ModelComparisonCheck):
"""Summarize performance scores for multiple models on test datasets.
Parameters
----------
alternative_scorers : Dict[str, Callable] , default: None
An optional dictionary of scorer name to scorer functions.
If none given, using default scorers
"""

def __init__(self, alternative_scorers: Dict[str, Callable] = None, **kwargs):
super().__init__(**kwargs)
self.user_scorers = alternative_scorers

def run_logic(self, multi_context: ModelComparisonContext):
"""Run check logic."""
first_context = multi_context[0]
scorers = first_context.get_scorers(self.user_scorers, class_avg=False)

if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]:
plot_x_axis = ['Class', 'Model']
results = []

for context in multi_context:
test = context.test
model = context.model
label = cast(pd.Series, test.label_col)
n_samples = label.groupby(label).count()
results.extend(
[context.model_name, class_score, scorer.name, class_name, n_samples[class_name]]
for scorer in scorers
# scorer returns numpy array of results with item per class
for class_score, class_name in zip(scorer(model, test), test.classes)
)

results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Class', 'Number of samples'])

else:
plot_x_axis = 'Model'
results = [
[context.model_name, scorer(context.model, context.test), scorer.name,
cast(pd.Series, context.test.label_col).count()]
for context in multi_context
for scorer in scorers
]
results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Number of samples'])

fig = px.histogram(
results_df,
x=plot_x_axis,
y='Value',
color='Model',
barmode='group',
facet_col='Metric',
facet_col_spacing=0.05,
hover_data=['Number of samples'],
)

if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]:
fig.update_xaxes(title=None, tickprefix='Class ', tickangle=60)
else:
fig.update_xaxes(title=None)

fig = (
fig.update_yaxes(title=None, matches=None)
.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
)

return CheckResult(results_df, display=[fig])
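For orientation, a minimal sketch of running the relocated check end to end. The exact `run(...)` call shape (parallel lists of train datasets, test datasets, and models) is an assumption about `ModelComparisonCheck`'s comparison API, and the synthetic data and model choices are illustrative only:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

from deepchecks.tabular import Dataset
from deepchecks.tabular.checks.performance import MultiModelPerformanceReport

# Tiny synthetic binary-classification frame, just to exercise the check.
rng = np.random.RandomState(0)
df = pd.DataFrame(rng.normal(size=(200, 3)), columns=['a', 'b', 'c'])
df['target'] = (df['a'] + df['b'] > 0).astype(int)
train_df, test_df = df.iloc[:150], df.iloc[150:]

train_ds = Dataset(train_df, label='target')
test_ds = Dataset(test_df, label='target')

model_a = RandomForestClassifier(random_state=0).fit(train_df[['a', 'b', 'c']], train_df['target'])
model_b = DecisionTreeClassifier(random_state=0).fit(train_df[['a', 'b', 'c']], train_df['target'])

# Assumed call shape: one train/test dataset and one model per compared entry.
result = MultiModelPerformanceReport().run([train_ds, train_ds], [test_ds, test_ds], [model_a, model_b])
print(result.value)  # the per-model, per-metric scores DataFrame built in run_logic
```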
77 changes: 3 additions & 74 deletions deepchecks/tabular/checks/performance/performance_report.py
@@ -10,18 +10,19 @@
#
"""Module containing performance report check."""
from typing import Callable, TypeVar, Dict, cast

import pandas as pd
import plotly.express as px

from deepchecks.core import CheckResult, ConditionResult
from deepchecks.core.condition import ConditionCategory
from deepchecks.core.errors import DeepchecksValueError
-from deepchecks.tabular import Context, ModelComparisonContext, TrainTestCheck, ModelComparisonCheck
+from deepchecks.tabular import Context, TrainTestCheck
from deepchecks.utils.strings import format_percent, format_number
from deepchecks.utils.metrics import MULTICLASS_SCORERS_NON_AVERAGE, ModelType


-__all__ = ['PerformanceReport', 'MultiModelPerformanceReport']
+__all__ = ['PerformanceReport']


PR = TypeVar('PR', bound='PerformanceReport')
@@ -301,75 +302,3 @@ def condition(check_result: pd.DataFrame) -> ConditionResult:
),
condition_func=condition
)


-class MultiModelPerformanceReport(ModelComparisonCheck):
-    """Summarize performance scores for multiple models on test datasets.
-    Parameters
-    ----------
-    alternative_scorers : Dict[str, Callable] , default: None
-        An optional dictionary of scorer name to scorer functions.
-        If none given, using default scorers
-    """
-
-    def __init__(self, alternative_scorers: Dict[str, Callable] = None, **kwargs):
-        super().__init__(**kwargs)
-        self.user_scorers = alternative_scorers
-
-    def run_logic(self, multi_context: ModelComparisonContext):
-        """Run check logic."""
-        first_context = multi_context[0]
-        scorers = first_context.get_scorers(self.user_scorers, class_avg=False)
-
-        if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]:
-            plot_x_axis = ['Class', 'Model']
-            results = []
-
-            for context in multi_context:
-                test = context.test
-                model = context.model
-                label = cast(pd.Series, test.label_col)
-                n_samples = label.groupby(label).count()
-                results.extend(
-                    [context.model_name, class_score, scorer.name, class_name, n_samples[class_name]]
-                    for scorer in scorers
-                    # scorer returns numpy array of results with item per class
-                    for class_score, class_name in zip(scorer(model, test), test.classes)
-                )
-
-            results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Class', 'Number of samples'])
-
-        else:
-            plot_x_axis = 'Model'
-            results = [
-                [context.model_name, scorer(context.model, context.test), scorer.name,
-                 cast(pd.Series, context.test.label_col).count()]
-                for context in multi_context
-                for scorer in scorers
-            ]
-            results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Number of samples'])
-
-        fig = px.histogram(
-            results_df,
-            x=plot_x_axis,
-            y='Value',
-            color='Model',
-            barmode='group',
-            facet_col='Metric',
-            facet_col_spacing=0.05,
-            hover_data=['Number of samples'],
-        )
-
-        if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]:
-            fig.update_xaxes(title=None, tickprefix='Class ', tickangle=60)
-        else:
-            fig.update_xaxes(title=None)
-
-        fig = (
-            fig.update_yaxes(title=None, matches=None)
-            .for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
-            .for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
-        )
-
-        return CheckResult(results_df, display=[fig])
13 changes: 6 additions & 7 deletions deepchecks/utils/strings.py
@@ -122,16 +122,15 @@ def _generate_check_docs_link_html(check):
# compare check full name and link to the notebook to
# understand how link is formatted:
#
-    # - deepchecks.tabular.checks.integrity.new_category.CategoryMismatchTrainTest
-    # - docs.deepchecks.com/{version}/examples/tabular/checks/integrity/category_mismatch_train_test.html # noqa: E501 # pylint: disable=line-too-long
+    # - deepchecks.tabular.checks.integrity.StringMismatchComparison
+    # - https://docs.deepchecks.com/{version}/examples/tabular/checks/integrity/examples/plot_string_mismatch_comparison.html # noqa: E501 # pylint: disable=line-too-long

# Remove deepchecks from the start
module_path = module_path[len('deepchecks.'):]
-    # There is a bug in doc rendering where the "tabular" is omitted, so do it for now
-    if module_path.startswith('tabular.'):
-        module_path = module_path[len('tabular.'):]
-
-    url = '/'.join([*module_path.split('.')])
+    module_parts = module_path.split('.')
+    module_parts[-1] = f'plot_{module_parts[-1]}'
+    module_parts.insert(len(module_parts) - 1, 'examples')
+    url = '/'.join([*module_parts])
version = deepchecks.__version__ or 'stable'
link = link_template.format(version=version, path=url)
return f' <a href="{link}" target="_blank">Read More...</a>'
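To make the new link construction concrete, a standalone sketch of the same path-to-URL rewriting (illustrative only; the real helper derives the module path from the check object and formats it into `link_template`, which, per the updated test below, also carries the UTM query parameters). The function name `docs_url_for` is hypothetical:

```python
# Sketch of the mapping introduced above (not the library code itself):
# a check's module path becomes a docs URL with an 'examples/' directory
# and the docs gallery's 'plot_' page prefix.
def docs_url_for(module_path: str, version: str = 'stable') -> str:
    parts = module_path.split('.')
    if parts[0] == 'deepchecks':        # drop the top-level package name
        parts = parts[1:]
    parts[-1] = f'plot_{parts[-1]}'     # page name gets the 'plot_' prefix
    parts.insert(len(parts) - 1, 'examples')
    return f'https://docs.deepchecks.com/{version}/examples/{"/".join(parts)}.html'


print(docs_url_for('deepchecks.tabular.checks.integrity.string_mismatch_comparison'))
# -> https://docs.deepchecks.com/stable/examples/tabular/checks/integrity/examples/plot_string_mismatch_comparison.html
```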
@@ -6,7 +6,7 @@

#%%

-from deepchecks.tabular.checks.integrity.new_category import CategoryMismatchTrainTest
+from deepchecks.tabular.checks.integrity import CategoryMismatchTrainTest
from deepchecks.tabular import Dataset
import pandas as pd

@@ -6,7 +6,7 @@

#%%

-from deepchecks.tabular.checks.integrity.new_label import NewLabelTrainTest
+from deepchecks.tabular.checks.integrity import NewLabelTrainTest
from deepchecks.tabular import Dataset
import pandas as pd

6 changes: 3 additions & 3 deletions tests/utils/test_string_utils.py
@@ -75,9 +75,9 @@ def test_generate_check_docs_link_html():
# Act
html = _generate_check_docs_link_html(check)
# Assert
-    assert_that(html, equal_to(f' <a href="https://docs.deepchecks.com/{version}/examples/checks/overview/model_info'
-                               f'.html?utm_source=display_output&utm_medium=referral&utm_campaign=check_link" '
-                               f'target="_blank">Read More...</a>'))
+    assert_that(html, equal_to(f' <a href="https://docs.deepchecks.com/{version}/examples/tabular/checks/overview/'
+                               f'examples/plot_model_info.html?utm_source=display_output&utm_medium=referral'
+                               f'&utm_campaign=check_link" target="_blank">Read More...</a>'))


def test_generate_check_docs_link_html_not_a_check():
