Skip to content

Commit

Permalink
Revert "add gputree support"
Browse files Browse the repository at this point in the history
This reverts commit 8f9f9af.
  • Loading branch information
oegedijk committed Dec 18, 2023
1 parent 8f9f9af commit 4f45e61
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,21 +321,21 @@ def __init__(
"sklearn-compatible NeuralNet wrapper are supported for now! "
"See https://github.com/skorch-dev/skorch"
)
assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' "
"or 'gputree' are supported for now!"
assert shap in ["tree", "linear", "deep", "kernel", "skorch"], (
"ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are "
" supported for now!"
)
self.shap = shap
if self.shap in {"kernel", "skorch", "linear", "gputree"}:
if self.shap in {"kernel", "skorch", "linear"}:
print(
f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately "
"not be calculated!"
)
self.interactions_should_work = False
if self.shap == "skorch":
print(
"WARNING: For shap='skorch' the additivity check tends to fail. "
"For now you can set shap_kwargs=dict(check_additivity=False) to supress "
"WARNING: For shap='skorch' the additivity check tends to fail, "
"you set set shap_kwargs=dict(check_additivity=False) to supress "
"this error (at your own risk)!"
)

Expand Down Expand Up @@ -1068,8 +1068,8 @@ def shap_explainer(self):
)
elif self.shap == "deep":
print(
"Generating self.shap_explainer = "
"shap.DeepExplainer(model, X_background)"
f"Generating self.shap_explainer = "
f"shap.DeepExplainer(model, X_background)"
)
print(
"Warning: shap values for shap.DeepExplainer get "
Expand All @@ -1084,8 +1084,8 @@ def shap_explainer(self):
)
elif self.shap == "skorch":
print(
"Generating self.shap_explainer = "
"shap.DeepExplainer(model, X_background)"
f"Generating self.shap_explainer = "
f"shap.DeepExplainer(model, X_background)"
)
print(
"Warning: shap values for shap.DeepExplainer get "
Expand Down Expand Up @@ -1123,12 +1123,6 @@ def model_predict(data_asarray):
if self.X_background is not None
else shap.sample(self.X, 50),
)
elif self.shap == "gputree":
print(
"Using GPUTree explainer. Make sure you set up you CUDA GPU correctly first."
"See e.g. https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html"
)
self._shap_explainer = shap.explainer.GPUTree(self.model, self.X)
return self._shap_explainer

@insert_pos_label
Expand Down

0 comments on commit 4f45e61

Please sign in to comment.