Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Causal fidelity for tabular #91

Merged
merged 3 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ disable=
R0903, # allows to expose only one public method
R0914, # allow multiples local variables

E0401, # pending issue with pylint see pylint#2603

E1123, # issues between pylint and tensorflow since 2.2.0
E1120, # see https://github.com/PyCQA/pylint/issues/3613
E1120, # see pylint#3613

[FORMAT]
max-line-length=100
Expand Down
53 changes: 53 additions & 0 deletions docs/api/deletion_tab.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# DeletionTab

The tabular data Deletion Fidelity metric measures the faithfulness of explanations on tabular data predictions.
This metric computes the capacity of the model to make predictions while only the most important features are not perturbed[^1].


## Score interpretation

The interpretation of the score depends on the score metric you are using to evaluate your model.
- For metrics where the score increases with the performance of the model (such as accuracy).
If explanations are accurate, the score will quickly rise to the score on non-perturbed input.
Thus, in this case, a higher score represent a more accurate explanation.

- For metrics where the score decreases with the performance of the model (such as losses).
If explanations are accurate, the score will quickly fall to the score on non-perturbed input.
Thus, in this case, a lower score represent a more accurate explanation.


## Remarks

This metric only evaluate the order of importance between features.

The parameters metric, steps and max_percentage_perturbed may drastically change the score :

- For inputs with many features, increasing the number of steps will allow you to capture more efficiently the difference between attributions methods.

- The order of importance of features with low importance may not matter, hence, decreasing the max_percentage_perturbed,
may make the score more relevant.

Sometimes, attributions methods also returns negative attributions,
for those methods, do not take the absolute value before computing deletion metrics.
Otherwise, negative attributions may have higher absolute values, and the order of importance between features will change.
Therefore, take those previous remarks into account to get a relevant score.


## Example

```python
from xplique.metrics import DeletionTab
from xplique.attributions import Saliency

# load images, labels and model
# ...
explainer = Saliency(model)
explanations = explainer(inputs, labels)

metric = DeletionTab(model, inputs, labels)
score = metric.evaluate(explanations)
```

{{xplique.metrics.DeletionTab}}

