Skip to content

Commit

Permalink
adds various tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed Jun 16, 2022
1 parent 56ead0c commit 2a21322
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 58 deletions.
8 changes: 0 additions & 8 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,8 @@
- Add this method? : https://arxiv.org/abs/2006.04750?

## Tests:
- add hub to html test
- add pipeline with X_background test
- add test dataframe y (passing + failing)
- test explainer.dump and explainer.from_file with .pkl or .dill
- explainer.to_yaml return_dict=True
- explainer.__contains__
- explainer.get_idx(int) and explainer.get_idx not found
- explainer.get_index(str)
- test index_exists_func with method
- reset_index_list
- add get_descriptions_df tests -> sort='shap'
- set_shap_values test
- set_shap_interaction_values test
Expand Down
1 change: 0 additions & 1 deletion explainerdashboard/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
import dash_bootstrap_components as dbc

from .dashboard_components import *
from .dashboard_tabs import *
from .dashboards import ExplainerTabsLayout, ExplainerPageLayout
from . import to_html
82 changes: 36 additions & 46 deletions explainerdashboard/dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

from .dashboard_methods import instantiate_component, encode_callables, decode_callables
from .dashboard_components import *
from .dashboard_tabs import *
from .explainers import BaseExplainer
from . import to_html

Expand Down Expand Up @@ -1913,6 +1912,9 @@ def dashboard_decks(dashboards, n_cols):
return index_page

