From 4f45e616a65c1760622fafaf4f17f67e882b3ea3 Mon Sep 17 00:00:00 2001 From: Oege Dijk Date: Mon, 18 Dec 2023 20:57:13 +0100 Subject: [PATCH] Revert "add gputree support" This reverts commit 8f9f9affa90e79449603a55e8f2c15b83ef909de. --- explainerdashboard/explainers.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index 7fae612..cbc50a2 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -321,12 +321,12 @@ def __init__( "sklearn-compatible NeuralNet wrapper are supported for now! " "See https://github.com/skorch-dev/skorch" ) - assert shap in ["tree", "linear", "deep", "kernel", "skorch", "gputree"], ( - "ERROR! Only shap='guess', 'tree', 'linear', ' kernel', 'skorch' " - "or 'gputree' are supported for now!" + assert shap in ["tree", "linear", "deep", "kernel", "skorch"], ( + "ERROR! Only shap='guess', 'tree', 'linear', ' kernel' or 'skorch' are " + " supported for now!" ) self.shap = shap - if self.shap in {"kernel", "skorch", "linear", "gputree"}: + if self.shap in {"kernel", "skorch", "linear"}: print( f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately " "not be calculated!" @@ -334,8 +334,8 @@ def __init__( self.interactions_should_work = False if self.shap == "skorch": print( - "WARNING: For shap='skorch' the additivity check tends to fail. " - "For now you can set shap_kwargs=dict(check_additivity=False) to supress " + "WARNING: For shap='skorch' the additivity check tends to fail, " + "you set set shap_kwargs=dict(check_additivity=False) to supress " "this error (at your own risk)!" ) @@ -1068,8 +1068,8 @@ def shap_explainer(self): ) elif self.shap == "deep": print( - "Generating self.shap_explainer = " - "shap.DeepExplainer(model, X_background)" + f"Generating self.shap_explainer = " + f"shap.DeepExplainer(model, X_background)" ) print( "Warning: shap values for shap.DeepExplainer get " @@ -1084,8 +1084,8 @@ def shap_explainer(self): ) elif self.shap == "skorch": print( - "Generating self.shap_explainer = " - "shap.DeepExplainer(model, X_background)" + f"Generating self.shap_explainer = " + f"shap.DeepExplainer(model, X_background)" ) print( "Warning: shap values for shap.DeepExplainer get " @@ -1123,12 +1123,6 @@ def model_predict(data_asarray): if self.X_background is not None else shap.sample(self.X, 50), ) - elif self.shap == "gputree": - print( - "Using GPUTree explainer. Make sure you set up you CUDA GPU correctly first." - "See e.g. https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/explainers/GPUTree.html" - ) - self._shap_explainer = shap.explainer.GPUTree(self.model, self.X) return self._shap_explainer @insert_pos_label