diff --git a/README.md b/README.md index 5e09d734..583d8a88 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available - `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023) +- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024) + #### Step functions Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported: @@ -303,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi ## Research using Inseq -Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*). +Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. + +> [!TIP] +> Last update: April 2024. Please open a pull request to add your publication to the list.
2023 @@ -324,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
  1. LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools (Wang et al., 2024)
  2. ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models (Zhao et al., 2024)
  3. […]
+ 4. Revisiting subword tokenization: A case study on affixal negation in large language models (Truong et al., 2024)
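A quick way to sanity-check the registration described in the README hunk above (a minimal sketch, not part of the diff — the `gpt2` checkpoint and the method's default hyperparameters are illustrative choices):

```python
import inseq

# The README points to `inseq.list_feature_attribution_methods` for discovering
# available methods; after this change, "reagent" should appear in that list.
assert "reagent" in inseq.list_feature_attribution_methods()

# Minimal end-to-end run with the method's default hyperparameters.
model = inseq.load_model("gpt2", "reagent")
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```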
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst index 1f282626..d7c4f5fc 100644 --- a/docs/source/main_classes/feature_attribution.rst +++ b/docs/source/main_classes/feature_attribution.rst @@ -90,4 +90,25 @@ Perturbation-based Attribution Methods :members: .. autoclass:: inseq.attr.feat.ValueZeroingAttribution - :members: \ No newline at end of file + :members: + +.. autoclass:: inseq.attr.feat.ReagentAttribution + :members: + + .. automethod:: __init__ + +.. code:: python + + import inseq + + model = inseq.load_model( + "gpt2-medium", + "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 + ) + out = model.attribute("Super Mario Land is a game that developed by") + out.show() diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py index 2b25778a..7a81014f 100644 --- a/inseq/attr/feat/__init__.py +++ b/inseq/attr/feat/__init__.py @@ -18,6 +18,7 @@ LimeAttribution, OcclusionAttribution, PerturbationAttributionRegistry, + ReagentAttribution, ValueZeroingAttribution, ) @@ -43,4 +44,5 @@ "SequentialIntegratedGradientsAttribution", "ValueZeroingAttribution", "PerturbationAttributionRegistry", + "ReagentAttribution", ] diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 7d86167a..a40b9dba 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,6 +1,7 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder +from .reagent import Reagent from .sequential_integrated_gradients import SequentialIntegratedGradients from .value_zeroing import ValueZeroing @@ -9,5 +10,6 @@ "MonotonicPathBuilder", "ValueZeroing", "Lime", + "Reagent", "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py new file mode 100644 index 00000000..080a2131 --- /dev/null +++ b/inseq/attr/feat/ops/reagent.py @@ -0,0 +1,134 @@ +from typing import TYPE_CHECKING, Any, Union + +import torch +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from torch import Tensor +from typing_extensions import override + +from ....utils.typing import InseqAttribution +from .reagent_core import ( + AggregateRationalizer, + DeltaProbImportanceScoreEvaluator, + POSTagTokenSampler, + TopKStoppingConditionEvaluator, + UniformTokenReplacer, +) + +if TYPE_CHECKING: + from ....models import HuggingfaceModel + + +class Reagent(InseqAttribution): + r"""Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. + + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + `__ + + Args: + forward_func (callable): The forward function of the model or any modification of it + keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. + keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
+ stopping_condition_top_k (int): Threshold indicating that the stop condition achieved when the predicted target + exist in top k predictions + replacing_ratio (float): replacing ratio of tokens for probing + max_probe_steps (int): max_probe_steps + num_probes (int): number of probes in parallel + + Example: + ``` + import inseq + + model = inseq.load_model("gpt2-medium", "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 + ) + out = model.attribute("Super Mario Land is a game that developed by") + out.show() + ``` + """ + + def __init__( + self, + attribution_model: "HuggingfaceModel", + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 16, + ) -> None: + super().__init__(attribution_model) + + model = attribution_model.model + tokenizer = attribution_model.tokenizer + model_name = attribution_model.model_name + + sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device) + stopping_condition_evaluator = TopKStoppingConditionEvaluator( + model=model, + sampler=sampler, + top_k=stopping_condition_top_k, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, + ) + importance_score_evaluator = DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio), + stopping_condition_evaluator=stopping_condition_evaluator, + max_steps=max_probe_steps, + ) + + self.rationalizer = AggregateRationalizer( + importance_score_evaluator=importance_score_evaluator, + batch_size=num_probes, + overlap_threshold=0, + overlap_strict_pos=True, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + ) + + @override + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + _target: TargetType = None, + additional_forward_args: Any = None, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, + tuple[TensorOrTupleOfTensorsGeneric, Tensor], + ]: + """Implement attribute""" + # encoder-decoder + if self.forward_func.is_encoder_decoder: + # with target-side attribution + if len(inputs) > 1: + self.rationalizer( + additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True + ) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) + return ( + res[:, : additional_forward_args[0].shape[1], :], + res[:, additional_forward_args[0].shape[1] :, :], + ) + # source-side only + else: + self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) + # decoder-only + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) + return (res,) diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py new file mode 100644 index 00000000..13917c00 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/__init__.py @@ -0,0 +1,13 @@ +from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator +from .rationalizer import AggregateRationalizer +from .stopping_condition_evaluator import TopKStoppingConditionEvaluator +from .token_replacer 
import UniformTokenReplacer +from .token_sampler import POSTagTokenSampler + +__all__ = [ + "DeltaProbImportanceScoreEvaluator", + "AggregateRationalizer", + "TopKStoppingConditionEvaluator", + "UniformTokenReplacer", + "POSTagTokenSampler", +] diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py new file mode 100644 index 00000000..adc79b7b --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from jaxtyping import Float +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from typing_extensions import override + +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor +from .stopping_condition_evaluator import StoppingConditionEvaluator +from .token_replacer import TokenReplacer + + +class BaseImportanceScoreEvaluator(ABC): + """Importance Score Evaluator""" + + def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None: + """Base Constructor + + Args: + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model + tokenizer: A Huggingface AutoTokenizer + """ + self.model = model + self.tokenizer = tokenizer + self.importance_score = None + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + + """ + raise NotImplementedError() + + +class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): + """Importance Score Evaluator""" + + @override + def __init__( + self, + model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, + tokenizer: AutoTokenizer, + token_replacer: TokenReplacer, + stopping_condition_evaluator: StoppingConditionEvaluator, + max_steps: float, + ) -> None: + """Constructor + + Args: + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model + tokenizer: A Huggingface AutoTokenizer + token_replacer: A TokenReplacer + stopping_condition_evaluator: A StoppingConditionEvaluator + """ + super().__init__(model, tokenizer) + self.token_replacer = token_replacer + self.stopping_condition_evaluator = stopping_condition_evaluator + self.max_steps = max_steps + self.importance_score = None + self.num_steps = 0 + + def update_importance_score( + self, + logit_importance_score: MultipleScoresPerStepTensor, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + prob_original_target: Float[torch.Tensor, "batch_size 1"], + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Update importance score by one step + + Args: + logit_importance_score: Current importance score in logistic scale [batch, sequence] + input_ids: input tensor [batch, sequence] + target_id: target tensor [batch] + prob_original_target: predictive probability of the 
target on the original sequence [batch, 1] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + logit_importance_score: updated importance score in logistic scale [batch, sequence] + """ + # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} + if not attribute_target: + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + else: + ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] + + logging.debug(f"Replacing mask: { mask_replacing }") + logging.debug( + f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }" + ) + + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] + prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id] + + # Compute changes delta = p^{(y)} - \hat{p^{(y)}} + delta_prob_target = prob_original_target - prob_replaced_target + logging.debug(f"likelihood delta: { delta_prob_target }") + + # Update importance scores based on delta (magnitude) and replacement (direction) + delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target + # TODO: better solution? + # Rescaling from [-1, 1] to [0, 1] before logit function + logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) + logit_importance_score = logit_importance_score + logit_delta_score + logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") + return logit_importance_score + + @override + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + """ + self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) + + # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) + if decoder_input_ids is None: + logits_original = self.model(input_ids)["logits"] + else: + logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"] + + prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id] + + # Initialize importance score s for each token in the sequence y_{1...t} + if not attribute_target: + logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) + else: + logit_importance_score = torch.rand( + (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device + ) + logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") + + # TODO: limit max steps + self.num_steps = 0 + 
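+        # Iterative refinement: each pass replaces a random subset of tokens, measures the
+        # resulting drop in target probability, and accumulates per-token evidence in logit
+        # space. Rows whose stopping condition has already fired keep their previous scores
+        # via the stop_mask gating below, and the loop exits early once all sequences stop.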
while self.num_steps < self.max_steps: + self.num_steps += 1 + # Update importance score + logit_importance_score_update = self.update_importance_score( + logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target + ) + logit_importance_score = ( + ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + ) + self.importance_score = torch.softmax(logit_importance_score, -1) + + # Evaluate stop condition + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( + input_ids, target_id, self.importance_score, decoder_input_ids, attribute_target + ) + if torch.prod(self.stop_mask) > 0: + break + + logging.info(f"Importance score evaluated in {self.num_steps} steps.") + return torch.softmax(logit_importance_score, -1) diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py new file mode 100644 index 00000000..ab7c2be8 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -0,0 +1,128 @@ +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from jaxtyping import Int64 +from typing_extensions import override + +from .....utils.typing import IdsTensor, TargetIdsTensor +from .importance_score_evaluator import BaseImportanceScoreEvaluator + + +class BaseRationalizer(ABC): + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: + super().__init__() + self.importance_score_evaluator = importance_score_evaluator + self.mean_importance_score = None + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> Int64[torch.Tensor, "batch_size other_dims"]: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] (first dimension need to be 1) + target_id: The target [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + raise NotImplementedError() + + +class AggregateRationalizer(BaseRationalizer): + """AggregateRationalizer""" + + @override + def __init__( + self, + importance_score_evaluator: BaseImportanceScoreEvaluator, + batch_size: int, + overlap_threshold: int, + overlap_strict_pos: bool = True, + keep_top_n: int = 0, + keep_ratio: float = 0, + ) -> None: + """Constructor + + Args: + importance_score_evaluator: A ImportanceScoreEvaluator + batch_size: Batch size for aggregate + overlap_threshold: Overlap threshold of rational tokens within a batch + overlap_strict_pos: Whether overlap strict to position ot not + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. 
+ """ + super().__init__(importance_score_evaluator) + self.batch_size = batch_size + self.overlap_threshold = overlap_threshold + self.overlap_strict_pos = overlap_strict_pos + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio + assert overlap_strict_pos, "overlap_strict_pos = False is not supported yet" + + @override + @torch.no_grad() + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> Int64[torch.Tensor, "batch_size other_dims"]: + """Compute rational of a sequence on a target + + Args: + input_ids: A tensor of ids of shape [batch, sequence_len] + target_id: A tensor of predicted targets of size [batch] + decoder_input_ids (optional): A tensor of ids representing the decoder input sequence for + ``AutoModelForSeq2SeqLM``, with shape [batch, sequence_len] + attribute_target: whether attribute target for encoder-decoder models + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" + batch_input_ids = input_ids.repeat(self.batch_size, 1) + batch_decoder_input_ids = ( + decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None + ) + batch_importance_score = self.importance_score_evaluator( + batch_input_ids, target_id, batch_decoder_input_ids, attribute_target + ) + importance_score_masked = batch_importance_score * torch.unsqueeze( + self.importance_score_evaluator.stop_mask, -1 + ) + self.mean_importance_score = torch.sum(importance_score_masked, dim=0) / torch.sum( + self.importance_score_evaluator.stop_mask + ) + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) + top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n + pos_top_n = pos_sorted[:, :top_n] + self.pos_top_n = pos_top_n + if self.overlap_strict_pos: + count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) + pos_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) + return pos_top_n_overlap + else: + raise NotImplementedError("overlap_strict_pos = False not been supported yet") + # TODO: Convert back to pos + # token_id_top_n = input_ids[0, pos_top_n] + # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) + # _token_id_top_n_overlap = torch.unsqueeze( + # torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + # ) diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py new file mode 100644 index 00000000..fd3bb67d --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -0,0 +1,136 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from transformers import AutoModelForCausalLM + +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor +from .token_replacer import RankingTokenReplacer +from .token_sampler import TokenSampler + + +class StoppingConditionEvaluator(ABC): + """Base class for Stopping Condition Evaluators""" + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: 
bool = False, + ) -> TargetIdsTensor: + """Evaluate stop condition according to the specified strategy. + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + + """ + raise NotImplementedError() + + +class TopKStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Evaluator stopping when target exist among the top k predictions, + while top n tokens based on importance_score are not been replaced. + """ + + def __init__( + self, + model: AutoModelForCausalLM, + sampler: TokenSampler, + top_k: int, + keep_top_n: int = 0, + keep_ratio: float = 0, + invert_keep: bool = False, + ) -> None: + """Constructor for the TopKStoppingConditionEvaluator class. + + Args: + model: A Huggingface ``AutoModelForCausalLM``. + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens. + top_k: Top K predictions in which the target must be included in order to achieve the stopping condition. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. + """ + self.model = model + self.top_k = top_k + self.replacer = RankingTokenReplacer(sampler, keep_top_n, keep_ratio, invert_keep) + + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> TargetIdsTensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + """ + # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} + self.replacer.set_score(importance_score) + if not attribute_target: + input_ids_replaced, mask_replacing = self.replacer(input_ids) + else: + ids_replaced, mask_replacing = self.replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] + + logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") + + # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} + assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" + with torch.no_grad(): + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] + 
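+        # The stop condition holds for a sequence when the original target token still
+        # ranks among the model's top-k next-token predictions after the replacement.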
ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) + ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] + match_mask = ids_prediction_top_k == target_id + match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool) + return match_hit + + +class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Stopping Condition Evaluator which stop when target exist in top k predictions, + while top n tokens based on importance_score are not been replaced. + """ + + def __call__(self, input_ids: IdsTensor, **kwargs) -> TargetIdsTensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + """ + return torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py new file mode 100644 index 00000000..0d889144 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py @@ -0,0 +1,111 @@ +import math +from abc import ABC, abstractmethod + +import torch +from typing_extensions import override + +from .....utils.typing import IdsTensor +from .token_sampler import TokenSampler + + +class TokenReplacer(ABC): + """ + Base class for token replacers + + """ + + def __init__(self, sampler: TokenSampler) -> None: + self.sampler = sampler + + @abstractmethod + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Replace tokens according to the specified strategy. + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + + """ + raise NotImplementedError() + + +class RankingTokenReplacer(TokenReplacer): + """Replace tokens in a sequence based on top-N ranking""" + + @override + def __init__( + self, sampler: TokenSampler, keep_top_n: int = 0, keep_ratio: float = 0, invert_keep: bool = False + ) -> None: + """Constructor for the RankingTokenReplacer class. + + Args: + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
+ """ + super().__init__(sampler) + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio + self.invert_keep = invert_keep + + def set_score(self, value: torch.Tensor) -> None: + pos_sorted = torch.argsort(value, descending=True) + top_n = int(math.ceil(self.keep_ratio * value.shape[-1])) if not self.keep_top_n else self.keep_top_n + pos_top_n = pos_sorted[..., :top_n] + self.replacement_mask = torch.ones_like(value, device=value.device, dtype=torch.bool).scatter( + -1, pos_top_n, self.invert_keep + ) + + @override + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Sample a sequence + + Args: + input: Input sequence of ids of shape [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + """ + token_sampled = self.sampler(input) + input_replaced = input * ~self.replacement_mask + token_sampled * self.replacement_mask + return input_replaced, self.replacement_mask + + +class UniformTokenReplacer(TokenReplacer): + """Replace tokens in a sequence where selecting is base on uniform distribution""" + + @override + def __init__(self, sampler: TokenSampler, ratio: float) -> None: + """Constructor + + Args: + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + ratio: Ratio of tokens to replace in the sequence. + """ + super().__init__(sampler) + self.ratio = ratio + + @override + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Sample a sequence + + Args: + input: Input sequence of ids of shape [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + """ + sample_uniform = torch.rand(input.shape, device=input.device) + replacement_mask = sample_uniform < self.ratio + token_sampled = self.sampler(input) + input_replaced = input * ~replacement_mask + token_sampled * replacement_mask + return input_replaced, replacement_mask diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py new file mode 100644 index 00000000..7ca41bf2 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -0,0 +1,107 @@ +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing_extensions import override + +from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available +from .....utils.typing import IdsTensor + +logger = logging.getLogger(__name__) + + +class TokenSampler(ABC): + """Base class for token samplers""" + + @abstractmethod + def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: + """Sample tokens according to the specified strategy. 
+ + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + """ + raise NotImplementedError() + + +class POSTagTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution on a set of words with the same POS tag.""" + + def __init__( + self, + tokenizer: Union[str, PreTrainedTokenizerBase], + identifier: str = "pos_tag_sampler", + save_cache: bool = True, + overwrite_cache: bool = False, + cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache", + device: Optional[str] = None, + tokenizer_kwargs: Optional[dict[str, Any]] = {}, + ) -> None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) + cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl" + self.pos2ids = self.build_pos_mapping_from_vocab( + cache_dir, + cache_filename, + save_cache, + overwrite_cache, + tokenizer=self.tokenizer, + ) + num_postags = len(self.pos2ids) + self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device) + for pos_idx, ids in enumerate(self.pos2ids.values()): + self.id2pos[ids] = pos_idx + self.num_ids_per_pos = torch.tensor( + [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device + ) + self.offsets = torch.sum( + torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos, + dim=-1, + ) + self.compact_idx = torch.cat( + tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values()) + ) + + @staticmethod + @cache_results + def build_pos_mapping_from_vocab( + tokenizer: PreTrainedTokenizerBase, + log_every: int = 5000, + ) -> dict[str, list[int]]: + """Build mapping from POS tags to list of token ids from tokenizer's vocabulary.""" + if not is_nltk_available(): + raise ImportError("nltk is required to build POS tag mapping. 
Please install nltk.") + import nltk + + nltk.download("averaged_perceptron_tagger") + pos2ids = defaultdict(list) + for i in range(tokenizer.vocab_size): + word = tokenizer.decode([i]) + _, tag = nltk.pos_tag([word.strip()])[0] + pos2ids[tag].append(i) + if i % log_every == 0: + logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") + return pos2ids + + @override + def __call__(self, input_ids: IdsTensor) -> IdsTensor: + """Sample a tensor + + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + """ + input_ids_pos = self.id2pos[input_ids] + sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) + compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long() + return self.compact_idx[compact_group_idx] diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index c3eb0211..498093af 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import TYPE_CHECKING, Any from captum.attr import Occlusion @@ -11,7 +11,10 @@ from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime, ValueZeroing +from .ops import Lime, Reagent, ValueZeroing + +if TYPE_CHECKING: + from ...models import HuggingfaceModel logger = logging.getLogger(__name__) @@ -120,6 +123,70 @@ def attribute_step( ) +class ReagentAttribution(PerturbationAttributionRegistry): + """Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. + + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models `__ + """ + + method_name = "reagent" + + def __init__( + self, + attribution_model: "HuggingfaceModel", + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 16, + ): + """ReAGent method constructor. + + Args: + keep_top_n (:obj:`int`, `optional`): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. Default: ``5``. + keep_ratio (:obj:`float`, `optional`): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep (:obj:`bool`, `optional`): If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. Default: ``False``. + stopping_condition_top_k (:obj:`int`, `optional`): Threshold indicating that the stop condition achieved when the predicted target + exist in top k predictions. Default: ``3``. + replacing_ratio (:obj:`float`, `optional`): replacing ratio of tokens for probing. Default: ``0.3``. + max_probe_steps (:obj:`int`, `optional`): Max number of steps before stopping the probing. Default: ``3000``. + num_probes (:obj:`int`, `optional`): Number of probes performed in parallel. Default: ``16``. 
+ """ + super().__init__(attribution_model) + # Custom target attribution is currently not supported + self.use_predicted_target = False + self.method = Reagent( + attribution_model=self.attribution_model, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, + stopping_condition_top_k=stopping_condition_top_k, + replacing_ratio=replacing_ratio, + max_probe_steps=max_probe_steps, + num_probes=num_probes, + ) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any] = {}, + ) -> GranularFeatureAttributionStepOutput: + out = super().attribute_step(attribute_fn_main_args, attribution_args) + return GranularFeatureAttributionStepOutput( + source_attributions=out.source_attributions, + target_attributions=out.target_attributions, + sequence_scores=out.sequence_scores, + ) + + class ValueZeroingAttribution(PerturbationAttributionRegistry): """Value Zeroing method for feature attribution. diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 69d9d1ad..f632ba32 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -14,6 +14,7 @@ is_datasets_available, is_ipywidgets_available, is_joblib_available, + is_nltk_available, is_scikitlearn_available, is_sentencepiece_available, is_transformers_available, @@ -99,6 +100,7 @@ "is_datasets_available", "is_captum_available", "is_joblib_available", + "is_nltk_available", "check_device", "get_default_device", "ndarray_to_bin_str", diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py index cbd03420..2a1ccc2d 100644 --- a/inseq/utils/import_utils.py +++ b/inseq/utils/import_utils.py @@ -7,6 +7,7 @@ _datasets_available = find_spec("datasets") is not None _captum_available = find_spec("captum") is not None _joblib_available = find_spec("joblib") is not None +_nltk_available = find_spec("nltk") is not None def is_ipywidgets_available(): @@ -35,3 +36,7 @@ def is_captum_available(): def is_joblib_available(): return _joblib_available + + +def is_nltk_available(): + return _nltk_available diff --git a/pyproject.toml b/pyproject.toml index 3babbe9a..379b8060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ docs = [ ] lint = [ "bandit>=1.7.4", - "safety>=2.2.0", + "safety>=3.1.0", "pydoclint>=0.4.0", "pre-commit>=2.19.0", "pytest>=7.2.0", @@ -93,6 +93,9 @@ notebook = [ "ipykernel>=6.29.2", "ipywidgets>=8.1.2" ] +nltk = [ + "nltk>=3.8.1", +] [project.urls] homepage = "https://github.com/inseq-team/inseq" diff --git a/requirements-dev.txt b/requirements-dev.txt index 92a9ca95..2fda0ec1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile --all-extras pyproject.toml -o requirements-dev.txt aiohttp==3.9.3 # via @@ -32,6 +32,7 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via + # nltk # pydoclint # safety # typer @@ -43,7 +44,7 @@ contourpy==1.2.0 # via matplotlib coverage==7.4.1 # via pytest-cov -cryptography==42.0.2 +cryptography==42.0.5 # via authlib cycler==0.12.1 # via matplotlib @@ -123,7 +124,9 @@ jinja2==3.1.3 # sphinx # torch joblib==1.3.2 - # via scikit-learn + # via + # nltk + # scikit-learn jupyter-client==8.6.0 # via ipykernel jupyter-core==5.7.1 @@ -160,6 +163,7 @@ nest-asyncio==1.6.0 # via ipykernel networkx==3.2.1 # via torch +nltk==3.8.1 nodeenv==1.8.0 # via pre-commit numpy==1.26.4 @@ -258,7 +262,9 @@ pyzmq==25.1.2 # ipykernel # 
jupyter-client regex==2023.12.25 - # via transformers + # via + # nltk + # transformers requests==2.31.0 # via # datasets @@ -280,7 +286,7 @@ ruamel-yaml-clib==0.2.8 ruff==0.2.1 safetensors==0.4.2 # via transformers -safety==3.0.1 +safety==3.1.0 safety-schemas==0.0.2 # via safety scikit-learn==1.4.0 @@ -351,6 +357,7 @@ tqdm==4.66.2 # captum # datasets # huggingface-hub + # nltk # transformers traitlets==5.14.1 # via diff --git a/requirements.txt b/requirements.txt index a0a99e61..93809632 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 certifi==2024.2.2
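One practical note on the dependency changes above: the POS-tag sampler imports `nltk` lazily and raises an `ImportError` when it is missing (hence `is_nltk_available` and the new `nltk` extra in `pyproject.toml`). The sketch below shows what a first run looks like once the extra is installed; the `pip install "inseq[nltk]"` spelling follows the standard extras convention and is an assumption, as is the checkpoint choice.

```python
# Assumption: the optional extra is installed, e.g. `pip install "inseq[nltk]"`.
# On first use, POSTagTokenSampler downloads NLTK's averaged_perceptron_tagger,
# POS-tags every token in the tokenizer vocabulary, and caches the resulting
# POS -> token-ids mapping under INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache",
# so subsequent runs skip the slow vocabulary pass.
import inseq

model = inseq.load_model("gpt2-medium", "reagent")  # POS mapping is built here
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```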