def to_html(self):
"""
returns static html version of the hub landing page
"""
def dashboard_cards(dashboards, n_cols):
full_rows = int(len(dashboards)/ n_cols)
n_last_row = len(dashboards) % n_cols
Expand Down Expand Up @@ -2142,8 +2144,14 @@ class InlineExplainer:
"""
Run a single tab inline in a Jupyter notebook using specific method calls.
"""
def __init__(self, explainer, mode='inline', width=1000, height=800,
port=8050, **kwargs):
def __init__(
self,
explainer:BaseExplainer,
mode:str='inline',
width:int=1000,
height:int=800,
port:int=8050,
**kwargs):
"""
:param explainer: an Explainer object
:param mode: either 'inline', 'jupyterlab' or 'external'
Expand Down Expand Up @@ -2275,53 +2283,56 @@ def __repr__(self):

class InlineExplainerTabs(InlineExplainerComponent):

@delegates_kwargs(ImportancesTab)
@delegates_doc(ImportancesTab)
@delegates_kwargs(ImportancesComposite)
@delegates_doc(ImportancesComposite)
def importances(self, title='Importances', **kwargs):
"""Show importances (permutation or shap) inline in notebook"""
tab = ImportancesTab(self._explainer, **kwargs)
tab = ImportancesComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(ModelSummaryTab)
@delegates_doc(ModelSummaryTab)
@delegates_kwargs(RegressionModelStatsComposite)
@delegates_doc(RegressionModelStatsComposite)
def modelsummary(self, title='Model Summary', **kwargs):
"""Runs model_summary tab inline in notebook"""
tab = ModelSummaryTab(self._explainer, **kwargs)
if self._explainer.is_classifier:
tab = ClassifierModelStatsComposite(self._explainer, **kwargs)
else:
tab = RegressionModelStatsComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(ContributionsTab)
@delegates_doc(ContributionsTab)
@delegates_kwargs(IndividualPredictionsComposite)
@delegates_doc(IndividualPredictionsComposite)
def contributions(self, title='Contributions', **kwargs):
"""Show contributions (permutation or shap) inline in notebook"""
tab = ContributionsTab(self._explainer, **kwargs)
tab = IndividualPredictionsComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(WhatIfTab)
@delegates_doc(WhatIfTab)
@delegates_kwargs(WhatIfComposite)
@delegates_doc(WhatIfComposite)
def whatif(self, title='What if...', **kwargs):
"""Show What if... tab inline in notebook"""
tab = WhatIfTab(self._explainer, **kwargs)
tab = WhatIfComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(ShapDependenceTab)
@delegates_doc(ShapDependenceTab)
@delegates_kwargs(ShapDependenceComposite)
@delegates_doc(ShapDependenceComposite)
def dependence(self, title='Shap Dependence', **kwargs):
"""Runs shap_dependence tab inline in notebook"""
tab = ShapDependenceTab(self._explainer, **kwargs)
tab = ShapDependenceComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(ShapInteractionsTab)
@delegates_doc(ShapInteractionsTab)
@delegates_kwargs(ShapInteractionsComposite)
@delegates_doc(ShapInteractionsComposite)
def interactions(self, title='Shap Interactions', **kwargs):
"""Runs shap_interactions tab inline in notebook"""
tab = ShapInteractionsTab(self._explainer, **kwargs)
tab = ShapInteractionsComposite(self._explainer, **kwargs)
self._run_component(tab, title)

@delegates_kwargs(DecisionTreesTab)
@delegates_doc(DecisionTreesTab)
@delegates_kwargs(DecisionTreesComposite)
@delegates_doc(DecisionTreesComposite)
def decisiontrees(self, title='Decision Trees', **kwargs):
"""Runs decision_trees tab inline in notebook"""
tab = DecisionTreesTab(self._explainer, **kwargs)
tab = DecisionTreesComposite(self._explainer, **kwargs)
self._run_component(tab, title)


Expand Down Expand Up @@ -2510,25 +2521,4 @@ def decisionpath_table(self, title='Decision path', **kwargs):
def decisionpath_graph(self, title='Decision path', **kwargs):
"""Runs decision_trees tab inline in notebook"""
comp = DecisionPathTableComponent(self._explainer, **kwargs)
self._run_component(comp, title)



class JupyterExplainerDashboard(ExplainerDashboard):
def __init__(self, *args, **kwargs):
raise ValueError("JupyterExplainerDashboard has been deprecated. "
"Use e.g. ExplainerDashboard(mode='inline') instead.")

class ExplainerTab:
def __init__(self, *args, **kwargs):
raise ValueError("ExplainerTab has been deprecated. "
"Use e.g. ExplainerDashboard(explainer, ImportancesTab) instead.")


class JupyterExplainerTab(ExplainerTab):
def __init__(self, *args, **kwargs):
raise ValueError("ExplainerTab has been deprecated. "
"Use e.g. ExplainerDashboard(explainer, ImportancesTab, mode='inline') instead.")



self._run_component(comp, title)
7 changes: 5 additions & 2 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,11 @@ def __len__(self):
return len(self.X)

def __contains__(self, index):
if self.get_idx(index) is not None:
return True
try:
if self.get_idx(index) is not None:
return True
except IndexNotFoundError:
return False
return False

def get_idx(self, index):
Expand Down
17 changes: 17 additions & 0 deletions tests/hub/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,20 @@ def test_load_from_config(explainer_hub, tmp_path_factory):
explainer_hub.to_yaml(tmp_path / "hub.yaml")
explainer_hub2 = ExplainerHub.from_config(tmp_path / "hub.yaml")
assert isinstance(explainer_hub2, ExplainerHub)

def test_hub_to_html(explainer_hub):
    """The hub's static html export comes back as a string."""
    assert isinstance(explainer_hub.to_html(), str)

def test_hub_save_html(explainer_hub, tmp_path_factory):
    """Saving the hub with save_dashboards=True writes a non-empty hub.html."""
    html_path = tmp_path_factory.mktemp("tmp_hub") / "hub.html"
    explainer_hub.save_html(html_path, save_dashboards=True)

    # Checking isinstance(read(), str) can never fail: a text-mode read always
    # returns str. Assert on the file and its content instead.
    assert html_path.is_file()
    with open(html_path) as html_file:
        assert html_file.read().strip()

def test_hub_to_zip(explainer_hub, tmp_path_factory):
    """Exporting the hub with to_zip creates the requested file."""
    zip_path = tmp_path_factory.mktemp("tmp_hub") / "hub.zip"
    explainer_hub.to_zip(zip_path)
    assert zip_path.exists()
assert (tmp_path / "hub.zip").exists()


47 changes: 46 additions & 1 deletion tests/test_classifier_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,59 @@
import pytest

import pandas as pd
import numpy as np
from pandas.api.types import is_categorical_dtype, is_numeric_dtype


import plotly.graph_objects as go

from explainerdashboard import ClassifierExplainer, ExplainerDashboard
from explainerdashboard.explainer_methods import IndexNotFoundError



