80 changes: 52 additions & 28 deletions captum/attr/_core/feature_ablation.py
@@ -219,6 +219,52 @@ def check_output_shape_valid(
)


def _should_skip_inputs_and_warn(
    current_feature_idxs: List[int],
    feature_idx_to_tensor_idx: Dict[int, List[int]],
    formatted_inputs: Tuple[Tensor, ...],
    min_examples_per_batch_grouped: Optional[int] = None,
) -> bool:
    """
    Determines whether a feature group should be skipped during attribution
    computation.

    This function checks two conditions that would cause a feature group to be
    skipped:
    1. If min_examples_per_batch_grouped is specified and any input tensor in the
       feature group has a batch size (0th dimension) smaller than this threshold.
    2. If all input tensors in the feature group are empty (contain no elements).

    Args:
        current_feature_idxs: Feature indices that make up the feature group.
        feature_idx_to_tensor_idx: Mapping from each feature index to the indices
            of the input tensors it belongs to.
        formatted_inputs: Tuple of formatted input tensors.
        min_examples_per_batch_grouped: Optional minimum batch size (0th dim)
            required for each input tensor in the group.

    Returns:
        bool: True if the feature group should be skipped, False otherwise.
    """
    should_skip = False
    all_empty = True
    tensor_idx_list = []
    for feature_idx in current_feature_idxs:
        tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
    for tensor_idx in set(tensor_idx_list):
        if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
            all_empty = False
        if min_examples_per_batch_grouped is not None and (
            formatted_inputs[tensor_idx].shape[0] < min_examples_per_batch_grouped
        ):
            should_skip = True
            break
    if should_skip:
        logger.warning(
            f"Skipping feature group {current_feature_idxs} since it contains "
            f"at least one input tensor with 0th dim less than "
            f"{min_examples_per_batch_grouped}"
        )
        return True
    if all_empty:
        logger.info(
            f"Skipping feature group {current_feature_idxs} since all "
            f"input tensors are empty"
        )
        return True
    return False


class FeatureAblation(PerturbationAttribution):
"""
A perturbation based approach to computing attribution, involving
@@ -688,34 +734,12 @@ def _should_skip_inputs_and_warn(
        feature_idx_to_tensor_idx: Dict[int, List[int]],
        formatted_inputs: Tuple[Tensor, ...],
    ) -> bool:
        should_skip = False
        all_empty = True
        tensor_idx_list = []
        for feature_idx in current_feature_idxs:
            tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
        for tensor_idx in set(tensor_idx_list):
            if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
                all_empty = False
            if self._min_examples_per_batch_grouped is not None and (
                formatted_inputs[tensor_idx].shape[0]
                < cast(int, self._min_examples_per_batch_grouped)
            ):
                should_skip = True
                break
        if should_skip:
            logger.warning(
                f"Skipping feature group {current_feature_idxs} since it contains "
                f"at least one input tensor with 0th dim less than "
                f"{self._min_examples_per_batch_grouped}"
            )
            return True
        if all_empty:
            logger.info(
                f"Skipping feature group {current_feature_idxs} since all "
                f"input tensors are empty"
            )
            return True
        return False
        return _should_skip_inputs_and_warn(
            current_feature_idxs=current_feature_idxs,
            feature_idx_to_tensor_idx=feature_idx_to_tensor_idx,
            formatted_inputs=formatted_inputs,
            min_examples_per_batch_grouped=self._min_examples_per_batch_grouped,
        )

    def _construct_ablated_input_across_tensors(
        self,
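As a quick illustration (not part of the diff itself), the sketch below shows how the now module-level helper behaves for the two skip conditions described in its docstring. It assumes a Captum checkout with this change applied and uses the same import path as the new tests.

# Illustrative sketch, not part of the PR: exercises the two skip conditions
# of the new module-level helper.
import torch

from captum.attr._core.feature_ablation import _should_skip_inputs_and_warn

# Two features, both mapped to the single input tensor at index 0.
feature_idx_to_tensor_idx = {0: [0], 1: [0]}
inputs = (torch.randn(2, 4),)  # batch size (0th dim) is 2

# Condition 1: batch size 2 < min_examples_per_batch_grouped=3 -> skipped (True);
# a warning is logged.
print(_should_skip_inputs_and_warn([0, 1], feature_idx_to_tensor_idx, inputs, 3))

# Threshold met (2 >= 2) and tensors are non-empty -> not skipped (False).
print(_should_skip_inputs_and_warn([0, 1], feature_idx_to_tensor_idx, inputs, 2))

# Condition 2: all tensors in the group are empty -> skipped (True); info is logged.
print(_should_skip_inputs_and_warn([0], {0: [0]}, (torch.tensor([]),)))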
67 changes: 67 additions & 0 deletions tests/attr/test_feature_ablation.py
@@ -14,6 +14,7 @@
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.feature_ablation import (
    _parse_forward_out,
    _should_skip_inputs_and_warn,
    check_output_shape_valid,
    FeatureAblation,
    format_result,
@@ -1086,5 +1087,71 @@ def test_invalid_batch_size_not_divisible_by_num_examples(self) -> None:
)


class TestShouldSkipInputsAndWarn(BaseTest):
    def test_skip_when_batch_size_less_than_min_examples(self) -> None:
        current_feature_idxs = [0, 1]
        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)
        min_examples_per_batch_grouped = 3

        with unittest.mock.patch(
            "captum.attr._core.feature_ablation.logger"
        ) as mock_logger:
            result = _should_skip_inputs_and_warn(
                current_feature_idxs,
                feature_idx_to_tensor_idx,
                formatted_inputs,
                min_examples_per_batch_grouped,
            )

        self.assertTrue(result)
        mock_logger.warning.assert_called_once()

    def test_no_skip_when_batch_size_equal_to_min_examples(self) -> None:
        current_feature_idxs = [0, 1]
        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)
        min_examples_per_batch_grouped = 2

        result = _should_skip_inputs_and_warn(
            current_feature_idxs,
            feature_idx_to_tensor_idx,
            formatted_inputs,
            min_examples_per_batch_grouped,
        )

        self.assertFalse(result)

    def test_skip_when_all_tensors_empty(self) -> None:
        current_feature_idxs = [0]
        feature_idx_to_tensor_idx = {0: [0]}
        formatted_inputs = (torch.tensor([]),)

        with unittest.mock.patch(
            "captum.attr._core.feature_ablation.logger"
        ) as mock_logger:
            result = _should_skip_inputs_and_warn(
                current_feature_idxs,
                feature_idx_to_tensor_idx,
                formatted_inputs,
            )

        self.assertTrue(result)
        mock_logger.info.assert_called_once()

    def test_no_skip_when_tensors_not_empty(self) -> None:
        current_feature_idxs = [0, 1]
        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)

        result = _should_skip_inputs_and_warn(
            current_feature_idxs,
            feature_idx_to_tensor_idx,
            formatted_inputs,
        )

        self.assertFalse(result)


if __name__ == "__main__":
    unittest.main()
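To run just the new test class locally, a standard pytest node-id invocation should work (command shown for illustration, assuming pytest is installed in the development environment):

python -m pytest tests/attr/test_feature_ablation.py::TestShouldSkipInputsAndWarn -v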