Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions captum/attr/_core/shapley_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,15 +397,19 @@ def attribute(
if show_progress:
attr_progress.update()
if agg_output_mode:
eval_diff = modified_eval - prev_results
eval_diff = (modified_eval - prev_results).to(
inputs_tuple[0].device
)
prev_results = modified_eval
else:
# when perturb_per_eval > 1, every num_examples stands for
# one perturb. Since the perturbs are from a consecutive
# perumuation, each diff of a perturb is its eval minus
# the eval of the previous perturb
all_eval = torch.cat((prev_results, modified_eval), dim=0)
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
eval_diff = (
all_eval[num_examples:] - all_eval[:-num_examples]
).to(inputs_tuple[0].device)
prev_results = all_eval[-num_examples:]

for j in range(len(total_attrib)):
Expand Down Expand Up @@ -689,7 +693,7 @@ def _evalFutToPrevResultsTuple(
agg_output_mode,
) = prev_results_tuple
if agg_output_mode:
eval_diff = modified_eval - prev_results
eval_diff = (modified_eval - prev_results).to(inputs_tuple[0].device)
prev_results = modified_eval
else:
# when perturb_per_eval > 1, every num_examples stands for
Expand All @@ -698,7 +702,9 @@ def _evalFutToPrevResultsTuple(
# the eval of the previous perturb

all_eval = torch.cat((prev_results, modified_eval), dim=0)
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
eval_diff = (all_eval[num_examples:] - all_eval[:-num_examples]).to(
inputs_tuple[0].device
)
prev_results = all_eval[-num_examples:]

for j in range(len(total_attrib)):
Expand Down Expand Up @@ -799,7 +805,10 @@ def _perturbation_generator(
)
current_tensors_list.append(current_tensors)
current_mask_list.append(
tuple(mask == feature_permutation[i] for mask in input_masks)
tuple(
(mask == feature_permutation[i]).to(inputs[0].device)
for mask in input_masks
)
)
if len(current_tensors_list) == perturbations_per_eval:
combined_inputs = tuple(
Expand Down