Skip to content

Commit

Permalink
speeds up linear kernel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Oege Dijk committed May 10, 2022
1 parent b524eb5 commit 5b9a4c4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 3 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,11 @@ def logistic_regression_kernel_explainer(fitted_logistic_regression_model, class
_, _, X_test, y_test = classifier_data
explainer = ClassifierExplainer(
fitted_logistic_regression_model,
X_test,
y_test,
X_test.iloc[:10], y_test.iloc[:10],
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck', 'Embarked'],
cats_notencoded={'Gender': 'No Gender'},
labels=['Not survived', 'Survived'],
shap='kernel', model_output='probability'
shap='kernel',
)
return explainer

Expand Down Expand Up @@ -299,9 +298,8 @@ def linear_regression_kernel_explainer(fitted_linear_regression_model, regressio
_, _, X_test, y_test = regression_data
explainer = RegressionExplainer(
fitted_linear_regression_model,
X_test, y_test,
X_test.iloc[:10], y_test.iloc[:10],
cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, 'Deck', 'Embarked'],
idxs=test_names,
shap='kernel')
return explainer

Expand Down
3 changes: 2 additions & 1 deletion tests/test_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import pytest

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -135,13 +136,13 @@ def test_logreg_lift_curve_df(precalculated_logistic_regression_explainer):



##### KERNEL TESTS

def test_logistic_regression_kernel_shap_values(logistic_regression_kernel_explainer):
assert isinstance(logistic_regression_kernel_explainer.shap_base_value(), (np.floating, float))
assert (logistic_regression_kernel_explainer.get_shap_values_df().shape == (len(logistic_regression_kernel_explainer), len(logistic_regression_kernel_explainer.merged_cols)))
assert isinstance(logistic_regression_kernel_explainer.get_shap_values_df(), pd.DataFrame)


def test_linear_regression_kernel_shap_values(linear_regression_kernel_explainer):
assert isinstance(linear_regression_kernel_explainer.shap_base_value(), (np.floating, float))
assert (linear_regression_kernel_explainer.get_shap_values_df().shape == (len(linear_regression_kernel_explainer), len(linear_regression_kernel_explainer.merged_cols)))
Expand Down

0 comments on commit 5b9a4c4

Please sign in to comment.