23 changes: 9 additions & 14 deletions captum/attr/_core/shapley_value.py
@@ -130,10 +130,6 @@ def attribute(
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
r"""
-NOTE: The feature_mask argument differs from other perturbation based
-methods, since feature indices can overlap across tensors. See the
-description of the feature_mask argument below for more details.
-
Args:

inputs (Tensor or tuple[Tensor, ...]): Input for which Shapley value
@@ -225,8 +221,7 @@ def attribute(
all tensors should be integers in the range 0 to
num_features - 1, and indices corresponding to the same
feature should have the same value.
-Note that features are grouped across tensors
-(unlike feature ablation and occlusion), so
+Note that features are grouped across tensors, so
if the same index is used in different tensors, those
features are still grouped and added simultaneously.
If the forward function returns a single scalar per batch,
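
The grouping rule described above means that a shared index ties features from different input tensors into a single Shapley group. A minimal sketch of a call that relies on this, assuming the public ShapleyValueSampling API; the forward function, shapes, and mask values here are made up purely for illustration:

import torch
from captum.attr import ShapleyValueSampling

# Toy forward function returning one scalar per example (made up for this sketch).
def forward_func(x1, x2):
    return x1.sum(dim=1) + x2.sum(dim=1)

sv = ShapleyValueSampling(forward_func)

x1 = torch.rand(1, 4)
x2 = torch.rand(1, 2)

# Index 0 appears in both masks, so x1[:, 0:2] and x2[:, 0] form one feature
# group and are perturbed together; indices run from 0 to num_features - 1.
mask1 = torch.tensor([[0, 0, 1, 2]])
mask2 = torch.tensor([[0, 3]])

attr_x1, attr_x2 = sv.attribute((x1, x2), feature_mask=(mask1, mask2), n_samples=10)
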
@@ -521,7 +516,7 @@ def attribute_future(
prev_result_tuple: Future[
Tuple[Tensor, Tensor, Size, List[Tensor], bool]
] = initial_eval.then(
-lambda inp=initial_eval: self._initialEvalToPrevResultsTuple( # type: ignore # noqa: E501 line too long
+lambda inp=initial_eval: self._initial_eval_to_prev_results_tuple( # type: ignore # noqa: E501 line too long
inp,
num_examples,
perturbations_per_eval,
@@ -537,7 +532,7 @@ def attribute_future(
total_features, n_samples
):
prev_result_tuple = prev_result_tuple.then(
-lambda inp=prev_result_tuple: self._setPrevResultsToInitialEval(inp) # type: ignore # noqa: E501 line too long
+lambda inp=prev_result_tuple: self._set_prev_results_to_initial_eval(inp) # type: ignore # noqa: E501 line too long
)

iter_count += 1
@@ -590,7 +585,7 @@ def attribute_future(
] = collect_all([prev_result_tuple, modified_eval])

prev_result_tuple = eval_futs.then(
-lambda evals=eval_futs, masks=current_masks: self._evalFutToPrevResultsTuple( # type: ignore # noqa: E501 line too long
+lambda evals=eval_futs, masks=current_masks: self._eval_fut_to_prev_results_tuple( # type: ignore # noqa: E501 line too long
evals, num_examples, inputs_tuple, masks
)
)
@@ -602,14 +597,14 @@ def attribute_future(
# formatted attributions.
formatted_attr: Future[Union[Tensor, tuple[Tensor, ...]]] = (
prev_result_tuple.then(
-lambda inp=prev_result_tuple: self._prevResultTupleToFormattedAttr( # type: ignore # noqa: E501 line too long
+lambda inp=prev_result_tuple: self._prev_result_tuple_to_formatted_attr( # type: ignore # noqa: E501 line too long
inp, iter_count, is_inputs_tuple
)
)
)
return cast(Future[TensorOrTupleOfTensorsGeneric], formatted_attr)

-def _initialEvalToPrevResultsTuple(
+def _initial_eval_to_prev_results_tuple(
self,
initial_eval: Future[Tensor],
num_examples: int,
@@ -657,7 +652,7 @@ def _initialEvalToPrevResultsTuple(
) from e
return result

-def _setPrevResultsToInitialEval(
+def _set_prev_results_to_initial_eval(
self,
processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
@@ -669,7 +664,7 @@ def _setPrevResultsToInitialEval(
prev_results = initial_eval
return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)

-def _evalFutToPrevResultsTuple(
+def _eval_fut_to_prev_results_tuple(
self,
eval_futs: Future[
List[
@@ -755,7 +750,7 @@ def _evalFutToPrevResultsTuple(
)
return result

-def _prevResultTupleToFormattedAttr(
+def _prev_result_tuple_to_formatted_attr(
self,
prev_result_tuple: Future[
Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool]
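
The helpers renamed above are callbacks chained onto torch.futures Futures inside attribute_future. As a standalone sketch of that chaining pattern, not Captum code, assuming only Future.then, set_result, wait, and collect_all from torch.futures:

import torch
from torch.futures import Future, collect_all

def add_header(fut: Future) -> str:
    # .then() hands the callback the completed Future; read it with .value().
    return "result: " + str(fut.value())

f: Future[int] = Future()
chained = f.then(add_header)   # schedule a follow-up computation
f.set_result(7)
print(chained.wait())          # "result: 7"

# collect_all bundles several Futures into one Future[List[Future]],
# mirroring how attribute_future joins prev_result_tuple with modified_eval.
g: Future[int] = Future()
h: Future[int] = Future()
both = collect_all([g, h]).then(lambda futs: sum(x.value() for x in futs.value()))
g.set_result(2)
h.set_result(3)
print(both.wait())             # 5
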