Skip to content

Commit

Permalink
updates multiclass tests to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed May 7, 2022
1 parent 98e5337 commit a97b2cf
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 167 deletions.
44 changes: 41 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,50 @@ def rf_multiclass_explainer(fitted_rf_multiclass_model):
_, _, X_test, y_test = titanic_embarked()
_, test_names = titanic_names()
explainer = ClassifierExplainer(fitted_rf_multiclass_model, X_test, y_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck'],
idxs=test_names,
labels=['Queenstown', 'Southampton', 'Cherbourg'])
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck'],
idxs=test_names,
labels=['Queenstown', 'Southampton', 'Cherbourg'])
return explainer

@pytest.fixture(scope="session")
def precalculated_rf_multiclass_explainer(rf_multiclass_explainer):
_ = ExplainerDashboard(rf_multiclass_explainer)
return rf_multiclass_explainer

@pytest.fixture(scope="session")
def rf_multiclass_explainer_no_y(fitted_rf_multiclass_model):
_, _, X_test, _ = titanic_embarked()
_, test_names = titanic_names()
explainer = ClassifierExplainer(fitted_rf_multiclass_model, X_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck'],
idxs=test_names,
labels=['Queenstown', 'Southampton', 'Cherbourg'])
return explainer

@pytest.fixture(scope="session")
def precalculated_rf_multiclass_explainer_no_y(rf_multiclass_explainer_no_y):
_ = ExplainerDashboard(rf_multiclass_explainer_no_y)
return rf_multiclass_explainer


@pytest.fixture(scope="session")
def fitted_xgb_multiclass_model():
X_train, y_train, _, _ = titanic_embarked()
model = XGBClassifier(n_estimators=5, max_depth=2)
model.fit(X_train, y_train)
return model

@pytest.fixture(scope="session")
def xgb_multiclass_explainer(fitted_xgb_multiclass_model):
_, _, X_test, y_test = titanic_embarked()
_, test_names = titanic_names()
explainer = ClassifierExplainer(fitted_xgb_multiclass_model, X_test, y_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck'],
idxs=test_names,
labels=['Queenstown', 'Southampton', 'Cherbourg'])
return explainer

@pytest.fixture(scope="session")
def precalculated_xgb_multiclass_explainer(xgb_multiclass_explainer):
_ = ExplainerDashboard(xgb_multiclass_explainer)
return xgb_multiclass_explainer
132 changes: 6 additions & 126 deletions tests/integration_tests/test_dashboards.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,7 @@


import dash

from catboost import CatBoostClassifier, CatBoostRegressor
from xgboost import XGBClassifier, XGBRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from explainerdashboard.explainers import ClassifierExplainer, RegressionExplainer
from explainerdashboard.datasets import titanic_survive, titanic_fare, titanic_embarked, titanic_names
from explainerdashboard.dashboards import ExplainerDashboard


# def get_classification_explainer(xgboost=False, include_y=True):
# X_train, y_train, X_test, y_test = titanic_survive()
# if xgboost:
# model = XGBClassifier().fit(X_train, y_train)
# else:
# model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(X_train, y_train)
# if include_y:
# explainer = ClassifierExplainer(
# model, X_test, y_test,
# cats=['Sex', 'Deck', 'Embarked'],
# labels=['Not survived', 'Survived'])
# else:
# explainer = ClassifierExplainer(
# model, X_test,
# cats=['Sex', 'Deck', 'Embarked'],
# labels=['Not survived', 'Survived'])

# explainer.calculate_properties()
# return explainer


# def get_regression_explainer(xgboost=False, include_y=True):
# X_train, y_train, X_test, y_test = titanic_fare()
# train_names, test_names = titanic_names()
# if xgboost:
# model = XGBRegressor().fit(X_train, y_train)
# else:
# model = RandomForestRegressor(n_estimators=50, max_depth=10).fit(X_train, y_train)

# if include_y:
# reg_explainer = RegressionExplainer(model, X_test, y_test,
# cats=['Sex', 'Deck', 'Embarked'],
# idxs=test_names,
# units="$")
# else:
# reg_explainer = RegressionExplainer(model, X_test,
# cats=['Sex', 'Deck', 'Embarked'],
# idxs=test_names,
# units="$")

# reg_explainer.calculate_properties()
# return reg_explainer

def get_multiclass_explainer(xgboost=False, include_y=True):
X_train, y_train, X_test, y_test = titanic_embarked()
train_names, test_names = titanic_names()
if xgboost:
model = XGBClassifier().fit(X_train, y_train)
else:
model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(X_train, y_train)

if include_y:
if xgboost:
multi_explainer = ClassifierExplainer(model, X_test, y_test,
model_output='logodds',
cats=['Sex', 'Deck'],
labels=['Queenstown', 'Southampton', 'Cherbourg'])
else:
multi_explainer = ClassifierExplainer(model, X_test, y_test,
cats=['Sex', 'Deck'],
labels=['Queenstown', 'Southampton', 'Cherbourg'])
else:
if xgboost:
multi_explainer = ClassifierExplainer(model, X_test,
model_output='logodds',
cats=['Sex', 'Deck'],
labels=['Queenstown', 'Southampton', 'Cherbourg'])
else:
multi_explainer = ClassifierExplainer(model, X_test,
cats=['Sex', 'Deck'],
labels=['Queenstown', 'Southampton', 'Cherbourg'])

multi_explainer.calculate_properties()
return multi_explainer


# def get_catboost_classifier():
# X_train, y_train, X_test, y_test = titanic_survive()
# train_names, test_names = titanic_names()

# model = CatBoostClassifier(iterations=100, verbose=0).fit(X_train, y_train)
# explainer = ClassifierExplainer(
# model, X_test, y_test,
# cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
# 'Deck', 'Embarked'],
# labels=['Not survived', 'Survived'],
# idxs=test_names)

# X_cats, y_cats = explainer.X_merged, explainer.y.astype("int")
# model = CatBoostClassifier(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[5, 6, 7])
# explainer = ClassifierExplainer(model, X_cats, y_cats, idxs=X_test.index)
# explainer.calculate_properties(include_interactions=False)
# return explainer


# def get_catboost_regressor():
# X_train, y_train, X_test, y_test = titanic_fare()

# model = CatBoostRegressor(iterations=5, verbose=0).fit(X_train, y_train)
# explainer = RegressionExplainer(model, X_test, y_test,
# cats=["Sex", 'Deck', 'Embarked'])
# X_cats, y_cats = explainer.X_merged, explainer.y
# model = CatBoostRegressor(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[5, 6, 7])
# explainer = RegressionExplainer(model, X_cats, y_cats, idxs=X_test.index)
# explainer.calculate_properties(include_interactions=False)
# return explainer


