Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 48 additions & 31 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python3

# pyre-strict
import typing
from enum import Enum
from functools import reduce
from inspect import signature
Expand Down Expand Up @@ -65,7 +64,7 @@ def safe_div(
denom: Union[Tensor, int, float],
default_denom: Union[Tensor, int, float] = 1.0,
) -> Tensor:
r"""
"""
A simple utility function to perform `numerator / denom`
if the statement is undefined => result will be `numerator / default_denorm`
"""
Expand All @@ -81,15 +80,15 @@ def safe_div(
return numerator / torch.where(denom != 0, denom, default_denom)


@typing.overload
@overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


@typing.overload
@overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
@overload
def _is_tuple(
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> bool: ...
Expand Down Expand Up @@ -150,7 +149,7 @@ def _validate_input(


def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
r"""
"""
Takes a tuple of tensors as input and returns a tuple that has the same
length as `inputs` with each element as the integer 0.
"""
Expand All @@ -160,6 +159,10 @@ def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
def _format_baseline(
baselines: BaselineType, inputs: Tuple[Tensor, ...]
) -> Tuple[Union[Tensor, int, float], ...]:
"""
Converts baselines to tuple format, returning zeros if None,
or wrapping single values in a tuple.
"""
if baselines is None:
return _zeros(inputs)

Expand Down Expand Up @@ -197,11 +200,8 @@ def _format_feature_mask(
start_idx: int = 0,
) -> Tuple[Tensor, ...]:
"""
Format a feature mask into a tuple of tensors.
The `inputs` should be correctly formatted first
If `feature_mask` is None, assign each non-batch dimension with a consecutive
integer from `start_idx`.
If `feature_mask` is a tensor, wrap it in a tuple.
Converts feature mask to tuple format, auto-generating default mask
from start_idx if None.
"""
if feature_mask is None:
formatted_mask = []
Expand Down Expand Up @@ -240,6 +240,9 @@ def _format_tensor_into_tuples(
def _format_tensor_into_tuples(
inputs: Union[None, Tensor, Tuple[Tensor, ...]],
) -> Union[None, Tuple[Tensor, ...]]:
"""
Converts tensor inputs to tuple format, returning None unchanged if None.
"""
if inputs is None:
return None
if not isinstance(inputs, tuple):
Expand All @@ -252,6 +255,10 @@ def _format_tensor_into_tuples(


def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
"""
Returns inputs unchanged if already tuple/list
and unpack_inputs=True, otherwise wraps in tuple.
"""
return (
inputs
if (isinstance(inputs, tuple) or isinstance(inputs, list)) and unpack_inputs
Expand All @@ -262,6 +269,9 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
def _format_float_or_tensor_into_tuples(
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]],
) -> Tuple[Union[float, Tensor], ...]:
"""
Converts float or tensor inputs to tuple format, wrapping single values in a tuple.
"""
if not isinstance(inputs, tuple):
assert isinstance(
inputs, (torch.Tensor, float)
Expand All @@ -274,23 +284,28 @@ def _format_float_or_tensor_into_tuples(

@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[Tensor, Tuple],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...
additional_forward_args: Union[Tensor, Tuple[Any, ...]],
) -> Tuple[Any, ...]: ...


@overload
def _format_additional_forward_args( # type: ignore
def _format_additional_forward_args(
additional_forward_args: None,
) -> None: ...


@overload
def _format_additional_forward_args(
additional_forward_args: Optional[object],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]: ...
) -> Optional[Tuple[Any, ...]]: ...


def _format_additional_forward_args(
additional_forward_args: Optional[object],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]:
) -> Optional[Tuple[Any, ...]]:
"""
Converts additional forward args to tuple format, returning None unchanged if None.
"""
if additional_forward_args is not None and not isinstance(
additional_forward_args, tuple
):
Expand Down Expand Up @@ -478,21 +493,21 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
kwargs["feature_mask"] = feature_mask


@typing.overload
@overload
def _format_output(
is_inputs_tuple: Literal[True],
output: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]: ...


@typing.overload
@overload
def _format_output(
is_inputs_tuple: Literal[False],
output: Tuple[Tensor, ...],
) -> Tensor: ...


@typing.overload
@overload
def _format_output(
is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
Expand All @@ -501,7 +516,7 @@ def _format_output(
def _format_output(
is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
r"""
"""
In case input is a tensor and the output is returned in form of a
tuple we take the first element of the output's tuple to match the
same shape signatues of the inputs
Expand All @@ -516,21 +531,21 @@ def _format_output(
return output if is_inputs_tuple else output[0]


@typing.overload
@overload
def _format_outputs(
is_multiple_inputs: Literal[False],
outputs: List[Tuple[Tensor, ...]],
) -> Union[Tensor, Tuple[Tensor, ...]]: ...


@typing.overload
@overload
def _format_outputs(
is_multiple_inputs: Literal[True],
outputs: List[Tuple[Tensor, ...]],
) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...


@typing.overload
@overload
def _format_outputs(
is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
Expand All @@ -539,6 +554,10 @@ def _format_outputs(
def _format_outputs(
is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
"""
Formats list of output tuples: returns list if is_multiple_inputs is True,
otherwise single formatted output.
"""
assert isinstance(outputs, list), "Outputs must be a list"
assert is_multiple_inputs or len(outputs) == 1, (
"outputs should contain multiple inputs or have a single output"
Expand All @@ -554,9 +573,7 @@ def _format_outputs(

# pyre-fixme[24] Callable requires 2 arguments
def _construct_future_forward(original_forward: Callable) -> Callable:
# pyre-fixme[3] return type not specified
def future_forward(*args: Any, **kwargs: Any):
# pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
def future_forward(*args: Any, **kwargs: Any) -> torch.futures.Future[Tensor]:
fut: torch.futures.Future[Tensor] = torch.futures.Future()
fut.set_result(original_forward(*args, **kwargs))
return fut
Expand Down Expand Up @@ -829,7 +846,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:


def _get_module_from_name(model: Module, layer_name: str) -> Any:
r"""
"""
Returns the module (layer) object, given its (string) name
in the model.

Expand Down
Loading
Loading