adds a shap_kwargs parameter
This allows you to pass optional keyword arguments through to the shap
explainer, such as check_additivity=False.
Oege Dijk committed May 23, 2022
1 parent f43af40 commit c8c1d5f
Showing 3 changed files with 36 additions and 19 deletions.
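
For orientation, a minimal usage sketch of what this commit enables: the explainer stores shap_kwargs and forwards them to its shap_explainer.shap_values() calls. The toy dataset and model below are illustrative placeholders, not part of the commit.

# Minimal usage sketch (toy data and model are illustrative placeholders).
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from explainerdashboard import ClassifierExplainer

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f"col{i}" for i in range(5)])
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

explainer = ClassifierExplainer(
    model, X, pd.Series(y),
    shap_kwargs=dict(check_additivity=False),  # forwarded to shap_values() calls
)
shap_df = explainer.get_shap_values_df()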
4 changes: 2 additions & 2 deletions explainerdashboard/dashboard_methods.py
@@ -1,15 +1,15 @@
 __all__ = [
     'delegates_kwargs',
     'delegates_doc',
+    'update_params',
+    'update_kwargs',
     'DummyComponent',
     'ExplainerComponent',
     'PosLabelSelector',
     'GraphPopout',
     'IndexSelector',
     'make_hideable',
     'get_dbc_tooltips',
-    'update_params',
-    'update_kwargs',
     'encode_callables',
     'decode_callables',
     'reset_id_generator',
2 changes: 2 additions & 0 deletions explainerdashboard/explainer_methods.py
@@ -89,6 +89,8 @@ def safe_isinstance(obj, *instance_str):
         return False
 
 
+
+
 def guess_shap(model):
     """guesses which SHAP explainer to use for a particular model, based
     on str(type(model)). Returns 'tree' for tree based models such as
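
The guess_shap docstring above describes dispatching on str(type(model)); here is a hedged sketch of that idea, not the library's actual implementation, with illustrative hint lists.

# Hedged sketch of the dispatch idea in guess_shap -- illustrative, not the
# library's actual code. It inspects str(type(model)) and returns the name of
# a shap explainer family, e.g. 'tree' for tree-based models.
def guess_shap_sketch(model):
    model_str = str(type(model))
    tree_hints = ("RandomForest", "DecisionTree", "GradientBoosting", "XGB", "LGBM", "CatBoost")
    linear_hints = ("LinearRegression", "LogisticRegression", "Ridge", "Lasso")
    if any(hint in model_str for hint in tree_hints):
        return "tree"
    if any(hint in model_str for hint in linear_hints):
        return "linear"
    return None  # unknown: caller can fall back to the model-agnostic kernel explainer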
49 changes: 32 additions & 17 deletions explainerdashboard/explainers.py
@@ -76,7 +76,7 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None, permutation_metric:C
                 idxs:pd.Index=None, index_name:str=None, target:str=None,
                 descriptions:dict=None,
                 n_jobs:int=None, permutation_cv:int=None, cv:int=None, na_fill:float=-999,
-                precision:str="float64"):
+                precision:str="float64", shap_kwargs:Dict=None):
         """Defines the basic functionality that is shared by both
         ClassifierExplainer and RegressionExplainer.
@@ -121,11 +121,14 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None, permutation_metric:C
                 Defaults to None.
             na_fill (int): The filler used for missing values, defaults to -999.
             precision: precision with which to store values. Defaults to "float64".
+            shap_kwargs(dict): dictionary of keyword arguments to be passed to the shap explainer,
+                most typically used to suppress an additivity check, e.g. `shap_kwargs=dict(check_additivity=False)`
         """
         self._params_dict = dict(
             shap=shap, model_output=model_output, cats=cats,
             descriptions=descriptions, target=target, n_jobs=n_jobs,
-            permutation_cv=permutation_cv, cv=cv, na_fill=na_fill, precision=precision)
+            permutation_cv=permutation_cv, cv=cv, na_fill=na_fill,
+            precision=precision, shap_kwargs=shap_kwargs)
 
         if permutation_cv is not None:
             warnings.warn("Parameter permutation_cv has been deprecated! Please use "
@@ -246,7 +249,7 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None, permutation_metric:C
             print(f"WARNING: For shap='{self.shap}', shap interaction values can unfortunately "
                   "not be calculated!")
             self.interactions_should_work = False
-
+        self.shap_kwargs = shap_kwargs if shap_kwargs else {}
         self.model_output = model_output
 
         if idxs is not None:
@@ -293,7 +296,6 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None, permutation_metric:C
         if not hasattr(self, "interactions_should_work"):
             self.interactions_should_work = True
 
-
         self.__version__ = "0.3.3"
 
     def get_lock(self):
@@ -934,11 +936,15 @@ def get_shap_values_df(self, pos_label=None):
             print("Calculating shap values...", flush=True)
             if self.shap == 'skorch':
                 import torch
-                self._shap_values_df = pd.DataFrame(self.shap_explainer.shap_values(torch.tensor(self.X.values)),
-                                                    columns=self.columns)
+                self._shap_values_df = pd.DataFrame(
+                    self.shap_explainer.shap_values(torch.tensor(self.X.values), **self.shap_kwargs),
+                    columns=self.columns
+                )
             else:
-                self._shap_values_df = pd.DataFrame(self.shap_explainer.shap_values(self.X),
-                                                    columns=self.columns)
+                self._shap_values_df = pd.DataFrame(
+                    self.shap_explainer.shap_values(self.X, **self.shap_kwargs),
+                    columns=self.columns
+                )
             self._shap_values_df = merge_categorical_shap_values(
                 self._shap_values_df, self.onehot_dict, self.merged_cols).astype(self.precision)
         return self._shap_values_df
@@ -974,8 +980,9 @@ def get_shap_row(self, index=None, X_row=None, pos_label=None):
                 import torch
                 X_row = torch.tensor(X_row.values.astype("float32"))
             with self.get_lock():
+                shap_kwargs = dict(self.shap_kwargs, silent=True) if self.shap == 'kernel' else self.shap_kwargs
                 shap_row = pd.DataFrame(
-                    self.shap_explainer.shap_values(X_row, **(dict(silent=True) if self.shap=='kernel' else {})),
+                    self.shap_explainer.shap_values(X_row, **shap_kwargs),
                     columns=self.columns)
                 shap_row = merge_categorical_shap_values(shap_row,
                     self.onehot_dict, self.merged_cols)
@@ -986,8 +993,9 @@ def get_shap_row(self, index=None, X_row=None, pos_label=None):
                 import torch
                 X_row = torch.tensor(X_row.values.astype("float32"))
             with self.get_lock():
+                shap_kwargs = dict(self.shap_kwargs, silent=True) if self.shap == 'kernel' else self.shap_kwargs
                 shap_row = pd.DataFrame(
-                    self.shap_explainer.shap_values(X_row, **(dict(silent=True) if self.shap=='kernel' else {})),
+                    self.shap_explainer.shap_values(X_row, **shap_kwargs),
                     columns=self.columns)
                 shap_row = merge_categorical_shap_values(shap_row,
                     self.onehot_dict, self.merged_cols)
@@ -1955,7 +1963,8 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None,
                 index_name:str=None, target:str=None,
                 descriptions:Dict=None, n_jobs:int=None,
                 permutation_cv:int=None, cv:int=None, na_fill:float=-999,
-                precision:str="float64", labels:List=None, pos_label:int=1):
+                precision:str="float64", shap_kwargs:Dict=None,
+                labels:List=None, pos_label:int=1):
         """
         Explainer for classification models. Defines the shap values for
         each possible class in the classification.
@@ -2008,6 +2017,8 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None,
                 Defaults to None.
             na_fill (int): The filler used for missing values, defaults to -999.
             precision: precision with which to store values. Defaults to "float64".
+            shap_kwargs(dict): dictionary of keyword arguments to be passed to the shap explainer,
+                most typically used to suppress an additivity check, e.g. `shap_kwargs=dict(check_additivity=False)`
             labels(list): list of str labels for the different classes,
                 defaults to e.g. ['0', '1'] for a binary classification
             pos_label: class that should be used as the positive class,
@@ -2017,7 +2028,7 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None,
                 shap, X_background, model_output,
                 cats, cats_notencoded, idxs, index_name, target,
                 descriptions, n_jobs, permutation_cv, cv, na_fill,
-                precision)
+                precision, shap_kwargs)
 
         assert hasattr(model, "predict_proba"), \
             ("for ClassifierExplainer, model should be a scikit-learn "
@@ -2294,9 +2305,9 @@ def get_shap_values_df(self, pos_label=None):
             print("Calculating shap values...", flush=True)
             if self.shap == 'skorch':
                 import torch
-                _shap_values = self.shap_explainer.shap_values(torch.tensor(self.X.values.astype("float32")))
+                _shap_values = self.shap_explainer.shap_values(torch.tensor(self.X.values.astype("float32")), **self.shap_kwargs)
             else:
-                _shap_values = self.shap_explainer.shap_values(self.X.values)
+                _shap_values = self.shap_explainer.shap_values(self.X.values, **self.shap_kwargs)
 
             if len(self.labels) == 2:
                 if not isinstance(_shap_values, list):
@@ -2406,7 +2417,8 @@ def X_row_to_shap_row(X_row):
                 import torch
                 X_row = torch.tensor(X_row.values.astype("float32"))
             with self.get_lock():
-                sv = self.shap_explainer.shap_values(X_row, **(dict(silent=True) if self.shap=='kernel' else {}))
+                shap_kwargs = dict(self.shap_kwargs, silent=True) if self.shap == 'kernel' else self.shap_kwargs
+                sv = self.shap_explainer.shap_values(X_row, **shap_kwargs)
             if isinstance(sv, list) and len(sv) > 1:
                 shap_row = pd.DataFrame(sv[pos_label], columns=self.columns)
             elif len(self.labels) == 2:
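
A note on the sv[pos_label] indexing in the hunk above: for classifiers, shap_values() can return one array per class (depending on the shap version), so the requested class's array has to be picked out. A runnable illustration with stand-in arrays:

# Illustration with stand-in arrays (actual values come from shap, and
# list-vs-array output depends on the shap version in use).
import numpy as np

sv = [np.zeros((1, 5)), np.ones((1, 5))]  # one array per class, as shap may return
pos_label = 1
shap_row = sv[pos_label] if isinstance(sv, list) and len(sv) > 1 else sv
print(shap_row.shape)  # (1, 5): shap values for the positive class only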
@@ -3170,7 +3182,8 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None,
                 cats:Union[List, Dict]=None, cats_notencoded:Dict=None,
                 idxs:pd.Index=None, index_name:str=None, target:str=None,
                 descriptions:Dict=None, n_jobs:int=None, permutation_cv:int=None,
-                cv:int=None, na_fill:float=-999, precision:str="float64", units:str=""):
+                cv:int=None, na_fill:float=-999, precision:str="float64",
+                shap_kwargs:Dict=None, units:str=""):
         """Explainer for regression models.
 
         In addition to BaseExplainer defines a number of plots specific to
@@ -3219,13 +3232,15 @@ def __init__(self, model, X:pd.DataFrame, y:pd.Series=None,
                 Defaults to None.
             na_fill (int): The filler used for missing values, defaults to -999.
             precision: precision with which to store values. Defaults to "float64".
+            shap_kwargs(dict): dictionary of keyword arguments to be passed to the shap explainer,
+                most typically used to suppress an additivity check, e.g. `shap_kwargs=dict(check_additivity=False)`
             units(str): units to display for regression quantity
         """
         super().__init__(model, X, y, permutation_metric,
                          shap, X_background, model_output,
                          cats, cats_notencoded, idxs, index_name, target,
                          descriptions, n_jobs, permutation_cv, cv, na_fill,
-                         precision)
+                         precision, shap_kwargs)
 
         self._params_dict = {**self._params_dict, **dict(units=units)}
         self.units = units
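
The dict(self.shap_kwargs, silent=True) pattern that recurs in the hunks above builds a fresh dict, so the explainer's stored shap_kwargs are never mutated and the injected silent=True takes precedence over the base entries. A standalone demonstration:

# Standalone demonstration of the kwargs-merge pattern used for kernel explainers.
shap_kwargs = dict(check_additivity=False)         # as a user might pass to the explainer
merged = dict(shap_kwargs, silent=True)            # new dict; keyword overrides win
assert shap_kwargs == {"check_additivity": False}  # original dict untouched
assert merged == {"check_additivity": False, "silent": True}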
