Skip to content

Commit

Permalink
cv tests to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed May 6, 2022
1 parent e28b4c0 commit 3a7dcff
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions tests/test_cv.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,47 @@
import unittest
import pytest

import pandas as pd
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from explainerdashboard.explainers import ClassifierExplainer, RegressionExplainer
from explainerdashboard.datasets import titanic_survive, titanic_fare


class ClassifierCVTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_survive()
@pytest.fixture(scope="module")
def classifier_explainer_with_cv(fitted_rf_classifier_model):
_, _, X_test, y_test = titanic_survive()
return ClassifierExplainer(
fitted_rf_classifier_model,
X_test, y_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck', 'Embarked'],
cv=3
)

model = RandomForestClassifier(n_estimators=5, max_depth=2)
model.fit(X_train, y_train)
@pytest.fixture(scope="module")
def regression_explainer_with_cv(fitted_rf_regression_model):
_, _, X_test, y_test = titanic_fare()
return RegressionExplainer(
fitted_rf_regression_model,
X_test, y_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck', 'Embarked'],
cv=3
)

self.explainer = ClassifierExplainer(
model, X_train.iloc[:50], y_train.iloc[:50],
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
'Deck', 'Embarked'],
cv=3)

def test_cv_permutation_importances(self):
self.assertIsInstance(self.explainer.permutation_importances(), pd.DataFrame)
self.assertIsInstance(self.explainer.permutation_importances(pos_label=0), pd.DataFrame)

def test_cv_metrics(self):
self.assertIsInstance(self.explainer.metrics(), dict)
self.assertIsInstance(self.explainer.metrics(pos_label=0), dict)
def test_clas_cv_permutation_importances(classifier_explainer_with_cv):
assert isinstance(classifier_explainer_with_cv.permutation_importances(), pd.DataFrame)
assert isinstance(classifier_explainer_with_cv.permutation_importances(pos_label=0), pd.DataFrame)

def test_clas_cv_metrics(classifier_explainer_with_cv):
assert isinstance(classifier_explainer_with_cv.metrics(), dict)
assert isinstance(classifier_explainer_with_cv.metrics(pos_label=0), dict)

class RegressionCVTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_fare()
model = RandomForestRegressor(n_estimators=5, max_depth=2).fit(X_train, y_train)

self.explainer = RegressionExplainer(
model, X_test, y_test,
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']},
'Deck', 'Embarked'],
cv=3)

def test_cv_permutation_importances(self):
self.assertIsInstance(self.explainer.permutation_importances(), pd.DataFrame)

def test_cv_metrics(self):
self.assertIsInstance(self.explainer.metrics(), dict)
def test_reg_cv_permutation_importances(regression_explainer_with_cv):
assert isinstance(regression_explainer_with_cv.permutation_importances(), pd.DataFrame)
assert isinstance(regression_explainer_with_cv.permutation_importances(pos_label=0), pd.DataFrame)

def test_reg_cv_metrics(regression_explainer_with_cv):
assert isinstance(regression_explainer_with_cv.metrics(), dict)


0 comments on commit 3a7dcff

Please sign in to comment.