Skip to content

Commit

Permalink
bump njobs tests to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed May 7, 2022
1 parent 385558d commit ac37e99
Showing 1 changed file with 11 additions and 31 deletions.
42 changes: 11 additions & 31 deletions tests/test_njobs.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,20 @@
import unittest

import pandas as pd
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

from explainerdashboard.explainers import ClassifierExplainer
from explainerdashboard.datasets import titanic_survive, titanic_names


class NJobs5ExplainerTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_survive()
train_names, test_names = titanic_names()

model = RandomForestClassifier(n_estimators=5, max_depth=2)
model.fit(X_train, y_train)

self.explainer = ClassifierExplainer(
model, X_test, y_test, roc_auc_score, n_jobs=5)

def test_permutation_importances(self):
self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame)

from explainerdashboard.datasets import titanic_survive

class NJobsMinusOneExplainerTests(unittest.TestCase):
def setUp(self):
X_train, y_train, X_test, y_test = titanic_survive()
train_names, test_names = titanic_names()

model = RandomForestClassifier(n_estimators=5, max_depth=2)
model.fit(X_train, y_train)
def test_permutation_importances_njobs_5(fitted_rf_classifier_model):
_, _, X_test, y_test = titanic_survive()
explainer = ClassifierExplainer(
fitted_rf_classifier_model, X_test, y_test, roc_auc_score, n_jobs=5)
assert isinstance(explainer.get_permutation_importances_df(), pd.DataFrame)

self.explainer = ClassifierExplainer(
model, X_test, y_test, roc_auc_score, n_jobs=-1)

def test_permutation_importances(self):
self.assertIsInstance(self.explainer.get_permutation_importances_df(), pd.DataFrame)
def test_permutation_importances_njobs_minus1(fitted_rf_classifier_model):
_, _, X_test, y_test = titanic_survive()
explainer = ClassifierExplainer(
fitted_rf_classifier_model, X_test, y_test, roc_auc_score, n_jobs=-1)
assert isinstance(explainer.get_permutation_importances_df(), pd.DataFrame)

0 comments on commit ac37e99

Please sign in to comment.