diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py
index ec97f8e8c..8f44202ea 100644
--- a/captum/attr/_utils/interpretable_input.py
+++ b/captum/attr/_utils/interpretable_input.py
@@ -1,7 +1,9 @@
 # pyre-strict
 
 from abc import ABC, abstractmethod
-from typing import Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
+import numpy as np
+import PIL.Image
 import torch
 from captum._utils.typing import TokenizerLike
@@ -123,9 +125,7 @@ def to_tensor(self) -> Tensor:
         pass
 
     @abstractmethod
-    def to_model_input(
-        self, perturbed_tensor: Optional[Tensor] = None
-    ) -> Union[str, Tensor]:
+    def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
         """
         Get the (perturbed) input in the format required by the model
         based on the given (perturbed) interpretable representation.
@@ -486,3 +486,161 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Tensor:
 
     def format_attr(self, itp_attr: Tensor) -> Tensor:
         return itp_attr
+
+
+class MMImageMaskInput(InterpretableInput):
+    """
+    MMImageMaskInput is an implementation of InterpretableInput for the image
+    in multi-modality inputs, whose interpretable features are certain image
+    segments (e.g., groups of pixels). It takes the image, its corresponding
+    segmentation mask, and a processor function which converts the image into
+    the model inputs. Its input format to the model is the tokenized
+    multi-modality tensors (the output of the processor function), while its
+    interpretable representation is a binary tensor with one value per image
+    segment feature, indicating whether that segment is "present" or "absent".
+
+    Args:
+
+        processor_fn (Callable): the multi-modality processor function, which
+            takes an input image, encodes it together with any text prompt,
+            and outputs the inputs for the model
+        image (PIL.Image.Image): an opened PIL image file.
+        mask (Tensor, optional): the mask to group the image pixels into
+            segment features. It must have the same shape as the image and
+            assign each pixel a mask index. Pixels with the same index are
+            treated as a single interpretable feature, which means they are
+            perturbed together and receive the same attribution. When mask is
+            None, the entire image is considered as one interpretable feature.
+            Default: None
+        baselines (Tuple[int, int, int], optional): the baseline RGB value for
+            the "absent" image pixels.
+            Default: (255, 255, 255)
+
+    Examples::
+
+        >>> def processor_fn(image):
+        >>>     messages = [
+        >>>         {
+        >>>             "role": "user",
+        >>>             "content": [
+        >>>                 {"type": "image"},
+        >>>                 {
+        >>>                     "type": "text",
+        >>>                     "text": "Please describe the image in detail.",
+        >>>                 },
+        >>>             ],
+        >>>         }
+        >>>     ]
+        >>>
+        >>>     prompt = processor.apply_chat_template(
+        >>>         messages, add_generation_prompt=True
+        >>>     )
+        >>>
+        >>>     return processor(
+        >>>         text=prompt,
+        >>>         images=image,
+        >>>         # tokenize=True,
+        >>>         # return_dict=True,
+        >>>         return_tensors="pt",
+        >>>     ).to(model.device)
+
+        >>> image = Image.open("test.jpg")
+
+        >>> # Split horizontally: left half = 0, right half = 1
+        >>> mask = torch.zeros(image.size[::-1], dtype=torch.int32)
+        >>> mask[:, image.size[0] // 2:] = 1
+
+        >>> image_mask_inp = MMImageMaskInput(
+        >>>     processor_fn=processor_fn,
+        >>>     image=image,
+        >>>     mask=mask,
+        >>> )
+        >>>
+        >>> image_mask_inp.to_tensor()
+        >>> # torch.tensor([[1, 1]])
+        >>>
+        >>> image_mask_inp.to_model_input(torch.tensor([[0, 1]]))
+        >>> # model inputs where the right half of the image is masked out
+
+    """
+
+    processor_fn: Callable[[PIL.Image.Image], Any]
+    image: PIL.Image.Image
+    mask: Optional[Tensor]
+    baselines: Tuple[int, int, int]
+    n_itp_features: int
+    original_model_inputs: Any
+    mask_id_to_idx: Dict[int, int]
+    values: List[str] = []  # unused for now
+
+    def __init__(
+        self,
+        processor_fn: Callable[[PIL.Image.Image], Any],
+        image: PIL.Image.Image,
+        mask: Optional[Tensor] = None,
+        baselines: Tuple[int, int, int] = (255, 255, 255),
+    ) -> None:
+        super().__init__()
+
+        self.processor_fn = processor_fn
+        self.image = image
+        self.mask = mask
+        self.baselines = baselines
+
+        if mask is None:
+            self.n_itp_features = 1
+            self.mask_id_to_idx = {}
+        else:
+            mask_ids = torch.unique(mask)
+            self.n_itp_features = len(mask_ids)
+            self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
+
+        self.original_model_inputs = processor_fn(image)
+
+    def to_tensor(self) -> Tensor:
+        return torch.tensor([[1.0] * self.n_itp_features])
+
+    def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
+        if perturbed_tensor is None:
+            return self.original_model_inputs
+
+        img_array = np.array(self.image)
+
+        if self.mask is None:
+            # Single feature: blank the whole image when it is marked absent
+            if perturbed_tensor[0][0] == 0:
+                img_array[:, :] = self.baselines
+        else:
+            # Replace every absent segment's pixels with the baseline color
+            for mask_id, itp_idx in self.mask_id_to_idx.items():
+                if perturbed_tensor[0][itp_idx] == 0:
+                    mask_positions = self.mask == mask_id
+                    img_array[mask_positions] = self.baselines
+
+        perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))
+
+        return self.processor_fn(perturbed_image)
+
+    def format_attr(self, itp_attr: Tensor) -> Tensor:
+        device = itp_attr.device
+
+        if self.mask is None:
+            # When mask is None, treat the entire image as one segment.
+            # Create a uniform all-zero mask to broadcast the single attribution.
+            img_array = np.array(self.image)
+            image_shape = img_array.shape[:2]  # (height, width)
+            formatted_mask = torch.zeros(image_shape, dtype=torch.long, device=device)
+        else:
+            # Map mask IDs to continuous indices
+            image_shape = self.mask.shape
+            formatted_mask = torch.zeros_like(self.mask, device=device)
+            for mask_id, itp_idx in self.mask_id_to_idx.items():
+                formatted_mask[self.mask == mask_id] = itp_idx
+
+        formatted_attr = _scatter_itp_attr_by_mask(
+            itp_attr,
+            (1, *image_shape),
+            formatted_mask.unsqueeze(0),
+        )
+        return formatted_attr
diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py
index 4d6be2c33..78400508f 100644
--- a/tests/attr/test_interpretable_input.py
+++ b/tests/attr/test_interpretable_input.py
@@ -2,11 +2,17 @@
 
 # pyre-unsafe
 
-from typing import List, Literal, Optional, overload, Union
+from typing import Any, Dict, List, Literal, Optional, overload, Union
 
+import numpy as np
+import PIL.Image
 import torch
 from captum._utils.typing import BatchEncodingType
-from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
+from captum.attr._utils.interpretable_input import (
+    MMImageMaskInput,
+    TextTemplateInput,
+    TextTokenInput,
+)
 from captum.testing.helpers import BaseTest
 from captum.testing.helpers.basic import assertTensorAlmostEqual
 from parameterized import parameterized
@@ -228,3 +234,282 @@ def test_input_with_skip_tokens(self) -> None:
         assertTensorAlmostEqual(
             self, tt_input.to_model_input(perturbed_tensor), expected_perturbed_inp
         )
+
+
+class TestMMImageMaskInput(BaseTest):
+    def _create_test_image(
+        self, width: int = 10, height: int = 10, color: tuple = (255, 0, 0)
+    ) -> PIL.Image.Image:
+        """Helper method to create a test PIL image."""
+        img_array = np.full((height, width, 3), color, dtype=np.uint8)
+        return PIL.Image.fromarray(img_array)
+
+    def _simple_processor(self, image: PIL.Image.Image) -> Dict[str, Tensor]:
+        """Simple test processor that converts the image to a tensor."""
+        img_array = np.array(image)
+        return {"pixel_values": torch.from_numpy(img_array).float()}
+
+    def test_init_without_mask(self) -> None:
+        # Setup: create test image and processor
+        image = self._create_test_image()
+
+        # Execute: create MMImageMaskInput without mask
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+        )
+
+        # Assert: verify n_itp_features is 1 when no mask is provided
+        self.assertEqual(mm_input.n_itp_features, 1)
+        self.assertEqual(len(mm_input.mask_id_to_idx), 0)
+        self.assertIsNone(mm_input.mask)
+
+    def test_init_with_mask(self) -> None:
+        # Setup: create test image and mask with 2 segments
+        image = self._create_test_image()
+        mask = torch.zeros((10, 10), dtype=torch.int32)
+        mask[:, 5:] = 1  # Split horizontally into 2 segments
+
+        # Execute: create MMImageMaskInput with mask
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+        )
+
+        # Assert: verify n_itp_features matches the number of unique mask values
+        self.assertEqual(mm_input.n_itp_features, 2)
+        self.assertEqual(len(mm_input.mask_id_to_idx), 2)
+        self.assertIn(0, mm_input.mask_id_to_idx)
+        self.assertIn(1, mm_input.mask_id_to_idx)
+
+    def test_init_with_non_continuous_mask_ids(self) -> None:
+        # Setup: create mask with non-continuous IDs (0, 5, 10)
+        image = self._create_test_image(width=15, height=10)
+        mask = torch.zeros((10, 15), dtype=torch.int32)
+        mask[:, 5:10] = 5
+        mask[:, 10:] = 10
+
+        # Execute: create MMImageMaskInput
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+        )
+
+        # Assert: verify mask_id_to_idx creates a continuous mapping
+        self.assertEqual(mm_input.n_itp_features, 3)
+        self.assertEqual(len(mm_input.mask_id_to_idx), 3)
+        # Verify all mask IDs are mapped to continuous indices 0, 1, 2
+        mapped_indices = set(mm_input.mask_id_to_idx.values())
+        self.assertEqual(mapped_indices, {0, 1, 2})
+
+    def test_to_tensor_without_mask(self) -> None:
+        # Setup: create MMImageMaskInput without mask
+        image = self._create_test_image()
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+        )
+
+        # Execute: convert to tensor
+        result = mm_input.to_tensor()
+
+        # Assert: verify tensor shape and values for a single feature
+        expected = torch.tensor([[1.0]])
+        assertTensorAlmostEqual(self, result, expected)
+
+    def test_to_tensor_with_mask(self) -> None:
+        # Setup: create MMImageMaskInput with 3 segments
+        image = self._create_test_image(width=15)
+        mask = torch.zeros((10, 15), dtype=torch.int32)
+        mask[:, 5:10] = 1
+        mask[:, 10:] = 2
+
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+        )
+
+        # Execute: convert to tensor
+        result = mm_input.to_tensor()
+
+        # Assert: verify tensor has the correct number of features
+        expected = torch.tensor([[1.0, 1.0, 1.0]])
+        assertTensorAlmostEqual(self, result, expected)
+
+    def test_to_model_input_without_perturbation(self) -> None:
+        # Setup: create MMImageMaskInput
+        image = self._create_test_image()
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+        )
+
+        # Execute: get model input without perturbation
+        result = mm_input.to_model_input()
+
+        # Assert: verify it returns the original model inputs
+        self.assertIn("pixel_values", result)
+        assertTensorAlmostEqual(
+            self, result["pixel_values"], mm_input.original_model_inputs["pixel_values"]
+        )
+
+    def test_to_model_input_with_perturbation_no_mask_present(self) -> None:
+        # Setup: create red image without mask
+        image = self._create_test_image(color=(255, 0, 0))
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            baselines=(255, 255, 255),  # white baseline
+        )
+
+        # Execute: perturb with feature present (value 1)
+        perturbed_tensor = torch.tensor([[1.0]])
+        result = mm_input.to_model_input(perturbed_tensor)
+
+        # Assert: image should remain red (unchanged)
+        img_array = result["pixel_values"].numpy().astype(np.uint8)
+        self.assertTrue(np.all(img_array[:, :, 0] == 255))
+        self.assertTrue(np.all(img_array[:, :, 1] == 0))
+        self.assertTrue(np.all(img_array[:, :, 2] == 0))
+
+    def test_to_model_input_with_perturbation_no_mask_absent(self) -> None:
+        # Setup: create red image without mask
+        image = self._create_test_image(color=(255, 0, 0))
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            baselines=(255, 255, 255),  # white baseline
+        )
+
+        # Execute: perturb with feature absent (value 0)
+        perturbed_tensor = torch.tensor([[0.0]])
+        result = mm_input.to_model_input(perturbed_tensor)
+
+        # Assert: entire image should be white (baseline)
+        img_array = result["pixel_values"].numpy().astype(np.uint8)
+        self.assertTrue(np.all(img_array == 255))
+
+    def test_to_model_input_with_mask_partial_perturbation(self) -> None:
+        # Setup: create image with 2 segments (left red, right green)
+        img_array = np.zeros((10, 10, 3), dtype=np.uint8)
+        img_array[:, :5] = [255, 0, 0]  # Left half red
+        img_array[:, 5:] = [0, 255, 0]  # Right half green
+        image = PIL.Image.fromarray(img_array)
+
+        mask = torch.zeros((10, 10), dtype=torch.int32)
+        mask[:, 5:] = 1  # Right half is segment 1
+
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+            baselines=(255, 255, 255),  # white baseline
+        )
+
+        # Execute: keep the left segment (0) but remove the right segment (1)
+        perturbed_tensor = torch.tensor([[1.0, 0.0]])
+        result = mm_input.to_model_input(perturbed_tensor)
+
+        # Assert: left half should be red, right half should be white
+        img_array = result["pixel_values"].numpy().astype(np.uint8)
+        # Left half should be red
+        self.assertTrue(np.all(img_array[:, :5, 0] == 255))
+        self.assertTrue(np.all(img_array[:, :5, 1] == 0))
+        # Right half should be white (baseline)
+        self.assertTrue(np.all(img_array[:, 5:] == 255))
+
+    def test_to_model_input_with_custom_baselines(self) -> None:
+        # Setup: create image with a custom baseline color
+        image = self._create_test_image(color=(255, 0, 0))
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            baselines=(0, 128, 255),  # Custom blue-ish baseline
+        )
+
+        # Execute: perturb to remove the feature
+        perturbed_tensor = torch.tensor([[0.0]])
+        result = mm_input.to_model_input(perturbed_tensor)
+
+        # Assert: image should have the custom baseline color
+        img_array = result["pixel_values"].numpy().astype(np.uint8)
+        self.assertTrue(np.all(img_array[:, :, 0] == 0))
+        self.assertTrue(np.all(img_array[:, :, 1] == 128))
+        self.assertTrue(np.all(img_array[:, :, 2] == 255))
+
+    def test_format_attr_without_mask(self) -> None:
+        # Setup: create MMImageMaskInput without mask
+        image = self._create_test_image(width=5, height=5)
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+        )
+
+        # Execute: format attribution for a single feature
+        attr = torch.tensor([[0.5]])
+        result = mm_input.format_attr(attr)
+
+        # Assert: attribution should be broadcast to all pixels
+        self.assertEqual(result.shape, (1, 5, 5))
+        self.assertTrue(torch.all(result == 0.5))
+
+    def test_format_attr_with_mask(self) -> None:
+        # Setup: create MMImageMaskInput with 2 segments
+        image = self._create_test_image(width=10, height=5)
+        mask = torch.zeros((5, 10), dtype=torch.int32)
+        mask[:, 5:] = 1  # Split horizontally
+
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+        )
+
+        # Execute: format attribution with different values for each segment
+        attr = torch.tensor([[0.3, 0.7]])
+        result = mm_input.format_attr(attr)
+
+        # Assert: left half should have 0.3, right half should have 0.7
+        self.assertEqual(result.shape, (1, 5, 10))
+        assertTensorAlmostEqual(
+            self, result[0, :, :5], torch.full((5, 5), 0.3)
+        )  # Left half
+        assertTensorAlmostEqual(
+            self, result[0, :, 5:], torch.full((5, 5), 0.7)
+        )  # Right half
+
+    def test_format_attr_with_non_continuous_mask(self) -> None:
+        # Setup: create mask with non-continuous IDs
+        image = self._create_test_image(width=15, height=5)
+        mask = torch.zeros((5, 15), dtype=torch.int32)
+        mask[:, 5:10] = 10
+        mask[:, 10:] = 20
+
+        mm_input = MMImageMaskInput(
+            processor_fn=self._simple_processor,
+            image=image,
+            mask=mask,
+        )
+
+        # Execute: format attribution
+        attr = torch.tensor([[0.1, 0.2, 0.3]])
+        result = mm_input.format_attr(attr)
+
+        # Assert: verify the correct attribution value for each segment
+        self.assertEqual(result.shape, (1, 5, 15))
+        # Find which continuous index maps to which mask ID
+        idx_0 = mm_input.mask_id_to_idx[0]
+        idx_10 = mm_input.mask_id_to_idx[10]
+        idx_20 = mm_input.mask_id_to_idx[20]
+
+        # Verify each segment has its corresponding attribution
+        segment_0_value = attr[0, idx_0].item()
+        segment_10_value = attr[0, idx_10].item()
+        segment_20_value = attr[0, idx_20].item()
+
+        self.assertTrue(torch.all(result[0, :, :5] == segment_0_value))
+        self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value))
+        self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value))
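
For reviewers, a minimal runnable sketch of the intended end-to-end flow (not part of this diff) may help. The toy `processor_fn` and `score` below are stand-ins for a real multi-modal processor and a real model score (e.g., the log-probability of a target completion); everything else uses only the API added in this PR.

```python
from typing import Dict

import numpy as np
import PIL.Image
import torch

from captum.attr._utils.interpretable_input import MMImageMaskInput


def processor_fn(image: PIL.Image.Image) -> Dict[str, torch.Tensor]:
    # Toy stand-in for a real multi-modality processor (cf. the docstring example).
    return {"pixel_values": torch.from_numpy(np.array(image)).float()}


def score(model_inputs: Dict[str, torch.Tensor]) -> float:
    # Toy stand-in for a model forward pass: mean pixel intensity.
    return model_inputs["pixel_values"].mean().item()


# Two-segment image: left half red, right half green.
img_array = np.zeros((10, 10, 3), dtype=np.uint8)
img_array[:, :5] = (255, 0, 0)
img_array[:, 5:] = (0, 255, 0)
image = PIL.Image.fromarray(img_array)

mask = torch.zeros((10, 10), dtype=torch.int32)
mask[:, 5:] = 1  # segment 0 = left half, segment 1 = right half

inp = MMImageMaskInput(processor_fn=processor_fn, image=image, mask=mask)

base = score(inp.to_model_input())  # all segments present
attrs = []
for i in range(inp.n_itp_features):
    perturbed = inp.to_tensor()  # fresh [[1.0, 1.0]]
    perturbed[0, i] = 0.0  # ablate segment i (its pixels become the baseline)
    attrs.append(base - score(inp.to_model_input(perturbed)))

# Scatter the per-segment scores back onto pixel space, shape (1, H, W).
pixel_attr = inp.format_attr(torch.tensor([attrs]))
```

The `base - score(...)` difference is plain leave-one-out ablation; in practice a perturbation-based attribution method automates this loop, with `format_attr` scattering the per-segment results back to pixel space for visualization.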