Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,24 +488,21 @@ def format_attr(self, itp_attr: Tensor) -> Tensor:
return itp_attr


class MMImageMaskInput(InterpretableInput):
class ImageMaskInput(InterpretableInput):
"""
MMImageMaskInput is an implementation of InterpretableInput for the image in
multi-modality inputs, whose
ImageMaskInput is an implementation of InterpretableInput for the image, whose
interpretable features are certain image segments (e.g., groups of pixels).
It takes the image, its corresponding segmentation masks, and
a processor function which converts the image into the model inputs.
Its input format to the model will be the tokenized multi-modality tensors
(the output of the processor function),
while its interpretable representation will be a binary tensor of the number of
It takes an image and its corresponding segmentation masks .
Its interpretable representation will be a binary tensor of the number of
the image segment features whose values indicates if the image segment is
“presence” or “absence”.
“presence” or “absence”. By default, its model input will be the masked out image.
But it also accepts an optional processor function, which converts the image into
the model inputs if the model does not directly consume images.
A typical example is MM LLM, which requires the image to be processed with
other texts into the tokenized multi-modality tensors.

Args:

processor_fn (Callable): the multi-modality processor function which take
an input image to encode with any text prompt and outputs the inputs
for the model
image (PIL.Image.Image): an opened PIL image file.
mask (Tensor, optional): the mask to group the image pixels into
segment features. It must be in the same shape as the image size
Expand All @@ -515,9 +512,13 @@ class MMImageMaskInput(InterpretableInput):
and end with same attributions. When mask is None, the entire image is
considered as one interpretable feature.
Default: None
baselines (Tuple[int, int, int], optional): the baseline RGB value for
baseline (Tuple[int, int, int], optional): the baseline RGB value for
the “absent” image pixels.
Default: (255, 255, 255)
processor_fn (Callable, optional): function convert the image into
the model input. A common example is the multi-modality LLM processor
function which take an input image to encode with any text prompt
and outputs the inputs for the LLM

Examples::

Expand Down Expand Up @@ -551,10 +552,10 @@ class MMImageMaskInput(InterpretableInput):
>>> mask = torch.zeros(image.size[::-1], dtype=torch.int32)
>>> mask[:, image.size[0] // 2:] = 1
>>>
>>> image_mask_inp = MMImageMaskInput(
>>> processor_fn=processor_fn,
>>> image_mask_inp = ImageMaskInput(
>>> image=image,
>>> mask=mask,
>>> processor_fn=processor_fn,
>>> )
>>>
>>> text_inp.to_tensor()
Expand All @@ -565,27 +566,27 @@ class MMImageMaskInput(InterpretableInput):

"""

processor_fn: Callable[[PIL.Image.Image], Any]
image: PIL.Image.Image
mask: Tensor
baselines: Tuple[int, int, int]
baseline: Tuple[int, int, int]
processor_fn: Callable[[PIL.Image.Image], Any]
n_itp_features: int
original_model_inputs: Any
mask_id_to_idx: Dict[int, int]
values: List[str]

def __init__(
self,
processor_fn: Callable[[PIL.Image.Image], Any],
image: PIL.Image.Image,
mask: Optional[Tensor] = None,
baselines: Tuple[int, int, int] = (255, 255, 255),
baseline: Tuple[int, int, int] = (255, 255, 255),
processor_fn: Callable[[PIL.Image.Image], Any] = lambda x: x,
) -> None:
super().__init__()

self.processor_fn = processor_fn
self.image = image
self.baselines = baselines
self.baseline = baseline

# Create a dummy mask if None is provided
if mask is None:
Expand Down Expand Up @@ -622,7 +623,7 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
for mask_id, itp_idx in self.mask_id_to_idx.items():
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = self.mask == mask_id
img_array[mask_positions] = self.baselines
img_array[mask_positions] = self.baseline

perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))

Expand Down
54 changes: 27 additions & 27 deletions tests/attr/test_interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from captum._utils.typing import BatchEncodingType
from captum.attr._utils.interpretable_input import (
MMImageMaskInput,
ImageMaskInput,
TextTemplateInput,
TextTokenInput,
)
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_input_with_skip_tokens(self) -> None:
)


class TestMMImageMaskInput(BaseTest):
class TestImageMaskInput(BaseTest):
def _create_test_image(
self, width: int = 10, height: int = 10, color: tuple = (255, 0, 0)
) -> PIL.Image.Image:
Expand All @@ -253,8 +253,8 @@ def test_init_without_mask(self) -> None:
# Setup: create test image and processor
image = self._create_test_image()

# Execute: create MMImageMaskInput without mask
mm_input = MMImageMaskInput(
# Execute: create ImageMaskInput without mask
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
)
Expand All @@ -275,8 +275,8 @@ def test_init_with_mask(self) -> None:
mask = torch.zeros((10, 10), dtype=torch.int32)
mask[:, 5:] = 1 # Split horizontally into 2 segments

