-
Notifications
You must be signed in to change notification settings - Fork 247
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix read more * Fix test * Rename files for the docs * Fix lint * Fix init * Fix docs import
- Loading branch information
matanper
committed
Apr 8, 2022
1 parent
b6824ff
commit 904e010
Showing
12 changed files
with
112 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
deepchecks/tabular/checks/performance/multi_model_performance_report.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# ---------------------------------------------------------------------------- | ||
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com) | ||
# | ||
# This file is part of Deepchecks. | ||
# Deepchecks is distributed under the terms of the GNU Affero General | ||
# Public License (version 3 or later). | ||
# You should have received a copy of the GNU Affero General Public License | ||
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>. | ||
# ---------------------------------------------------------------------------- | ||
# | ||
"""Module containing multi model performance report check.""" | ||
from typing import Dict, Callable, cast | ||
|
||
import pandas as pd | ||
import plotly.express as px | ||
|
||
from deepchecks.core import CheckResult | ||
from deepchecks.tabular import ModelComparisonCheck, ModelComparisonContext | ||
from deepchecks.utils.metrics import ModelType | ||
|
||
|
||
__all__ = ['MultiModelPerformanceReport'] | ||
|
||
|
||
class MultiModelPerformanceReport(ModelComparisonCheck): | ||
"""Summarize performance scores for multiple models on test datasets. | ||
Parameters | ||
---------- | ||
alternative_scorers : Dict[str, Callable] , default: None | ||
An optional dictionary of scorer name to scorer functions. | ||
If none given, using default scorers | ||
""" | ||
|
||
def __init__(self, alternative_scorers: Dict[str, Callable] = None, **kwargs): | ||
super().__init__(**kwargs) | ||
self.user_scorers = alternative_scorers | ||
|
||
def run_logic(self, multi_context: ModelComparisonContext): | ||
"""Run check logic.""" | ||
first_context = multi_context[0] | ||
scorers = first_context.get_scorers(self.user_scorers, class_avg=False) | ||
|
||
if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]: | ||
plot_x_axis = ['Class', 'Model'] | ||
results = [] | ||
|
||
for context in multi_context: | ||
test = context.test | ||
model = context.model | ||
label = cast(pd.Series, test.label_col) | ||
n_samples = label.groupby(label).count() | ||
results.extend( | ||
[context.model_name, class_score, scorer.name, class_name, n_samples[class_name]] | ||
for scorer in scorers | ||
# scorer returns numpy array of results with item per class | ||
for class_score, class_name in zip(scorer(model, test), test.classes) | ||
) | ||
|
||
results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Class', 'Number of samples']) | ||
|
||
else: | ||
plot_x_axis = 'Model' | ||
results = [ | ||
[context.model_name, scorer(context.model, context.test), scorer.name, | ||
cast(pd.Series, context.test.label_col).count()] | ||
for context in multi_context | ||
for scorer in scorers | ||
] | ||
results_df = pd.DataFrame(results, columns=['Model', 'Value', 'Metric', 'Number of samples']) | ||
|
||
fig = px.histogram( | ||
results_df, | ||
x=plot_x_axis, | ||
y='Value', | ||
color='Model', | ||
barmode='group', | ||
facet_col='Metric', | ||
facet_col_spacing=0.05, | ||
hover_data=['Number of samples'], | ||
) | ||
|
||
if multi_context.task_type in [ModelType.MULTICLASS, ModelType.BINARY]: | ||
fig.update_xaxes(title=None, tickprefix='Class ', tickangle=60) | ||
else: | ||
fig.update_xaxes(title=None) | ||
|
||
fig = ( | ||
fig.update_yaxes(title=None, matches=None) | ||
.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1])) | ||
.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True)) | ||
) | ||
|
||
return CheckResult(results_df, display=[fig]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters