Skip to content

Commit

Permalink
more pytest tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed May 8, 2022
1 parent 6ebadc4 commit c323497
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 82 deletions.
20 changes: 19 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pytest

from pathlib import Path
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import ExtraTreesClassifier, ExtraTreesRegressor
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from xgboost import XGBClassifier, XGBRegressor
from lightgbm.sklearn import LGBMClassifier, LGBMRegressor
from catboost import CatBoostClassifier, CatBoostRegressor

from explainerdashboard import RegressionExplainer, ClassifierExplainer, ExplainerDashboard
from explainerdashboard import RegressionExplainer, ClassifierExplainer, ExplainerDashboard, ExplainerHub
from explainerdashboard.custom import ShapDependenceComposite
from explainerdashboard.datasets import titanic_survive, titanic_fare, titanic_names, titanic_embarked

Expand Down Expand Up @@ -447,3 +448,20 @@ def dashboard_dumps_folder(tmp_path_factory, precalculated_rf_classifier_explain
precalculated_rf_classifier_explainer.to_yaml(dump_path / "explainer.yaml")
custom_dashboard.to_yaml(dump_path / "dashboard.yaml", explainerfile=str(dump_path / "explainer.joblib"))
return dump_path

@pytest.fixture(scope="session")
def explainer_hub(precalculated_rf_classifier_explainer, precalculated_rf_regression_explainer):
hub = ExplainerHub([
ExplainerDashboard(precalculated_rf_classifier_explainer, description="Super interesting dashboard"),
ExplainerDashboard(precalculated_rf_regression_explainer, title="Dashboard Two",
name='db2', logins=[['user2', 'password2']])
],
users_file=str(Path.cwd() / "tests" / "test_assets" / "users.yaml")
)
return hub

@pytest.fixture(scope="session")
def explainer_hub_dump_folder(tmp_path_factory, explainer_hub):
dump_path = tmp_path_factory.mktemp("hub_dump")
explainer_hub.to_yaml(dump_path / "hub.yaml")
return dump_path
46 changes: 16 additions & 30 deletions tests/hub/test_hub.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
import unittest
from pathlib import Path
from explainerdashboard import ExplainerHub

from sklearn.ensemble import RandomForestClassifier
from explainerdashboard import *
from explainerdashboard.datasets import *
from explainerdashboard.custom import *

class ExplainerHubTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_survive()

model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train)
self.explainer = ClassifierExplainer(model, X_test, y_test)
self.db1 = ExplainerDashboard(self.explainer, description="Super interesting dashboard")
self.db2 = ExplainerDashboard(self.explainer, title="Dashboard Two",
name='db2', logins=[['user2', 'password2']])
self.hub = ExplainerHub([self.db1, self.db2], users_file=str(Path.cwd() / "tests" / "test_assets" / "users.yaml"))

def test_hub_users(self):
self.assertGreater(len(self.hub.users), 0)
self.assertIn("db2", self.hub.dashboards_with_users)
self.hub.add_user("user3", "password")
self.hub.add_user_to_dashboard("db2", "user3")
self.assertIn("user3", self.hub.dashboard_users['db2'])
self.hub.add_user("user4", "password", add_to_users_file=True)
self.hub.add_user_to_dashboard("db2", "user4", add_to_users_file=True)
self.assertIn("user4", self.hub.dashboard_users['db2'])
self.assertIn("user4", self.hub.get_dashboard_users("db2"))
def test_hub_users(explainer_hub):
assert len(explainer_hub.users) > 0
assert ("db2" in explainer_hub.dashboards_with_users)
explainer_hub.add_user("user3", "password")
explainer_hub.add_user_to_dashboard("db2", "user3")
assert ("user3" in explainer_hub.dashboard_users['db2'])
explainer_hub.add_user("user4", "password", add_to_users_file=True)
explainer_hub.add_user_to_dashboard("db2", "user4", add_to_users_file=True)
assert ("user4" in explainer_hub.dashboard_users['db2'])
assert ("user4" in explainer_hub.get_dashboard_users("db2"))

