Skip to content

Commit

Permalink
Potential fix for failed test_attribute_slice_seq2seq on CUDA-enabl…
Browse files Browse the repository at this point in the history
…ed platforms (#252)

Thank you for the fixes!
  • Loading branch information
xuan25 committed Feb 20, 2024
1 parent b4156a6 commit ff4ac86
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,11 @@ def filtered_attribute_step(
batch=batch,
)
step_fn_extra_args = get_step_scores_args([score], step_scores_args)
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args)
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
step_output.remap_from_filtered(target_attention_mask, orig_batch)
step_output = step_output.detach().to("cpu")
step_output = step_output.detach()
return step_output

def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
Expand Down
6 changes: 3 additions & 3 deletions inseq/attr/feat/gradient_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def attribute_step(
attr, self.attribution_model.is_encoder_decoder
)
return GranularFeatureAttributionStepOutput(
source_attributions=source_attributions.to("cpu") if source_attributions is not None else None,
target_attributions=target_attributions.to("cpu") if target_attributions is not None else None,
step_scores={"deltas": deltas.to("cpu")} if deltas is not None else None,
source_attributions=source_attributions if source_attributions is not None else None,
target_attributions=target_attributions if target_attributions is not None else None,
step_scores={"deltas": deltas} if deltas is not None else None,
)


Expand Down
2 changes: 1 addition & 1 deletion inseq/data/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def _filter_scores(
if len(set(indices)) != len(indices):
raise IndexError("Duplicate indices are not allowed.")
if isinstance(indices, tuple):
scores = scores.index_select(dim, torch.arange(indices[0], indices[1]))
scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device))
else:
scores = scores.index_select(dim, torch.tensor(indices, device=scores.device))
return scores
Expand Down
4 changes: 3 additions & 1 deletion inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def from_step_attributions(
start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203
]
if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]):
empty_final_row = torch.ones(1, *target_attributions[seq_id].shape[1:]) * float("nan")
empty_final_row = torch.ones(
1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device
) * float("nan")
target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0)
seq_attributions[seq_id].target_attributions = target_attributions[seq_id]
if attr.step_scores is not None:
Expand Down

0 comments on commit ff4ac86

Please sign in to comment.