
Commit

feat: add the possibility to get an operator from an enum or a string. With this new getter we do not need to change the model when using metrics for classification
lucashervier committed Aug 22, 2023
1 parent f5138ae commit 95ac086
Showing 9 changed files with 209 additions and 61 deletions.
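
In short, a minimal sketch of what the new getter enables (all names are taken from the diff below):

    from xplique.commons import Tasks, get_operator
    from xplique.commons.operators import predictions_operator, classif_metrics_operator

    # the same task can be requested with an enum member or a (case-insensitive) string
    assert get_operator(Tasks.CLASSIFICATION) is predictions_operator
    assert get_operator("classif") is predictions_operator

    # for metrics, the classification operator additionally applies a softmax
    assert get_operator("classif", is_for_metric=True) is classif_metrics_operator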
44 changes: 42 additions & 2 deletions tests/commons/test_operators.py
@@ -10,7 +10,9 @@
GradCAMPP, Lime, KernelShap, SobolAttributionMethod,
HsicAttributionMethod)
from xplique.commons.operators import (check_operator, predictions_operator, regression_operator,
binary_segmentation_operator, segmentation_operator)
binary_segmentation_operator, segmentation_operator,
classif_metrics_operator)
from xplique.commons.operators import Tasks, get_operator
from xplique.commons.exceptions import InvalidOperatorException
from ..utils import generate_data, generate_regression_model

@@ -69,10 +71,48 @@ def test_check_operator():
def test_proposed_operators():
# ensure all proposed operators are operators
for operator in [predictions_operator, regression_operator,
binary_segmentation_operator, segmentation_operator]:
binary_segmentation_operator, segmentation_operator,
classif_metrics_operator]:
check_operator(operator)


def test_get_operator():
tasks_name = [task.name for task in Tasks]
    assert sorted(tasks_name) == sorted(['CLASSIFICATION', 'REGRESSION',
                                         'BINARY_SEGMENTATION', 'SEGMENTATION'])
# get by enum
assert get_operator(Tasks.CLASSIFICATION) is predictions_operator
assert get_operator(Tasks.CLASSIFICATION, is_for_metric=True) is classif_metrics_operator
assert get_operator(Tasks.REGRESSION) is regression_operator
assert get_operator(Tasks.BINARY_SEGMENTATION) is binary_segmentation_operator
assert get_operator(Tasks.SEGMENTATION) is segmentation_operator

# get by string
assert get_operator("classif") is predictions_operator
assert get_operator("Classif", is_for_metric=True) is classif_metrics_operator
assert get_operator("reg") is regression_operator
assert get_operator("bInary_Seg") is binary_segmentation_operator
assert get_operator("segmentation") is segmentation_operator

    # ensure an invalid string raises an error
    try:
        get_operator("random")
        assert False, "an invalid task name should raise a ValueError"
    except ValueError:
        pass

    # operator must have at least 3 arguments
    function_with_2_arguments = lambda x, y: 0

    # operator must be Callable
    not_a_function = [1, 2, 3]

    for operator in [function_with_2_arguments, not_a_function]:
        try:
            get_operator(operator)
            assert False, "an invalid operator should raise an InvalidOperatorException"
        except InvalidOperatorException:
            pass


def test_regression_operator():
input_shape, nb_labels, samples = ((10, 10, 1), 10, 20)
x, y = generate_data(input_shape, nb_labels, samples)
22 changes: 9 additions & 13 deletions tests/wrappers/test_pytorch_wrapper.py
@@ -46,15 +46,15 @@ def _default_methods_cnn(model):

def _default_methods_regression(model):
return [
Saliency(model),
GradientInput(model),
IntegratedGradients(model),
SmoothGrad(model),
SquareGrad(model),
VarGrad(model),
Occlusion(model, patch_size=1, patch_stride=1),
Lime(model, nb_samples = 20),
KernelShap(model, nb_samples = 20),
Saliency(model, operator="reg"),
GradientInput(model, operator="regression"),
IntegratedGradients(model, operator="regression"),
SmoothGrad(model, operator="regression"),
SquareGrad(model, operator="regression"),
VarGrad(model, operator="regression"),
Occlusion(model, operator="regression", patch_size=1, patch_stride=1),
Lime(model, operator="regression", nb_samples = 20),
KernelShap(model, operator="regression", nb_samples = 20),
]

def generate_torch_model(input_shape=(32, 32, 3), output_shape=10):
@@ -185,7 +185,3 @@ def test_metric_dense():
score = metric(explanations)
print(f"\n\n\n {type(score)} \n\n\n")
assert type(score) in [np.float32, np.float64, float]

def test_operator():
"""TODO"""
pass
10 changes: 7 additions & 3 deletions xplique/attributions/base.py
@@ -8,7 +8,8 @@
import tensorflow as tf
import numpy as np

from ..types import Callable, Dict, Tuple, Union, Optional
from ..types import Callable, Dict, Tuple, Union, Optional, OperatorSignature
from ..commons import Tasks, get_operator
from ..commons import (find_layer, tensor_sanitize, get_inference_function,
get_gradient_functions, no_gradients_available)

@@ -53,7 +54,7 @@ class BlackBoxExplainer(ABC):
_cache_models: Dict[Tuple[int, int], tf.keras.Model] = {}

def __init__(self, model: Callable, batch_size: Optional[int] = 64,
operator: Optional[Callable[[tf.keras.Model, tf.Tensor, tf.Tensor], float]] = None):
operator: Optional[Union[Tasks, str, OperatorSignature]] = None):

if isinstance(model, tf.keras.Model):
try:
@@ -68,6 +69,9 @@ def __init__(self, model: Callable, batch_size: Optional[int] = 64,

self.batch_size = batch_size

# get the operator
operator = get_operator(operator)

# define the inference function according to the model type
self.inference_function, self.batch_inference_function = \
get_inference_function(model, operator)
@@ -155,4 +159,4 @@ def __init__(self,
pass

# check and get gradient function from model and operator
self.gradient, self.batch_gradient = get_gradient_functions(model, operator)
self.gradient, self.batch_gradient = get_gradient_functions(model, get_operator(operator))
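
As a usage note, here is a small sketch (not part of the diff) of what the widened operator argument allows when constructing an explainer; Saliency and the "reg" alias are both exercised in the tests of this commit:

    from xplique.attributions import Saliency
    from xplique.commons import Tasks

    # model: any tf.keras.Model
    explainer = Saliency(model, operator=Tasks.REGRESSION)  # by enum
    explainer = Saliency(model, operator="reg")             # by string alias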
2 changes: 1 addition & 1 deletion xplique/commons/__init__.py
@@ -7,7 +7,7 @@
find_layer, open_relu_policy
from .tf_operations import repeat_labels, batch_tensor
from .callable_operations import predictions_one_hot_callable
from .operators import check_operator, operator_batching,\
from .operators import Tasks, get_operator, check_operator, operator_batching,\
get_inference_function, get_gradient_functions
from .exceptions import no_gradients_available, raise_invalid_operator
from .forgrad import forgrad
166 changes: 131 additions & 35 deletions xplique/commons/operators.py
@@ -3,13 +3,14 @@
"""

import inspect
from enum import Enum, auto

import tensorflow as tf

from ..types import Callable, Optional
from ..types import Callable, Optional, Union, OperatorSignature
from .exceptions import raise_invalid_operator, no_gradients_available
from .callable_operations import predictions_one_hot_callable


@tf.function
def predictions_operator(model: Callable,
inputs: tf.Tensor,
Expand All @@ -34,6 +35,31 @@ def predictions_operator(model: Callable,
scores = tf.reduce_sum(model(inputs) * targets, axis=-1)
return scores

@tf.function
def classif_metrics_operator(model: Callable,
inputs: tf.Tensor,
targets: tf.Tensor) -> tf.Tensor:
"""
Compute prediction scores, only for the label class, for a batch of samples. Unlike
`predictions_operator`, a softmax is applied here: metrics need probability scores to be
computed correctly, whereas the softmax is removed when computing attribution values.
Parameters
----------
model
Model used for computing predictions.
inputs
Input samples to be explained.
targets
One-hot encoded labels, one for each sample.
Returns
-------
scores
Probability scores computed, only for the label class.
"""
scores = tf.reduce_sum(tf.nn.softmax(model(inputs)) * targets, axis=-1)
return scores
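
# Illustration (not part of the commit): for a model returning logits, the only
# difference with `predictions_operator` is the softmax:
#   predictions_operator(model, x, y)     -> logit score of the label class
#   classif_metrics_operator(model, x, y) -> probability of the label class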

@tf.function
def regression_operator(model: Callable,
@@ -111,6 +137,109 @@ def segmentation_operator(model: Callable,
scores = tf.reduce_sum(model(inputs) * targets, axis=(1, 2, 3))
return scores

class Tasks(Enum):
"""
Enumeration of different tasks for which we have defined operators
"""
CLASSIFICATION = auto()
REGRESSION = auto()
BINARY_SEGMENTATION = auto()
SEGMENTATION = auto()

enum_to_method = {
Tasks.CLASSIFICATION: predictions_operator,
Tasks.REGRESSION: regression_operator,
Tasks.BINARY_SEGMENTATION: binary_segmentation_operator,
Tasks.SEGMENTATION: segmentation_operator
}

def check_operator(operator: Callable):
"""
Check that the operator is a valid g(f, x, y) -> tf.Tensor function:
raise an exception if it is not, and return True if it is.
Parameters
----------
operator
Operator to check
Returns
-------
is_valid
True if the operator is valid; an exception is raised otherwise.
"""
# handle tf functions
# pylint: disable=protected-access
if hasattr(operator, '_python_function'):
return check_operator(operator._python_function)

# the operator must be callable
# pylint: disable=isinstance-second-argument-not-valid-type
if not isinstance(operator, Callable):
raise_invalid_operator()

# the operator should take at least three arguments
args = inspect.getfullargspec(operator).args
if len(args) < 3:
raise_invalid_operator()

return True
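
# Illustration (not part of the commit): any callable taking at least the three
# g(model, inputs, targets) arguments passes the check, e.g.
#   def custom_operator(model, inputs, targets):
#       return tf.reduce_mean(model(inputs), axis=-1)
#   check_operator(custom_operator)  # returns True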

def get_operator(
operator: Optional[Union[Tasks, str, OperatorSignature]],
is_for_metric: bool = False):
"""
This function retrieves an operator from a `Tasks` enum member or a task name. If the
operator is a custom one, we simply check that its signature is correct.
Parameters
----------
operator
An operator from the Tasks enum or the task name or a custom operator. If None, use a
classification operator.
is_for_metric
A boolean that specifies whether the operator is requested for computing a metric.
This is especially relevant for classification, where the softmax is needed for the
metric but not for computing explanations.
Returns
-------
operator
The operator requested
"""
# case when no operator is provided
if operator is None:
if is_for_metric:
return classif_metrics_operator
return predictions_operator

# case when the query is a string
if isinstance(operator, str):

# map the string to a Tasks member; matching is done by (case-insensitive)
# substring, so abbreviations such as "classif" or "reg" are accepted
if operator.upper() in 'CLASSIFICATION':
operator = Tasks.CLASSIFICATION
elif operator.upper() in 'REGRESSION':
operator = Tasks.REGRESSION
elif operator.upper() in 'SEGMENTATION':
operator = Tasks.SEGMENTATION
elif operator.upper() in 'BINARY_SEGMENTATION':
operator = Tasks.BINARY_SEGMENTATION
else:
valid_op_name = ', '.join([task.name for task in Tasks])
raise ValueError(
    f"Invalid operator name: {operator}. "
    f"Available operators are: {valid_op_name}."
)

# case when the query belongs to the Tasks enum
if operator in Tasks.__members__.values():
if operator is Tasks.CLASSIFICATION and is_for_metric:
return classif_metrics_operator
return enum_to_method[operator]

# otherwise the operator is a custom one: check its signature
check_operator(operator)
return operator

def get_gradient_of_operator(operator):
"""
@@ -167,39 +296,6 @@ def batched_operator(model, inputs, targets, batch_size=None):
return batched_operator


def check_operator(operator: Callable):
"""
Check if the operator is valid g(f, x, y) -> tf.Tensor
and raise an exception and return true if so.
Parameters
----------
operator
Operator to check
Returns
-------
is_valid
True if the operator is valid, False otherwise.
"""
# handle tf functions
# pylint: disable=protected-access
if hasattr(operator, '_python_function'):
return check_operator(operator._python_function)

# the operator must be callable
# pylint: disable=isinstance-second-argument-not-valid-type
if not isinstance(operator, Callable):
raise_invalid_operator()

# the operator should take at least three arguments
args = inspect.getfullargspec(operator).args
if len(args) < 3:
raise_invalid_operator()

return True


batch_predictions = operator_batching(predictions_operator)
gradients_predictions = get_gradient_of_operator(predictions_operator)
batch_gradients_predictions = operator_batching(gradients_predictions)
11 changes: 7 additions & 4 deletions xplique/metrics/base.py
@@ -7,8 +7,8 @@
import tensorflow as tf
import numpy as np

from ..commons import numpy_sanitize, get_inference_function
from ..types import Callable, Optional, Union
from ..commons import Tasks, numpy_sanitize, get_operator, get_inference_function
from ..types import Callable, Optional, Union, OperatorSignature


class BaseAttributionMetric(ABC):
@@ -96,17 +96,20 @@ class ExplanationMetric(BaseAttributionMetric, ABC):
with f the model, x the inputs and y the targets. If None, use the standard
operator g(f, x, y) = f(x)[y].
"""

# accept a Union; the default operator is the classification one (Tasks enum)
def __init__(self,
model: tf.keras.Model,
inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
targets: Optional[Union[tf.Tensor, np.ndarray]] = None,
batch_size: Optional[int] = 64,
operator: Optional[Callable] = None,
operator: Optional[Union[Tasks, str, OperatorSignature]] = None,
):
# pylint: disable=R0913
super().__init__(model, inputs, targets, batch_size)

# get the operator
operator = get_operator(operator, is_for_metric=True)

# define the inference function according to the model type
self.inference_function, self.batch_inference_function = \
get_inference_function(model, operator)
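
# Illustration (not part of the commit): a metric can now be parametrized the
# same way, e.g. with the Deletion metric (assuming it forwards these
# arguments to ExplanationMetric):
#   metric = Deletion(model, inputs, targets, operator="classif")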
1 change: 1 addition & 0 deletions xplique/types/__init__.py
@@ -3,3 +3,4 @@
"""

from typing import Union, Tuple, List, Callable, Dict, Optional, Any
from .custom_type import OperatorSignature
7 changes: 7 additions & 0 deletions xplique/types/custom_type.py
@@ -0,0 +1,7 @@
"""
Module for custom types and signatures
"""
from typing import Callable
import tensorflow as tf

OperatorSignature = Callable[[tf.keras.Model, tf.Tensor, tf.Tensor], float]
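
For reference, a hypothetical operator matching this signature (illustration only, not part of the commit):

    import tensorflow as tf

    def my_operator(model: tf.keras.Model, inputs: tf.Tensor, targets: tf.Tensor) -> float:
        # one score per sample, as for the built-in operators
        return tf.reduce_sum(model(inputs) * targets, axis=-1)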