[^1]:[RISE: Randomized Input Sampling for Explanation of Black-box Models (2018)](https://arxiv.org/abs/1806.07421)
53 changes: 53 additions & 0 deletions docs/api/insertion_tab.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# InsertionTab

The tabular data Insertion Fidelity metric measures the faithfulness of explanations on tabular data predictions.
This metric computes the capacity of the model to make predictions while only the most important features are not perturbed[^1].


## Score interpretation

The interpretation of the score depends on the score metric you are using to evaluate your model.
- For metrics where the score increases with the performance of the model (such as accuracy).
If explanations are accurate, the score will quickly rise to the score on non-perturbed input.
Thus, in this case, a higher score represent a more accurate explanation.

- For metrics where the score decreases with the performance of the model (such as losses).
If explanations are accurate, the score will quickly fall to the score on non-perturbed input.
Thus, in this case, a lower score represent a more accurate explanation.


## Remarks

This metric only evaluate the order of importance between features.

The parameters metric, steps and max_percentage_perturbed may drastically change the score :

- For inputs with many features, increasing the number of steps will allow you to capture more efficiently the difference between attributions methods.

- The order of importance of features with low importance may not matter, hence, decreasing the max_percentage_perturbed,
may make the score more relevant.

Sometimes, attributions methods also returns negative attributions,
for those methods, do not take the absolute value before computing insertion metrics.
Otherwise, negative attributions may have higher absolute values, and the order of importance between features will change.
Therefore, take those previous remarks into account to get a relevant score.


## Example

```python
from xplique.metrics import InsertionTab
from xplique.attributions import Saliency

# load images, labels and model
# ...
explainer = Saliency(model)
explanations = explainer(inputs, labels)

metric = InsertionTab(model, inputs, labels)
score = metric.evaluate(explanations)
```

{{xplique.metrics.InsertionTab}}

[^1]:[RISE: Randomized Input Sampling for Explanation of Black-box Models (2018)](https://arxiv.org/abs/1806.07421)
4 changes: 3 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ All the attributions method presented below handle both **Classification** and *
\* : See the [Callable documentation](callable.md)

| **Attribution Metrics** | Type of Model | Property | Source |
| :---------------------- | :------------ | :--------------- | :------------------------------------------------------------------------------------ |
|:------------------------| :------------ | :--------------- |:--------------------------------------------------------------------------------------|
| MuFidelity | TF | Fidelity | [Paper](https://arxiv.org/abs/2005.00631) |
| Deletion | TF | Fidelity | [Paper](https://arxiv.org/abs/1806.07421) |
| Insertion | TF | Fidelity | [Paper](https://arxiv.org/abs/1806.07421) |
| Deletion TS | TF | Fidelity | [Paper1](https://arxiv.org/abs/1806.07421) [Paper2](https://arxiv.org/abs/1909.07082) |
| Insertion TS | TF | Fidelity | [Paper1](https://arxiv.org/abs/1806.07421) [Paper2](https://arxiv.org/abs/1909.07082) |
| Deletion Tab | TF | Fidelity | [Paper1](https://arxiv.org/abs/1806.07421) [Paper2](https://arxiv.org/abs/1909.07082) |
| Insertion Tab | TF | Fidelity | [Paper1](https://arxiv.org/abs/1806.07421) [Paper2](https://arxiv.org/abs/1909.07082) |
| Average Stability | TF | Stability | [Paper](https://arxiv.org/abs/2005.00631) |
| MeGe | TF | Representativity | [Paper](https://arxiv.org/abs/2009.04521) |
| ReCo | TF | Consistency | [Paper](https://arxiv.org/abs/2009.04521) |
Expand Down
46 changes: 40 additions & 6 deletions tests/metrics/test_fidelity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import tensorflow as tf
import numpy as np

from ..utils import generate_model, generate_timeseries_model, generate_data, almost_equal
from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS
from ..utils import generate_model, generate_timeseries_model, generate_regression_model, generate_data, almost_equal
from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS, InsertionTab, DeletionTab


def test_mu_fidelity():
Expand Down Expand Up @@ -51,22 +51,56 @@ def test_perturbation_metrics():
model = generate_timeseries_model(input_shape, nb_labels)
explanations = np.random.uniform(0, 1, x.shape)

def inverse(x):
maximums = x.max(axis=1)
maximums = np.expand_dims(maximums, axis=1)
maximums = np.repeat(maximums, x.shape[1], axis=1)
baselines = maximums - x
return baselines

for step in [-1, 10]:
for baseline_mode in [0.0, "inverse"]:
for baseline_mode in [0.0, inverse]:
for metric in ["loss", "accuracy"]:
score_insertion = InsertionTS(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)
).evaluate(explanations)
score_deletion = DeletionTS(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)
).evaluate(explanations)

for score in [score_insertion, score_deletion]:
if metric == "loss":
assert 0.0 < score
elif score == "accuracy":
elif metric == "accuracy":
assert 0.0 <= score <= 1.0


def test_regression_metrics():
# ensure we can compute insertion/deletion metric with consistent arguments
input_shape, nb_labels, nb_samples = ((20, 10), 5, 50)
x, y = generate_data(input_shape, nb_labels, nb_samples)
model = generate_regression_model(input_shape, nb_labels)
explanations = np.random.uniform(0, 1, x.shape)

for step in [5, 10]:
for baseline_mode in [0.0, lambda x: x-0.5]:
for metric in ["loss", "accuracy"]:
score_insertion = InsertionTab(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
).evaluate(explanations)
score_deletion = DeletionTab(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
).evaluate(explanations)

for score in [score_insertion, score_deletion]:
if metric == "loss":
assert 0.0 < score

elif metric == "accuracy":
assert 0.0 <= score <= 1.0


Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def generate_regression_model(features_shape, output_shape=1):
model.add(Dense(4, activation='relu'))
model.add(Dense(4, activation='relu'))
model.add(Dense(output_shape))
model.compile(loss='mean_absolute_error',
optimizer='sgd')
model.compile(loss='mean_absolute_error', optimizer='sgd',
metrics=['accuracy'])
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved

return model

Expand Down
3 changes: 2 additions & 1 deletion xplique/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Explanations Metrics module
"""

from .fidelity import MuFidelity, Deletion, Insertion, DeletionTS, InsertionTS
from .fidelity import MuFidelity, Deletion, Insertion, DeletionTS, InsertionTS, \
DeletionTab, InsertionTab
from .stability import AverageStability
from .representativity import MeGe
111 changes: 93 additions & 18 deletions xplique/metrics/fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __init__(self,
metric: str = "loss",
batch_size: Optional[int] = 64,
causal_mode: str = "deletion",
baseline_mode: Union[float, str] = 0.0,
baseline_mode: Union[float, Callable] = 0.0,
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved
steps: int = 10,
max_percentage_perturbed: float = 1.0,
): # pylint: disable=R0913
Expand Down Expand Up @@ -484,21 +484,8 @@ def detailed_evaluate(self,
# for each sample, sort by most important features according to the explanation
most_important_features = np.argsort(explanations_flatten, axis=-1)[:, ::-1]

if isinstance(self.baseline_mode, float):
baselines = np.full(self.inputs.shape, self.baseline_mode, dtype=np.float32)
elif self.baseline_mode == "zero":
baselines = np.zeros(self.inputs.shape)
elif self.baseline_mode == "inverse":
time_ax = 1
maximums = self.inputs.max(axis=time_ax)
maximums = np.expand_dims(maximums, axis=time_ax)
maximums = np.repeat(maximums, self.inputs.shape[time_ax], axis=time_ax)
baselines = maximums - self.inputs
elif self.baseline_mode == "negative":
baselines = -self.inputs
else:
raise NotImplementedError(f'Unknown perturbation type `{self.baseline_mode}`.')

baselines = self.baseline_mode(self.inputs) if isfunction(self.baseline_mode) else \
np.full(self.inputs.shape, self.baseline_mode, dtype=np.float32)
baselines_flatten = baselines.reshape(self.inputs_flatten.shape)

steps = np.linspace(0, self.max_nb_perturbed, self.steps+1, dtype=np.int32)
Expand Down Expand Up @@ -572,7 +559,7 @@ def __init__(self,
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
metric: str = "loss",
batch_size: Optional[int] = 64,
baseline_mode: Union[float, str] = 0.0,
baseline_mode: Union[float, Callable] = 0.0,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
): # pylint: disable=R0913
Expand Down Expand Up @@ -619,9 +606,97 @@ def __init__(self,
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
metric: str = "loss",
batch_size: Optional[int] = 64,
baseline_mode: Union[float, str] = 0.0,
baseline_mode: Union[float, Callable] = 0.0,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
): # pylint: disable=R0913
super().__init__(model, inputs, targets, metric, batch_size,
"insertion", baseline_mode, steps, max_percentage_perturbed)


class DeletionTab(CausalFidelityTS):
"""
Adaptation of the deletion metric for tabular data.

Ref. Petsiuk & al., RISE: Randomized Input Sampling for Explanation of Black-box Models (2018).
https://arxiv.org/pdf/1806.07421.pdf
Ref. Schlegel et al., Towards a Rigorous Evaluation of XAI Methods (2019).

Parameters
----------
model
Model used for computing metric.
inputs
Input samples under study.
targets
One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
metric
The metric used to evaluate the model performance. One of the model metric keys when calling
the evaluate function (e.g 'loss', 'accuracy'...). Default to loss.
batch_size
Number of samples to explain at once, if None compute all at once.
baseline_mode
Value of the baseline state, will be called with the inputs if it is a function.
steps
Number of steps between the start and the end state.
Can be set to -1 for all possible steps to be computed.
max_percentage_perturbed
Maximum percentage of the input perturbed.
""" # pylint: disable=R0913

def __init__(self,
model: tf.keras.Model,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
metric: str = "loss",
batch_size: Optional[int] = 64,
baseline_mode: Union[float, Callable] = 0.0,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
):
super().__init__(model, inputs, targets, metric, batch_size,
"deletion", baseline_mode, steps, max_percentage_perturbed)


class InsertionTab(CausalFidelityTS):
"""
Adaptation of the insertion metric for tabular data.

Ref. Petsiuk & al., RISE: Randomized Input Sampling for Explanation of Black-box Models (2018).
https://arxiv.org/pdf/1806.07421.pdf
Ref. Schlegel et al., Towards a Rigorous Evaluation of XAI Methods (2019).

Parameters
----------
model
Model used for computing metric.
inputs
Input samples under study.
targets
One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
metric
The metric used to evaluate the model performance. One of the model metric keys when calling
the evaluate function (e.g 'loss', 'accuracy'...). Default to loss.
batch_size
Number of samples to explain at once, if None compute all at once.
baseline_mode
Value of the baseline state, will be called with the inputs if it is a function.
steps
Number of steps between the start and the end state.
Can be set to -1 for all possible steps to be computed.
max_percentage_perturbed
Maximum percentage of the input perturbed.
""" # pylint: disable=R0913

def __init__(self,
model: tf.keras.Model,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
metric: str = "loss",
batch_size: Optional[int] = 64,
baseline_mode: Union[float, Callable] = 0.0,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
):
super().__init__(model, inputs, targets, metric, batch_size,
"insertion", baseline_mode, steps, max_percentage_perturbed)
3 changes: 2 additions & 1 deletion xplique/plots/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def plot_feature_impact(

# add some text for labels and custom y-axis tick labels
axes.set_xlabel('Impact on output')
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved
axes.set_ylabel('')
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved
axes.set_title('Features impact')
axes.set_yticks(y_pos)
axes.set_yticklabels(yticklabels)
axes.legend()

# make the plot prettier
fig.tight_layout()
Expand Down Expand Up @@ -293,6 +293,7 @@ def summary_plot_tabular(

# build the figure
row_height = 0.4
plt.figure()
if plot_size is None:
plt.gcf().set_size_inches(8, nb_features_kept * row_height + 1.5)
elif isinstance(plot_size,(list, tuple)):
Expand Down