def test_classification_dashboard(dash_duo, precalculated_rf_classifier_explainer):
db = ExplainerDashboard(precalculated_rf_classifier_explainer, title="testing", responsive=False)
html = db.to_html()
Expand Down Expand Up @@ -158,9 +41,8 @@ def test_simple_regression_dashboard(dash_duo, precalculated_rf_regression_expla
assert dash_duo.get_logs() == [], "browser console should contain no error"


def test_multiclass_dashboard(dash_duo):
explainer = get_multiclass_explainer()
db = ExplainerDashboard(explainer, title="testing", responsive=False)
def test_multiclass_dashboard(dash_duo, precalculated_rf_multiclass_explainer):
db = ExplainerDashboard(precalculated_rf_multiclass_explainer, title="testing", responsive=False)
html = db.to_html()
assert html.startswith('\n<!DOCTYPE html>\n<html'), "failed to generate dashboard to_html"

Expand Down Expand Up @@ -189,9 +71,8 @@ def test_xgboost_regression_dashboard(dash_duo, precalculated_xgb_regression_exp
assert dash_duo.get_logs() == [], "browser console should contain no error"


def test_xgboost_multiclass_dashboard(dash_duo):
explainer = get_multiclass_explainer(xgboost=True)
db = ExplainerDashboard(explainer, title="testing", responsive=False)
def test_xgboost_multiclass_dashboard(dash_duo, precalculated_xgb_multiclass_explainer):
db = ExplainerDashboard(precalculated_xgb_multiclass_explainer, title="testing", responsive=False)
html = db.to_html()
assert html.startswith('\n<!DOCTYPE html>\n<html'), "failed to generate dashboard to_html"

Expand Down Expand Up @@ -220,9 +101,8 @@ def test_regression_dashboard_no_y(dash_duo, precalculated_rf_regression_explain
assert dash_duo.get_logs() == [], "browser console should contain no error"


def test_multiclass_dashboard_no_y(dash_duo):
explainer = get_multiclass_explainer(include_y=False)
db = ExplainerDashboard(explainer, title="testing", responsive=False)
def test_multiclass_dashboard_no_y(dash_duo, precalculated_rf_multiclass_explainer_no_y):
db = ExplainerDashboard(precalculated_rf_multiclass_explainer_no_y, title="testing", responsive=False)
html = db.to_html()
assert html.startswith('\n<!DOCTYPE html>\n<html'), "failed to generate dashboard to_html"

Expand Down
62 changes: 24 additions & 38 deletions tests/test_xgboost_treeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,53 +64,39 @@ def test_xgbreg_plot_trees(precalculated_xgb_regression_explainer, test_names):
assert isinstance(fig, go.Figure)


class XGBMultiClassifierExplainerTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_embarked()
train_names, test_names = titanic_names()
_, self.names = titanic_names()

model = XGBClassifier(n_estimators=5)
model.fit(X_train, y_train)

self.explainer = ClassifierExplainer(
model, X_test, y_test, model_output='raw',
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
'Deck'],
idxs=test_names,
labels=['Queenstown', 'Southampton', 'Cherbourg'])
def test_graphviz_available(precalculated_xgb_multiclass_explainer):
assert isinstance(precalculated_xgb_multiclass_explainer.graphviz_available, bool)

def test_graphviz_available(self):
self.assertIsInstance(self.explainer.graphviz_available, bool)

def test_shadow_trees(self):
dt = self.explainer.shadow_trees
self.assertIsInstance(dt, list)
self.assertIsInstance(dt[0], dtreeviz.models.shadow_decision_tree.ShadowDecTree)
def test_shadow_trees(precalculated_xgb_multiclass_explainer):
dt = precalculated_xgb_multiclass_explainer.shadow_trees
assert isinstance(dt, list)
assert isinstance(dt[0], dtreeviz.models.shadow_decision_tree.ShadowDecTree)

def test_decisionpath_df(self):
df = self.explainer.get_decisionpath_df(tree_idx=0, index=0)
self.assertIsInstance(df, pd.DataFrame)
def test_decisionpath_df(precalculated_xgb_multiclass_explainer, test_names):
df = precalculated_xgb_multiclass_explainer.get_decisionpath_df(tree_idx=0, index=0)
assert isinstance(df, pd.DataFrame)

df = self.explainer.get_decisionpath_df(tree_idx=0, index=self.names[0])
self.assertIsInstance(df, pd.DataFrame)
df = precalculated_xgb_multiclass_explainer.get_decisionpath_df(tree_idx=0, index=test_names[0])
assert isinstance(df, pd.DataFrame)

df = self.explainer.get_decisionpath_df(tree_idx=0, index=self.names[0], pos_label=0)
self.assertIsInstance(df, pd.DataFrame)
df = precalculated_xgb_multiclass_explainer.get_decisionpath_df(tree_idx=0, index=test_names[0], pos_label=0)
assert isinstance(df, pd.DataFrame)


def test_plot_trees(self):
fig = self.explainer.plot_trees(index=0)
self.assertIsInstance(fig, go.Figure)
def test_plot_trees(precalculated_xgb_multiclass_explainer, test_names):
fig = precalculated_xgb_multiclass_explainer.plot_trees(index=0)
assert isinstance(fig, go.Figure)

fig = self.explainer.plot_trees(index=self.names[0])
self.assertIsInstance(fig, go.Figure)
fig = precalculated_xgb_multiclass_explainer.plot_trees(index=test_names[0])
assert isinstance(fig, go.Figure)

fig = self.explainer.plot_trees(index=self.names[0], highlight_tree=0)
self.assertIsInstance(fig, go.Figure)
fig = precalculated_xgb_multiclass_explainer.plot_trees(index=test_names[0], highlight_tree=0)
assert isinstance(fig, go.Figure)

fig = self.explainer.plot_trees(index=self.names[0], pos_label=0)
self.assertIsInstance(fig, go.Figure)
fig = precalculated_xgb_multiclass_explainer.plot_trees(index=test_names[0], pos_label=0)
assert isinstance(fig, go.Figure)

def test_calculate_properties(self):
self.explainer.calculate_properties()
def test_calculate_properties(precalculated_xgb_multiclass_explainer):
precalculated_xgb_multiclass_explainer.calculate_properties()

0 comments on commit a97b2cf

Please sign in to comment.