Skip to content

Commit

Permalink
fidelity: add causal fidelity tabular
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidPetiteau committed Jan 25, 2022
1 parent e4afe76 commit 318a019
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 1 deletion.
3 changes: 2 additions & 1 deletion xplique/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Explanations Metrics module
"""

from .fidelity import MuFidelity, Deletion, Insertion, DeletionTS, InsertionTS
from .fidelity import MuFidelity, Deletion, Insertion, DeletionTS, InsertionTS, \
DeletionTab, InsertionTab
from .stability import AverageStability
from .representativity import MeGe
243 changes: 243 additions & 0 deletions xplique/metrics/fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,246 @@ def __init__(self,
): # pylint: disable=R0913
super().__init__(model, inputs, targets, metric, batch_size,
"insertion", baseline_mode, steps, max_percentage_perturbed)


class CausalFidelityTab(ExplanationMetric):
    """
    Used to compute the insertion and deletion metrics for tabular data explanations.

    Features are perturbed from the most important to the least important
    (according to the explanation), and the model is re-evaluated after each
    batch of perturbations. The resulting score curve is summarised by its AUC.

    Parameters
    ----------
    model
        Model used for computing metric.
    inputs
        Input samples under study. Every dimension after the first one is
        flattened internally into a single feature axis.
    targets
        One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
    metric
        The metric used to evaluate the model performance. One of the model metric keys when calling
        the evaluate function (e.g 'loss', 'accuracy'...). Default to loss.
    batch_size
        Number of samples to explain at once, if None compute all at once.
    causal_mode
        If 'insertion', the path is baseline to original tabular data,
        for 'deletion' the path is original tabular data to baseline.
    baseline_mode
        Value of the baseline state, will be called with the inputs if it is a callable.
    steps
        Number of steps between the start and the end state.
        Can be set to -1 for all possible steps to be computed.
    max_percentage_perturbed
        Maximum percentage of the input perturbed.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(self,
                 model: tf.keras.Model,
                 inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                 targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
                 metric: str = "loss",
                 batch_size: Optional[int] = 64,
                 causal_mode: str = "deletion",
                 baseline_mode: Union[float, Callable] = 0.0,
                 steps: int = 10,
                 max_percentage_perturbed: float = 1.0,
                 ):  # pylint: disable=R0913
        super().__init__(model, inputs, targets, batch_size)

        self.baseline_mode = baseline_mode
        self.causal_mode = causal_mode
        assert metric == "loss" or metric in self.model.metrics_names
        self.metric = metric

        # Work from `self.inputs` (the attribute every other method of this
        # class reads) rather than the raw argument: the raw argument may be a
        # tf.Tensor, which has no `.reshape`. `np.asarray` is a no-op when the
        # stored inputs are already a numpy array.
        inputs_array = np.asarray(self.inputs)

        self.nb_samples = inputs_array.shape[0]
        # `np.prod(())` is the float 1.0 (1-D inputs), so cast explicitly to
        # keep a valid integer dimension for the reshape below.
        self.nb_features = int(np.prod(inputs_array.shape[1:]))
        # trailing singleton axis kept so that indexing [sample, feature_ids]
        # yields a (k, 1) slab on both the start and the end states
        self.inputs_flatten = inputs_array.reshape(
            (self.nb_samples, self.nb_features, 1)
        )

        assert 0 < max_percentage_perturbed <= 1, \
            "max_percentage_perturbed should be between 0 and 1"
        self.max_nb_perturbed = int(self.nb_features * max_percentage_perturbed)

        # -1 means "one step per perturbed feature"
        if steps == -1:
            steps = self.max_nb_perturbed
        self.steps = steps

    def evaluate(self,
                 explanations: Union[tf.Tensor, np.ndarray]) -> float:
        """
        Evaluate the causal score for tabular data explanations.

        Parameters
        ----------
        explanations
            Explanation for the inputs, labels to evaluate.

        Returns
        -------
        causal_score
            Metric score (for interpretation, see score interpretation in the documentation).
        """
        scores_dict = self.detailed_evaluate(explanations)

        # AUC of the score curve with the trapezoidal rule: mean over
        # consecutive pairs of (left + right) / 2
        np_scores = np.array(list(scores_dict.values()))
        auc = np.mean(np_scores[:-1] + np_scores[1:]) * 0.5

        return auc

    def detailed_evaluate(self,
                          explanations: Union[tf.Tensor, np.ndarray]) -> Dict[int, float]:
        """
        Evaluate model performance for successive perturbations of an input.
        Used to compute causal score for tabular data explanations.

        The successive perturbations in the Insertion and Deletion metrics create a list of scores.
        This list of scores make a score evolution curve.
        The AUC of such curve is used as an explanation metric.
        However, the curve in itself is rich in information,
        its visualization and interpretation can bring further comprehension
        on the explanation and the model.
        Therefore this method was added so that it is possible to construct such curves.

        Parameters
        ----------
        explanations
            Explanation for the inputs, labels to evaluate.

        Returns
        -------
        causal_score_dict
            Dictionary of scores obtain for different perturbations
            Keys are the steps, i.e the number of features perturbed
            Values are the scores, the score of the model
            on the inputs with the corresponding number of features perturbed

        Raises
        ------
        NotImplementedError
            If `causal_mode` is neither 'deletion' nor 'insertion'.
        """
        explanations = np.array(explanations)
        assert explanations.shape == self.inputs.shape, "The number of explanations must be the " \
                                                        "same as the number of inputs"

        explanations_flatten = explanations.reshape((len(explanations), -1))

        # for each sample, sort by most important features according to the explanation
        most_important_features = np.argsort(explanations_flatten, axis=-1)[:, ::-1]

        # `callable` (rather than `inspect.isfunction`) so that functools.partial
        # objects, builtins and callable instances are accepted as baselines too
        if callable(self.baseline_mode):
            baselines = self.baseline_mode(self.inputs)
        else:
            baselines = np.ones_like(self.inputs, dtype=np.float32) * self.baseline_mode
        baselines_flatten = np.asarray(baselines).reshape(self.inputs_flatten.shape)

        # The integer cast can collapse neighbouring steps onto the same value
        # when there are more steps than perturbable features; de-duplicate so
        # the model is not re-evaluated on identical inputs (the resulting
        # dictionary is unchanged since duplicate keys would overwrite).
        steps = np.unique(
            np.linspace(0, self.max_nb_perturbed, self.steps + 1, dtype=np.int32)
        )

        if self.causal_mode == "deletion":
            start = self.inputs_flatten
            end = baselines_flatten
        elif self.causal_mode == "insertion":
            start = baselines_flatten
            end = self.inputs_flatten
        else:
            raise NotImplementedError(f'Unknown causal mode `{self.causal_mode}`.')

        # column vector of sample indices, broadcast against the per-sample
        # feature ids to flip every sample in one indexing operation
        sample_ids = np.arange(len(most_important_features))[:, np.newaxis]

        scores_dict = {}
        for step in steps:
            ids_to_flip = most_important_features[:, :step]

            perturbed_inputs = start.copy()
            perturbed_inputs[sample_ids, ids_to_flip] = end[sample_ids, ids_to_flip]
            perturbed_inputs = perturbed_inputs.reshape((-1, *self.inputs.shape[1:]))

            score = self.model.evaluate(perturbed_inputs, self.targets,
                                        batch_size=self.batch_size, verbose=0,
                                        return_dict=True)
            # plain `int` keys, as advertised by the return annotation
            scores_dict[int(step)] = score[self.metric]

        return scores_dict


