8 changes: 6 additions & 2 deletions captum/_utils/gradient.py
@@ -28,6 +28,7 @@
 from captum._utils.sample_gradient import SampleGradientWrapper
 from captum._utils.typing import (
     ModuleOrModuleList,
+    SliceIntType,
     TargetType,
     TensorOrTupleOfTensorsGeneric,
 )
@@ -775,8 +776,11 @@ def compute_layer_gradients_and_eval(
 
 def construct_neuron_grad_fn(
     layer: Module,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+    neuron_selector: Union[
+        int,
+        Tuple[Union[int, SliceIntType], ...],
+        Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+    ],
     device_ids: Union[None, List[int]] = None,
     attribute_to_neuron_input: bool = False,
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
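Note: the widened annotation admits three selector forms. A minimal sketch of each (names and shapes are hypothetical, not taken from this PR):

import torch
from torch import Tensor
from typing import Tuple, Union

# 1. A flat integer index into the layer output.
selector_int: int = 5

# 2. Per-dimension indices and/or slices, e.g. channel 0,
#    every second unit in positions 2..7.
selector_tuple: Tuple[Union[int, slice], ...] = (0, slice(2, 8, 2))

# 3. A callable reducing the layer output to one scalar per example.
def selector_fn(layer_out: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor:
    out = layer_out if isinstance(layer_out, Tensor) else layer_out[0]
    return out.flatten(start_dim=1).sum(dim=1)

demo = torch.rand(4, 6, 5)        # batch of 4 examples
assert selector_fn(demo).shape == (4,)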
7 changes: 7 additions & 0 deletions captum/_utils/typing.py
@@ -41,6 +41,13 @@
     TensorLikeList5D,
 ]
 
+try:
+    # Subscripted slice syntax is not supported in previous Python versions,
+    # falling back to slice type.
+    SliceIntType = slice[int, int, int]
+except TypeError:
+    # pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
+    SliceIntType = slice  # type: ignore
 
 # Necessary for Python >=3.7 and <3.9!
 if TYPE_CHECKING:
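For context, this fallback can be exercised directly. To the best of my knowledge, `slice` only gained `__class_getitem__` in recent CPython releases (3.12 at the time of writing), so on older interpreters the subscription raises TypeError and the alias degrades to the plain builtin:

import sys

try:
    SliceIntType = slice[int, int, int]  # works on newer interpreters
except TypeError:
    SliceIntType = slice                 # older interpreters: plain builtin

# Runtime values are ordinary slice objects either way, so isinstance
# checks against the un-subscripted builtin keep working:
assert isinstance(slice(2, 8, 2), slice)
print(sys.version_info[:2], SliceIntType)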
55 changes: 30 additions & 25 deletions captum/attr/_core/neuron/neuron_conductance.py
@@ -14,7 +14,12 @@
     _verify_select_neuron,
 )
 from captum._utils.gradient import compute_layer_gradients_and_eval
-from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineType,
+    SliceIntType,
+    TargetType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
 from captum.attr._utils.batching import _batch_attribution
@@ -39,8 +44,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -94,8 +98,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[int, ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: BaselineType = None,
         target: TargetType = None,
         additional_forward_args: Optional[object] = None,
@@ -285,28 +292,24 @@ def attribute(
             " results.",
             stacklevel=1,
         )
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
-        # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
-        # `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
-        # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-        # `TensorOrTupleOfTensorsGeneric`.
-        _validate_input(inputs, baselines, n_steps, method)
+        formatted_inputs, formatted_baselines = _format_input_baseline(
+            inputs, baselines
+        )
+        _validate_input(formatted_inputs, formatted_baselines, n_steps, method)
 
-        num_examples = inputs[0].shape[0]
+        num_examples = formatted_inputs[0].shape[0]
 
         if internal_batch_size is not None:
-            num_examples = inputs[0].shape[0]
+            num_examples = formatted_inputs[0].shape[0]
             attrs = _batch_attribution(
                 self,
                 num_examples,
                 internal_batch_size,
                 n_steps,
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                 neuron_selector=neuron_selector,
                 target=target,
                 additional_forward_args=additional_forward_args,
@@ -315,11 +318,9 @@
             )
         else:
             attrs = self._attribute(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                # got `TensorOrTupleOfTensorsGeneric`.
-                inputs=inputs,
+                inputs=formatted_inputs,
                 neuron_selector=neuron_selector,
-                baselines=baselines,
+                baselines=formatted_baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 n_steps=n_steps,
@@ -334,8 +335,11 @@ def attribute(
     def _attribute(
         self,
         inputs: Tuple[Tensor, ...],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[int, ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: Tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
         additional_forward_args: Optional[object] = None,
@@ -409,8 +413,9 @@ def _attribute(
 
         # Aggregates across all steps for each tensor in the input tuple
         total_grads = tuple(
-            # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
-            _reshape_and_sum(scaled_grad, n_steps, num_examples, input_grad.shape[1:])
+            _reshape_and_sum(
+                scaled_grad, n_steps, num_examples, tuple(input_grad.shape[1:])
+            )
             for (scaled_grad, input_grad) in zip(scaled_grads, input_grads)
         )
 
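Why the `tuple(...)` wrapper fixes the pyre error without changing behavior: `torch.Size` is a `tuple` subclass at runtime, but the checker sees `Size` rather than `Tuple[int, ...]`. A quick sketch:

import torch

grad = torch.zeros(4, 3, 2)
trailing = grad.shape[1:]           # torch.Size([3, 2])

assert isinstance(trailing, tuple)  # Size subclasses tuple at runtime
assert tuple(trailing) == (3, 2)    # explicit Tuple[int, ...] for the checker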
20 changes: 15 additions & 5 deletions captum/attr/_core/neuron/neuron_deep_lift.py
@@ -4,7 +4,11 @@
 from typing import Callable, cast, Optional, Tuple, Union
 
 from captum._utils.gradient import construct_neuron_grad_fn
-from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineType,
+    SliceIntType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
 from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
 from captum.log import log_usage
@@ -79,8 +83,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: BaselineType = None,
         additional_forward_args: Optional[object] = None,
         attribute_to_neuron_input: bool = False,
@@ -309,8 +316,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: Union[
             TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
         ],
19 changes: 12 additions & 7 deletions captum/attr/_core/neuron/neuron_feature_ablation.py
@@ -6,7 +6,11 @@
 import torch
 from captum._utils.common import _verify_select_neuron
 from captum._utils.gradient import _forward_layer_eval
-from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineType,
+    SliceIntType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._utils.attribution import NeuronAttribution, PerturbationAttribution
 from captum.log import log_usage
@@ -31,8 +35,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Union[int, float, Tensor]],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
@@ -61,8 +64,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: BaselineType = None,
         additional_forward_args: Optional[object] = None,
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
@@ -250,8 +256,7 @@ def attribute(
         >>> feature_mask=feature_mask)
         """
 
-        # pyre-fixme[3]: Return type must be annotated.
-        def neuron_forward_func(*args: Any):
+        def neuron_forward_func(*args: Any) -> Tensor:
             with torch.no_grad():
                 layer_eval = _forward_layer_eval(
                     self.forward_func,
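The recurring pattern across these files replaces bare `Callable` (which pyre flags as error [24], missing type parameters) with a parameterized signature such as `Callable[..., Tensor]`, keeping arbitrary arguments but pinning the return type. An illustrative sketch with hypothetical names:

from typing import Any, Callable

from torch import Tensor

# Bare `Callable` says nothing about arguments or return type;
# pyre reports it as error [24].
loose: Callable  # type: ignore[type-arg]

# `Callable[..., Tensor]` lets closures like neuron_forward_func
# above be checked end to end.
def make_forward(model: Callable[..., Tensor]) -> Callable[..., Tensor]:
    def forward(*args: Any) -> Tensor:
        return model(*args)

    return forward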
12 changes: 7 additions & 5 deletions captum/attr/_core/neuron/neuron_integrated_gradients.py
@@ -4,7 +4,7 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 from captum._utils.gradient import construct_neuron_grad_fn
-from captum._utils.typing import TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import SliceIntType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.integrated_gradients import IntegratedGradients
 from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
 from captum.log import log_usage
@@ -27,8 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
         multiply_by_inputs: bool = True,
@@ -76,8 +75,11 @@ def __init__(
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+        neuron_selector: Union[
+            int,
+            Tuple[Union[int, SliceIntType], ...],
+            Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
+        ],
         baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         additional_forward_args: Optional[object] = None,
         n_steps: int = 50,