diff --git a/README.md b/README.md
index 35650658..1efeefe2 100644
--- a/README.md
+++ b/README.md
@@ -159,14 +159,15 @@ Use the `inseq.list_feature_attribution_methods` function to list all available
 Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
 
 - `logits`: Logits of the target token.
-- `probability`: Probability of the target token.
+- `probability`: Probability of the target token. Can also return the log-probability by passing `logprob=True`.
 - `entropy`: Entropy of the predictive distribution.
 - `crossentropy`: Cross-entropy loss between target token and predicted distribution.
 - `perplexity`: Perplexity of the target token.
-- `contrast_prob`: Probability of the target token when different contrastive inputs are provided to the model. Equivalent to `probability` when no contrastive inputs are provided.
+- `contrast_logits`/`contrast_prob`: Logits/probabilities of the target token when different contrastive inputs are provided to the model. Equivalent to `logits`/`probability` when no contrastive inputs are provided.
+- `contrast_logits_diff`/`contrast_prob_diff`: Difference in logits/probability between the original and foil target tokens; can be used for contrastive evaluation as in [contrastive attribution](https://aclanthology.org/2022.emnlp-main.14/) (Yin and Neubig, 2022).
 - `pcxmi`: Point-wise Contextual Cross-Mutual Information (P-CXMI) for the target token given original and contrastive contexts [(Yin et al. 2021)](https://arxiv.org/abs/2109.07446).
 - `kl_divergence`: KL divergence of the predictive distribution given original and contrastive contexts. Can be restricted to the most likely target tokens using the `top_k` and `top_p` parameters.
-- `contrast_prob_diff`: Difference in probability between original and foil target tokens pair, can be used for contrastive evaluation as in [Contrastive Attribution](https://aclanthology.org/2022.emnlp-main.14/) (Yin and Neubig, 2022).
+- `in_context_pvi`: In-context Pointwise V-usable Information (PVI) to measure the amount of contextual information used in model predictions [(Lu et al. 2023)](https://arxiv.org/abs/2310.12300).
 - `mc_dropout_prob_avg`: Probability of the target token, normalized using the mean and standard deviation of multiple samples obtained with [MC Dropout](https://arxiv.org/abs/1506.02142) (Gal and Ghahramani, 2016).
 - `top_p_size`: The number of tokens with cumulative probability greater than `top_p` in the predictive distribution of the model.
 
diff --git a/docs/source/main_classes/step_functions.rst b/docs/source/main_classes/step_functions.rst
index 54a98b2f..8f11407f 100644
--- a/docs/source/main_classes/step_functions.rst
+++ b/docs/source/main_classes/step_functions.rst
@@ -37,13 +37,19 @@ The following functions can be used out-of-the-box as attribution targets or ste
 
 .. autofunction:: perplexity_fn
 
+.. autofunction:: contrast_logits_fn
+
 .. autofunction:: contrast_prob_fn
 
+.. autofunction:: contrast_logits_diff_fn
+
+.. autofunction:: contrast_prob_diff_fn
+
 .. autofunction:: pcxmi_fn
 
 .. autofunction:: kl_divergence_fn
 
-.. autofunction:: contrast_prob_diff_fn
+.. autofunction:: in_context_pvi_fn
 
 .. autofunction:: mc_dropout_prob_avg_fn
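For orientation, a minimal sketch of how the step functions above are invoked (the checkpoint, method, and texts are illustrative placeholders; the call pattern follows the README's documented API and is not part of this diff):

```python
import inseq

# Any supported model/method pair works here; these choices are arbitrary examples.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency")

out = model.attribute(
    "Can you stop the cat?",
    "Puoi fermare il gatto?",
    # A step function used as attribution target: attribution is computed w.r.t. this score.
    attributed_fn="contrast_prob_diff",
    # Extra arguments of contrastive step functions are forwarded through attribute().
    contrast_targets="Puoi fermare la pietra?",
    # Step functions recorded as scores at every generation step.
    step_scores=["probability", "contrast_prob_diff"],
)
out.show()
```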
diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py
index 4bbd752b..39fff88e 100644
--- a/inseq/attr/step_functions.py
+++ b/inseq/attr/step_functions.py
@@ -3,11 +3,12 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Tuple
 
 import torch
+import torch.nn.functional as F
 from transformers.modeling_outputs import ModelOutput
 
 from ..data import DecoderOnlyBatch, FeatureAttributionInput, get_batch_from_inputs, slice_batch_from_position
 from ..data.aggregation_functions import DEFAULT_ATTRIBUTION_AGGREGATE_DICT
-from ..utils import extract_signature_args, logits_kl_divergence, top_p_logits_mask
+from ..utils import extract_signature_args, filter_logits, top_p_logits_mask
 from ..utils.typing import EmbeddingsTensor, IdsTensor, SingleScorePerStepTensor, TargetIdsTensor
 
 if TYPE_CHECKING:
@@ -72,6 +73,39 @@ def __call__(
         ...
 
 
+CONTRAST_FN_ARGS_DOCSTRING = """Args:
+    contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
+        the contrastive step function for encoder-decoder models. If not specified, the source text is assumed to
+        match the original source text. Defaults to :obj:`None`.
+    contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to
+        compute the contrastive step function. If not specified, no target prefix beyond previously generated
+        tokens is assumed. Defaults to :obj:`None`.
+    contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
+        target text. If not specified, the original target text is used as contrastive target (will result in same
+        output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
+    contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
+        first element is the index of the original target token and the second element is the index of the
+        contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
+        not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
+        available tokens. Defaults to :obj:`None`.
+""" + + +def contrast_fn_docstring(): + def docstring_decorator(fn: StepFunction): + """Returns the docstring for the contrastive step functions.""" + if fn.__doc__ is not None: + if "Args:\n" in fn.__doc__: + fn.__doc__ = fn.__doc__.replace("Args:\n", CONTRAST_FN_ARGS_DOCSTRING) + else: + fn.__doc__ = fn.__doc__ + "\n " + CONTRAST_FN_ARGS_DOCSTRING + else: + fn.__doc__ = CONTRAST_FN_ARGS_DOCSTRING + return fn + + return docstring_decorator + + def logit_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: """Compute the logit of the target_ids from the model's output logits.""" logits = args.attribution_model.output2logits(args.forward_output) @@ -79,11 +113,11 @@ def logit_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: return logits.gather(-1, target_ids).squeeze(-1) -def probability_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: +def probability_fn(args: StepFunctionArgs, logprob: bool = False) -> SingleScorePerStepTensor: """Compute the probabilty of target_ids from the model's output logits.""" logits = args.attribution_model.output2logits(args.forward_output) target_ids = args.target_ids.reshape(logits.shape[0], 1) - logits = logits.softmax(dim=-1) + logits = logits.softmax(dim=-1) if not logprob else logits.log_softmax(dim=-1) # Extracts the ith score from the softmax output over the vocabulary (dim -1 of the logits) # where i is the value of the corresponding index in target_ids. return logits.gather(-1, target_ids).squeeze(-1) @@ -102,7 +136,8 @@ def crossentropy_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: """Compute the cross entropy between the target_ids and the logits. See: https://github.com/ZurichNLP/nmtscore/blob/master/src/nmtscore/models/m2m100.py#L99. """ - return -torch.log2(probability_fn(args)) + logits = args.attribution_model.output2logits(args.forward_output) + return F.cross_entropy(logits, args.target_ids, reduction="none").squeeze(-1) def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: @@ -114,6 +149,7 @@ def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: return 2 ** crossentropy_fn(args) +@contrast_fn_docstring() def _get_contrast_output( args: StepFunctionArgs, contrast_target_prefixes: Optional[FeatureAttributionInput] = None, @@ -126,20 +162,6 @@ def _get_contrast_output( """Utility function to return the output of the model for given contrastive inputs. Args: - contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute - target probabilities for encoder-decoder models. If not specified, the source text is assumed to match the - original source text. Defaults to :obj:`None`. - contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to - compute target probabilities. If not specified, no target prefix beyond previously generated tokens is - assumed. Defaults to :obj:`None`. - contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original - target text. If not specified, the original target text is used as contrastive target (will result in same - output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`. - contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the - first element is the index of the original target token and the second element is the index of the - contrastive target token, used only if :obj:`contrast_targets` is specified. 
         return_contrastive_target_ids (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to return the
             contrastive target ids as well as the model output. Defaults to :obj:`False`.
         **forward_kwargs: Additional keyword arguments to be passed to the model's forward pass.
@@ -199,33 +221,46 @@ def _get_contrast_output(
     return c_out
 
 
+@contrast_fn_docstring()
+def contrast_logits_fn(
+    args: StepFunctionArgs,
+    contrast_sources: Optional[FeatureAttributionInput] = None,
+    contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
+    contrast_targets: Optional[FeatureAttributionInput] = None,
+    contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
+):
+    """Returns the logit of a generation target given contrastive context or target prediction alternative.
+    If only ``contrast_targets`` are specified, the logit of the contrastive prediction is computed given the
+    same context. The logit for the same token given contrastive source/target preceding context can also be
+    computed using ``contrast_sources`` and ``contrast_target_prefixes`` without specifying ``contrast_targets``.
+    """
+    c_output, c_tgt_ids = _get_contrast_output(
+        args,
+        contrast_sources=contrast_sources,
+        contrast_target_prefixes=contrast_target_prefixes,
+        contrast_targets=contrast_targets,
+        contrast_targets_alignments=contrast_targets_alignments,
+        return_contrastive_target_ids=True,
+    )
+    if c_tgt_ids is not None:
+        args.target_ids = c_tgt_ids
+    args.forward_output = c_output
+    return logit_fn(args)
+
+
+@contrast_fn_docstring()
 def contrast_prob_fn(
     args: StepFunctionArgs,
     contrast_sources: Optional[FeatureAttributionInput] = None,
     contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
+    logprob: bool = False,
 ):
     """Returns the probability of a generation target given contrastive context or target prediction alternative.
     If only ``contrast_targets`` are specified, the probability of the contrastive prediction is computed given the
     same context. The probability for the same token given contrastive source/target preceding context can also be
     computed using ``contrast_sources`` and ``contrast_target_prefixes`` without specifying ``contrast_targets``.
-
-    Args:
-        contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
-            target probabilities for encoder-decoder models. If not specified, the source text is assumed to match the
-            original source text. Defaults to :obj:`None`.
-        contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to
-            compute target probabilities. If not specified, no target prefix beyond previously generated tokens is
-            assumed. Defaults to :obj:`None`.
-        contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
-            target text. If not specified, the original target text is used as contrastive target (will result in same
-            output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
-        contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
-            first element is the index of the original target token and the second element is the index of the
-            contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
-            not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
-            available tokens. Defaults to :obj:`None`.
     """
     c_output, c_tgt_ids = _get_contrast_output(
         args,
@@ -238,9 +273,11 @@ def contrast_prob_fn(
-    if c_tgt_ids:
+    if c_tgt_ids is not None:
         args.target_ids = c_tgt_ids
     args.forward_output = c_output
-    return probability_fn(args)
+    return probability_fn(args, logprob=logprob)
 
 
+@contrast_fn_docstring()
 def pcxmi_fn(
     args: StepFunctionArgs,
     contrast_sources: Optional[FeatureAttributionInput] = None,
@@ -253,22 +289,6 @@ def pcxmi_fn(
     input options. The P-CXMI is defined as the negative log-ratio between the conditional probability of the target
     given the original input and the conditional probability of the target given the contrastive input, as defined
     by `Yin et al. (2021) <https://arxiv.org/abs/2109.07446>`__.
-
-    Args:
-        contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
-            the P-CXMI for encoder-decoder models. If not specified, the source text is assumed to match the original
-            source text. Defaults to :obj:`None`.
-        contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to
-            compute the P-CXMI. If not specified, no target prefix beyond previously generated tokens is assumed.
-            Defaults to :obj:`None`.
-        contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
-            target text. If not specified, the original target text is used as contrastive target (will result in same
-            output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
-        contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
-            first element is the index of the original target token and the second element is the index of the
-            contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
-            not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
-            available tokens. Defaults to :obj:`None`.
     """
     original_probs = probability_fn(args)
     contrast_probs = contrast_prob_fn(
         args=args,
@@ -281,6 +301,7 @@
     return -torch.log2(torch.div(original_probs, contrast_probs))
 
 
+@contrast_fn_docstring()
 def kl_divergence_fn(
     args: StepFunctionArgs,
     contrast_sources: Optional[FeatureAttributionInput] = None,
@@ -296,20 +317,6 @@ def kl_divergence_fn(
     (Q) inputs.
 
     Args:
-        contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
-            the KL divergence for encoder-decoder models. If not specified, the source text is assumed to match the
-            original source text. Defaults to :obj:`None`.
-        contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to
-            compute the KL divergence. If not specified, no target prefix beyond previously generated tokens is
-            assumed. Defaults to :obj:`None`.
-        contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
-            target text. If not specified, the original target text is used as contrastive target (will result in same
-            output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
-        contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
-            first element is the index of the original target token and the second element is the index of the
-            contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
-            not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
-            available tokens. Defaults to :obj:`None`.
         top_k (:obj:`int`): If set to a value > 0, only the top :obj:`top_k` tokens will be considered for computing
             the KL divergence. Defaults to :obj:`0` (no top-k selection).
         top_p (:obj:`float`): If set to a value > 0 and < 1, only the tokens with cumulative probability above
@@ -329,21 +336,31 @@ def kl_divergence_fn(
         return_contrastive_target_ids=False,
     )
     contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output)
-    return logits_kl_divergence(
+    filtered_original_logits, filtered_contrast_logits = filter_logits(
         original_logits=original_logits,
         contrast_logits=contrast_logits,
         top_p=top_p,
         top_k=top_k,
         min_tokens_to_keep=min_tokens_to_keep,
     )
+    filtered_original_logprobs = F.log_softmax(filtered_original_logits, dim=-1)
+    filtered_contrast_logprobs = F.log_softmax(filtered_contrast_logits, dim=-1)
+    kl_divergence = torch.zeros(filtered_original_logprobs.size(0))
+    for i in range(filtered_original_logprobs.size(0)):
+        kl_divergence[i] = F.kl_div(
+            filtered_contrast_logprobs[i], filtered_original_logprobs[i], reduction="sum", log_target=True
+        )
+    return kl_divergence
 
 
+@contrast_fn_docstring()
 def contrast_prob_diff_fn(
     args: StepFunctionArgs,
     contrast_sources: Optional[FeatureAttributionInput] = None,
     contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
+    logprob: bool = False,
 ):
     """Returns the difference between next step probability for a candidate generation target vs. a contrastive
     alternative. Can be used as attribution target to answer the question: "Which features were salient in the
@@ -351,38 +368,76 @@ def contrast_prob_diff_fn(
     of `Yin and Neubig (2022) <https://aclanthology.org/2022.emnlp-main.14/>`__. Can also be used to compute the
     difference in probability for the same token given contrastive source/target preceding context using
     ``contrast_sources`` and ``contrast_target_prefixes`` without specifying ``contrast_targets``.
-
-    Args:
-        contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
-            target probabilities for encoder-decoder models. If not specified, the source text is assumed to match the
-            original source text. Defaults to :obj:`None`.
-        contrast_target_prefixes (:obj:`str` or :obj:`list(str)`): Target prefix(es) used as contrastive inputs to
-            compute target probabilities. If not specified, no target prefix beyond previously generated tokens is
-            assumed. Defaults to :obj:`None`.
-        contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
-            target text. If not specified, the original target text is used as contrastive target (will result in same
-            output unless ``contrast_sources`` or ``contrast_target_prefixes`` are specified). Defaults to :obj:`None`.
-        contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
-            first element is the index of the original target token and the second element is the index of the
-            contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
-            not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
-            available tokens. Defaults to :obj:`None`.
     """
-    model_probs = probability_fn(args)
+    model_probs = probability_fn(args, logprob=logprob)
     contrast_probs = contrast_prob_fn(
         args=args,
         contrast_sources=contrast_sources,
         contrast_target_prefixes=contrast_target_prefixes,
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
+        logprob=logprob,
     )
     # Return the prob difference as target for attribution
     return model_probs - contrast_probs
 
 
+@contrast_fn_docstring()
+def contrast_logits_diff_fn(
+    args: StepFunctionArgs,
+    contrast_sources: Optional[FeatureAttributionInput] = None,
+    contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
+    contrast_targets: Optional[FeatureAttributionInput] = None,
+    contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
+):
+    """Equivalent to ``contrast_prob_diff_fn``, but computed on logits. This is the original attribution target
+    used in `Yin and Neubig (2022) <https://aclanthology.org/2022.emnlp-main.14/>`__.
+    """
+    model_logits = logit_fn(args)
+    contrast_logits = contrast_logits_fn(
+        args=args,
+        contrast_sources=contrast_sources,
+        contrast_target_prefixes=contrast_target_prefixes,
+        contrast_targets=contrast_targets,
+        contrast_targets_alignments=contrast_targets_alignments,
+    )
+    # Return the logit difference as target for attribution
+    return model_logits - contrast_logits
+
+
+@contrast_fn_docstring()
+def in_context_pvi_fn(
+    args: StepFunctionArgs,
+    contrast_sources: Optional[FeatureAttributionInput] = None,
+    contrast_target_prefixes: Optional[FeatureAttributionInput] = None,
+    contrast_targets: Optional[FeatureAttributionInput] = None,
+    contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
+):
+    """Returns the in-context pointwise V-usable information as defined by `Lu et al. (2023)
+    <https://arxiv.org/abs/2310.12300>`__. In-context PVI is a variant of P-CXMI that captures the amount of usable
+    information in a given contextual example, i.e. how much context information contributes to the model's
+    prediction. In-context PVI was used by `Lu et al. (2023) <https://arxiv.org/abs/2310.12300>`__ to estimate
+    example difficulty for a given model, and by `Prasad et al. (2023) <https://arxiv.org/abs/2304.10703>`__ to
+    measure the informativeness of intermediate reasoning steps in chain-of-thought prompting.
+
+    Reference implementation: https://github.com/boblus/in-context-pvi/blob/main/in_context_pvi.ipynb
+    """
+    orig_logprob = probability_fn(args, logprob=True)
+    contrast_logprob = contrast_prob_fn(
+        args=args,
+        contrast_sources=contrast_sources,
+        contrast_target_prefixes=contrast_target_prefixes,
+        contrast_targets=contrast_targets,
+        contrast_targets_alignments=contrast_targets_alignments,
+        logprob=True,
+    )
+    return -orig_logprob + contrast_logprob
+
+
 def mc_dropout_prob_avg_fn(
     args: StepFunctionArgs,
     n_mcd_steps: int = 5,
+    logprob: bool = False,
 ):
     """Returns the probability of the target token, normalized using the mean and standard deviation of a pool
     of noisy predictions computed with MC Dropout. Can be used as an attribution target to compute more robust
     attribution scores.
@@ -395,7 +450,7 @@ def mc_dropout_prob_avg_fn(
         n_mcd_steps (:obj:`int`): The number of prediction steps that should be used to normalize the original output.
     """
     # Original probability from the model without noise
-    orig_prob = probability_fn(args)
+    orig_prob = probability_fn(args, logprob=logprob)
 
     # Compute noisy predictions using the noisy model
     # Important: must be in train mode to ensure noise for MCD
@@ -407,7 +462,7 @@ def mc_dropout_prob_avg_fn(
             aux_batch, use_embeddings=args.attribution_model.is_encoder_decoder
         )
         args.forward_output = aux_output
-        noisy_prob = probability_fn(args)
+        noisy_prob = probability_fn(args, logprob=logprob)
         noisy_probs.append(noisy_prob)
     # Z-score the original based on the mean and standard deviation of MC dropout predictions
     return (orig_prob - torch.stack(noisy_probs).mean(0)).div(torch.stack(noisy_probs).std(0))
@@ -434,10 +489,13 @@ def top_p_size_fn(
     "entropy": entropy_fn,
     "crossentropy": crossentropy_fn,
     "perplexity": perplexity_fn,
+    "contrast_logits": contrast_logits_fn,
     "contrast_prob": contrast_prob_fn,
+    "contrast_logits_diff": contrast_logits_diff_fn,
+    "contrast_prob_diff": contrast_prob_diff_fn,
     "pcxmi": pcxmi_fn,
     "kl_divergence": kl_divergence_fn,
-    "contrast_prob_diff": contrast_prob_diff_fn,
+    "in_context_pvi": in_context_pvi_fn,
     "mc_dropout_prob_avg": mc_dropout_prob_avg_fn,
     "top_p_size": top_p_size_fn,
 }
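To illustrate the new contrastive scores added above, a hedged sketch of measuring how much a context sentence contributes to each generated token (the checkpoint and texts are invented for this example; `in_context_pvi` and `pcxmi` both accept the shared contrastive arguments injected by `contrast_fn_docstring`):

```python
import inseq

# Illustrative encoder-decoder checkpoint; any supported model should work.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")

out = model.attribute(
    "Did you know? Amsterdam is the capital of the Netherlands.",
    # Both scores compare the prediction given the original source against the
    # prediction given the contrastive source below (context sentence removed).
    step_scores=["in_context_pvi", "pcxmi"],
    contrast_sources="Amsterdam is the capital of the Netherlands.",
)
```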
""" # Original probability from the model without noise - orig_prob = probability_fn(args) + orig_prob = probability_fn(args, logprob=logprob) # Compute noisy predictions using the noisy model # Important: must be in train mode to ensure noise for MCD @@ -407,7 +462,7 @@ def mc_dropout_prob_avg_fn( aux_batch, use_embeddings=args.attribution_model.is_encoder_decoder ) args.forward_output = aux_output - noisy_prob = probability_fn(args) + noisy_prob = probability_fn(args, logprob=logprob) noisy_probs.append(noisy_prob) # Z-score the original based on the mean and standard deviation of MC dropout predictions return (orig_prob - torch.stack(noisy_probs).mean(0)).div(torch.stack(noisy_probs).std(0)) @@ -434,10 +489,13 @@ def top_p_size_fn( "entropy": entropy_fn, "crossentropy": crossentropy_fn, "perplexity": perplexity_fn, + "contrast_logits": contrast_logits_fn, "contrast_prob": contrast_prob_fn, + "contrast_logits_diff": contrast_logits_diff_fn, + "contrast_prob_diff": contrast_prob_diff_fn, "pcxmi": pcxmi_fn, "kl_divergence": kl_divergence_fn, - "contrast_prob_diff": contrast_prob_diff_fn, + "in_context_pvi": in_context_pvi_fn, "mc_dropout_prob_avg": mc_dropout_prob_avg_fn, "top_p_size": top_p_size_fn, } diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 73e76fc8..5f59ea04 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -48,10 +48,10 @@ aggregate_contiguous, check_device, euclidean_distance, + filter_logits, get_default_device, get_front_padding, get_sequences_from_batched_steps, - logits_kl_divergence, normalize, remap_from_filtered, top_p_logits_mask, @@ -116,5 +116,5 @@ "get_adjusted_alignments", "get_aligned_idx", "top_p_logits_mask", - "logits_kl_divergence", + "filter_logits", ] diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py index cc99797f..5bf4ceb9 100644 --- a/inseq/utils/torch_utils.py +++ b/inseq/utils/torch_utils.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Sequence, Tuple, Union import torch import torch.nn.functional as F @@ -75,37 +75,65 @@ def top_p_logits_mask(logits: torch.Tensor, top_p: float, min_tokens_to_keep: in return indices_to_remove -def logits_kl_divergence( +def top_k_logits_mask(logits: torch.Tensor, top_k: int, min_tokens_to_keep: int) -> torch.Tensor: + """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py + """ + top_k = max(top_k, min_tokens_to_keep) + return logits < logits.topk(top_k).values[..., -1, None] + + +def get_logits_from_filter_strategy( + filter_strategy: Union[Literal["original"], Literal["contrast"], Literal["merged"]], + original_logits: torch.Tensor, + contrast_logits: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if filter_strategy == "original": + return original_logits + elif filter_strategy == "contrast": + return contrast_logits + elif filter_strategy == "merged": + return original_logits + contrast_logits + + +def filter_logits( original_logits: torch.Tensor, - contrast_logits: torch.Tensor, + contrast_logits: Optional[torch.Tensor] = None, top_p: float = 1.0, top_k: int = 0, min_tokens_to_keep: int = 1, -) -> torch.Tensor: - """Compute the KL divergence between two sets of logits, with optional top-k and top-p filtering.""" - top_k = min(top_k, contrast_logits.size(-1)) + filter_strategy: Union[Literal["original"], Literal["contrast"], Literal["merged"], 
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """Applies top-k and top-p filtering to logits, and optionally to an additional set of contrastive logits."""
+    if top_k > original_logits.size(-1) or top_k < 0:
+        raise ValueError(f"`top_k` has to be an integer between 0 and {original_logits.size(-1)}, but is {top_k}")
+    if filter_strategy and filter_strategy != "original" and contrast_logits is None:
+        raise ValueError(f"`filter_strategy` {filter_strategy} can only be used if `contrast_logits` is provided")
+    if not filter_strategy:
+        if contrast_logits is None:
+            filter_strategy = "original"
+        else:
+            filter_strategy = "merged"
     if top_p < 1.0:
-        original_indices_to_remove = top_p_logits_mask(original_logits, top_p, min_tokens_to_keep)
-        contrastive_indices_to_remove = top_p_logits_mask(contrast_logits, top_p, min_tokens_to_keep)
-        joined_indices_to_remove = original_indices_to_remove & contrastive_indices_to_remove
-        original_logits = original_logits.masked_select(~joined_indices_to_remove)[None, ...]
-        contrast_logits = contrast_logits.masked_select(~joined_indices_to_remove)[None, ...]
+        indices_to_remove = top_p_logits_mask(
+            get_logits_from_filter_strategy(filter_strategy, original_logits, contrast_logits),
+            top_p,
+            min_tokens_to_keep,
+        )
+        original_logits = original_logits.masked_fill(indices_to_remove, float("-inf"))
+        if contrast_logits is not None:
+            contrast_logits = contrast_logits.masked_fill(indices_to_remove, float("-inf"))
     if top_k > 0:
-        filtered_contrast_logits = torch.zeros(contrast_logits.size(0), top_k)
-        filtered_original_logits = torch.zeros(original_logits.size(0), top_k)
-        indices_to_remove = contrast_logits < contrast_logits.topk(top_k).values[..., -1, None]
-        for i in range(contrast_logits.size(0)):
-            filtered_contrast_logits[i] = contrast_logits[i].masked_select(~indices_to_remove[i])
-            filtered_original_logits[i] = original_logits[i].masked_select(~indices_to_remove[i])
-    else:
-        filtered_contrast_logits = contrast_logits
-        filtered_original_logits = original_logits
-    original_logprobs = F.log_softmax(filtered_original_logits, dim=-1)
-    contrast_logprobs = F.log_softmax(filtered_contrast_logits, dim=-1)
-    kl_divergence = torch.zeros(original_logprobs.size(0))
-    for i in range(original_logprobs.size(0)):
-        kl_divergence[i] = F.kl_div(contrast_logprobs[i], original_logprobs[i], reduction="sum", log_target=True)
-    return kl_divergence
+        indices_to_remove = top_k_logits_mask(
+            get_logits_from_filter_strategy(filter_strategy, original_logits, contrast_logits),
+            top_k,
+            min_tokens_to_keep,
+        )
+        original_logits = original_logits.masked_fill(indices_to_remove, float("-inf"))
+        if contrast_logits is not None:
+            contrast_logits = contrast_logits.masked_fill(indices_to_remove, float("-inf"))
+    if contrast_logits is not None:
+        return original_logits, contrast_logits
+    return original_logits
 
 
 def euclidean_distance(vec_a: torch.Tensor, vec_b: torch.Tensor) -> torch.Tensor:
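A quick sketch of the `filter_logits` semantics introduced above (values are made up; the outputs follow from the implementation shown in this hunk):

```python
import torch

from inseq.utils import filter_logits

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])

# Keep the top-2 candidates; the rest are masked to -inf, so a downstream
# softmax assigns them zero probability.
filtered = filter_logits(logits, top_k=2)
# tensor([[-inf, -inf, -inf, 4., 5.]])

# With contrastive logits, both tensors are masked over a shared support,
# selected by default with the "merged" strategy (sum of the two logits).
contrast = torch.tensor([[0.5, 0.4, 0.3, 5.0, 0.2]])
filtered_orig, filtered_contrast = filter_logits(logits, contrast, top_k=2)
# filtered_orig     -> tensor([[-inf, -inf, -inf, 4.0, 5.0]])
# filtered_contrast -> tensor([[-inf, -inf, -inf, 5.0, 0.2]])
```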
diff --git a/tests/utils/test_torch_utils.py b/tests/utils/test_torch_utils.py
index e33a889c..115e05af 100644
--- a/tests/utils/test_torch_utils.py
+++ b/tests/utils/test_torch_utils.py
@@ -2,6 +2,7 @@
 import torch
 
 from inseq.utils.misc import pretty_tensor
+from inseq.utils.torch_utils import filter_logits
 
 
 @pytest.mark.parametrize(
@@ -46,3 +47,31 @@ def test_probits2prob():
     probs = torch.gather(probits, -1, target_ids.T)
     assert probs.shape == (1, 1)
     assert torch.eq(probs, torch.tensor([23456.0])).all()
+
+
+def test_filter_logits():
+    original_logits = torch.tensor(
+        [
+            [1.0, 2.0, 3.0, 4.0, 5.0],
+            [2.0, 3.0, 4.0, 5.0, 6.0],
+            [3.0, 4.0, 5.0, 6.0, 7.0],
+            [4.0, 5.0, 6.0, 7.0, 8.0],
+        ]
+    )
+    filtered_logits = filter_logits(original_logits, top_k=2)
+    topk2 = original_logits.clone().index_fill(1, torch.tensor([0, 1, 2]), float("-inf"))
+    assert torch.eq(filtered_logits, topk2).all()
+    filtered_logits = filter_logits(original_logits, top_p=0.9)
+    topp90 = original_logits.clone().index_fill(1, torch.tensor([0, 1]), float("-inf"))
+    assert torch.eq(filtered_logits, topp90).all()
+    contrast_logits = torch.tensor(
+        [
+            [15.0, 13.0, 11.0, 9.0, 7.0],
+            [13.0, 11.0, 9.0, 7.0, 5.0],
+            [11.0, 9.0, 7.0, 5.0, 3.0],
+            [9.0, 7.0, 5.0, 3.0, 1.0],
+        ]
+    )
+    filtered_logits, contrast_logits = filter_logits(original_logits, contrast_logits=contrast_logits, top_k=2)
+    top2merged = original_logits.clone().index_fill(1, torch.tensor([2, 3, 4]), float("-inf"))
+    assert torch.eq(filtered_logits, top2merged).all()
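Closing the loop on the README change, a sketch of the restricted KL divergence score in use. This assumes step-function keyword arguments such as `top_k` are forwarded through `model.attribute` in the same way as the contrastive arguments; the checkpoint and texts are again illustrative:

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")

out = model.attribute(
    "The weather is nice today.",
    step_scores=["kl_divergence"],
    # How much does the predictive distribution shift under a perturbed source?
    contrast_sources="The weather is terrible today.",
    # Restrict the divergence to the 10 most likely tokens (merged filtering).
    top_k=10,
)
```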