diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index ca7f6f7e9..77111328b 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -397,7 +397,9 @@ def attribute( if show_progress: attr_progress.update() if agg_output_mode: - eval_diff = modified_eval - prev_results + eval_diff = (modified_eval - prev_results).to( + inputs_tuple[0].device + ) prev_results = modified_eval else: # when perturb_per_eval > 1, every num_examples stands for @@ -405,7 +407,9 @@ def attribute( # perumuation, each diff of a perturb is its eval minus # the eval of the previous perturb all_eval = torch.cat((prev_results, modified_eval), dim=0) - eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + eval_diff = ( + all_eval[num_examples:] - all_eval[:-num_examples] + ).to(inputs_tuple[0].device) prev_results = all_eval[-num_examples:] for j in range(len(total_attrib)): @@ -689,7 +693,7 @@ def _evalFutToPrevResultsTuple( agg_output_mode, ) = prev_results_tuple if agg_output_mode: - eval_diff = modified_eval - prev_results + eval_diff = (modified_eval - prev_results).to(inputs_tuple[0].device) prev_results = modified_eval else: # when perturb_per_eval > 1, every num_examples stands for @@ -698,7 +702,9 @@ def _evalFutToPrevResultsTuple( # the eval of the previous perturb all_eval = torch.cat((prev_results, modified_eval), dim=0) - eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + eval_diff = (all_eval[num_examples:] - all_eval[:-num_examples]).to( + inputs_tuple[0].device + ) prev_results = all_eval[-num_examples:] for j in range(len(total_attrib)): @@ -799,7 +805,10 @@ def _perturbation_generator( ) current_tensors_list.append(current_tensors) current_mask_list.append( - tuple(mask == feature_permutation[i] for mask in input_masks) + tuple( + (mask == feature_permutation[i]).to(inputs[0].device) + for mask in input_masks + ) ) if len(current_tensors_list) == perturbations_per_eval: combined_inputs = tuple(