Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lucas/wip #48

Merged
merged 4 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 15 additions & 6 deletions tests/attributions/test_occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@ def test_output_shape():


def test_polymorphic_parameters():
"""Ensure we could pass tuple or int to define patch parameters"""
"""Ensure we could pass tuple or int to define patch parameters when inputs are images"""
s = 3
model = generate_model()

occlusion_int = Occlusion(model, patch_size=s, patch_stride=s)
occlusion_tuple = Occlusion(model, patch_size=(s, s), patch_stride=(s, s))
input_shapes = [(28, 28, 1), (32, 32, 3)]
nb_labels = 10

for input_shape in input_shapes:
features, targets = generate_data(input_shape, nb_labels, 20)
model = generate_model(input_shape, nb_labels)

occlusion_int = Occlusion(model, patch_size=s, patch_stride=s)
occlusion_tuple = Occlusion(model, patch_size=(s, s), patch_stride=(s, s))

occlusion_int(features, targets)
occlusion_tuple(features, targets)

assert occlusion_int.patch_size == occlusion_tuple.patch_size
assert occlusion_int.patch_stride == occlusion_tuple.patch_stride
assert occlusion_int.patch_size == occlusion_tuple.patch_size
assert occlusion_int.patch_stride == occlusion_tuple.patch_stride
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved


def test_mask_generator():
Expand Down
71 changes: 48 additions & 23 deletions xplique/attributions/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@ class Occlusion(BlackBoxExplainer):
def __init__(self,
model: Callable,
batch_size: Optional[int] = 32,
patch_size: Union[int, Tuple[int, int]] = (3, 3),
patch_stride: Union[int, Tuple[int, int]] = (3, 3),
patch_size: Union[int, Tuple[int, int]] = 3,
patch_stride: Union[int, Tuple[int, int]] = 3,
occlusion_value: float = 0.5):
super().__init__(model, batch_size)

self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
self.patch_stride = patch_stride if isinstance(patch_stride, tuple) \
else (patch_stride, patch_stride)
self.patch_size = patch_size
self.patch_stride = patch_stride
self.occlusion_value = occlusion_value

@sanitize_input_output
Expand All @@ -66,6 +65,16 @@ def explain(self,
explanations
Occlusion sensitivity, same shape as the inputs, except for the channels.
"""

# check if data is tabular
is_tabular = (len(inputs.shape)==2)
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved

if not is_tabular:
if not isinstance(self.patch_size, tuple):
self.patch_size = (self.patch_size, self.patch_size)
if not isinstance(self.patch_stride, tuple):
self.patch_stride = (self.patch_stride, self.patch_stride)
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved

sensitivity = None
batch_size = self.batch_size or len(inputs)

Expand All @@ -88,9 +97,9 @@ def explain(self,
return sensitivity

@staticmethod
def _get_masks(input_shape: Tuple[int, int, int],
patch_size: Tuple[int, int],
patch_stride: Tuple[int, int]) -> tf.Tensor:
def _get_masks(input_shape: Union[Tuple[int, int, int], Tuple[int, int], Tuple[int]],
patch_size: Union[int,Tuple[int, int]],
patch_stride: Union[int,Tuple[int, int]]) -> tf.Tensor:
"""
Create all the possible patches for the given configuration.

Expand All @@ -110,16 +119,28 @@ def _get_masks(input_shape: Tuple[int, int, int],
"""
masks = []

x_anchors = [x * patch_stride[0] for x in
range(0, ceil((input_shape[0] - patch_size[0] + 1) / patch_stride[0]))]
y_anchors = [y * patch_stride[1] for y in
range(0, ceil((input_shape[1] - patch_size[1] + 1) / patch_stride[1]))]
# check if we have tabular data
is_tabular = (len(input_shape)==1)

if is_tabular:
x_anchors = [x * patch_stride for x in
range(0, ceil((input_shape[0] - patch_size + 1) / patch_stride))]

for x_anchor in x_anchors:
for y_anchor in y_anchors:
mask = np.zeros(input_shape[:2], dtype=bool)
mask[x_anchor:x_anchor + patch_size[0], y_anchor:y_anchor + patch_size[1]] = 1
for x_anchor in x_anchors:
mask = np.zeros(input_shape, dtype=bool)
mask[x_anchor:x_anchor + patch_size] = 1
masks.append(mask)
else:
x_anchors = [x * patch_stride[0] for x in
range(0, ceil((input_shape[0] - patch_size[0] + 1) / patch_stride[0]))]
y_anchors = [y * patch_stride[1] for y in
range(0, ceil((input_shape[1] - patch_size[1] + 1) / patch_stride[1]))]

for x_anchor in x_anchors:
for y_anchor in y_anchors:
mask = np.zeros(input_shape[:2], dtype=bool)
mask[x_anchor:x_anchor + patch_size[0], y_anchor:y_anchor + patch_size[1]] = 1
masks.append(mask)

return tf.cast(masks, dtype=tf.bool)

Expand All @@ -145,15 +166,17 @@ def _apply_masks(inputs: tf.Tensor,
occluded_inputs
All the occluded combinations for each inputs.
"""

masks = tf.expand_dims(masks, axis=-1)
masks = tf.repeat(masks, repeats=inputs.shape[-1], axis=-1)

occluded_inputs = tf.expand_dims(inputs, axis=1)
occluded_inputs = tf.repeat(occluded_inputs, repeats=masks.shape[0], axis=1)

occluded_inputs = occluded_inputs * tf.cast(tf.logical_not(masks), tf.float32) + tf.cast(
masks, tf.float32) * occlusion_value
# check if inputs shape is (N, W, H, C)
has_channels = (len(inputs.shape)>3)
fel-thomas marked this conversation as resolved.
Show resolved Hide resolved
if has_channels:
masks = tf.expand_dims(masks, axis=-1)
masks = tf.repeat(masks, repeats=inputs.shape[-1], axis=-1)

occluded_inputs = occluded_inputs * tf.cast(tf.logical_not(masks), tf.float32)
occluded_inputs += tf.cast(masks, tf.float32) * occlusion_value

occluded_inputs = tf.reshape(occluded_inputs, (-1, *occluded_inputs.shape[2:]))

Expand Down Expand Up @@ -186,7 +209,9 @@ def _compute_sensitivity(baseline_scores: tf.Tensor,
occluded_scores = tf.reshape(occluded_scores, (-1, masks.shape[0]))

score_delta = baseline_scores - occluded_scores
score_delta = tf.reshape(score_delta, (*score_delta.shape, 1, 1))
# reshape the delta score to fit masks
score_delta = tf.reshape(score_delta, (*score_delta.shape, *(1,) * len(masks.shape[1:])))

sensitivity = score_delta * tf.cast(masks, tf.float32)
sensitivity = tf.reduce_sum(sensitivity, axis=1)

Expand Down