Skip to content

Commit

Permalink
occlusion: adapt api for tabular data
Browse files Browse the repository at this point in the history
  • Loading branch information
lucashervier committed Jul 27, 2021
1 parent 65676ec commit feea923
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 29 deletions.
21 changes: 15 additions & 6 deletions tests/attributions/test_occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@ def test_output_shape():


def test_polymorphic_parameters():
"""Ensure we could pass tuple or int to define patch parameters"""
"""Ensure we could pass tuple or int to define patch parameters when inputs are images"""
s = 3
model = generate_model()

occlusion_int = Occlusion(model, patch_size=s, patch_stride=s)
occlusion_tuple = Occlusion(model, patch_size=(s, s), patch_stride=(s, s))
input_shapes = [(28, 28, 1), (32, 32, 3)]
nb_labels = 10

for input_shape in input_shapes:
features, targets = generate_data(input_shape, nb_labels, 20)
model = generate_model(input_shape, nb_labels)

occlusion_int = Occlusion(model, patch_size=s, patch_stride=s)
occlusion_tuple = Occlusion(model, patch_size=(s, s), patch_stride=(s, s))

occlusion_int(features, targets)
occlusion_tuple(features, targets)

assert occlusion_int.patch_size == occlusion_tuple.patch_size
assert occlusion_int.patch_stride == occlusion_tuple.patch_stride
assert occlusion_int.patch_size == occlusion_tuple.patch_size
assert occlusion_int.patch_stride == occlusion_tuple.patch_stride


def test_mask_generator():
Expand Down
71 changes: 48 additions & 23 deletions xplique/attributions/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@ class Occlusion(BlackBoxExplainer):
def __init__(self,
model: Callable,
batch_size: Optional[int] = 32,
patch_size: Union[int, Tuple[int, int]] = (3, 3),
patch_stride: Union[int, Tuple[int, int]] = (3, 3),
patch_size: Union[int, Tuple[int, int]] = 3,
patch_stride: Union[int, Tuple[int, int]] = 3,
occlusion_value: float = 0.5):
super().__init__(model, batch_size)

self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
self.patch_stride = patch_stride if isinstance(patch_stride, tuple) \
else (patch_stride, patch_stride)
self.patch_size = patch_size
self.patch_stride = patch_stride
self.occlusion_value = occlusion_value

@sanitize_input_output
Expand All @@ -66,6 +65,16 @@ def explain(self,
explanations
Occlusion sensitivity, same shape as the inputs, except for the channels.
"""

# check if data is tabular
is_tabular = (len(inputs.shape)==2)

if not is_tabular:
if not isinstance(self.patch_size, tuple):
self.patch_size = (self.patch_size, self.patch_size)
if not isinstance(self.patch_stride, tuple):
self.patch_stride = (self.patch_stride, self.patch_stride)

sensitivity = None
batch_size = self.batch_size or len(inputs)

Expand All @@ -88,9 +97,9 @@ def explain(self,
return sensitivity

@staticmethod
def _get_masks(input_shape: Tuple[int, int, int],
patch_size: Tuple[int, int],
patch_stride: Tuple[int, int]) -> tf.Tensor:
def _get_masks(input_shape: Union[Tuple[int, int, int], Tuple[int, int], Tuple[int]],
patch_size: Union[int,Tuple[int, int]],
patch_stride: Union[int,Tuple[int, int]]) -> tf.Tensor:
"""
Create all the possible patches for the given configuration.
Expand All @@ -110,16 +119,28 @@ def _get_masks(input_shape: Tuple[int, int, int],
"""
masks = []

x_anchors = [x * patch_stride[0] for x in
range(0, ceil((input_shape[0] - patch_size[0] + 1) / patch_stride[0]))]
y_anchors = [y * patch_stride[1] for y in
range(0, ceil((input_shape[1] - patch_size[1] + 1) / patch_stride[1]))]
# check if we have tabular data
is_tabular = (len(input_shape)==1)

if is_tabular:
x_anchors = [x * patch_stride for x in
range(0, ceil((input_shape[0] - patch_size + 1) / patch_stride))]

for x_anchor in x_anchors:
for y_anchor in y_anchors:
mask = np.zeros(input_shape[:2], dtype=bool)
mask[x_anchor:x_anchor + patch_size[0], y_anchor:y_anchor + patch_size[1]] = 1
for x_anchor in x_anchors:
mask = np.zeros(input_shape, dtype=bool)
mask[x_anchor:x_anchor + patch_size] = 1
masks.append(mask)
else:
x_anchors = [x * patch_stride[0] for x in
range(0, ceil((input_shape[0] - patch_size[0] + 1) / patch_stride[0]))]
y_anchors = [y * patch_stride[1] for y in
range(0, ceil((input_shape[1] - patch_size[1] + 1) / patch_stride[1]))]

for x_anchor in x_anchors:
for y_anchor in y_anchors:
mask = np.zeros(input_shape[:2], dtype=bool)
mask[x_anchor:x_anchor + patch_size[0], y_anchor:y_anchor + patch_size[1]] = 1
masks.append(mask)

return tf.cast(masks, dtype=tf.bool)

Expand All @@ -145,15 +166,17 @@ def _apply_masks(inputs: tf.Tensor,
occluded_inputs
All the occluded combinations for each inputs.
"""

masks = tf.expand_dims(masks, axis=-1)
masks = tf.repeat(masks, repeats=inputs.shape[-1], axis=-1)

occluded_inputs = tf.expand_dims(inputs, axis=1)
occluded_inputs = tf.repeat(occluded_inputs, repeats=masks.shape[0], axis=1)

occluded_inputs = occluded_inputs * tf.cast(tf.logical_not(masks), tf.float32) + tf.cast(
masks, tf.float32) * occlusion_value
# check if inputs shape is (N, W, H, C)
has_channels = (len(inputs.shape)>3)
if has_channels:
masks = tf.expand_dims(masks, axis=-1)
masks = tf.repeat(masks, repeats=inputs.shape[-1], axis=-1)

occluded_inputs = occluded_inputs * tf.cast(tf.logical_not(masks), tf.float32)
occluded_inputs += tf.cast(masks, tf.float32) * occlusion_value

occluded_inputs = tf.reshape(occluded_inputs, (-1, *occluded_inputs.shape[2:]))

Expand Down Expand Up @@ -186,7 +209,9 @@ def _compute_sensitivity(baseline_scores: tf.Tensor,
occluded_scores = tf.reshape(occluded_scores, (-1, masks.shape[0]))

score_delta = baseline_scores - occluded_scores
score_delta = tf.reshape(score_delta, (*score_delta.shape, 1, 1))
# reshape the delta score to fit masks
score_delta = tf.reshape(score_delta, (*score_delta.shape, *(1,) * len(masks.shape[1:])))

sensitivity = score_delta * tf.cast(masks, tf.float32)
sensitivity = tf.reduce_sum(sensitivity, axis=1)

Expand Down

0 comments on commit feea923

Please sign in to comment.