class DeletionTab(CausalFidelityTab):
    """
    Adaptation of the deletion metric for tabular data.

    The most important features (according to the explanation) are progressively
    replaced by the baseline, i.e. the path goes from the original tabular data
    to the baseline.

    Ref. Petsiuk & al., RISE: Randomized Input Sampling for Explanation of Black-box Models (2018).
    https://arxiv.org/pdf/1806.07421.pdf
    Ref. Schlegel et al., Towards a Rigorous Evaluation of XAI Methods (2019).

    Parameters
    ----------
    model
        Model used for computing metric.
    inputs
        Input samples under study.
    targets
        One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
    metric
        The metric used to evaluate the model performance. One of the model metric keys when calling
        the evaluate function (e.g 'loss', 'accuracy'...). Default to loss.
    batch_size
        Number of samples to explain at once, if None compute all at once.
    baseline_mode
        Value of the baseline state, will be called with the inputs if it is a function.
    steps
        Number of steps between the start and the end state.
        Can be set to -1 for all possible steps to be computed.
    max_percentage_perturbed
        Maximum percentage of the input perturbed.
    """

    def __init__(self,
                 model: tf.keras.Model,
                 inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                 targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
                 metric: str = "loss",
                 batch_size: Optional[int] = 64,
                 baseline_mode: Union[float, Callable] = 0.0,
                 steps: int = 10,
                 max_percentage_perturbed: float = 1.0,
                 ):  # pylint: disable=R0913
        # same constructor as the parent with the causal mode pinned to "deletion"
        super().__init__(model, inputs, targets, metric, batch_size,
                         "deletion", baseline_mode, steps, max_percentage_perturbed)


class InsertionTab(CausalFidelityTab):
    """
    Adaptation of the insertion metric for tabular data.

    The most important features (according to the explanation) are progressively
    restored on top of the baseline, i.e. the path goes from the baseline to the
    original tabular data.

    Ref. Petsiuk & al., RISE: Randomized Input Sampling for Explanation of Black-box Models (2018).
    https://arxiv.org/pdf/1806.07421.pdf
    Ref. Schlegel et al., Towards a Rigorous Evaluation of XAI Methods (2019).

    Parameters
    ----------
    model
        Model used for computing metric.
    inputs
        Input samples under study.
    targets
        One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
    metric
        The metric used to evaluate the model performance. One of the model metric keys when calling
        the evaluate function (e.g 'loss', 'accuracy'...). Default to loss.
    batch_size
        Number of samples to explain at once, if None compute all at once.
    baseline_mode
        Value of the baseline state, will be called with the inputs if it is a function.
    steps
        Number of steps between the start and the end state.
        Can be set to -1 for all possible steps to be computed.
    max_percentage_perturbed
        Maximum percentage of the input perturbed.
    """

    def __init__(self,
                 model: tf.keras.Model,
                 inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                 targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
                 metric: str = "loss",
                 batch_size: Optional[int] = 64,
                 baseline_mode: Union[float, Callable] = 0.0,
                 steps: int = 10,
                 max_percentage_perturbed: float = 1.0,
                 ):  # pylint: disable=R0913
        # same constructor as the parent with the causal mode pinned to "insertion"
        super().__init__(model, inputs, targets, metric, batch_size,
                         "insertion", baseline_mode, steps, max_percentage_perturbed)

0 comments on commit 318a019

Please sign in to comment.