def test_explainer_with_dataframe_y(fitted_rf_classifier_model, classifier_data):
    """An explainer built with ``y_test.to_frame()`` as y can back a dashboard."""
    _, _, X_test, y_test = classifier_data
    y_frame = y_test.to_frame()
    cats = [{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck', 'Embarked']

    explainer = ClassifierExplainer(
        fitted_rf_classifier_model,
        X_test,
        y_frame,
        cats=cats,
        cats_notencoded={'Gender': 'No Gender'},
        labels=['Not survived', 'Survived'],
    )
    # Only construction is checked; the dashboard object itself is discarded.
    ExplainerDashboard(explainer)

def test_explainer_contains(precalculated_rf_classifier_explainer, test_names):
    """``in`` is True for a known int/name and False for unknown ones."""
    explainer = precalculated_rf_classifier_explainer
    first_name = test_names[0]

    assert 1 in explainer
    assert first_name in explainer
    assert 1000 not in explainer
    assert "randomname" not in explainer

def test_explainer_len(precalculated_rf_classifier_explainer, testlen):
    """len(explainer) equals the ``testlen`` fixture value."""
    n_rows = len(precalculated_rf_classifier_explainer)
    assert n_rows == testlen

def test_int_idx(precalculated_rf_classifier_explainer, test_names):
    """The first index name resolves to positional index 0."""
    first_name = test_names[0]
    position = precalculated_rf_classifier_explainer.get_idx(first_name)
    assert position == 0

def test_getindex(precalculated_rf_classifier_explainer, test_names):
    """get_index maps a known int or name to the name, and unknowns to None."""
    explainer = precalculated_rf_classifier_explainer
    first_name = test_names[0]

    assert explainer.get_index(0) == first_name
    assert explainer.get_index(first_name) == first_name

    # Negative, out-of-range and unknown-name lookups all yield None.
    for unknown in (-1, 10_000, "Non existent index"):
        assert explainer.get_index(unknown) is None

def test_get_idx(precalculated_rf_classifier_explainer, test_names):
    """get_idx resolves known names/ints and raises IndexNotFoundError otherwise."""
    explainer = precalculated_rf_classifier_explainer

    assert explainer.get_idx(test_names[0]) == 0
    assert explainer.get_idx(5) == 5

    # Negative, out-of-range and unknown-name lookups must all raise.
    for unknown in (-1, 1000, "randomname"):
        with pytest.raises(IndexNotFoundError):
            explainer.get_idx(unknown)


def test_random_index(precalculated_rf_classifier_explainer):
assert isinstance(precalculated_rf_classifier_explainer.random_index(), int)
assert isinstance(precalculated_rf_classifier_explainer.random_index(return_str=True), str)
Expand All @@ -21,7 +63,6 @@ def test_index_exists(precalculated_rf_classifier_explainer):
assert (precalculated_rf_classifier_explainer.index_exists(precalculated_rf_classifier_explainer.idxs[0]))
assert (not precalculated_rf_classifier_explainer.index_exists('bla'))


def test_preds(precalculated_rf_classifier_explainer):
assert isinstance(precalculated_rf_classifier_explainer.preds, np.ndarray)

Expand Down Expand Up @@ -237,6 +278,10 @@ def test_yaml(precalculated_rf_classifier_explainer):
yaml = precalculated_rf_classifier_explainer.to_yaml()
assert isinstance(yaml, str)

def test_yaml_return_dict(precalculated_rf_classifier_explainer):
    """to_yaml(return_dict=True) hands back a dict."""
    config = precalculated_rf_classifier_explainer.to_yaml(return_dict=True)
    assert isinstance(config, dict)




72 changes: 72 additions & 0 deletions tests/test_externalsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,43 @@ def X_func(index):
explainer.set_y_func(y_func)
return explainer


@pytest.fixture(scope='session')
def classifier_explainer_with_external_data_methods(fitted_rf_classifier_model):
    """Classifier explainer whose external-data callbacks take ``self`` first.

    The explainer is built on the first 100 test rows (string index labels);
    the registered callbacks answer from the remaining rows.
    """
    _, _, X_test, y_test = titanic_survive()
    X_test.reset_index(drop=True, inplace=True)
    X_test.index = X_test.index.astype(str)

    # Rows handed to the explainer directly.
    X_test1 = X_test.iloc[:100]
    y_test1 = y_test.iloc[:100]
    # Rows only reachable through the external callbacks below.
    X_test2 = X_test.iloc[100:]
    y_test2 = y_test.iloc[100:]

    explainer = ClassifierExplainer(
        fitted_rf_classifier_model, X_test1, y_test1, cats=['Sex', 'Deck'])

    def index_exists_func(self, index):
        """Report whether ``index`` is one of the external rows."""
        assert self.is_classifier
        return index in X_test2.index

    def index_list_func(self):
        """List the external indexes, truncated to the first 50."""
        assert self.is_classifier
        # only returns first 50 indexes
        return list(X_test2.index[:50])

    def y_func(self, index):
        """Return the single-row slice of external y for ``index``."""
        assert self.is_classifier
        row = X_test2.index.get_loc(index)
        return y_test2.iloc[[row]]

    def X_func(self, index):
        """Return the single-row slice of external X for ``index``."""
        assert self.is_classifier
        row = X_test2.index.get_loc(index)
        return X_test2.iloc[[row]]

    explainer.set_index_exists_func(index_exists_func)
    explainer.set_index_list_func(index_list_func)
    explainer.set_X_row_func(X_func)
    explainer.set_y_func(y_func)
    return explainer

@pytest.fixture(scope='session')
def regression_explainer_with_external_data(fitted_rf_regression_model):
_, _, X_test, y_test = titanic_fare()
Expand Down Expand Up @@ -71,6 +108,11 @@ def X_func(index):
explainer.set_y_func(y_func)
return explainer

def test_clas_externalsource_reset_index_list(classifier_explainer_with_external_data):
    """After reset_index_list, the index list holds '100' but not '160'."""
    explainer = classifier_explainer_with_external_data
    explainer.reset_index_list()

    index_list = explainer.get_index_list()
    assert '100' in index_list
    assert '160' not in index_list

def test_clas_externalsource_get_X_row(classifier_explainer_with_external_data):
assert isinstance(classifier_explainer_with_external_data.get_X_row(0), pd.DataFrame)
Expand Down Expand Up @@ -105,8 +147,38 @@ def test_clas_externalsource_index_exists(classifier_explainer_with_external_dat
assert (not classifier_explainer_with_external_data.index_exists(120))
assert (not classifier_explainer_with_external_data.index_exists("wrong index"))

def test_clas_externalsource_methods_get_X_row(classifier_explainer_with_external_data_methods):
    """get_X_row returns a DataFrame for each of the probed indexes."""
    explainer = classifier_explainer_with_external_data_methods
    for index in (0, "0", "120", "150"):
        assert isinstance(explainer.get_X_row(index), pd.DataFrame)

def test_clas_externalsource_methods_get_shap_row(classifier_explainer_with_external_data_methods):
    """get_shap_row returns a DataFrame for each of the probed indexes."""
    explainer = classifier_explainer_with_external_data_methods
    for index in (0, "0", "120", "150"):
        assert isinstance(explainer.get_shap_row(index), pd.DataFrame)

def test_clas_externalsource_methods_get_y(classifier_explainer_with_external_data_methods):
    """get_y returns an int for each of the probed indexes."""
    explainer = classifier_explainer_with_external_data_methods
    for index in (0, "0", "120", "150"):
        assert isinstance(explainer.get_y(index), int)

def test_clas_externalsource_methods_index_list(classifier_explainer_with_external_data_methods):
    """The index list holds '100' but not '160'."""
    explainer = classifier_explainer_with_external_data_methods
    index_list = explainer.get_index_list()
    assert '100' in index_list
    assert '160' not in index_list

def test_clas_externalsource_methods_index_exists(classifier_explainer_with_external_data_methods):
    """index_exists is True for the known indexes and False for the unknown ones."""
    explainer = classifier_explainer_with_external_data_methods

    for known in ("0", "100", "160", 0):
        assert explainer.index_exists(known)

    for unknown in (-1, 120, "wrong index"):
        assert not explainer.index_exists(unknown)

def test_reg_externalsource_get_X_row(regression_explainer_with_external_data):
assert isinstance(regression_explainer_with_external_data.get_X_row(0), pd.DataFrame)
Expand Down

0 comments on commit 2a21322

Please sign in to comment.