-
-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Oege Dijk
committed
May 8, 2022
1 parent
6ebadc4
commit c323497
Showing
4 changed files
with
70 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,19 @@ | ||
import unittest | ||
from pathlib import Path | ||
from explainerdashboard import ExplainerHub | ||
|
||
from sklearn.ensemble import RandomForestClassifier | ||
from explainerdashboard import * | ||
from explainerdashboard.datasets import * | ||
from explainerdashboard.custom import * | ||
|
||
class ExplainerHubTests(unittest.TestCase): | ||
def setUp(self): | ||
X_train, y_train, X_test, y_test = titanic_survive() | ||
|
||
model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train) | ||
self.explainer = ClassifierExplainer(model, X_test, y_test) | ||
self.db1 = ExplainerDashboard(self.explainer, description="Super interesting dashboard") | ||
self.db2 = ExplainerDashboard(self.explainer, title="Dashboard Two", | ||
name='db2', logins=[['user2', 'password2']]) | ||
self.hub = ExplainerHub([self.db1, self.db2], users_file=str(Path.cwd() / "tests" / "test_assets" / "users.yaml")) | ||
|
||
def test_hub_users(self): | ||
self.assertGreater(len(self.hub.users), 0) | ||
self.assertIn("db2", self.hub.dashboards_with_users) | ||
self.hub.add_user("user3", "password") | ||
self.hub.add_user_to_dashboard("db2", "user3") | ||
self.assertIn("user3", self.hub.dashboard_users['db2']) | ||
self.hub.add_user("user4", "password", add_to_users_file=True) | ||
self.hub.add_user_to_dashboard("db2", "user4", add_to_users_file=True) | ||
self.assertIn("user4", self.hub.dashboard_users['db2']) | ||
self.assertIn("user4", self.hub.get_dashboard_users("db2")) | ||
def test_hub_users(explainer_hub): | ||
assert len(explainer_hub.users) > 0 | ||
assert ("db2" in explainer_hub.dashboards_with_users) | ||
explainer_hub.add_user("user3", "password") | ||
explainer_hub.add_user_to_dashboard("db2", "user3") | ||
assert ("user3" in explainer_hub.dashboard_users['db2']) | ||
explainer_hub.add_user("user4", "password", add_to_users_file=True) | ||
explainer_hub.add_user_to_dashboard("db2", "user4", add_to_users_file=True) | ||
assert ("user4" in explainer_hub.dashboard_users['db2']) | ||
assert ("user4" in explainer_hub.get_dashboard_users("db2")) | ||
|
||
def test_load_from_config(self): | ||
self.hub.to_yaml(Path.cwd() / "tests" / "test_assets" / "hub.yaml") | ||
self.hub2 = ExplainerHub.from_config(Path.cwd() / "tests" / "test_assets" / "hub.yaml") | ||
def test_load_from_config(explainer_hub, tmp_path_factory): | ||
tmp_path = tmp_path_factory.mktemp("tmp_hub") | ||
explainer_hub.to_yaml(tmp_path / "hub.yaml") | ||
explainer_hub2 = ExplainerHub.from_config(tmp_path / "hub.yaml") | ||
assert isinstance(explainer_hub2, ExplainerHub) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,7 @@ | ||
import pytest | ||
from pathlib import Path | ||
|
||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
from explainerdashboard import * | ||
from explainerdashboard.datasets import * | ||
from explainerdashboard.custom import * | ||
|
||
@pytest.fixture | ||
def generate_assets(): | ||
X_train, y_train, X_test, y_test = titanic_survive() | ||
model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train) | ||
explainer = ClassifierExplainer(model, X_test, y_test) | ||
db1 = ExplainerDashboard(explainer, description="Super interesting dashboard") | ||
db2 = ExplainerDashboard(explainer, title="Dashboard Two", | ||
name='db2', logins=[['user2', 'password2']]) | ||
hub = ExplainerHub([db1, db2]) | ||
hub.to_yaml(Path.cwd() / "tests" / "test_assets" / "hub.yaml") | ||
return None | ||
|
||
|
||
def test_explainerhub_cli_help(generate_assets, script_runner): | ||
ret = script_runner.run('explainerhub', ' --help', | ||
cwd=str(Path().cwd() / "tests" / "test_assets")) | ||
def test_explainerhub_cli_help(explainer_hub_dump_folder, script_runner): | ||
ret = script_runner.run('explainerhub', ' --help', cwd=str(explainer_hub_dump_folder)) | ||
assert ret.success | ||
assert ret.stderr == '' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,40 @@ | ||
from unittest import TestCase | ||
import pytest | ||
|
||
from sklearn.ensemble import RandomForestClassifier | ||
@pytest.fixture(scope="session") | ||
def explainer_hub_client(explainer_hub): | ||
explainer_hub.app.config["TESTING"] = True | ||
explainer_hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False | ||
client = explainer_hub.app.test_client() | ||
_ctx = explainer_hub.app.test_request_context() | ||
_ctx.push() | ||
|
||
from explainerdashboard import * | ||
from explainerdashboard.datasets import * | ||
from explainerdashboard.custom import * | ||
yield client | ||
|
||
from unittest import TestCase | ||
if _ctx is not None: | ||
_ctx.pop() | ||
|
||
class UserTest(TestCase): | ||
def setUp(self): | ||
X_train, y_train, X_test, y_test = titanic_survive() | ||
model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train) | ||
explainer = ClassifierExplainer(model, X_test, y_test) | ||
db1 = ExplainerDashboard(explainer, description="Super interesting dashboard") | ||
db2 = ExplainerDashboard(explainer, title="Dashboard Two", | ||
logins=[['user', 'password']], name='db2') | ||
self.hub = ExplainerHub([db1, db2]) | ||
self.hub.app.config["TESTING"] = True | ||
self.hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False | ||
self.client = self.hub.app.test_client() | ||
self._ctx = self.hub.app.test_request_context() | ||
self._ctx.push() | ||
# class UserTest(TestCase): | ||
# def setUp(self): | ||
# X_train, y_train, X_test, y_test = titanic_survive() | ||
# model = RandomForestClassifier(n_estimators=5, max_depth=2).fit(X_train, y_train) | ||
# explainer = ClassifierExplainer(model, X_test, y_test) | ||
# db1 = ExplainerDashboard(explainer, description="Super interesting dashboard") | ||
# db2 = ExplainerDashboard(explainer, title="Dashboard Two", | ||
# logins=[['user', 'password']], name='db2') | ||
# self.hub = ExplainerHub([db1, db2]) | ||
# self.hub.app.config["TESTING"] = True | ||
# self.hub.app.config["WTF_CSRF_CHECK_DEFAULT"] = False | ||
# self.client = self.hub.app.test_client() | ||
# self._ctx = self.hub.app.test_request_context() | ||
# self._ctx.push() | ||
|
||
def tearDown(self): | ||
if self._ctx is not None: | ||
self._ctx.pop() | ||
# def tearDown(self): | ||
# if self._ctx is not None: | ||
# self._ctx.pop() | ||
|
||
def test_user_index(self): | ||
with self.client: | ||
data = {"username": "user", "password": "password", "next": "/"} | ||
response = self.client.post("/login/", data=data) | ||
self.assertEqual(response.status_code, 200) | ||
def test_explainer_hub_client(explainer_hub_client): | ||
with explainer_hub_client: | ||
data = {"username": "user", "password": "password", "next": "/"} | ||
response = explainer_hub_client.post("/login/", data=data) | ||
assert (response.status_code == 200) | ||
|