diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py
index 7a0f58f02..3288cc423 100644
--- a/captum/attr/_core/shapley_value.py
+++ b/captum/attr/_core/shapley_value.py
@@ -130,10 +130,6 @@ def attribute(
         show_progress: bool = False,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
-        NOTE: The feature_mask argument differs from other perturbation based
-        methods, since feature indices can overlap across tensors. See the
-        description of the feature_mask argument below for more details.
-
         Args:
 
                 inputs (Tensor or tuple[Tensor, ...]): Input for which Shapley value
@@ -225,8 +221,7 @@ def attribute(
                             all tensors should be integers in the range 0 to
                             num_features - 1, and indices corresponding to the same
                             feature should have the same value.
-                            Note that features are grouped across tensors
-                            (unlike feature ablation and occlusion), so
+                            Note that features are grouped across tensors, so
                             if the same index is used in different tensors, those
                             features are still grouped and added simultaneously.
                             If the forward function returns a single scalar per batch,
@@ -521,7 +516,7 @@ def attribute_future(
         prev_result_tuple: Future[
             Tuple[Tensor, Tensor, Size, List[Tensor], bool]
         ] = initial_eval.then(
-            lambda inp=initial_eval: self._initialEvalToPrevResultsTuple(  # type: ignore # noqa: E501 line too long
+            lambda inp=initial_eval: self._initial_eval_to_prev_results_tuple(  # type: ignore # noqa: E501 line too long
                 inp,
                 num_examples,
                 perturbations_per_eval,
@@ -537,7 +532,7 @@ def attribute_future(
             total_features, n_samples
         ):
             prev_result_tuple = prev_result_tuple.then(
-                lambda inp=prev_result_tuple: self._setPrevResultsToInitialEval(inp)  # type: ignore # noqa: E501 line too long
+                lambda inp=prev_result_tuple: self._set_prev_results_to_initial_eval(inp)  # type: ignore # noqa: E501 line too long
             )
 
             iter_count += 1
@@ -590,7 +585,7 @@ def attribute_future(
                 ] = collect_all([prev_result_tuple, modified_eval])
 
                 prev_result_tuple = eval_futs.then(
-                    lambda evals=eval_futs, masks=current_masks: self._evalFutToPrevResultsTuple(  # type: ignore # noqa: E501 line too long
+                    lambda evals=eval_futs, masks=current_masks: self._eval_fut_to_prev_results_tuple(  # type: ignore # noqa: E501 line too long
                         evals, num_examples, inputs_tuple, masks
                     )
                 )
@@ -602,14 +597,14 @@ def attribute_future(
         # formatted attributions.
         formatted_attr: Future[Union[Tensor, tuple[Tensor, ...]]] = (
             prev_result_tuple.then(
-                lambda inp=prev_result_tuple: self._prevResultTupleToFormattedAttr(  # type: ignore # noqa: E501 line too long
+                lambda inp=prev_result_tuple: self._prev_result_tuple_to_formatted_attr(  # type: ignore # noqa: E501 line too long
                     inp, iter_count, is_inputs_tuple
                 )
             )
         )
         return cast(Future[TensorOrTupleOfTensorsGeneric], formatted_attr)
 
-    def _initialEvalToPrevResultsTuple(
+    def _initial_eval_to_prev_results_tuple(
         self,
         initial_eval: Future[Tensor],
         num_examples: int,
@@ -657,7 +652,7 @@ def _initialEvalToPrevResultsTuple(
             ) from e
         return result
 
-    def _setPrevResultsToInitialEval(
+    def _set_prev_results_to_initial_eval(
         self,
         processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
     ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
@@ -669,7 +664,7 @@ def _setPrevResultsToInitialEval(
         prev_results = initial_eval
         return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)
 
-    def _evalFutToPrevResultsTuple(
+    def _eval_fut_to_prev_results_tuple(
         self,
         eval_futs: Future[
             List[
@@ -755,7 +750,7 @@ def _evalFutToPrevResultsTuple(
         )
         return result
 
-    def _prevResultTupleToFormattedAttr(
+    def _prev_result_tuple_to_formatted_attr(
         self,
         prev_result_tuple: Future[
             Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool]
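For context on the feature_mask behaviour the edited docstring describes (features grouped across tensors), here is a minimal usage sketch. It is not part of the diff; the toy forward function, tensor shapes, and mask values are illustrative assumptions. Reusing index 0 in both masks groups x1[:, 0] and x2[:, 0] into a single feature that is added to the baseline simultaneously.

```python
import torch

from captum.attr import ShapleyValueSampling


# Toy model: returns a scalar per example by summing both input tensors.
def forward_func(x1, x2):
    return x1.sum(dim=1) + x2.sum(dim=1)


x1 = torch.rand(4, 3)
x2 = torch.rand(4, 3)

# Index 0 appears in both masks, so x1[:, 0] and x2[:, 0] form one feature
# group; the remaining columns form two further groups (indices 1 and 2).
# Masks with a leading dimension of 1 broadcast across the batch.
mask1 = torch.tensor([[0, 1, 1]])
mask2 = torch.tensor([[0, 2, 2]])

sv = ShapleyValueSampling(forward_func)
attr1, attr2 = sv.attribute(
    (x1, x2),
    feature_mask=(mask1, mask2),
    n_samples=25,
)
print(attr1.shape, attr2.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```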