Multi-device support for attribution (#238)
gsarti committed Jan 8, 2024
1 parent de20361 commit eba92ec
Showing 2 changed files with 13 additions and 13 deletions.
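Every change below follows one pattern: tensors created on the default device (typically CPU) are moved onto the device of the tensor they are about to be combined with. This matters for models sharded across devices (e.g. loaded with `device_map="auto"`), where output logits may live on a GPU while target ids were built on CPU; `gather`, `stack`, and binary ops all raise a RuntimeError on mixed-device operands. A minimal sketch of the failure mode being fixed (hypothetical shapes, assuming a CUDA device is available):

    import torch

    logits = torch.randn(2, 32000, device="cuda:0")  # model output on GPU
    target_ids = torch.tensor([5, 7]).reshape(2, 1)  # built on CPU

    # logits.gather(-1, target_ids) would raise:
    #   RuntimeError: Expected all tensors to be on the same device
    scores = logits.gather(-1, target_ids.to(logits.device)).squeeze(-1)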
22 changes: 11 additions & 11 deletions inseq/attr/step_functions.py
@@ -84,14 +84,14 @@ def __call__(
 def logit_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
     """Compute the logit of the target_ids from the model's output logits."""
     logits = args.attribution_model.output2logits(args.forward_output)
-    target_ids = args.target_ids.reshape(logits.shape[0], 1)
+    target_ids = args.target_ids.reshape(logits.shape[0], 1).to(logits.device)
     return logits.gather(-1, target_ids).squeeze(-1)


 def probability_fn(args: StepFunctionArgs, logprob: bool = False) -> SingleScorePerStepTensor:
     """Compute the probability of target_ids from the model's output logits."""
     logits = args.attribution_model.output2logits(args.forward_output)
-    target_ids = args.target_ids.reshape(logits.shape[0], 1)
+    target_ids = args.target_ids.reshape(logits.shape[0], 1).to(logits.device)
     logits = logits.softmax(dim=-1) if not logprob else logits.log_softmax(dim=-1)
     # Extracts the ith score from the softmax output over the vocabulary (dim -1 of the logits)
     # where i is the value of the corresponding index in target_ids.
@@ -101,7 +101,7 @@ def probability_fn(args: StepFunctionArgs, logprob: bool = False) -> SingleScorePerStepTensor:
 def entropy_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
     """Compute the entropy of the model's output distribution."""
     logits = args.attribution_model.output2logits(args.forward_output)
-    entropy = torch.zeros(logits.size(0))
+    entropy = torch.zeros(logits.size(0)).to(logits.device)
     for i in range(logits.size(0)):
         entropy[i] = torch.distributions.Categorical(logits=logits[i]).entropy()
     return entropy
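As an aside (not part of this commit), the per-row loop above could likely be replaced with one batched call, since `torch.distributions.Categorical` treats the leading dimensions of `logits` as batch dimensions and returns its entropy on `logits.device` directly; a sketch:

    import torch

    def batched_entropy(logits: torch.Tensor) -> torch.Tensor:
        # One batched distribution instead of a per-row Python loop;
        # the result already lives on logits.device.
        return torch.distributions.Categorical(logits=logits).entropy()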
@@ -112,7 +112,7 @@ def crossentropy_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
     See: https://github.com/ZurichNLP/nmtscore/blob/master/src/nmtscore/models/m2m100.py#L99.
     """
     logits = args.attribution_model.output2logits(args.forward_output)
-    return F.cross_entropy(logits, args.target_ids, reduction="none").squeeze(-1)
+    return F.cross_entropy(logits, args.target_ids.to(logits.device), reduction="none").squeeze(-1)


 def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
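For intuition on the `crossentropy_fn` change (not in the diff): with `reduction="none"`, per-step cross-entropy is exactly the negative log-probability that `probability_fn(args, logprob=True)` computes, so both functions need their targets on the logits' device for the same reason:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 100)
    targets = torch.tensor([3, 42])
    ce = F.cross_entropy(logits, targets, reduction="none")
    nll = -logits.log_softmax(-1).gather(-1, targets[:, None]).squeeze(-1)
    assert torch.allclose(ce, nll)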
@@ -191,7 +191,7 @@ def pcxmi_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
-    )
+    ).to(original_probs.device)
     return -torch.log2(torch.div(original_probs, contrast_probs))
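(Algebraically, `-log2(original_probs / contrast_probs)` equals `log2(contrast_probs) - log2(original_probs)`; the division requires both probability tensors on the same device, which the added `.to(original_probs.device)` guarantees.)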


@@ -236,7 +236,7 @@ def kl_divergence_fn(
     c_forward_output = args.attribution_model.get_forward_output(
         contrast_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
     )
-    contrast_logits: torch.Tensor = args.attribution_model.output2logits(c_forward_output)
+    contrast_logits: torch.Tensor = args.attribution_model.output2logits(c_forward_output).to(original_logits.device)
     filtered_original_logits, filtered_contrast_logits = filter_logits(
         original_logits=original_logits,
         contrast_logits=contrast_logits,
@@ -278,7 +278,7 @@ def contrast_prob_diff_fn(
         contrast_targets_alignments=contrast_targets_alignments,
         logprob=logprob,
         contrast_force_inputs=contrast_force_inputs,
-    )
+    ).to(model_probs.device)
     return model_probs - contrast_probs


@@ -300,7 +300,7 @@ def contrast_logits_diff_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
-    )
+    ).to(model_logits.device)
     return model_logits - contrast_logits


@@ -329,7 +329,7 @@ def in_context_pvi_fn(
         contrast_targets_alignments=contrast_targets_alignments,
         logprob=True,
         contrast_force_inputs=contrast_force_inputs,
-    )
+    ).to(orig_logprob.device)
     return -orig_logprob + contrast_logprob


@@ -361,7 +361,7 @@ def mc_dropout_prob_avg_fn(
             aux_batch, use_embeddings=args.attribution_model.is_encoder_decoder
         )
         args.forward_output = aux_output
-        noisy_prob = probability_fn(args, logprob=logprob)
+        noisy_prob = probability_fn(args, logprob=logprob).to(orig_prob.device)
         noisy_probs.append(noisy_prob)
     # Z-score the original based on the mean and standard deviation of MC dropout predictions
     return (orig_prob - torch.stack(noisy_probs).mean(0)).div(torch.stack(noisy_probs).std(0))
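The `.to(orig_prob.device)` is needed one step later: `torch.stack` raises on tensors living on different devices, so every noisy probability must be co-located with `orig_prob` before the mean/std aggregation. A toy sketch of the aggregation (hypothetical values, single device):

    import torch

    orig_prob = torch.rand(4)
    noisy_probs = [torch.rand(4) for _ in range(10)]  # must share orig_prob.device
    noisy = torch.stack(noisy_probs)
    z_score = (orig_prob - noisy.mean(0)) / noisy.std(0)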
@@ -377,7 +377,7 @@ def top_p_size_fn(
         top_p (:obj:`float`): The cumulative probability threshold to use for filtering the logits.
     """
     logits: torch.Tensor = args.attribution_model.output2logits(args.forward_output)
-    indices_to_remove = top_p_logits_mask(logits, top_p, 1)
+    indices_to_remove = top_p_logits_mask(logits, top_p, 1).to(logits.device)
     logits = logits.masked_select(~indices_to_remove)[None, ...]
     return torch.tensor(logits.size(-1))[None, ...]
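For readers unfamiliar with the helper: `top_p_logits_mask` flags tokens outside the smallest set whose cumulative probability exceeds `top_p`, so after masking only the nucleus remains and the function returns its size. An illustrative stand-in (not inseq's actual implementation):

    import torch

    def toy_top_p_mask(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        # True marks indices to remove (tokens outside the nucleus).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        return remove.scatter(-1, sorted_idx, remove)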

4 changes: 2 additions & 2 deletions inseq/utils/contrast_utils.py
@@ -82,7 +82,7 @@ def _get_contrast_inputs(
             inputs=contrast_targets,
             as_targets=is_enc_dec,
         )
-    )
+    ).to(args.decoder_input_ids.device)
     curr_prefix_len = args.decoder_input_ids.size(1)
     c_batch, c_tgt_ids = slice_batch_from_position(c_batch, curr_prefix_len, contrast_targets_alignments)

@@ -107,7 +107,7 @@ def _get_contrast_inputs(
             "Contrastive source inputs can only be used with encoder-decoder models. "
             "Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models."
         )
-    c_enc_in = args.attribution_model.encode(contrast_sources)
+    c_enc_in = args.attribution_model.encode(contrast_sources).to(args.encoder_input_ids.device)
     if (
         args.encoder_input_ids.shape != c_enc_in.input_ids.shape
         or torch.ne(args.encoder_input_ids, c_enc_in.input_ids).any()
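Both contrast_utils changes mirror the step-function fixes above: the contrastive batch and the encoded contrastive sources are moved onto the device of the original input ids, since the prefix slicing and the shape/equality check operate on both tensors at once and would otherwise fail on mixed devices.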
