Skip to content

Commit

Permalink
feat: add the possibility to add an activation layer (softmax or sigm…
Browse files Browse the repository at this point in the history
…oid) to the model when computing metrics (but not for generating explanations), add the corrsponding tests and documentation
  • Loading branch information
lucashervier committed Aug 24, 2023
1 parent dc8484f commit e60a5ff
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 23 deletions.
43 changes: 31 additions & 12 deletions docs/api/metrics/api_metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,25 @@ As pointed out by [Petsiuk et al.](http://arxiv.org/abs/1806.07421) most explana

All metrics inherits from the base class `BaseAttributionMetric` which has the following `__init__` arguments:

- \- `model`: The model from which we want to obtain explanations
- \- `inputs`: Input samples to be explained
- `model`: The model from which we want to obtain explanations
- `inputs`: Input samples to be explained

!!!warning
Inputs should be the same as defined in the attribution's [API Description](https://deel-ai.github.io/xplique/api/attributions/api_attributions/)
!!!info
Inputs should be the same as defined in the [model's documentation](../../attributions/model)

- \- `targets`: One-hot encoding of the model's output from which an explanation is desired
- `targets`: Specify the kind of explanations we want depending on the task at end (e.g. a one-hot encoding of a class of interest, a difference to a ground-truth value..)

!!!warning
Idem
!!!info
Targets should be the same as defined in the [model's documentation](../../attributions/model)

- \- `batch_size`
- `batch_size`

- `activation`: A string that belongs to [None, 'sigmoid', 'softmax']. See the [dedicated section](#activation) for details

Then we can distinguish two category of metrics:

- \- Those which only need the attribution ouputs of an explainer: `ExplanationMetric`
- \- Those which need the explainer: `ExplainerMetric`
- Those which only need the attribution ouputs of an explainer: `ExplanationMetric`
- Those which need the explainer: `ExplainerMetric`

### `ExplanationMetric`

Expand All @@ -40,7 +42,7 @@ Those metrics are agnostic of the explainer used and rely only on the attributio

All metrics inheriting from this class have another argument in their `__init__` method:

- \- `operator`: Optionnal function wrapping the model. It can be seen as a metric which allow to evaluate model evolution. For more details, see the attribution's [API Description](../attributions/api_attributions/) and the [operator documentation](../attributions/operator/)
- `operator`: Optionnal function wrapping the model. It can be seen as a metric which allow to evaluate model evolution. For more details, see the attribution's [API Description](../../attributions/api_attributions/) and the [operator documentation](../../attributions/operator/)

All metrics inheriting from this class have to define a method `evaluate` which will take as input the `attributions` given by an explainer. Those attributions should correspond to the `model`, `inputs` and `targets` used to build the metric object.

Expand All @@ -60,18 +62,35 @@ Those metrics will not assess the quality of the explanations provided but (also
All metrics inheriting from this class have to define a method `evaluate` which will take as input the `explainer` evaluated.

!!!info
It is even more important that `inputs` and `targets` are the same as defined in the attribution's [API Description](../attributions/api_attributions/)
It is even more important that `inputs` and `targets` are the same as defined in the attribution's [API Description](../../attributions/api_attributions/)

Currently, there is only one Stability metric inheriting from this class:

| Metric Name (Stability) |Notebook |
|:----------------------- |:----------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| AverageStability | **(WIP)** |

## Activation

This parameter specify if an additional activation layer should be added once a model has been called on the inputs when you have to compute the metric.

Indeed, most of the times it is recommended when you instantiate an **explainer** (*i.e.* an attribution methods) to provide a model which gives logits as explaining the logits is to explain the class, while explaining the softmax is to explain why this class rather than another.

However, when you compute metrics some were thought to measure a "drop of probability" when you occlude the "most relevant" part of an input. Thus, once you get your explanations (computed from the logits), you might need to have access to a probability score of occluded inputs of a specific class, thus to have access to the logits after a `softmax` or `sigmoid` layer.

Consequently, we add this `activation` parameter so one can provide a model that predicts logits but add an activation layer for the purpose of having probability when using a metric method.

The default behavior is to compute the metric without adding any activation layer to the model.

!!!note
In our opinion, there is no consensus at present concerning the "best practices". Should the model used to generate the explanations be exactly the same for generating the metrics or it should depend on the metric ? As we do not claim to have an answer (yet!), we choose to let the user as much flexibility as possible!

## Other Metrics

A Representatibity metric: [MeGe](https://arxiv.org/abs/2009.04521) is also available. Documentation about it should be added soon.

## Notebooks

- [**Metrics**: Getting started](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) <sub> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) </sub>

- [**Metrics**: With Pytorch's model](https://colab.research.google.com/drive/16bEmYXzLEkUWLRInPU17QsodAIbjdhGP) <sub> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/16bEmYXzLEkUWLRInPU17QsodAIbjdhGP) </sub>
29 changes: 28 additions & 1 deletion tests/metrics/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,32 @@ def test_common():
assert hasattr(metric, 'inference_function')
assert hasattr(metric, 'batch_inference_function')
score = metric(explanations)
print(f"\n\n\n {type(score)} \n\n\n")
assert type(score) in [np.float32, np.float64, float]

def test_add_activation():
"""Test that adding a softmax or sigmoid layer still works"""
input_shape, nb_labels, samples = ((16, 16, 3), 5, 8)
x, y = generate_data(input_shape, nb_labels, samples)
model = generate_model(input_shape, nb_labels)

explainers = _default_methods(model)

activations = ["sigmoid", "softmax"]

for explainer in explainers:
explanations = explainer(x, y)
for activation in activations:
metrics = [
Deletion(model, x, y, steps=3, activation=activation),
Insertion(model, x, y, steps=3, activation=activation),
MuFidelity(model, x, y, nb_samples=3, activation=activation),
]
for metric in metrics:
assert hasattr(metric, 'evaluate')
if isinstance(metric, ExplainerMetric):
score = metric(explainer)
else:
assert hasattr(metric, 'inference_function')
assert hasattr(metric, 'batch_inference_function')
score = metric(explanations)
assert type(score) in [np.float32, np.float64, float]
34 changes: 29 additions & 5 deletions xplique/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,28 @@ class BaseAttributionMetric(ABC):
One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
batch_size
Number of samples to evaluate at once, if None compute all at once.
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

def __init__(self,
model: Callable,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
batch_size: Optional[int] = 64):
self.model = model
batch_size: Optional[int] = 64,
activation: Optional[str] = None):
if activation is None:
self.model = model
else:
assert activation in ['sigmoid', 'softmax'], \
"activation must be in ['sigmoid', 'softmax']"
if activation == 'sigmoid':
self.model = lambda x: tf.nn.sigmoid(model(x))
else:
self.model = lambda x: tf.nn.softmax(model(x))
self.inputs, self.targets = numpy_sanitize(inputs, targets)
self.batch_size = batch_size

Expand All @@ -51,6 +65,11 @@ class ExplainerMetric(BaseAttributionMetric, ABC):
One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample.
batch_size
Number of samples to evaluate at once, if None compute all at once.
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

@abstractmethod
Expand Down Expand Up @@ -95,21 +114,26 @@ class ExplanationMetric(BaseAttributionMetric, ABC):
Function g to explain, g take 3 parameters (f, x, y) and should return a scalar,
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""
# mettre union, operator par défaut c'est de la classif (ENUM)
def __init__(self,
model: tf.keras.Model,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
batch_size: Optional[int] = 64,
operator: Optional[Union[Tasks, str, OperatorSignature]] = None,
activation: Optional[str] = None,
):
# pylint: disable=R0913
super().__init__(model, inputs, targets, batch_size)
super().__init__(model, inputs, targets, batch_size, activation)

# define the inference function according to the model type
self.inference_function, self.batch_inference_function = \
get_inference_function(model, operator, is_for_metric=True)
get_inference_function(model, operator)

@abstractmethod
def evaluate(self,
Expand Down
34 changes: 29 additions & 5 deletions xplique/metrics/fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class MuFidelity(ExplanationMetric):
Function g to explain, g take 3 parameters (f, x, y) and should return a scalar,
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

def __init__(self,
Expand All @@ -64,9 +69,10 @@ def __init__(self,
subset_percent: float = 0.2,
baseline_mode: Union[Callable, float] = 0.0,
nb_samples: int = 200,
operator: Optional[Callable] = None,):
operator: Optional[Callable] = None,
activation: Optional[str] = None):
# pylint: disable=R0913
super().__init__(model, inputs, targets, batch_size, operator)
super().__init__(model, inputs, targets, batch_size, operator, activation)
self.grid_size = grid_size
self.subset_percent = subset_percent
self.baseline_mode = baseline_mode
Expand Down Expand Up @@ -166,6 +172,11 @@ class CausalFidelity(ExplanationMetric):
Function g to explain, g take 3 parameters (f, x, y) and should return a scalar,
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

def __init__(self,
Expand All @@ -178,9 +189,10 @@ def __init__(self,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
operator: Optional[Callable] = None,
activation: Optional[str] = None
):
# pylint: disable=R0913
super().__init__(model, inputs, targets, batch_size, operator)
super().__init__(model, inputs, targets, batch_size, operator, activation)
self.causal_mode = causal_mode
self.baseline_mode = baseline_mode

Expand Down Expand Up @@ -330,6 +342,11 @@ class Deletion(CausalFidelity):
Function g to explain, g take 3 parameters (f, x, y) and should return a scalar,
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

def __init__(self,
Expand All @@ -341,10 +358,11 @@ def __init__(self,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
operator: Optional[Callable] = None,
activation: Optional[str] = None
):
super().__init__(model, inputs, targets, batch_size, "deletion",
baseline_mode, steps, max_percentage_perturbed,
operator)
operator, activation)


class Insertion(CausalFidelity):
Expand Down Expand Up @@ -377,6 +395,11 @@ class of interest as pixels are added according to the generated importance map.
Function g to explain, g take 3 parameters (f, x, y) and should return a scalar,
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
activation
A string that belongs to [None, 'sigmoid', 'softmax']. Specify if we should add
an activation layer once the model has been called. It is useful, for instance
if you want to measure a 'drop of probability' by adding a sigmoid or softmax
after getting your logits. If None does not add a layer to your model.
"""

def __init__(self,
Expand All @@ -388,7 +411,8 @@ def __init__(self,
steps: int = 10,
max_percentage_perturbed: float = 1.0,
operator: Optional[Callable] = None,
activation: Optional[str] = None
):
super().__init__(model, inputs, targets, batch_size, "insertion",
baseline_mode, steps, max_percentage_perturbed,
operator)
operator, activation)

0 comments on commit e60a5ff

Please sign in to comment.