Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def _format_tensor_into_tuples(inputs: None) -> None: ...

@overload
def _format_tensor_into_tuples(
inputs: Union[Tensor, Tuple[Tensor, ...]]
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]: ...


def _format_tensor_into_tuples(
inputs: Union[None, Tensor, Tuple[Tensor, ...]]
inputs: Union[None, Tensor, Tuple[Tensor, ...]],
) -> Union[None, Tuple[Tensor, ...]]:
if inputs is None:
return None
Expand All @@ -261,7 +261,7 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:


def _format_float_or_tensor_into_tuples(
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]],
) -> Tuple[Union[float, Tensor], ...]:
if not isinstance(inputs, tuple):
assert isinstance(
Expand All @@ -276,7 +276,7 @@ def _format_float_or_tensor_into_tuples(
@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[Tensor, Tuple]
additional_forward_args: Union[Tensor, Tuple],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...

Expand Down
6 changes: 2 additions & 4 deletions captum/attr/_core/dataloader_attr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# pyre-strict

from collections import defaultdict
from copy import copy
from typing import Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -30,7 +31,6 @@ class InputRole:


# default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
# pyre-fixme[2]: Parameter must be annotated.
def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
return cur_output if accum is None else torch.cat([accum, cur_output])

Expand Down Expand Up @@ -87,9 +87,7 @@ def _perturb_inputs(
else:
baseline = baselines[attr_inp_count]

# pyre-fixme[58]: `*` is not supported for operand types `object` and
# `Tensor`.
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
perturbed_inp = cast(Tensor, inp) * pert_mask + baseline * (1 - pert_mask)
perturbed_inputs.append(perturbed_inp)

attr_inp_count += 1
Expand Down
2 changes: 1 addition & 1 deletion captum/attr/_core/layer/layer_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def _get_output_relevance(

@staticmethod
def _convert_list_to_tuple(
relevances: Union[List[T], Tuple[T, ...]]
relevances: Union[List[T], Tuple[T, ...]],
) -> Tuple[T, ...]:
if isinstance(relevances, list):
return tuple(relevances)
Expand Down
11 changes: 9 additions & 2 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,19 @@ def _forward_func(
outputs.past_key_values = DynamicCache.from_legacy_cache(
outputs.past_key_values
)
# nn.Module typing suggests non-base attributes are modules or
# tensors
_update_model_kwargs_for_generation = (
self.model._update_model_kwargs_for_generation
)
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
model_kwargs = self.model._update_model_kwargs_for_generation(
model_kwargs = _update_model_kwargs_for_generation( # type: ignore
outputs, model_kwargs
)
# nn.Module typing suggests non-base attributes are modules or tensors
prep_inputs_for_generation = self.model.prepare_inputs_for_generation
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
model_inputs = self.model.prepare_inputs_for_generation(
model_inputs = prep_inputs_for_generation( # type: ignore
model_inp, **model_kwargs
)
outputs = self.model.forward(**model_inputs)
Expand Down
2 changes: 1 addition & 1 deletion captum/attr/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _find_output_mode_and_verify(


def _construct_default_feature_mask(
inputs: Tuple[Tensor, ...]
inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask = []
current_num_features = 0
Expand Down
13 changes: 6 additions & 7 deletions captum/attr/_utils/stat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, List, Optional, TYPE_CHECKING

from typing import Any, Callable, cast, List, Optional, TYPE_CHECKING

import torch
from torch import Tensor
Expand Down Expand Up @@ -117,20 +118,18 @@ def get(self) -> Optional[Tensor]:
return self.rolling_mean

def init(self) -> None:
# pyre-fixme[8]: Attribute has type `Optional[Count]`; used as `Optional[Stat]`.
self.n = self._get_stat(Count()) # type: ignore
self.n = cast(Count, self._get_stat(Count()))

def update(self, x: Tensor) -> None:
# pyre-fixme[16]: `Optional` has no attribute `get`.
n = self.n.get() # type: ignore
n = cast(Count, self.n).get()

if self.rolling_mean is None:
# Ensures rolling_mean is a float tensor
self.rolling_mean = x.clone() if x.is_floating_point() else x.double()
else:
delta = x - self.rolling_mean
# pyre-fixme[16]: `Optional` has no attribute `__iadd__`.
self.rolling_mean += delta / n
# pyre-ignore[16]: `Optional` has no attribute `__iadd__` (false positive)
self.rolling_mean += delta / cast(int, n)


class MSE(Stat):
Expand Down
4 changes: 2 additions & 2 deletions captum/influence/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __len__(self) -> int:

def _format_inputs_dataset(
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
inputs_dataset: Union[Tuple[Any, ...], DataLoader]
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
) -> DataLoader:
# if `inputs_dataset` is not a `DataLoader`, turn it into one.
# `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
Expand Down Expand Up @@ -604,7 +604,7 @@ def _flatten_params(_params: Tuple[Tensor, ...]) -> Tensor:

# pyre-fixme[3]: Return type must be annotated.
def _unflatten_params_factory(
param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]]
param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]],
):
"""
returns a function which is the inverse of `_flatten_params`
Expand Down
2 changes: 1 addition & 1 deletion captum/insights/attr_vis/_utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def format_transforms(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
transforms: Optional[Union[Callable, List[Callable]]]
transforms: Optional[Union[Callable, List[Callable]]],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
) -> List[Callable]:
if transforms is None:
Expand Down
9 changes: 6 additions & 3 deletions captum/metrics/_core/infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def infidelity_perturb_func_decorator(
"""

def sub_infidelity_perturb_func_decorator(
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric]
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric],
) -> Callable[
[TensorOrTupleOfTensorsGeneric, BaselineType],
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
Expand Down Expand Up @@ -611,6 +611,11 @@ def _next_infidelity_tensors(
targets_expanded,
additional_forward_args_expanded,
)
if isinstance(inputs_perturbed_fwd, torch.futures.Future):
raise NotImplementedError(
f"Outputs from forward_func of type {type(inputs_perturbed_fwd)} are "
"not yet supported."
)
inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
# _run_forward may return future of Tensor,
# but we don't support it here now
Expand All @@ -619,8 +624,6 @@ def _next_infidelity_tensors(
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
# `Union[Future[Tensor], Tensor]`.
perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
attributions_expanded = tuple(
torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
Expand Down
6 changes: 3 additions & 3 deletions captum/testing/helpers/influence/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _isSorted(x, key=lambda x: x, descending=True) -> bool:

# pyre-fixme[2]: Parameter must be annotated.
def _wrap_model_in_dataparallel(net) -> Module:
alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)]
alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1))
net = net.cuda()
return torch.nn.DataParallel(net, device_ids=alt_device_ids)

Expand Down Expand Up @@ -505,7 +505,7 @@ def get_random_model_and_data(

# pyre-fixme[3]: Return type must be annotated.
def generate_symmetric_matrix_given_eigenvalues(
eigenvalues: Union[Tensor, List[float]]
eigenvalues: Union[Tensor, List[float]],
):
"""
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L123 # noqa: E501
Expand All @@ -523,7 +523,7 @@ def generate_symmetric_matrix_given_eigenvalues(


def generate_assymetric_matrix_given_eigenvalues(
eigenvalues: Union[Tensor, List[float]]
eigenvalues: Union[Tensor, List[float]],
) -> Tensor:
"""
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
Expand Down
5 changes: 3 additions & 2 deletions tests/attr/layer/test_layer_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_multiple_layers_multiple_inputs_shared_input(self) -> None:
self,
# last input for second layer is first input =>
# add the attributions
(attribs_inputs[0] + attribs_inputs[1][-1],) + attribs_inputs[1][0:-1],
(attribs_inputs[0] + attribs_inputs[1][-1],) # type: ignore
+ attribs_inputs[1][0:-1], # type: ignore
attribs_inputs_regular_ig,
delta=1e-5,
)
Expand Down Expand Up @@ -183,7 +184,7 @@ def test_multiple_layers_multiple_input_outputs(self) -> None:

assertTensorTuplesAlmostEqual(
self,
(attribs_inputs[0],) + attribs_inputs[1],
(attribs_inputs[0],) + attribs_inputs[1], # type: ignore
attribs_inputs_regular_ig,
delta=1e-7,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/concept/test_tcav.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def _compute_cavs_interpret(
attribute_to_layer_input: bool = False,
) -> None:
def wrap_in_list_if_not_already(
input: Union[str, float, List[float], List[str]]
input: Union[str, float, List[float], List[str]],
) -> Union[List[Union[float, str]], List[float], List[str]]:
return (
input
Expand Down
4 changes: 2 additions & 2 deletions tests/metrics/test_infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _local_perturb_func_default(
# pyre-ignore[43]: The implementation of `_local_perturb_func` does not accept all
# possible arguments of overload defined on line `43`.
def _local_perturb_func(
inputs: Tuple[Tensor, ...]
inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...


Expand Down Expand Up @@ -83,7 +83,7 @@ def _global_perturb_func1_default(
# pyre-fixme[43]: The implementation of `_global_perturb_func1` does not accept all
# possible arguments of overload defined on line `74`.
def _global_perturb_func1(
inputs: Tuple[Tensor, ...]
inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ...


Expand Down