add rise tabular #730

Merged: 23 commits, merged on Jun 19, 2024
Commits (23)
8c61395
refactor test lime tabular for reuse with rise
cwmeijer Mar 20, 2024
84eb480
refactor: use input_shape instead of complete input for generating masks
cwmeijer Mar 20, 2024
d7d866c
rename timeseries mask tests to clarify the data domain
cwmeijer Mar 26, 2024
59b4676
add mask functionality for tabular
cwmeijer Mar 26, 2024
80a400c
remove default for lime tabular and add ValueError for unsupported modes
cwmeijer Mar 26, 2024
61a2aba
add RISE tabular WIP + tests
cwmeijer Mar 27, 2024
019197a
fix: convert generator to list before stack
cwmeijer Apr 15, 2024
9993e09
add support for categorical features in tabular masks
cwmeijer Apr 18, 2024
c820d1f
add support for user defined mask type
cwmeijer Apr 22, 2024
897ba74
parameterize tabular tests (shared among xai methods), Fixes #712
cwmeijer Apr 24, 2024
de24b29
improve tabular tests (more strict)
cwmeijer Apr 25, 2024
c77dc52
refactor: encapsulate test_utils.run_model function
cwmeijer May 15, 2024
3cc26cc
fix consistent tabular return types
cwmeijer May 29, 2024
a0e7a9b
Update README.md; add links to rise tab and kernelshap tab tuts
cwmeijer May 29, 2024
3e6b502
add rise tabular penguin notebook tutorial
cwmeijer May 29, 2024
17bf17d
Merge remote-tracking branch 'origin/644-rise-tabular' into 644-rise-…
cwmeijer May 29, 2024
ef893e1
merge main into feature branch
cwmeijer May 29, 2024
19469d1
fix merge bugs
cwmeijer May 29, 2024
dec4980
Update README.md links to rise tabular tutorials
cwmeijer May 29, 2024
b85b5d8
move tutorial to explainer folder
cwmeijer May 29, 2024
53de62b
fix rise penguin notebook path and remove duplicate
cwmeijer May 30, 2024
046c72c
add link to rise penguin notebook
cwmeijer Jun 12, 2024
9d5b567
change docstring
cwmeijer Jun 12, 2024
Files changed
2 changes: 1 addition & 1 deletion dianna/__init__.py
@@ -131,7 +131,7 @@ def explain_text(model_or_function: Union[Callable,
def explain_tabular(model_or_function: Union[Callable, str],
input_tabular: np.ndarray,
method: str,
labels=(1, ),
labels=None,
**kwargs) -> np.ndarray:
"""Explain tabular (input_text) given a model and a chosen method.

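For context, a minimal sketch of calling the top-level entry point with the new labels=None default. The model is a hypothetical stand-in callable rather than a real ONNX file, and the method string 'rise' plus the forwarding of extra keyword arguments (such as training_data) to the RISE explainer are assumptions not shown in this diff.

```python
import numpy as np
import dianna

def run_model(data):
    # Hypothetical stand-in for an ONNX model: pseudo scores for two classes,
    # returned as an array of shape (n_rows, 2).
    class_1 = 1.0 / (1.0 + np.exp(-data.sum(axis=1, keepdims=True)))
    return np.concatenate([class_1, 1.0 - class_1], axis=1)

training_data = np.random.rand(100, 5)
instance = np.random.rand(5)

# labels=None is the new default: ask the explainer for all classes.
saliency = dianna.explain_tabular(run_model,
                                  input_tabular=instance,
                                  method='rise',
                                  labels=None,
                                  training_data=training_data)
print(saliency.shape)  # expected (n_classes, n_features) for RISE tabular
```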
17 changes: 8 additions & 9 deletions dianna/methods/kernelshap_tabular.py
@@ -34,7 +34,8 @@ def __init__(
weighted kmeans
"""
if training_data_kmeans:
self.training_data = shap.kmeans(training_data, training_data_kmeans)
self.training_data = shap.kmeans(training_data,
training_data_kmeans)
else:
self.training_data = training_data
self.feature_names = feature_names
@@ -65,17 +66,15 @@ def explain(
An array (np.ndarray) containing the KernelExplainer explanations for each class.
"""
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
KernelExplainer, kwargs
)
self.explainer = KernelExplainer(
model_or_function, self.training_data, link, **init_instance_kwargs
)
KernelExplainer, kwargs)
self.explainer = KernelExplainer(model_or_function, self.training_data,
link, **init_instance_kwargs)

explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
self.explainer.shap_values, kwargs
)
self.explainer.shap_values, kwargs)

saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)
saliency = self.explainer.shap_values(input_tabular,
**explain_instance_kwargs)

if self.mode == 'regression':
saliency = saliency[0]
20 changes: 11 additions & 9 deletions dianna/methods/lime_tabular.py
@@ -1,6 +1,8 @@
"""LIME tabular explainer."""
import sys
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import numpy as np
from lime.lime_tabular import LimeTabularExplainer
@@ -58,12 +60,10 @@ def __init__(
"""
self.mode = mode
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
LimeTabularExplainer, kwargs
)
LimeTabularExplainer, kwargs)

# temporary solution for setting num_features and top_labels
self.num_features = len(feature_names)
self.top_labels = len(class_names)

self.explainer = LimeTabularExplainer(
training_data,
@@ -83,7 +83,7 @@ def explain(
self,
model_or_function: Union[str, callable],
input_tabular: np.array,
labels: Iterable[int] = (1,),
labels: Optional[Iterable[int]] = None,
num_samples: int = 5000,
**kwargs,
) -> np.array:
@@ -93,7 +93,7 @@ def explain(
model_or_function (callable or str): The function that runs the model to be explained
or the path to a ONNX model on disk.
input_tabular (np.ndarray): Data to be explained.
labels (Iterable(int), optional): Indices of classes to be explained.
labels (Iterable(int)): Indices of classes to be explained.
num_samples (int, optional): Number of samples
kwargs: These parameters are passed on

@@ -105,15 +105,14 @@
"""
# run the explanation.
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
self.explainer.explain_instance, kwargs
)
self.explainer.explain_instance, kwargs)
runner = utils.get_function(model_or_function)

explanation = self.explainer.explain_instance(
input_tabular,
runner,
labels=labels,
top_labels=self.top_labels,
top_labels=sys.maxsize,
num_features=self.num_features,
num_samples=num_samples,
**explain_instance_kwargs,
@@ -126,10 +125,13 @@
elif self.mode == 'classification':
# extract scores from lime explainer
saliency = []
for i in range(self.top_labels):
for i in range(len(explanation.local_exp.items())):
local_exp = sorted(explanation.local_exp[i])
# shape of local_exp [(index, saliency)]
selected_saliency = [x[1] for x in local_exp]
saliency.append(selected_saliency[:])

else:
raise ValueError(f'Unsupported mode "{self.mode}"')

return np.array(saliency)
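To illustrate the new classification branch above, here is a small self-contained sketch of the flattening it performs, using a hand-made stand-in for LIME's explanation.local_exp mapping (label index to a list of (feature index, weight) pairs); the values are invented for illustration.

```python
import numpy as np

# Stand-in for explanation.local_exp as produced by LimeTabularExplainer:
# {label: [(feature_index, weight), ...]} with features in arbitrary order.
local_exp = {
    0: [(2, 0.10), (0, 0.50), (1, -0.30)],
    1: [(1, 0.30), (2, -0.10), (0, -0.50)],
}

saliency = []
for i in range(len(local_exp)):
    # Sorting by feature index puts the weights back into column order.
    sorted_exp = sorted(local_exp[i])
    saliency.append([weight for _, weight in sorted_exp])

print(np.array(saliency))
# [[ 0.5 -0.3  0.1]
#  [-0.5  0.3 -0.1]]
```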
2 changes: 1 addition & 1 deletion dianna/methods/lime_timeseries.py
@@ -81,7 +81,7 @@ def explain(
# wrap up the input model or function using the runner
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)
masks = generate_time_series_masks(input_timeseries,
masks = generate_time_series_masks(input_timeseries.shape,
num_samples,
p_keep=0.1)
# NOTE: Required by `lime_base` explainer since the first instance must be the original data
111 changes: 111 additions & 0 deletions dianna/methods/rise_tabular.py
@@ -0,0 +1,111 @@
"""RISE tabular explainer."""
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_tabular_masks
from dianna.utils.maskers import mask_data_tabular
from dianna.utils.predict import make_predictions
from dianna.utils.rise_utils import normalize


class RISETabular:
"""RISE explainer for tabular data."""

def __init__(
self,
training_data: np.array,
mode: str = "classification",
feature_names: List[str] = None,
categorical_features: List[int] = None,
n_masks: int = 1000,
feature_res: int = 8,
p_keep: float = 0.5,
preprocess_function: Optional[callable] = None,
class_names=None,
keep_masks: bool = False,
keep_masked: bool = False,
keep_predictions: bool = False,
) -> np.ndarray:
"""RISE initializer.

Args:
n_masks: Number of masks to generate.
feature_res: Resolution of features in masks.
p_keep: Fraction of input data to keep in each mask (default: 0.5).
preprocess_function: Function to preprocess input data with
categorical_features: list of categorical features
class_names: Names of the classes
feature_names: Names of the features
mode: Either classification or regression
training_data: Training data used for imputation of masked features
keep_masks: keep masks in memory for the user to inspect
keep_masked: keep masked data in memory for the user to inspect
keep_predictions: keep model predictions in memory for the user to inspect
"""
self.training_data = training_data
self.n_masks = n_masks
self.feature_res = feature_res
self.p_keep = p_keep
self.preprocess_function = preprocess_function
self.masks = None
self.masked = None
self.predictions = None
self.keep_masks = keep_masks
self.keep_masked = keep_masked
self.keep_predictions = keep_predictions
self.mode = mode

def explain(
self,
model_or_function: Union[str, callable],
input_tabular: np.array,
labels: Optional[Iterable[int]] = None,
mask_type: Optional[Union[str, callable]] = 'most_frequent',
batch_size: Optional[int] = 100,
) -> np.array:
"""Run the RISE explainer.

Args:
model_or_function: The function that runs the model to be explained
or the path to a ONNX model on disk.
input_tabular: Data to be explained.
labels: Indices of classes to be explained.
mask_type: Imputation strategy for masked features
batch_size: Number of samples to process by the model per batch

Returns:
explanation: An array (np.ndarray) containing the RISE explanations for each class.
"""
# run the explanation.
runner = utils.get_function(model_or_function)

masks = np.stack(
list(
generate_tabular_masks(input_tabular.shape,
number_of_masks=self.n_masks,
p_keep=self.p_keep)))
self.masks = masks if self.keep_masks else None

masked = mask_data_tabular(input_tabular,
masks,
self.training_data,
mask_type=mask_type)
self.masked = masked if self.keep_masked else None
predictions = make_predictions(masked, runner, batch_size)
self.predictions = predictions if self.keep_predictions else None
n_labels = predictions.shape[1]

masks_reshaped = masks.reshape(self.n_masks, -1)

saliency = predictions.T.dot(masks_reshaped).reshape(
n_labels, *input_tabular.shape)

if self.mode == 'regression':
return saliency[0]

selected_saliency = saliency if labels is None else saliency[labels]
return normalize(selected_saliency, self.n_masks, self.p_keep)
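A minimal usage sketch of the new explainer, based on the constructor and explain signatures added in this file. The model below is a hypothetical callable standing in for an ONNX model, and the commented output shapes are assumptions rather than verified behaviour.

```python
import numpy as np
from dianna.methods.rise_tabular import RISETabular

def run_model(data):
    # Hypothetical model: pseudo-probabilities for 3 classes, shape (n_rows, 3).
    logits = np.stack([data.sum(axis=1), data.max(axis=1), data.min(axis=1)], axis=1)
    return np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

training_data = np.random.rand(200, 4)
instance = np.random.rand(4)

explainer = RISETabular(training_data,
                        mode="classification",
                        n_masks=500,
                        p_keep=0.5,
                        keep_masks=True,        # retain masks for inspection
                        keep_predictions=True)  # retain model outputs as well

saliency = explainer.explain(run_model,
                             instance,
                             labels=[0, 2],
                             mask_type='most_frequent',
                             batch_size=50)

print(saliency.shape)         # expected (2, 4): one row per requested label
print(explainer.masks.shape)  # expected (500, 4) because keep_masks=True
```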
2 changes: 1 addition & 1 deletion dianna/methods/rise_timeseries.py
@@ -68,7 +68,7 @@ def explain(self,
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)

masks = generate_time_series_masks(input_timeseries,
masks = generate_time_series_masks(input_timeseries.shape,
number_of_masks=self.n_masks,
feature_res=self.feature_res,
p_keep=self.p_keep)
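Both timeseries call sites now pass only the input shape to the masker. A short sketch of that convention, with the array sizes and the commented mask shape as assumptions for illustration:

```python
import numpy as np
from dianna.utils.maskers import generate_time_series_masks

input_timeseries = np.random.rand(100, 1)  # hypothetical univariate series

# Only the shape is needed to build the masks; the data itself is masked later.
masks = generate_time_series_masks(input_timeseries.shape,
                                    number_of_masks=500,
                                    feature_res=8,
                                    p_keep=0.3)
print(np.asarray(masks).shape)  # expected (500, 100, 1)
```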