# Execute: create MMImageMaskInput with mask
mm_input = MMImageMaskInput(
# Execute: create ImageMaskInput with mask
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
Expand All @@ -295,8 +295,8 @@ def test_init_with_non_continuous_mask_ids(self) -> None:
mask[:, 5:10] = 5
mask[:, 10:] = 10

# Execute: create MMImageMaskInput
mm_input = MMImageMaskInput(
# Execute: create ImageMaskInput
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
Expand All @@ -310,9 +310,9 @@ def test_init_with_non_continuous_mask_ids(self) -> None:
self.assertEqual(mapped_indices, {0, 1, 2})

def test_to_tensor_without_mask(self) -> None:
# Setup: create MMImageMaskInput without mask
# Setup: create ImageMaskInput without mask
image = self._create_test_image()
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
)
Expand All @@ -325,13 +325,13 @@ def test_to_tensor_without_mask(self) -> None:
assertTensorAlmostEqual(self, result, expected)

def test_to_tensor_with_mask(self) -> None:
# Setup: create MMImageMaskInput with 3 segments
# Setup: create ImageMaskInput with 3 segments
image = self._create_test_image(width=15)
mask = torch.zeros((10, 15), dtype=torch.int32)
mask[:, 5:10] = 1
mask[:, 10:] = 2

mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
Expand All @@ -345,9 +345,9 @@ def test_to_tensor_with_mask(self) -> None:
assertTensorAlmostEqual(self, result, expected)

def test_to_model_input_without_perturbation(self) -> None:
# Setup: create MMImageMaskInput
# Setup: create ImageMaskInput
image = self._create_test_image()
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
)
Expand All @@ -364,10 +364,10 @@ def test_to_model_input_without_perturbation(self) -> None:
def test_to_model_input_with_perturbation_no_mask_present(self) -> None:
# Setup: create red image without mask
image = self._create_test_image(color=(255, 0, 0))
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
baselines=(255, 255, 255), # white baseline
baseline=(255, 255, 255), # white baseline
)

# Execute: perturb with feature present (value 1)
Expand All @@ -383,10 +383,10 @@ def test_to_model_input_with_perturbation_no_mask_present(self) -> None:
def test_to_model_input_with_perturbation_no_mask_absent(self) -> None:
# Setup: create red image without mask
image = self._create_test_image(color=(255, 0, 0))
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
baselines=(255, 255, 255), # white baseline
baseline=(255, 255, 255), # white baseline
)

# Execute: perturb with feature absent (value 0)
Expand All @@ -407,11 +407,11 @@ def test_to_model_input_with_mask_partial_perturbation(self) -> None:
mask = torch.zeros((10, 10), dtype=torch.int32)
mask[:, 5:] = 1 # Right half is segment 1

mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
baselines=(255, 255, 255), # white baseline
baseline=(255, 255, 255), # white baseline
)

# Execute: perturb to keep left segment (0) but remove right segment (1)
Expand All @@ -429,10 +429,10 @@ def test_to_model_input_with_mask_partial_perturbation(self) -> None:
def test_to_model_input_with_custom_baselines(self) -> None:
# Setup: create image with custom baseline color
image = self._create_test_image(color=(255, 0, 0))
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
baselines=(0, 128, 255), # Custom blue-ish baseline
baseline=(0, 128, 255), # Custom blue-ish baseline
)

# Execute: perturb to remove feature
Expand All @@ -446,9 +446,9 @@ def test_to_model_input_with_custom_baselines(self) -> None:
self.assertTrue(np.all(img_array[:, :, 2] == 255))

def test_format_attr_without_mask(self) -> None:
# Setup: create MMImageMaskInput without mask
# Setup: create ImageMaskInput without mask
image = self._create_test_image(width=5, height=5)
mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
)
Expand All @@ -462,12 +462,12 @@ def test_format_attr_without_mask(self) -> None:
self.assertTrue(torch.all(result == 0.5))

def test_format_attr_with_mask(self) -> None:
# Setup: create MMImageMaskInput with 2 segments
# Setup: create ImageMaskInput with 2 segments
image = self._create_test_image(width=10, height=5)
mask = torch.zeros((5, 10), dtype=torch.int32)
mask[:, 5:] = 1 # Split horizontally

mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
Expand All @@ -493,7 +493,7 @@ def test_format_attr_with_non_continuous_mask(self) -> None:
mask[:, 5:10] = 10
mask[:, 10:] = 20

mm_input = MMImageMaskInput(
mm_input = ImageMaskInput(
processor_fn=self._simple_processor,
image=image,
mask=mask,
Expand Down
Loading