166 changes: 162 additions & 4 deletions captum/attr/_utils/interpretable_input.py
@@ -1,7 +1,9 @@
 # pyre-strict
 from abc import ABC, abstractmethod
-from typing import Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
+import numpy as np
+import PIL.Image
 import torch
 
 from captum._utils.typing import TokenizerLike
@@ -123,9 +125,7 @@ def to_tensor(self) -> Tensor:
         pass
 
     @abstractmethod
-    def to_model_input(
-        self, perturbed_tensor: Optional[Tensor] = None
-    ) -> Union[str, Tensor]:
+    def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
         """
         Get the (perturbed) input in the format required by the model
         based on the given (perturbed) interpretable representation.
@@ -486,3 +486,161 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Tensor:
 
     def format_attr(self, itp_attr: Tensor) -> Tensor:
         return itp_attr
+
+
+class MMImageMaskInput(InterpretableInput):
+    """
+    MMImageMaskInput is an implementation of InterpretableInput for the image
+    in a multi-modality input, whose interpretable features are image segments
+    (e.g., groups of pixels). It takes the image, its corresponding
+    segmentation mask, and a processor function which converts the image into
+    the model inputs. Its input format to the model is the tokenized
+    multi-modality tensors (the output of the processor function), while its
+    interpretable representation is a binary tensor with one value per image
+    segment feature indicating whether that segment is "present" or "absent".
+
+    Args:
+
+        processor_fn (Callable): the multi-modality processor function which
+            takes an input image, encodes it together with any text prompt,
+            and outputs the inputs for the model
+        image (PIL.Image.Image): an opened PIL image file.
+        mask (Tensor, optional): the mask to group the image pixels into
+            segment features. It must have the same (height, width) shape as
+            the image and assign each pixel a mask index.
+            Pixels with the same index are treated as a single interpretable
+            feature, which means they must be perturbed together and receive
+            the same attribution. When mask is None, the entire image is
+            considered as one interpretable feature.
+            Default: None
+        baselines (Tuple[int, int, int], optional): the baseline RGB value for
+            "absent" image pixels.
+            Default: (255, 255, 255)
+
+    Examples::
+
+        >>> def processor_fn(image):
+        >>>     messages = [
+        >>>         {
+        >>>             "role": "user",
+        >>>             "content": [
+        >>>                 {"type": "image"},
+        >>>                 {
+        >>>                     "type": "text",
+        >>>                     "text": "Please describe the image in detail.",
+        >>>                 },
+        >>>             ],
+        >>>         }
+        >>>     ]
+        >>>
+        >>>     prompt = processor.apply_chat_template(
+        >>>         messages, add_generation_prompt=True
+        >>>     )
+        >>>
+        >>>     return processor(
+        >>>         text=prompt,
+        >>>         images=image,
+        >>>         # tokenize=True,
+        >>>         # return_dict=True,
+        >>>         return_tensors="pt",
+        >>>     ).to(model.device)
+
+        >>> image = Image.open("test.jpg")
+
+        >>> # Split horizontally: left half = 0, right half = 1
+        >>> mask = torch.zeros(image.size[::-1], dtype=torch.int32)
+        >>> mask[:, image.size[0] // 2:] = 1
+
+        >>> image_mask_inp = MMImageMaskInput(
+        >>>     processor_fn=processor_fn,
+        >>>     image=image,
+        >>>     mask=mask,
+        >>> )
+        >>>
+        >>> image_mask_inp.to_tensor()
+        >>> # torch.tensor([[1, 1]])
+        >>>
+        >>> image_mask_inp.to_model_input(torch.tensor([[0, 1]]))
+        >>> # model inputs where the left half of the image is masked out
+
+    """
+
+    processor_fn: Callable[[PIL.Image.Image], Any]
+    image: PIL.Image.Image
+    mask: Optional[Tensor]
+    baselines: Tuple[int, int, int]
+    n_itp_features: int
+    original_model_inputs: Any
+    mask_id_to_idx: Dict[int, int]
+    values: List[str] = []  # unused for now
+
+    def __init__(
+        self,
+        processor_fn: Callable[[PIL.Image.Image], Any],
+        image: PIL.Image.Image,
+        mask: Optional[Tensor] = None,
+        baselines: Tuple[int, int, int] = (255, 255, 255),
+    ) -> None:
+        super().__init__()
+
+        self.processor_fn = processor_fn
+        self.image = image
+        self.mask = mask
+        self.baselines = baselines
+
+        if mask is None:
+            self.n_itp_features = 1
+            self.mask_id_to_idx = {}
+        else:
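+            # Map each unique mask ID to a contiguous feature index so the
+            # binary interpretable vector can be indexed densely.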
+            mask_ids = torch.unique(mask)
+            self.n_itp_features = len(mask_ids)
+            self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
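+
+        # Cache the unperturbed model inputs so to_model_input(None) does not
+        # re-run the processor.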
+        self.original_model_inputs = processor_fn(image)
+
+    def to_tensor(self) -> Tensor:
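+        # Initial interpretable representation: every segment is "present".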
+        return torch.tensor([[1.0] * self.n_itp_features])
+
+    def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
+        if perturbed_tensor is None:
+            return self.original_model_inputs
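+
+        # Paint every "absent" segment (perturbed value 0) with the baseline
+        # color on a pixel copy, then re-run the processor on the result.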
+        img_array = np.array(self.image)
+
+        if self.mask is None:
+            if perturbed_tensor[0][0] == 0:
+                img_array[:, :] = self.baselines
+        else:
+            for mask_id, itp_idx in self.mask_id_to_idx.items():
+                if perturbed_tensor[0][itp_idx] == 0:
+                    mask_positions = self.mask == mask_id
+                    img_array[mask_positions] = self.baselines
+
+        perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))
+
+        return self.processor_fn(perturbed_image)
+
+    def format_attr(self, itp_attr: Tensor) -> Tensor:
+        device = itp_attr.device
+
+        if self.mask is None:
+            # When mask is None, treat the entire image as one segment and
+            # create a uniform all-zeros mask to broadcast the single attribution.
+            img_array = np.array(self.image)
+            image_shape = img_array.shape[:2]  # (height, width)
+            formatted_mask = torch.zeros(image_shape, dtype=torch.long, device=device)
+        else:
+            # Map mask IDs to contiguous indices.
+            image_shape = self.mask.shape
+            formatted_mask = torch.zeros_like(self.mask, device=device)
+            for mask_id, itp_idx in self.mask_id_to_idx.items():
+                formatted_mask[self.mask == mask_id] = itp_idx
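+
+        # Scatter each segment's attribution onto its pixel positions to get a
+        # (1, height, width) attribution map.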
+        formatted_attr = _scatter_itp_attr_by_mask(
+            itp_attr,
+            (1, *image_shape),
+            formatted_mask.unsqueeze(0),
+        )
+        return formatted_attr
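
A minimal end-to-end sketch of how the new class could be exercised, using a hand-rolled leave-one-out ablation rather than any particular Captum attribution method. model_score and "test.jpg" are hypothetical stand-ins, and processor_fn is the one from the docstring example; only MMImageMaskInput itself comes from this diff.

    import torch
    import PIL.Image

    def model_score(model_inputs) -> float:
        # Hypothetical: run the multi-modal model on the processor outputs
        # and reduce the output to a single scalar score.
        return 0.0  # placeholder

    image = PIL.Image.open("test.jpg")
    mask = torch.zeros(image.size[::-1], dtype=torch.int32)
    mask[:, image.size[0] // 2:] = 1  # segment 0 = left half, 1 = right half

    inp = MMImageMaskInput(processor_fn=processor_fn, image=image, mask=mask)

    base = inp.to_tensor()  # tensor([[1., 1.]]): both segments present
    full_score = model_score(inp.to_model_input())

    # Leave-one-out: drop one segment at a time and measure the score change.
    attrs = []
    for i in range(base.shape[1]):
        perturbed = base.clone()
        perturbed[0, i] = 0  # mark segment i "absent"
        attrs.append(full_score - model_score(inp.to_model_input(perturbed)))

    itp_attr = torch.tensor([attrs])        # shape (1, n_itp_features)
    pixel_attr = inp.format_attr(itp_attr)  # (1, height, width) pixel map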