From 5f8365daaa7e5b9416b901df9440da5a8a5b3bb8 Mon Sep 17 00:00:00 2001 From: Alejandro Perez Munoz Date: Wed, 12 Nov 2025 13:13:17 -0800 Subject: [PATCH 1/2] Add documentation and clean some typing information for common file. (#1662) Summary: Improve documentation of the _format method and cleanup a few typing errors. Reviewed By: styusuf Differential Revision: D86432880 --- captum/_utils/common.py | 79 +++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index ee4c84834..757e5b0b9 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # pyre-strict -import typing from enum import Enum from functools import reduce from inspect import signature @@ -65,7 +64,7 @@ def safe_div( denom: Union[Tensor, int, float], default_denom: Union[Tensor, int, float] = 1.0, ) -> Tensor: - r""" + """ A simple utility function to perform `numerator / denom` if the statement is undefined => result will be `numerator / default_denorm` """ @@ -81,15 +80,15 @@ def safe_div( return numerator / torch.where(denom != 0, denom, default_denom) -@typing.overload +@overload def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... -@typing.overload +@overload def _is_tuple(inputs: Tensor) -> Literal[False]: ... -@typing.overload +@overload def _is_tuple( inputs: Union[Tensor, Tuple[Tensor, ...]], ) -> bool: ... @@ -150,7 +149,7 @@ def _validate_input( def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]: - r""" + """ Takes a tuple of tensors as input and returns a tuple that has the same length as `inputs` with each element as the integer 0. """ @@ -160,6 +159,10 @@ def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]: def _format_baseline( baselines: BaselineType, inputs: Tuple[Tensor, ...] ) -> Tuple[Union[Tensor, int, float], ...]: + """ + Converts baselines to tuple format, returning zeros if None, + or wrapping single values in a tuple. + """ if baselines is None: return _zeros(inputs) @@ -197,11 +200,8 @@ def _format_feature_mask( start_idx: int = 0, ) -> Tuple[Tensor, ...]: """ - Format a feature mask into a tuple of tensors. - The `inputs` should be correctly formatted first - If `feature_mask` is None, assign each non-batch dimension with a consecutive - integer from `start_idx`. - If `feature_mask` is a tensor, wrap it in a tuple. + Converts feature mask to tuple format, auto-generating default mask + from start_idx if None. """ if feature_mask is None: formatted_mask = [] @@ -240,6 +240,9 @@ def _format_tensor_into_tuples( def _format_tensor_into_tuples( inputs: Union[None, Tensor, Tuple[Tensor, ...]], ) -> Union[None, Tuple[Tensor, ...]]: + """ + Converts tensor inputs to tuple format, returning None unchanged if None. + """ if inputs is None: return None if not isinstance(inputs, tuple): @@ -252,6 +255,10 @@ def _format_tensor_into_tuples( def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any: + """ + Returns inputs unchanged if already tuple/list + and unpack_inputs=True, otherwise wraps in tuple. + """ return ( inputs if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs @@ -262,6 +269,9 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any: def _format_float_or_tensor_into_tuples( inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]], ) -> Tuple[Union[float, Tensor], ...]: + """ + Converts float or tensor inputs to tuple format, wrapping single values in a tuple. + """ if not isinstance(inputs, tuple): assert isinstance( inputs, (torch.Tensor, float) @@ -274,23 +284,28 @@ def _format_float_or_tensor_into_tuples( @overload def _format_additional_forward_args( - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. - additional_forward_args: Union[Tensor, Tuple], - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Tuple: ... + additional_forward_args: Union[Tensor, Tuple[Any, ...]], +) -> Tuple[Any, ...]: ... @overload -def _format_additional_forward_args( # type: ignore +def _format_additional_forward_args( + additional_forward_args: None, +) -> None: ... + + +@overload +def _format_additional_forward_args( additional_forward_args: Optional[object], - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Union[None, Tuple]: ... +) -> Optional[Tuple[Any, ...]]: ... def _format_additional_forward_args( additional_forward_args: Optional[object], - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. -) -> Union[None, Tuple]: +) -> Optional[Tuple[Any, ...]]: + """ + Converts additional forward args to tuple format, returning None unchanged if None. + """ if additional_forward_args is not None and not isinstance( additional_forward_args, tuple ): @@ -478,21 +493,21 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None: kwargs["feature_mask"] = feature_mask -@typing.overload +@overload def _format_output( is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...], ) -> Tuple[Tensor, ...]: ... -@typing.overload +@overload def _format_output( is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...], ) -> Tensor: ... -@typing.overload +@overload def _format_output( is_inputs_tuple: bool, output: Tuple[Tensor, ...] ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @@ -501,7 +516,7 @@ def _format_output( def _format_output( is_inputs_tuple: bool, output: Tuple[Tensor, ...] ) -> Union[Tensor, Tuple[Tensor, ...]]: - r""" + """ In case input is a tensor and the output is returned in form of a tuple we take the first element of the output's tuple to match the same shape signatues of the inputs @@ -516,21 +531,21 @@ def _format_output( return output if is_inputs_tuple else output[0] -@typing.overload +@overload def _format_outputs( is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]], ) -> Union[Tensor, Tuple[Tensor, ...]]: ... -@typing.overload +@overload def _format_outputs( is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]], ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ... -@typing.overload +@overload def _format_outputs( is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]] ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ... @@ -539,6 +554,10 @@ def _format_outputs( def _format_outputs( is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]] ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: + """ + Formats list of output tuples: returns list if is_multiple_inputs is True, + otherwise single formatted output. + """ assert isinstance(outputs, list), "Outputs must be a list" assert is_multiple_inputs or len(outputs) == 1, ( "outputs should contain multiple inputs or have a single output" @@ -554,9 +573,7 @@ def _format_outputs( # pyre-fixme[24] Callable requires 2 arguments def _construct_future_forward(original_forward: Callable) -> Callable: - # pyre-fixme[3] return type not specified - def future_forward(*args: Any, **kwargs: Any): - # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function. + def future_forward(*args: Any, **kwargs: Any) -> torch.futures.Future[Tensor]: fut: torch.futures.Future[Tensor] = torch.futures.Future() fut.set_result(original_forward(*args, **kwargs)) return fut @@ -829,7 +846,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor: def _get_module_from_name(model: Module, layer_name: str) -> Any: - r""" + """ Returns the module (layer) object, given its (string) name in the model. From 8b8ae736cf9b9a2da8b164685f84ea9fa709f78e Mon Sep 17 00:00:00 2001 From: Alejandro Perez Munoz Date: Wed, 12 Nov 2025 13:13:17 -0800 Subject: [PATCH 2/2] Move methods out of FeatureAblation and make them free methods. (#1663) Summary: This diff moves methods out of the `FeatureAblation` class and makes them free methods. The changes include creating a new method `_parse_forward_out` to force forward output type assertion and conversion, and modifying the `add_one_back` module to use the new methods. The `attr/fb/add_one_back.py` file has been modified to use the new methods. The `attr/fb/within_group_utils.py` file has also been modified to use the new methods. The `attr/fb/test_within_group_utils.py` file has been Differential Revision: D86785624 --- captum/attr/_core/feature_ablation.py | 202 ++++++++++++++----------- captum/testing/helpers/basic_models.py | 1 - tests/attr/test_feature_ablation.py | 117 +++++++++++++- 3 files changed, 225 insertions(+), 95 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 8df1e53d3..e0fa0dbdb 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -4,7 +4,18 @@ import logging import math -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) import torch from captum._utils.common import ( @@ -37,6 +48,94 @@ logger: logging.Logger = logging.getLogger(__name__) +def _parse_forward_out(forward_output: object) -> Tensor: + """ + A temp wrapper for global _run_forward util to force forward output + type assertion & conversion. + Remove after the strict logic is supported by all attr classes + """ + if isinstance(forward_output, Tensor): + return forward_output + + output_type = type(forward_output) + assert output_type is int or output_type is float, ( + "the return of forward_func must be a tensor, int, or float," + f" received: {forward_output}" + ) + + # using python built-in type as torch dtype + # int -> torch.int64, float -> torch.float64 + # ref: https://github.com/pytorch/pytorch/pull/21215 + return torch.tensor(forward_output, dtype=cast(dtype, output_type)) + + +def process_initial_eval( + initial_eval: Tensor, + inputs: Iterable[Tensor], + use_weights: bool = False, +) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]: + + initial_eval = _parse_forward_out(initial_eval) + + # number of elements in the output of forward_func + n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1 + + # flatten eval outputs into 1D (n_outputs) + # add the leading dim for n_feature_perturbed + flattened_initial_eval = initial_eval.reshape(1, -1) + + # Initialize attribution totals and counts + attrib_type = flattened_initial_eval.dtype + + total_attrib = [ + # attribute w.r.t each output element + torch.zeros( + (n_outputs,) + input.shape[1:], + dtype=attrib_type, + device=input.device, + ) + for input in inputs + ] + + # Weights are used in cases where ablations may be overlapping. + weights = [] + if use_weights: + weights = [ + torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float() + for input in inputs + ] + + return ( + total_attrib, + weights, + initial_eval, + flattened_initial_eval, + n_outputs, + attrib_type, + ) + + +def format_result( + total_attrib: List[Tensor], + weights: List[Tensor], + is_inputs_tuple: bool, + use_weights: bool, +) -> Union[Tensor, Tuple[Tensor, ...]]: + """ + Normalizes attributions by weights if enabled and + formats output as single tensor or tuple. + """ + # Divide total attributions by counts and return formatted attributions + if use_weights: + attrib = tuple( + single_attrib.float() / weight + for single_attrib, weight in zip(total_attrib, weights) + ) + else: + attrib = tuple(total_attrib) + return _format_output(is_inputs_tuple, attrib) + + class FeatureAblation(PerturbationAttribution): r""" A perturbation based approach to computing attribution, involving @@ -331,9 +430,8 @@ def attribute( flattened_initial_eval, n_outputs, attrib_type, - ) = self._process_initial_eval( - initial_eval, - formatted_inputs, + ) = process_initial_eval( + initial_eval, formatted_inputs, use_weights=self.use_weights ) total_attrib, weights = self._attribute_with_cross_tensor_feature_masks( @@ -358,7 +456,9 @@ def attribute( return cast( TensorOrTupleOfTensorsGeneric, - self._generate_result(total_attrib, weights, is_inputs_tuple), + format_result( + total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights + ), ) def _attribute_with_cross_tensor_feature_masks( @@ -586,8 +686,8 @@ def _initial_eval_to_processed_initial_eval_fut( "initial_eval_to_processed_initial_eval_fut: " "initial_eval should be a Tensor" ) - result = self._process_initial_eval( - initial_eval_processed, formatted_inputs + result = process_initial_eval( + initial_eval_processed, formatted_inputs, use_weights=self.use_weights ) except FeatureAblationFutureError as e: @@ -886,10 +986,8 @@ def _generate_async_result_cross_tensor( ) result_fut = collect_all(accumulate_fut_list).then( - lambda x: self._generate_result( - total_attrib, - weights, - is_inputs_tuple, + lambda x: format_result( + total_attrib, weights, is_inputs_tuple, use_weights=self.use_weights ) ) @@ -955,70 +1053,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor( ) from e return total_attrib, weights - def _parse_forward_out(self, forward_output: Tensor) -> Tensor: - """ - A temp wrapper for global _run_forward util to force forward output - type assertion & conversion. - Remove after the strict logic is supported by all attr classes - """ - if isinstance(forward_output, Tensor): - return forward_output - - output_type = type(forward_output) - assert output_type is int or output_type is float, ( - "the return of forward_func must be a tensor, int, or float," - f" received: {forward_output}" - ) - - # using python built-in type as torch dtype - # int -> torch.int64, float -> torch.float64 - # ref: https://github.com/pytorch/pytorch/pull/21215 - return torch.tensor(forward_output, dtype=cast(dtype, output_type)) - - def _process_initial_eval( - self, - initial_eval: Tensor, - inputs: TensorOrTupleOfTensorsGeneric, - ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]: - initial_eval = self._parse_forward_out(initial_eval) - - # number of elements in the output of forward_func - n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1 - - # flatten eval outputs into 1D (n_outputs) - # add the leading dim for n_feature_perturbed - flattened_initial_eval = initial_eval.reshape(1, -1) - - # Initialize attribution totals and counts - attrib_type = flattened_initial_eval.dtype - - total_attrib = [ - # attribute w.r.t each output element - torch.zeros( - (n_outputs,) + input.shape[1:], - dtype=attrib_type, - device=input.device, - ) - for input in inputs - ] - - # Weights are used in cases where ablations may be overlapping. - weights = [] - if self.use_weights: - weights = [ - torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float() - for input in inputs - ] - - return ( - total_attrib, - weights, - initial_eval, - flattened_initial_eval, - n_outputs, - attrib_type, - ) - def _process_ablated_out_full( self, modified_eval: Tensor, @@ -1033,7 +1067,7 @@ def _process_ablated_out_full( attrib_type: dtype, perturbations_per_eval: int, ) -> Tuple[List[Tensor], List[Tensor]]: - modified_eval = self._parse_forward_out(modified_eval) + modified_eval = _parse_forward_out(modified_eval) # if perturbations_per_eval > 1, the output shape must grow with # input and not be aggregated current_batch_size = inputs[0].shape[0] @@ -1086,19 +1120,3 @@ def _process_ablated_out_full( total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) return total_attrib, weights - - def _generate_result( - self, - total_attrib: List[Tensor], - weights: List[Tensor], - is_inputs_tuple: bool, - ) -> Union[Tensor, Tuple[Tensor, ...]]: - # Divide total attributions by counts and return formatted attributions - if self.use_weights: - attrib = tuple( - single_attrib.float() / weight - for single_attrib, weight in zip(total_attrib, weights) - ) - else: - attrib = tuple(total_attrib) - return _format_output(is_inputs_tuple, attrib) diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index a36168b03..f663baa2d 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -581,7 +581,6 @@ def forward( self.relu(lin1_out) else: relu_out = self.relu(lin1_out) - # pyre-fixme [29]: `typing.Type[Future]` is not a function result = Future() lin2_out = self.linear2(relu_out) if multidim_output: diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index e18b1fda1..64c36d42b 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -12,7 +12,11 @@ import torch from captum._utils.common import _construct_future_forward from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric -from captum.attr._core.feature_ablation import FeatureAblation +from captum.attr._core.feature_ablation import ( + _parse_forward_out, + FeatureAblation, + format_result, +) from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.attribution import Attribution from captum.testing.helpers import BaseTest @@ -595,7 +599,6 @@ def slow_set_future(fut: torch.futures.Future[Tensor], value: Tensor) -> None: fut.set_result(out) def forward_func(inp: Tensor) -> torch.futures.Future[Tensor]: - # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function. fut: torch.futures.Future[Tensor] = torch.futures.Future() t = threading.Thread(target=slow_set_future, args=(fut, inp)) t.start() @@ -900,5 +903,115 @@ def _ablation_test_assert( assertTensorAlmostEqual(self, attributions, expected_ablation) +class TestParseForwardOutput(BaseTest): + + def test_parse_forward_out_tensor_passthrough(self) -> None: + input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + result = _parse_forward_out(input_tensor) + + self.assertIs(result, input_tensor) + assertTensorAlmostEqual(self, result, input_tensor) + + def test_parse_forward_out_python_int(self) -> None: + input_value = 42 + result = _parse_forward_out(input_value) + + self.assertIsInstance(result, Tensor) + self.assertEqual(result.dtype, torch.int64) + assertTensorAlmostEqual(self, result, torch.tensor(42)) + + def test_parse_forward_out_python_float(self) -> None: + input_value = 3.14 + result = _parse_forward_out(input_value) + + self.assertIsInstance(result, Tensor) + self.assertEqual(result.dtype, torch.float64) + assertTensorAlmostEqual(self, result, torch.tensor(3.14)) + + def test_parse_forward_out_invalid_none(self) -> None: + with self.assertRaises(AssertionError): + _parse_forward_out(None) + + +class TestFormatResult(BaseTest): + + def test_format_result_single_tensor_no_weights(self) -> None: + total_attrib = [torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])] + weights = [] + is_inputs_tuple = False + use_weights = False + + result = format_result(total_attrib, weights, is_inputs_tuple, use_weights) + + self.assertIsInstance(result, Tensor) + assert isinstance(result, Tensor) # Type narrowing for pyre + self.assertEqual(result.shape, (2, 3)) + assertTensorAlmostEqual( + self, result, torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + ) + + def test_format_result_tuple_output_no_weights(self) -> None: + total_attrib = [ + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + torch.tensor([[5.0, 6.0], [7.0, 8.0]]), + ] + weights = [] + is_inputs_tuple = True + use_weights = False + + result = format_result(total_attrib, weights, is_inputs_tuple, use_weights) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + assertTensorAlmostEqual(self, result[0], torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 6.0], [7.0, 8.0]])) + + def test_format_result_single_tensor_with_weights(self) -> None: + total_attrib = [torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])] + weights = [torch.tensor([[2.0, 4.0, 5.0], [8.0, 10.0, 12.0]])] + is_inputs_tuple = False + use_weights = True + + result = format_result(total_attrib, weights, is_inputs_tuple, use_weights) + + self.assertIsInstance(result, Tensor) + expected = torch.tensor([[5.0, 5.0, 6.0], [5.0, 5.0, 5.0]]) + assertTensorAlmostEqual(self, result, expected) + + def test_format_result_tuple_output_with_weights(self) -> None: + total_attrib = [ + torch.tensor([[10.0, 20.0], [30.0, 40.0]]), + torch.tensor([[50.0, 60.0], [70.0, 80.0]]), + ] + weights = [ + torch.tensor([[2.0, 4.0], [5.0, 8.0]]), + torch.tensor([[10.0, 12.0], [14.0, 16.0]]), + ] + is_inputs_tuple = True + use_weights = True + + result = format_result(total_attrib, weights, is_inputs_tuple, use_weights) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + assertTensorAlmostEqual(self, result[0], torch.tensor([[5.0, 5.0], [6.0, 5.0]])) + assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 5.0], [5.0, 5.0]])) + + def test_format_result_integer_dtype_no_weights(self) -> None: + total_attrib = [torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)] + weights = [] + is_inputs_tuple = False + use_weights = False + + result = format_result(total_attrib, weights, is_inputs_tuple, use_weights) + + self.assertIsInstance(result, Tensor) + assert isinstance(result, Tensor) # Type narrowing for pyre + self.assertEqual(result.dtype, torch.int32) + assertTensorAlmostEqual( + self, result, torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32) + ) + + if __name__ == "__main__": unittest.main()