Skip to content

Commit

Permalink
same colors for train test (#260)
Browse files Browse the repository at this point in the history
* same colors for train test

* fix lint

* change from global dict to func

* fix lint

* fix docstring

* revert to colors

Co-authored-by: Noam Bressler <noamzbr@gmail.com>
  • Loading branch information
benisraeldan and noamzbr committed Dec 15, 2021
1 parent 68fbb2f commit 2d310ea
Show file tree
Hide file tree
Showing 11 changed files with 397 additions and 387 deletions.
9 changes: 4 additions & 5 deletions deepchecks/checks/distribution/train_test_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from deepchecks import Dataset, CheckResult, TrainTestBaseCheck, ConditionResult
from deepchecks.checks.distribution.plot import plot_density
from deepchecks.utils.features import calculate_feature_importance_or_null
from deepchecks.utils.plot import colors
from deepchecks.utils.typing import Hashable
from deepchecks.errors import DeepchecksValueError
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -300,8 +301,6 @@ def drift_score_bar(axes, drift_score: float, drift_type: str):
axes.set_xlim([0, stop])
axes.set_yticklabels([])

colors = ['darkblue', '#69b3a2']

if feature_importances is not None:
fi_rank_series = feature_importances.rank(method='first', ascending=False)
fi_rank = fi_rank_series[column_name]
Expand All @@ -320,8 +319,8 @@ def plot_numerical():
fig.suptitle(plot_title, horizontalalignment='left', fontweight='bold', x=0.05)
drift_score_bar(axs[0], score, 'Earth Movers Distance')
plt.sca(axs[1])
pdf1 = plot_density(train_column, xs, colors[0])
pdf2 = plot_density(test_column, xs, colors[1])
pdf1 = plot_density(train_column, xs, colors['Train'])
pdf2 = plot_density(test_column, xs, colors['Test'])
plt.gca().set_ylim(bottom=0, top=max(max(pdf1), max(pdf2)) * 1.1)
axs[1].set_xlabel(column_name)
axs[1].set_ylabel('Probability Density')
Expand All @@ -347,7 +346,7 @@ def plot_categorical():
fig, axs = plt.subplots(3, figsize=(8, 4.5), gridspec_kw={'height_ratios': [1, 7, 0.2]})
fig.suptitle(plot_title, horizontalalignment='left', fontweight='bold', x=0.05)
drift_score_bar(axs[0], score, 'PSI')
cat_df.plot.bar(ax=axs[1], color=colors)
cat_df.plot.bar(ax=axs[1], color=(colors['Train'], colors['Test']))
axs[1].set_ylabel('Percentage')
axs[1].legend()
axs[1].set_title('Distribution')
Expand Down
8 changes: 4 additions & 4 deletions deepchecks/checks/distribution/trust_score_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from deepchecks.utils.metrics import task_type_check, ModelType
from deepchecks.utils.strings import format_percent
from deepchecks.utils.validation import validate_model
from deepchecks.utils.plot import colors
from deepchecks.errors import DeepchecksValueError


Expand Down Expand Up @@ -167,17 +168,16 @@ def filter_quantile(data):
x_range = [min(*test_trust_scores_cut, *train_trust_scores_cut),
max(*test_trust_scores_cut, *train_trust_scores_cut)]
xs = np.linspace(x_range[0], x_range[1], 40)
plot_density(test_trust_scores_cut, xs, 'darkblue')
plot_density(train_trust_scores_cut, xs, '#69b3a2')
plot_density(train_trust_scores_cut, xs, colors['Train'])
plot_density(test_trust_scores_cut, xs, colors['Test'])
# Set x axis
axes.set_xlim(x_range)
plt.xlabel('Trust score')
# Set y axis
axes.set_ylim(bottom=0)
plt.ylabel('Probability Density')
# Set labels
colors = {'Test': 'darkblue',
'Train': '#69b3a2'}

labels = list(colors.keys())
handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
plt.legend(handles, labels)
Expand Down
5 changes: 3 additions & 2 deletions deepchecks/checks/methodology/boosting_overfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from deepchecks.utils.metrics import task_type_check, DEFAULT_METRICS_DICT, validate_scorer, DEFAULT_SINGLE_METRIC
from deepchecks.utils.strings import format_percent
from deepchecks.utils.validation import validate_model
from deepchecks.utils.plot import colors
from deepchecks.errors import DeepchecksValueError


Expand Down Expand Up @@ -179,8 +180,8 @@ def display_func():
axes.set_xlabel('Number of boosting iterations')
axes.set_ylabel(metric_name)
axes.grid()
axes.plot(estimator_steps, np.array(train_scores), 'o-', color='r', label='Training score')
axes.plot(estimator_steps, np.array(test_scores), 'o-', color='g', label='Test score')
axes.plot(estimator_steps, np.array(train_scores), 'o-', color=colors['Train'], label='Training score')
axes.plot(estimator_steps, np.array(test_scores), 'o-', color=colors['Test'], label='Test score')
axes.legend(loc='best')
# Display x ticks as integers
axes.xaxis.set_major_locator(MaxNLocator(integer=True))
Expand Down
7 changes: 3 additions & 4 deletions deepchecks/checks/methodology/performance_overfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pandas as pd
import numpy as np

from deepchecks.utils.plot import colors
from deepchecks.utils.strings import format_percent
from deepchecks.utils.validation import validate_model
from deepchecks.utils.metrics import get_metrics_list
Expand Down Expand Up @@ -95,12 +96,10 @@ def _train_test_difference_overfit(self, train_dataset: Dataset, test_dataset: D
def plot_overfit():
res_df = pd.DataFrame.from_dict({'Training Metrics': train_metrics, 'Test Metrics': test_metrics})
width = 0.20
my_cmap = plt.cm.get_cmap('Set2')
indices = np.arange(len(res_df.index))

colors = my_cmap(range(len(res_df.columns)))
plt.bar(indices, res_df['Training Metrics'].values.flatten(), width=width, color=colors[0])
plt.bar(indices + width, res_df['Test Metrics'].values.flatten(), width=width, color=colors[1])
plt.bar(indices, res_df['Training Metrics'].values.flatten(), width=width, color=colors['Train'])
plt.bar(indices + width, res_df['Test Metrics'].values.flatten(), width=width, color=colors['Test'])
plt.ylabel('Metrics')
plt.xticks(ticks=indices + width / 2., labels=res_df.index)
plt.xticks(rotation=30)
Expand Down
9 changes: 6 additions & 3 deletions deepchecks/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from matplotlib.colors import LinearSegmentedColormap


__all__ = ['create_colorbar_barchart_for_check', 'shifted_color_map']
__all__ = ['create_colorbar_barchart_for_check', 'shifted_color_map', 'colors']

colors = {'Train': 'darkblue',
'Test': '#69b3a2'}


def create_colorbar_barchart_for_check(
Expand Down Expand Up @@ -56,8 +59,8 @@ def create_colorbar_barchart_for_check(
my_cmap = shifted_color_map(my_cmap, start=start, midpoint=color_shift_midpoint, stop=stop,
name=color_map + check_name)

colors = my_cmap(list(y))
rects = ax.bar(x, y, color=colors) # pylint: disable=unused-variable
cmap_colors = my_cmap(list(y))
rects = ax.bar(x, y, color=cmap_colors) # pylint: disable=unused-variable

sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(start, stop))
sm.set_array([])
Expand Down
10 changes: 5 additions & 5 deletions notebooks/checks/distribution/train_test_drift.ipynb

Large diffs are not rendered by default.

708 changes: 354 additions & 354 deletions notebooks/checks/distribution/trust_score_comparison.ipynb

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions notebooks/checks/methodology/boosting_overfit.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions notebooks/checks/methodology/performance_overfit.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/checks/performance/calibration_metric.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 2d310ea

Please sign in to comment.