
Merge pull request #730 from dianna-ai/644-rise-tabular
add rise tabular
cwmeijer committed Jun 19, 2024
2 parents a77e66b + 9d5b567 commit eeb4e71
Showing 21 changed files with 937 additions and 243 deletions.
2 changes: 1 addition & 1 deletion dianna/__init__.py
@@ -131,7 +131,7 @@ def explain_text(model_or_function: Union[Callable,
 def explain_tabular(model_or_function: Union[Callable, str],
                     input_tabular: np.ndarray,
                     method: str,
-                    labels=(1, ),
+                    labels=None,
                     **kwargs) -> np.ndarray:
     """Explain tabular (input_tabular) given a model and a chosen method.
17 changes: 8 additions & 9 deletions dianna/methods/kernelshap_tabular.py
@@ -34,7 +34,8 @@ def __init__(
                 weighted kmeans
         """
         if training_data_kmeans:
-            self.training_data = shap.kmeans(training_data, training_data_kmeans)
+            self.training_data = shap.kmeans(training_data,
+                                             training_data_kmeans)
         else:
             self.training_data = training_data
         self.feature_names = feature_names
@@ -65,17 +66,15 @@ def explain(
             An array (np.ndarray) containing the KernelExplainer explanations for each class.
         """
         init_instance_kwargs = utils.get_kwargs_applicable_to_function(
-            KernelExplainer, kwargs
-        )
-        self.explainer = KernelExplainer(
-            model_or_function, self.training_data, link, **init_instance_kwargs
-        )
+            KernelExplainer, kwargs)
+        self.explainer = KernelExplainer(model_or_function, self.training_data,
+                                         link, **init_instance_kwargs)
 
         explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
-            self.explainer.shap_values, kwargs
-        )
+            self.explainer.shap_values, kwargs)
 
-        saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)
+        saliency = self.explainer.shap_values(input_tabular,
+                                              **explain_instance_kwargs)
 
         if self.mode == 'regression':
             saliency = saliency[0]
20 changes: 11 additions & 9 deletions dianna/methods/lime_tabular.py
@@ -1,6 +1,8 @@
 """LIME tabular explainer."""
+import sys
 from typing import Iterable
 from typing import List
+from typing import Optional
 from typing import Union
 import numpy as np
 from lime.lime_tabular import LimeTabularExplainer
@@ -58,12 +60,10 @@ def __init__(
         """
         self.mode = mode
         init_instance_kwargs = utils.get_kwargs_applicable_to_function(
-            LimeTabularExplainer, kwargs
-        )
+            LimeTabularExplainer, kwargs)
 
         # temporary solution for setting num_features and top_labels
         self.num_features = len(feature_names)
-        self.top_labels = len(class_names)
 
         self.explainer = LimeTabularExplainer(
             training_data,
@@ -83,7 +83,7 @@ def explain(
         self,
         model_or_function: Union[str, callable],
         input_tabular: np.array,
-        labels: Iterable[int] = (1,),
+        labels: Optional[Iterable[int]] = None,
         num_samples: int = 5000,
         **kwargs,
     ) -> np.array:
@@ -93,7 +93,7 @@
             model_or_function (callable or str): The function that runs the model to be explained
                 or the path to a ONNX model on disk.
             input_tabular (np.ndarray): Data to be explained.
-            labels (Iterable(int), optional): Indices of classes to be explained.
+            labels (Iterable(int)): Indices of classes to be explained.
             num_samples (int, optional): Number of samples
             kwargs: These parameters are passed on
@@ -105,15 +105,14 @@
         """
         # run the explanation.
         explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
-            self.explainer.explain_instance, kwargs
-        )
+            self.explainer.explain_instance, kwargs)
         runner = utils.get_function(model_or_function)
 
         explanation = self.explainer.explain_instance(
             input_tabular,
             runner,
             labels=labels,
-            top_labels=self.top_labels,
+            top_labels=sys.maxsize,
             num_features=self.num_features,
             num_samples=num_samples,
             **explain_instance_kwargs,
@@ -126,10 +125,13 @@
         elif self.mode == 'classification':
             # extract scores from lime explainer
             saliency = []
-            for i in range(self.top_labels):
+            for i in range(len(explanation.local_exp.items())):
                 local_exp = sorted(explanation.local_exp[i])
                 # shape of local_exp [(index, saliency)]
                 selected_saliency = [x[1] for x in local_exp]
                 saliency.append(selected_saliency[:])
+
+        else:
+            raise ValueError(f'Unsupported mode "{self.mode}"')
 
         return np.array(saliency)
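
For readers unfamiliar with LIME's output format, the classification branch above can be illustrated with a small, self-contained sketch (the numbers below are made up and are not part of this change): explanation.local_exp maps each label index to a list of (feature_index, weight) pairs, and sorting each list by feature index lines the weights up into one saliency row per class.

# Illustration only: a hand-made stand-in for explanation.local_exp with
# two classes and three features.
local_exp = {0: [(2, 0.10), (0, -0.30), (1, 0.05)],
             1: [(1, 0.20), (0, 0.40), (2, -0.10)]}

saliency = []
for label in range(len(local_exp)):
    pairs = sorted(local_exp[label])                  # sort by feature index
    saliency.append([weight for _, weight in pairs])  # keep weights only

# saliency == [[-0.30, 0.05, 0.10], [0.40, 0.20, -0.10]]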
2 changes: 1 addition & 1 deletion dianna/methods/lime_timeseries.py
@@ -83,7 +83,7 @@ def explain(
         # wrap up the input model or function using the runner
         runner = utils.get_function(
             model_or_function, preprocess_function=self.preprocess_function)
-        masks = generate_time_series_masks(input_timeseries,
+        masks = generate_time_series_masks(input_timeseries.shape,
                                            num_samples,
                                            p_keep=0.1)
         # NOTE: Required by `lime_base` explainer since the first instance must be the original data
111 changes: 111 additions & 0 deletions dianna/methods/rise_tabular.py
@@ -0,0 +1,111 @@
+"""RISE tabular explainer."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+import numpy as np
+from dianna import utils
+from dianna.utils.maskers import generate_tabular_masks
+from dianna.utils.maskers import mask_data_tabular
+from dianna.utils.predict import make_predictions
+from dianna.utils.rise_utils import normalize
+
+
+class RISETabular:
+    """RISE explainer for tabular data."""
+
+    def __init__(
+        self,
+        training_data: np.array,
+        mode: str = "classification",
+        feature_names: List[str] = None,
+        categorical_features: List[int] = None,
+        n_masks: int = 1000,
+        feature_res: int = 8,
+        p_keep: float = 0.5,
+        preprocess_function: Optional[callable] = None,
+        class_names=None,
+        keep_masks: bool = False,
+        keep_masked: bool = False,
+        keep_predictions: bool = False,
+    ) -> np.ndarray:
+        """RISE initializer.
+
+        Args:
+            n_masks: Number of masks to generate.
+            feature_res: Resolution of features in masks.
+            p_keep: Fraction of input data to keep in each mask (Default: auto-tune this value).
+            preprocess_function: Function to preprocess input data with
+            categorical_features: list of categorical features
+            class_names: Names of the classes
+            feature_names: Names of the features
+            mode: Either classification or regression
+            training_data: Training data used for imputation of masked features
+            keep_masks: keep masks in memory for the user to inspect
+            keep_masked: keep masked data in memory for the user to inspect
+            keep_predictions: keep model predictions in memory for the user to inspect
+        """
+        self.training_data = training_data
+        self.n_masks = n_masks
+        self.feature_res = feature_res
+        self.p_keep = p_keep
+        self.preprocess_function = preprocess_function
+        self.masks = None
+        self.masked = None
+        self.predictions = None
+        self.keep_masks = keep_masks
+        self.keep_masked = keep_masked
+        self.keep_predictions = keep_predictions
+        self.mode = mode
+
+    def explain(
+        self,
+        model_or_function: Union[str, callable],
+        input_tabular: np.array,
+        labels: Optional[Iterable[int]] = None,
+        mask_type: Optional[Union[str, callable]] = 'most_frequent',
+        batch_size: Optional[int] = 100,
+    ) -> np.array:
+        """Run the RISE explainer.
+
+        Args:
+            model_or_function: The function that runs the model to be explained
+                or the path to a ONNX model on disk.
+            input_tabular: Data to be explained.
+            labels: Indices of classes to be explained.
+            num_samples: Number of samples
+            mask_type: Imputation strategy for masked features
+            batch_size: Number of samples to process by the model per batch
+
+        Returns:
+            explanation: An array (np.ndarray) containing the RISE explanations for each class.
+        """
+        # run the explanation.
+        runner = utils.get_function(model_or_function)
+
+        masks = np.stack(
+            list(
+                generate_tabular_masks(input_tabular.shape,
+                                       number_of_masks=self.n_masks,
+                                       p_keep=self.p_keep)))
+        self.masks = masks if self.keep_masks else None
+
+        masked = mask_data_tabular(input_tabular,
+                                   masks,
+                                   self.training_data,
+                                   mask_type=mask_type)
+        self.masked = masked if self.keep_masked else None
+        predictions = make_predictions(masked, runner, batch_size)
+        self.predictions = predictions if self.keep_predictions else None
+        n_labels = predictions.shape[1]
+
+        masks_reshaped = masks.reshape(self.n_masks, -1)
+
+        saliency = predictions.T.dot(masks_reshaped).reshape(
+            n_labels, *input_tabular.shape)
+
+        if self.mode == 'regression':
+            return saliency[0]
+
+        selected_saliency = saliency if labels is None else saliency[labels]
+        return normalize(selected_saliency, self.n_masks, self.p_keep)
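
To make the new explainer easier to try out, here is a minimal usage sketch based on the RISETabular signature added above. It is illustrative only: the model path "model.onnx", the data shapes, and the chosen keyword values are assumptions, not part of this diff.

# Hypothetical usage sketch for the new RISE tabular explainer.
import numpy as np
from dianna.methods.rise_tabular import RISETabular

training_data = np.random.rand(100, 5)      # rows used to impute masked features
instance = training_data[0]                 # the single row to be explained

explainer = RISETabular(training_data=training_data,
                        mode="classification",
                        n_masks=1000,
                        p_keep=0.5)
saliency = explainer.explain("model.onnx",  # ONNX model path or a callable
                             instance,
                             labels=[0],    # explain only class 0
                             mask_type="most_frequent",
                             batch_size=100)
# saliency holds one normalized relevance score per feature, with shape
# (len(labels), *instance.shape) when labels is given.

Passing labels=None instead returns the normalized saliency for every class, which matches the new labels=None default of dianna.explain_tabular.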
2 changes: 1 addition & 1 deletion dianna/methods/rise_timeseries.py
@@ -68,7 +68,7 @@ def explain(self,
         runner = utils.get_function(
             model_or_function, preprocess_function=self.preprocess_function)
 
-        masks = generate_time_series_masks(input_timeseries,
+        masks = generate_time_series_masks(input_timeseries.shape,
                                            number_of_masks=self.n_masks,
                                            feature_res=self.feature_res,
                                            p_keep=self.p_keep)