def test_load_from_config(self):
self.hub.to_yaml(Path.cwd() / "tests" / "test_assets" / "hub.yaml")
self.hub2 = ExplainerHub.from_config(Path.cwd() / "tests" / "test_assets" / "hub.yaml")
def test_load_from_config(explainer_hub, tmp_path_factory):
tmp_path = tmp_path_factory.mktemp("tmp_hub")
explainer_hub.to_yaml(tmp_path / "hub.yaml")
explainer_hub2 = ExplainerHub.from_config(tmp_path / "hub.yaml")
assert isinstance(explainer_hub2, ExplainerHub)
25 changes: 2 additions & 23 deletions tests/hub/test_hub_cli.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,7 @@
import pytest
from pathlib import Path

from sklearn.ensemble import RandomForestClassifier

from explainerdashboard import *
from explainerdashboard.datasets import *
from explainerdashboard.custom import *

@pytest.fixture
def generate_assets():
X_train, y_train, X_test, y_test = titanic_survive()
model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train)
explainer = ClassifierExplainer(model, X_test, y_test)
db1 = ExplainerDashboard(explainer, description="Super interesting dashboard")
db2 = ExplainerDashboard(explainer, title="Dashboard Two",
name='db2', logins=[['user2', 'password2']])
hub = ExplainerHub([db1, db2])
hub.to_yaml(Path.cwd() / "tests" / "test_assets" / "hub.yaml")
return None


def test_explainerhub_cli_help(generate_assets, script_runner):
ret = script_runner.run('explainerhub', ' --help',
cwd=str(Path().cwd() / "tests" / "test_assets"))
def test_explainerhub_cli_help(explainer_hub_dump_folder, script_runner):
ret = script_runner.run('explainerhub', ' --help', cwd=str(explainer_hub_dump_folder))
assert ret.success
assert ret.stderr == ''

61 changes: 33 additions & 28 deletions tests/hub/test_hub_integration.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
from unittest import TestCase
import pytest

from sklearn.ensemble import RandomForestClassifier
@pytest.fixture(scope="session")
def explainer_hub_client(explainer_hub):
explainer_hub.app.config["TESTING"] = True
explainer_hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False
client = explainer_hub.app.test_client()
_ctx = explainer_hub.app.test_request_context()
_ctx.push()

from explainerdashboard import *
from explainerdashboard.datasets import *
from explainerdashboard.custom import *
yield client

from unittest import TestCase
if _ctx is not None:
_ctx.pop()

class UserTest(TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_survive()
model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train)
explainer = ClassifierExplainer(model, X_test, y_test)
db1 = ExplainerDashboard(explainer, description="Super interesting dashboard")
db2 = ExplainerDashboard(explainer, title="Dashboard Two",
logins=[['user', 'password']], name='db2')
self.hub = ExplainerHub([db1, db2])
self.hub.app.config["TESTING"] = True
self.hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False
self.client = self.hub.app.test_client()
self._ctx = self.hub.app.test_request_context()
self._ctx.push()
# class UserTest(TestCase):
# def setUp(self):
# X_train, y_train, X_test, y_test = titanic_survive()
# model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train)
# explainer = ClassifierExplainer(model, X_test, y_test)
# db1 = ExplainerDashboard(explainer, description="Super interesting dashboard")
# db2 = ExplainerDashboard(explainer, title="Dashboard Two",
# logins=[['user', 'password']], name='db2')
# self.hub = ExplainerHub([db1, db2])
# self.hub.app.config["TESTING"] = True
# self.hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False
# self.client = self.hub.app.test_client()
# self._ctx = self.hub.app.test_request_context()
# self._ctx.push()

def tearDown(self):
if self._ctx is not None:
self._ctx.pop()
# def tearDown(self):
# if self._ctx is not None:
# self._ctx.pop()

def test_user_index(self):
with self.client:
data = {"username": "user", "password": "password", "next": "/"}
response = self.client.post("/login/", data=data)
self.assertEqual(response.status_code, 200)
def test_explainer_hub_client(explainer_hub_client):
with explainer_hub_client:
data = {"username": "user", "password": "password", "next": "/"}
response = explainer_hub_client.post("/login/", data=data)
assert (response.status_code == 200)

0 comments on commit c323497

Please sign in to comment.