kernel_shap: adapt api for tabular data
lucashervier committed Jul 27, 2021
1 parent feea923 commit 6899e89
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions xplique/attributions/kernel_shap.py
@@ -37,17 +37,23 @@ def __init__(self,
feature (e.g super-pixel).
It allows to transpose from (resp. to) the original input space to (resp. from)
the interpretable space.
The default mapping is the identity mapping which is quickly a poor mapping.
The default mapping is:
- the quickshift segmentation algorithm for inputs with an (N, W, H, C) shape;
we assume such a shape is used to represent (W, H, C) images.
- the felzenszwalb segmentation algorithm for inputs with an (N, W, H) shape;
we assume such a shape is used to represent (W, H) images.
- an identity mapping if inputs have an (N, W) shape; we assume such inputs
are tabular data.
To use your own custom map function you should use the following scheme:
def custom_map_to_interpret_space(inputs: tf.tensor (N, W, H, C)) ->
tf.tensor (N, W, H):
def custom_map_to_interpret_space(inputs: tf.tensor (N, W (, H, C) )) ->
tf.tensor (N, W (, H)):
**some grouping techniques**
return mappings
For instance you can use the scikit-image library to defines super pixels on your
images.
For instance you can use the scikit-image library (as we did for the quickshift
algorithm) to define super-pixels on your images.
nb_samples
The number of perturbed samples to generate for each input sample.
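As an illustration of the custom mapping scheme described above, here is a minimal sketch that groups pixels with scikit-image's SLIC segmentation instead of the default quickshift; the segmentation parameters are arbitrary and the integer-label output format is an assumption based on this docstring, not a verbatim excerpt from the library.

import numpy as np
import tensorflow as tf
from skimage.segmentation import slic

def custom_map_to_interpret_space(inputs: tf.Tensor) -> tf.Tensor:
    # inputs: a (N, W, H, C) batch of images (assumed to be eager tensors)
    mappings = []
    for image in inputs.numpy():
        # slic returns a (W, H) array of super-pixel labels starting at 0
        segments = slic(image, n_segments=50, compactness=10.0, start_label=0)
        mappings.append(segments)
    # (N, W, H) tensor where each pixel holds the index of its interpretable feature
    return tf.convert_to_tensor(np.stack(mappings), dtype=tf.int32)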
@@ -60,7 +66,8 @@ def custom_map_to_interpret_space(inputs: tf.tensor (N, W, H, C)) ->
ref_values
It defines the reference value that replaces each feature when the corresponding
interpretable feature is set to 0.
It should be provided as: a ndarray (C,)
It should be provided as an ndarray of shape (1,) if your inputs have no channel
dimension, and of shape (C,) otherwise.
The default ref value is set to (0.5,0.5,0.5) for inputs with 3 channels (corresponding
to a grey pixel when inputs are normalized by 255) and to 0 otherwise.
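Following the shapes described above, a ref_values array for 3-channel images versus tabular inputs could be built as in the sketch below; how it is then passed to the KernelShap constructor is assumed from this docstring rather than shown in the diff.

import numpy as np

# RGB images (N, W, H, C): one reference value per channel,
# here a grey pixel for inputs normalized to [0, 1]
ref_values_images = np.array([0.5, 0.5, 0.5], dtype=np.float32)   # shape (C,)

# tabular inputs (N, W): a single reference value shared by every feature
ref_values_tabular = np.zeros((1,), dtype=np.float32)             # shape (1,)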
@@ -92,9 +99,14 @@ def _kernel_shap_similarity_kernel(
This method computes the similarity between interpretable perturbed samples and
the original input (i.e. tf.ones(num_features)).
"""

# when calling the kernel, all the interpretable samples share the
# same size, so we can use the following trick to recover the total
# number of interpretable features for a given input
nb_total_features = interpret_samples.bounding_shape(out_type = tf.int32)[1]
interpret_samples = interpret_samples.to_tensor()
num_selected_features = tf.reduce_sum(interpret_samples, axis=1)
num_features = len(interpret_samples[0])
nb_selected_features = tf.reduce_sum(interpret_samples, axis=1)

# Theoretically, in the case where the number of selected
# features is zero or the total number of features of the
@@ -103,8 +115,8 @@ def _kernel_shap_similarity_kernel(
# weight to 1000000 (all other weights are 1).
similarities = tf.where(
tf.logical_or(
tf.equal(num_selected_features, tf.constant(0)),
tf.equal(num_selected_features, tf.constant(num_features))
tf.equal(nb_selected_features, tf.constant(0)),
tf.equal(nb_selected_features, tf.constant(nb_total_features))
),
tf.ones(len(interpret_samples), dtype=tf.float32)*1000000.0,
tf.ones(len(interpret_samples), dtype=tf.float32)
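For context, the large 1000000 weight approximates the theoretically infinite weight of the empty and full coalitions: the Kernel SHAP weighting pi(z') = (M - 1) / (C(M, |z'|) * |z'| * (M - |z'|)) divides by zero when |z'| is 0 or M. A small illustrative helper, not part of this file, that computes this weight:

from math import comb

def shapley_kernel_weight(nb_total_features: int, nb_selected_features: int) -> float:
    # theoretical Kernel SHAP weight pi(z') for a coalition of the given size
    m, s = nb_total_features, nb_selected_features
    if s == 0 or s == m:
        # the formula divides by zero here: these coalitions have an
        # "infinite" weight, approximated by a large constant in the code above
        return 1e6
    return (m - 1) / (comb(m, s) * s * (m - s))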
@@ -114,7 +126,7 @@ def _kernel_shap_pertub_func(num_features: Union[int, tf.Tensor],

@staticmethod
@tf.function
def _kernel_shap_pertub_func(num_features: Union[int, tf.Tensor],
def _kernel_shap_pertub_func(nb_features: Union[int, tf.Tensor],
nb_samples: int) -> tf.Tensor:
"""
The perturbed instances are sampled as follows:
@@ -132,7 +144,7 @@ def _kernel_shap_pertub_func(num_features: Union[int, tf.Tensor],
This trick is the one used in the Captum library: https://github.com/pytorch/captum
"""
probs_nb_selected_feature = KernelShap._get_probs_nb_selected_feature(
tf.cast(num_features, dtype=tf.int32))
tf.cast(nb_features, dtype=tf.int32))
nb_selected_features = tf.random.categorical(tf.math.log([probs_nb_selected_feature]),
nb_samples,
dtype=tf.int32)
@@ -141,7 +153,7 @@ def _kernel_shap_pertub_func(num_features: Union[int, tf.Tensor],
interpret_samples = []

for i in range(nb_samples):
rand_vals = tf.random.normal([num_features])
rand_vals = tf.random.normal([nb_features])
idx_sorted_values = tf.argsort(rand_vals, direction='DESCENDING')
threshold_idx = idx_sorted_values[nb_selected_features[i]]
threshold = rand_vals[threshold_idx]
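To make the sampling trick concrete, here is a minimal standalone sketch of the same idea in plain TensorFlow; the uniform probability vector is only a placeholder for the output of _get_probs_nb_selected_feature, which is not shown in this diff.

import tensorflow as tf

nb_features, nb_samples = 8, 5

# placeholder distribution over the number of selected features (1 .. nb_features - 1);
# in the actual code it comes from _get_probs_nb_selected_feature
probs = tf.ones((1, nb_features - 1)) / tf.cast(nb_features - 1, tf.float32)
nb_selected = tf.random.categorical(tf.math.log(probs), nb_samples, dtype=tf.int32)[0] + 1

interpret_samples = []
for i in range(nb_samples):
    # draw random scores and threshold at the score ranked nb_selected[i]:
    # this yields a random binary mask with exactly nb_selected[i] features switched on
    rand_vals = tf.random.normal([nb_features])
    threshold = tf.sort(rand_vals, direction='DESCENDING')[nb_selected[i]]
    interpret_samples.append(tf.cast(rand_vals > threshold, tf.int32))

interpret_samples = tf.stack(interpret_samples)  # (nb_samples, nb_features) binary masks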
