diff --git a/README.md b/README.md
index 5e09d734..583d8a88 100644
--- a/README.md
+++ b/README.md
@@ -149,6 +149,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available
- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023)
+- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024)
+
#### Step functions
Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
@@ -303,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi
## Research using Inseq
-Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*).
+Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.
+
+> [!TIP]
+> Last update: April 2024. Please open a pull request to add your publication to the list.
2023
@@ -324,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
- LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools (Wang et al., 2024)
- ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models (Zhao et al., 2024)
+ - Revisiting subword tokenization: A case study on affixal negation in large language models (Truong et al., 2024)
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst
index 1f282626..d7c4f5fc 100644
--- a/docs/source/main_classes/feature_attribution.rst
+++ b/docs/source/main_classes/feature_attribution.rst
@@ -90,4 +90,25 @@ Perturbation-based Attribution Methods
:members:
.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
- :members:
\ No newline at end of file
+ :members:
+
+.. autoclass:: inseq.attr.feat.ReagentAttribution
+ :members:
+
+ .. automethod:: __init__
+
+.. code:: python
+
+ import inseq
+
+ model = inseq.load_model(
+ "gpt2-medium",
+ "reagent",
+ keep_top_n=5,
+ stopping_condition_top_k=3,
+ replacing_ratio=0.3,
+ max_probe_steps=3000,
+ num_probes=8
+ )
+ out = model.attribute("Super Mario Land is a game that developed by")
+ out.show()
diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py
index 2b25778a..7a81014f 100644
--- a/inseq/attr/feat/__init__.py
+++ b/inseq/attr/feat/__init__.py
@@ -18,6 +18,7 @@
LimeAttribution,
OcclusionAttribution,
PerturbationAttributionRegistry,
+ ReagentAttribution,
ValueZeroingAttribution,
)
@@ -43,4 +44,5 @@
"SequentialIntegratedGradientsAttribution",
"ValueZeroingAttribution",
"PerturbationAttributionRegistry",
+ "ReagentAttribution",
]
diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py
index 7d86167a..a40b9dba 100644
--- a/inseq/attr/feat/ops/__init__.py
+++ b/inseq/attr/feat/ops/__init__.py
@@ -1,6 +1,7 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
+from .reagent import Reagent
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing
@@ -9,5 +10,6 @@
"MonotonicPathBuilder",
"ValueZeroing",
"Lime",
+ "Reagent",
"SequentialIntegratedGradients",
]
diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py
new file mode 100644
index 00000000..080a2131
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent.py
@@ -0,0 +1,134 @@
+from typing import TYPE_CHECKING, Any, Union
+
+import torch
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
+from torch import Tensor
+from typing_extensions import override
+
+from ....utils.typing import InseqAttribution
+from .reagent_core import (
+ AggregateRationalizer,
+ DeltaProbImportanceScoreEvaluator,
+ POSTagTokenSampler,
+ TopKStoppingConditionEvaluator,
+ UniformTokenReplacer,
+)
+
+if TYPE_CHECKING:
+ from ....models import HuggingfaceModel
+
+
+class Reagent(InseqAttribution):
+ r"""Recursive attribution generator (ReAGent) method.
+
+ Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+ alternative predicted by a LM.
+
+ Reference implementation:
+ `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
+ `__
+
+ Args:
+ forward_func (callable): The forward function of the model or any modification of it
+ keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``.
+ keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+ stopping_condition_top_k (int): Threshold indicating that the stop condition achieved when the predicted target
+ exist in top k predictions
+ replacing_ratio (float): replacing ratio of tokens for probing
+ max_probe_steps (int): max_probe_steps
+ num_probes (int): number of probes in parallel
+
+ Example:
+ ```
+ import inseq
+
+ model = inseq.load_model("gpt2-medium", "reagent",
+ keep_top_n=5,
+ stopping_condition_top_k=3,
+ replacing_ratio=0.3,
+ max_probe_steps=3000,
+ num_probes=8
+ )
+ out = model.attribute("Super Mario Land is a game that developed by")
+ out.show()
+ ```
+ """
+
+ def __init__(
+ self,
+ attribution_model: "HuggingfaceModel",
+ keep_top_n: int = 5,
+ keep_ratio: float = None,
+ invert_keep: bool = False,
+ stopping_condition_top_k: int = 3,
+ replacing_ratio: float = 0.3,
+ max_probe_steps: int = 3000,
+ num_probes: int = 16,
+ ) -> None:
+ super().__init__(attribution_model)
+
+ model = attribution_model.model
+ tokenizer = attribution_model.tokenizer
+ model_name = attribution_model.model_name
+
+ sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device)
+ stopping_condition_evaluator = TopKStoppingConditionEvaluator(
+ model=model,
+ sampler=sampler,
+ top_k=stopping_condition_top_k,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ invert_keep=invert_keep,
+ )
+ importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
+ model=model,
+ tokenizer=tokenizer,
+ token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio),
+ stopping_condition_evaluator=stopping_condition_evaluator,
+ max_steps=max_probe_steps,
+ )
+
+ self.rationalizer = AggregateRationalizer(
+ importance_score_evaluator=importance_score_evaluator,
+ batch_size=num_probes,
+ overlap_threshold=0,
+ overlap_strict_pos=True,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ )
+
+ @override
+ def attribute( # type: ignore
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ _target: TargetType = None,
+ additional_forward_args: Any = None,
+ ) -> Union[
+ TensorOrTupleOfTensorsGeneric,
+ tuple[TensorOrTupleOfTensorsGeneric, Tensor],
+ ]:
+ """Implement attribute"""
+ # encoder-decoder
+ if self.forward_func.is_encoder_decoder:
+ # with target-side attribution
+ if len(inputs) > 1:
+ self.rationalizer(
+ additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True
+ )
+ mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
+ res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
+ return (
+ res[:, : additional_forward_args[0].shape[1], :],
+ res[:, additional_forward_args[0].shape[1] :, :],
+ )
+ # source-side only
+ else:
+ self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2])
+ # decoder-only
+ self.rationalizer(additional_forward_args[0], additional_forward_args[1])
+ mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
+ res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
+ return (res,)
diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py
new file mode 100644
index 00000000..13917c00
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/__init__.py
@@ -0,0 +1,13 @@
+from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator
+from .rationalizer import AggregateRationalizer
+from .stopping_condition_evaluator import TopKStoppingConditionEvaluator
+from .token_replacer import UniformTokenReplacer
+from .token_sampler import POSTagTokenSampler
+
+__all__ = [
+ "DeltaProbImportanceScoreEvaluator",
+ "AggregateRationalizer",
+ "TopKStoppingConditionEvaluator",
+ "UniformTokenReplacer",
+ "POSTagTokenSampler",
+]
diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py
new file mode 100644
index 00000000..adc79b7b
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from jaxtyping import Float
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor
+from .stopping_condition_evaluator import StoppingConditionEvaluator
+from .token_replacer import TokenReplacer
+
+
+class BaseImportanceScoreEvaluator(ABC):
+ """Importance Score Evaluator"""
+
+ def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None:
+ """Base Constructor
+
+ Args:
+ model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
+ tokenizer: A Huggingface AutoTokenizer
+ """
+ self.model = model
+ self.tokenizer = tokenizer
+ self.importance_score = None
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Evaluate importance score of input sequence
+
+ Args:
+ input_ids: input sequence [batch, sequence]
+ target_id: target token [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ importance_score: evaluated importance score for each token in the input [batch, sequence]
+
+ """
+ raise NotImplementedError()
+
+
+class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator):
+ """Importance Score Evaluator"""
+
+ @override
+ def __init__(
+ self,
+ model: AutoModelForCausalLM | AutoModelForSeq2SeqLM,
+ tokenizer: AutoTokenizer,
+ token_replacer: TokenReplacer,
+ stopping_condition_evaluator: StoppingConditionEvaluator,
+ max_steps: float,
+ ) -> None:
+ """Constructor
+
+ Args:
+ model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
+ tokenizer: A Huggingface AutoTokenizer
+ token_replacer: A TokenReplacer
+ stopping_condition_evaluator: A StoppingConditionEvaluator
+ """
+ super().__init__(model, tokenizer)
+ self.token_replacer = token_replacer
+ self.stopping_condition_evaluator = stopping_condition_evaluator
+ self.max_steps = max_steps
+ self.importance_score = None
+ self.num_steps = 0
+
+ def update_importance_score(
+ self,
+ logit_importance_score: MultipleScoresPerStepTensor,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ prob_original_target: Float[torch.Tensor, "batch_size 1"],
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Update importance score by one step
+
+ Args:
+ logit_importance_score: Current importance score in logistic scale [batch, sequence]
+ input_ids: input tensor [batch, sequence]
+ target_id: target tensor [batch]
+ prob_original_target: predictive probability of the target on the original sequence [batch, 1]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ logit_importance_score: updated importance score in logistic scale [batch, sequence]
+ """
+ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}}
+ if not attribute_target:
+ input_ids_replaced, mask_replacing = self.token_replacer(input_ids)
+ else:
+ ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1))
+ input_ids_replaced = ids_replaced[:, : input_ids.shape[1]]
+ decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :]
+
+ logging.debug(f"Replacing mask: { mask_replacing }")
+ logging.debug(
+ f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }"
+ )
+
+ # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}})
+ kwargs = {"input_ids": input_ids_replaced}
+ if decoder_input_ids is not None:
+ kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids
+ logits_replaced = self.model(**kwargs)["logits"]
+ prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id]
+
+ # Compute changes delta = p^{(y)} - \hat{p^{(y)}}
+ delta_prob_target = prob_original_target - prob_replaced_target
+ logging.debug(f"likelihood delta: { delta_prob_target }")
+
+ # Update importance scores based on delta (magnitude) and replacement (direction)
+ delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target
+ # TODO: better solution?
+ # Rescaling from [-1, 1] to [0, 1] before logit function
+ logit_delta_score = torch.logit(delta_score * 0.5 + 0.5)
+ logit_importance_score = logit_importance_score + logit_delta_score
+ logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }")
+ return logit_importance_score
+
+ @override
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Evaluate importance score of input sequence
+
+ Args:
+ input_ids: input sequence [batch, sequence]
+ target_id: target token [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ importance_score: evaluated importance score for each token in the input [batch, sequence]
+ """
+ self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)
+
+ # Inference p^{(y)} = p(y_{t+1}|y_{1...t})
+ if decoder_input_ids is None:
+ logits_original = self.model(input_ids)["logits"]
+ else:
+ logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"]
+
+ prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id]
+
+ # Initialize importance score s for each token in the sequence y_{1...t}
+ if not attribute_target:
+ logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device)
+ else:
+ logit_importance_score = torch.rand(
+ (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device
+ )
+ logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }")
+
+ # TODO: limit max steps
+ self.num_steps = 0
+ while self.num_steps < self.max_steps:
+ self.num_steps += 1
+ # Update importance score
+ logit_importance_score_update = self.update_importance_score(
+ logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target
+ )
+ logit_importance_score = (
+ ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update
+ + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score
+ )
+ self.importance_score = torch.softmax(logit_importance_score, -1)
+
+ # Evaluate stop condition
+ self.stop_mask = self.stop_mask | self.stopping_condition_evaluator(
+ input_ids, target_id, self.importance_score, decoder_input_ids, attribute_target
+ )
+ if torch.prod(self.stop_mask) > 0:
+ break
+
+ logging.info(f"Importance score evaluated in {self.num_steps} steps.")
+ return torch.softmax(logit_importance_score, -1)
diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py
new file mode 100644
index 00000000..ab7c2be8
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py
@@ -0,0 +1,128 @@
+import math
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from jaxtyping import Int64
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor, TargetIdsTensor
+from .importance_score_evaluator import BaseImportanceScoreEvaluator
+
+
+class BaseRationalizer(ABC):
+ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None:
+ super().__init__()
+ self.importance_score_evaluator = importance_score_evaluator
+ self.mean_importance_score = None
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> Int64[torch.Tensor, "batch_size other_dims"]:
+ """Compute rational of a sequence on a target
+
+ Args:
+ input_ids: The sequence [batch, sequence] (first dimension need to be 1)
+ target_id: The target [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ pos_top_n: rational position in the sequence [batch, rational_size]
+
+ """
+ raise NotImplementedError()
+
+
+class AggregateRationalizer(BaseRationalizer):
+ """AggregateRationalizer"""
+
+ @override
+ def __init__(
+ self,
+ importance_score_evaluator: BaseImportanceScoreEvaluator,
+ batch_size: int,
+ overlap_threshold: int,
+ overlap_strict_pos: bool = True,
+ keep_top_n: int = 0,
+ keep_ratio: float = 0,
+ ) -> None:
+ """Constructor
+
+ Args:
+ importance_score_evaluator: A ImportanceScoreEvaluator
+ batch_size: Batch size for aggregate
+ overlap_threshold: Overlap threshold of rational tokens within a batch
+ overlap_strict_pos: Whether overlap strict to position ot not
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ """
+ super().__init__(importance_score_evaluator)
+ self.batch_size = batch_size
+ self.overlap_threshold = overlap_threshold
+ self.overlap_strict_pos = overlap_strict_pos
+ self.keep_top_n = keep_top_n
+ self.keep_ratio = keep_ratio
+ assert overlap_strict_pos, "overlap_strict_pos = False is not supported yet"
+
+ @override
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> Int64[torch.Tensor, "batch_size other_dims"]:
+ """Compute rational of a sequence on a target
+
+ Args:
+ input_ids: A tensor of ids of shape [batch, sequence_len]
+ target_id: A tensor of predicted targets of size [batch]
+ decoder_input_ids (optional): A tensor of ids representing the decoder input sequence for
+ ``AutoModelForSeq2SeqLM``, with shape [batch, sequence_len]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ pos_top_n: rational position in the sequence [batch, rational_size]
+
+ """
+ assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1"
+ batch_input_ids = input_ids.repeat(self.batch_size, 1)
+ batch_decoder_input_ids = (
+ decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None
+ )
+ batch_importance_score = self.importance_score_evaluator(
+ batch_input_ids, target_id, batch_decoder_input_ids, attribute_target
+ )
+ importance_score_masked = batch_importance_score * torch.unsqueeze(
+ self.importance_score_evaluator.stop_mask, -1
+ )
+ self.mean_importance_score = torch.sum(importance_score_masked, dim=0) / torch.sum(
+ self.importance_score_evaluator.stop_mask
+ )
+ pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True)
+ top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n
+ pos_top_n = pos_sorted[:, :top_n]
+ self.pos_top_n = pos_top_n
+ if self.overlap_strict_pos:
+ count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1])
+ pos_top_n_overlap = torch.unsqueeze(
+ torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0
+ )
+ return pos_top_n_overlap
+ else:
+ raise NotImplementedError("overlap_strict_pos = False not been supported yet")
+ # TODO: Convert back to pos
+ # token_id_top_n = input_ids[0, pos_top_n]
+ # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1])
+ # _token_id_top_n_overlap = torch.unsqueeze(
+ # torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0
+ # )
diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py
new file mode 100644
index 00000000..fd3bb67d
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py
@@ -0,0 +1,136 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM
+
+from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor
+from .token_replacer import RankingTokenReplacer
+from .token_sampler import TokenSampler
+
+
+class StoppingConditionEvaluator(ABC):
+ """Base class for Stopping Condition Evaluators"""
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ importance_score: MultipleScoresPerStepTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> TargetIdsTensor:
+ """Evaluate stop condition according to the specified strategy.
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+
+ """
+ raise NotImplementedError()
+
+
+class TopKStoppingConditionEvaluator(StoppingConditionEvaluator):
+ """
+ Evaluator stopping when target exist among the top k predictions,
+ while top n tokens based on importance_score are not been replaced.
+ """
+
+ def __init__(
+ self,
+ model: AutoModelForCausalLM,
+ sampler: TokenSampler,
+ top_k: int,
+ keep_top_n: int = 0,
+ keep_ratio: float = 0,
+ invert_keep: bool = False,
+ ) -> None:
+ """Constructor for the TopKStoppingConditionEvaluator class.
+
+ Args:
+ model: A Huggingface ``AutoModelForCausalLM``.
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens.
+ top_k: Top K predictions in which the target must be included in order to achieve the stopping condition.
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+ """
+ self.model = model
+ self.top_k = top_k
+ self.replacer = RankingTokenReplacer(sampler, keep_top_n, keep_ratio, invert_keep)
+
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ importance_score: MultipleScoresPerStepTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> TargetIdsTensor:
+ """Evaluate stop condition
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+ """
+ # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}}
+ self.replacer.set_score(importance_score)
+ if not attribute_target:
+ input_ids_replaced, mask_replacing = self.replacer(input_ids)
+ else:
+ ids_replaced, mask_replacing = self.replacer(torch.cat((input_ids, decoder_input_ids), 1))
+ input_ids_replaced = ids_replaced[:, : input_ids.shape[1]]
+ decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :]
+
+ logging.debug(f"Replacing mask based on importance score -> { mask_replacing }")
+
+ # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1}
+ assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled"
+ with torch.no_grad():
+ kwargs = {"input_ids": input_ids_replaced}
+ if decoder_input_ids is not None:
+ kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids
+ logits_replaced = self.model(**kwargs)["logits"]
+ ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True)
+ ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k]
+ match_mask = ids_prediction_top_k == target_id
+ match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool)
+ return match_hit
+
+
+class DummyStoppingConditionEvaluator(StoppingConditionEvaluator):
+ """
+ Stopping Condition Evaluator which stop when target exist in top k predictions,
+ while top n tokens based on importance_score are not been replaced.
+ """
+
+ def __call__(self, input_ids: IdsTensor, **kwargs) -> TargetIdsTensor:
+ """Evaluate stop condition
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+ attribute_target: whether attribute target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+ """
+ return torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)
diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py
new file mode 100644
index 00000000..0d889144
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py
@@ -0,0 +1,111 @@
+import math
+from abc import ABC, abstractmethod
+
+import torch
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor
+from .token_sampler import TokenSampler
+
+
+class TokenReplacer(ABC):
+ """
+ Base class for token replacers
+
+ """
+
+ def __init__(self, sampler: TokenSampler) -> None:
+ self.sampler = sampler
+
+ @abstractmethod
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Replace tokens according to the specified strategy.
+
+ Args:
+ input: input sequence [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+
+ """
+ raise NotImplementedError()
+
+
+class RankingTokenReplacer(TokenReplacer):
+ """Replace tokens in a sequence based on top-N ranking"""
+
+ @override
+ def __init__(
+ self, sampler: TokenSampler, keep_top_n: int = 0, keep_ratio: float = 0, invert_keep: bool = False
+ ) -> None:
+ """Constructor for the RankingTokenReplacer class.
+
+ Args:
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens.
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+ """
+ super().__init__(sampler)
+ self.keep_top_n = keep_top_n
+ self.keep_ratio = keep_ratio
+ self.invert_keep = invert_keep
+
+ def set_score(self, value: torch.Tensor) -> None:
+ pos_sorted = torch.argsort(value, descending=True)
+ top_n = int(math.ceil(self.keep_ratio * value.shape[-1])) if not self.keep_top_n else self.keep_top_n
+ pos_top_n = pos_sorted[..., :top_n]
+ self.replacement_mask = torch.ones_like(value, device=value.device, dtype=torch.bool).scatter(
+ -1, pos_top_n, self.invert_keep
+ )
+
+ @override
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Sample a sequence
+
+ Args:
+ input: Input sequence of ids of shape [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+ """
+ token_sampled = self.sampler(input)
+ input_replaced = input * ~self.replacement_mask + token_sampled * self.replacement_mask
+ return input_replaced, self.replacement_mask
+
+
+class UniformTokenReplacer(TokenReplacer):
+ """Replace tokens in a sequence where selecting is base on uniform distribution"""
+
+ @override
+ def __init__(self, sampler: TokenSampler, ratio: float) -> None:
+ """Constructor
+
+ Args:
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens.
+ ratio: Ratio of tokens to replace in the sequence.
+ """
+ super().__init__(sampler)
+ self.ratio = ratio
+
+ @override
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Sample a sequence
+
+ Args:
+ input: Input sequence of ids of shape [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+ """
+ sample_uniform = torch.rand(input.shape, device=input.device)
+ replacement_mask = sample_uniform < self.ratio
+ token_sampled = self.sampler(input)
+ input_replaced = input * ~replacement_mask + token_sampled * replacement_mask
+ return input_replaced, replacement_mask
diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py
new file mode 100644
index 00000000..7ca41bf2
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py
@@ -0,0 +1,107 @@
+import logging
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from typing_extensions import override
+
+from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available
+from .....utils.typing import IdsTensor
+
+logger = logging.getLogger(__name__)
+
+
+class TokenSampler(ABC):
+ """Base class for token samplers"""
+
+ @abstractmethod
+ def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor:
+ """Sample tokens according to the specified strategy.
+
+ Args:
+ input: input tensor [batch, sequence]
+
+ Returns:
+ token_uniform: A sampled tensor where its shape is the same with the input
+ """
+ raise NotImplementedError()
+
+
+class POSTagTokenSampler(TokenSampler):
+ """Sample tokens from Uniform distribution on a set of words with the same POS tag."""
+
+ def __init__(
+ self,
+ tokenizer: Union[str, PreTrainedTokenizerBase],
+ identifier: str = "pos_tag_sampler",
+ save_cache: bool = True,
+ overwrite_cache: bool = False,
+ cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache",
+ device: Optional[str] = None,
+ tokenizer_kwargs: Optional[dict[str, Any]] = {},
+ ) -> None:
+ if isinstance(tokenizer, PreTrainedTokenizerBase):
+ self.tokenizer = tokenizer
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
+ cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl"
+ self.pos2ids = self.build_pos_mapping_from_vocab(
+ cache_dir,
+ cache_filename,
+ save_cache,
+ overwrite_cache,
+ tokenizer=self.tokenizer,
+ )
+ num_postags = len(self.pos2ids)
+ self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device)
+ for pos_idx, ids in enumerate(self.pos2ids.values()):
+ self.id2pos[ids] = pos_idx
+ self.num_ids_per_pos = torch.tensor(
+ [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device
+ )
+ self.offsets = torch.sum(
+ torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos,
+ dim=-1,
+ )
+ self.compact_idx = torch.cat(
+ tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values())
+ )
+
+ @staticmethod
+ @cache_results
+ def build_pos_mapping_from_vocab(
+ tokenizer: PreTrainedTokenizerBase,
+ log_every: int = 5000,
+ ) -> dict[str, list[int]]:
+ """Build mapping from POS tags to list of token ids from tokenizer's vocabulary."""
+ if not is_nltk_available():
+ raise ImportError("nltk is required to build POS tag mapping. Please install nltk.")
+ import nltk
+
+ nltk.download("averaged_perceptron_tagger")
+ pos2ids = defaultdict(list)
+ for i in range(tokenizer.vocab_size):
+ word = tokenizer.decode([i])
+ _, tag = nltk.pos_tag([word.strip()])[0]
+ pos2ids[tag].append(i)
+ if i % log_every == 0:
+ logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%")
+ return pos2ids
+
+ @override
+ def __call__(self, input_ids: IdsTensor) -> IdsTensor:
+ """Sample a tensor
+
+ Args:
+ input: input tensor [batch, sequence]
+
+ Returns:
+ token_uniform: A sampled tensor where its shape is the same with the input
+ """
+ input_ids_pos = self.id2pos[input_ids]
+ sample_uniform = torch.rand(input_ids.shape, device=input_ids.device)
+ compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long()
+ return self.compact_idx[compact_group_idx]
diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py
index c3eb0211..498093af 100644
--- a/inseq/attr/feat/perturbation_attribution.py
+++ b/inseq/attr/feat/perturbation_attribution.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any
from captum.attr import Occlusion
@@ -11,7 +11,10 @@
from ...utils import Registry
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution
-from .ops import Lime, ValueZeroing
+from .ops import Lime, Reagent, ValueZeroing
+
+if TYPE_CHECKING:
+ from ...models import HuggingfaceModel
logger = logging.getLogger(__name__)
@@ -120,6 +123,70 @@ def attribute_step(
)
+class ReagentAttribution(PerturbationAttributionRegistry):
+ """Recursive attribution generator (ReAGent) method.
+
+ Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+ alternative predicted by a LM.
+
+ Reference implementation:
+ `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models `__
+ """
+
+ method_name = "reagent"
+
+ def __init__(
+ self,
+ attribution_model: "HuggingfaceModel",
+ keep_top_n: int = 5,
+ keep_ratio: float = None,
+ invert_keep: bool = False,
+ stopping_condition_top_k: int = 3,
+ replacing_ratio: float = 0.3,
+ max_probe_steps: int = 3000,
+ num_probes: int = 16,
+ ):
+ """ReAGent method constructor.
+
+ Args:
+ keep_top_n (:obj:`int`, `optional`): If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. Default: ``5``.
+ keep_ratio (:obj:`float`, `optional`): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep (:obj:`bool`, `optional`): If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept. Default: ``False``.
+ stopping_condition_top_k (:obj:`int`, `optional`): Threshold indicating that the stop condition achieved when the predicted target
+ exist in top k predictions. Default: ``3``.
+ replacing_ratio (:obj:`float`, `optional`): replacing ratio of tokens for probing. Default: ``0.3``.
+ max_probe_steps (:obj:`int`, `optional`): Max number of steps before stopping the probing. Default: ``3000``.
+ num_probes (:obj:`int`, `optional`): Number of probes performed in parallel. Default: ``16``.
+ """
+ super().__init__(attribution_model)
+ # Custom target attribution is currently not supported
+ self.use_predicted_target = False
+ self.method = Reagent(
+ attribution_model=self.attribution_model,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ invert_keep=invert_keep,
+ stopping_condition_top_k=stopping_condition_top_k,
+ replacing_ratio=replacing_ratio,
+ max_probe_steps=max_probe_steps,
+ num_probes=num_probes,
+ )
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any] = {},
+ ) -> GranularFeatureAttributionStepOutput:
+ out = super().attribute_step(attribute_fn_main_args, attribution_args)
+ return GranularFeatureAttributionStepOutput(
+ source_attributions=out.source_attributions,
+ target_attributions=out.target_attributions,
+ sequence_scores=out.sequence_scores,
+ )
+
+
class ValueZeroingAttribution(PerturbationAttributionRegistry):
"""Value Zeroing method for feature attribution.
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index 69d9d1ad..f632ba32 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -14,6 +14,7 @@
is_datasets_available,
is_ipywidgets_available,
is_joblib_available,
+ is_nltk_available,
is_scikitlearn_available,
is_sentencepiece_available,
is_transformers_available,
@@ -99,6 +100,7 @@
"is_datasets_available",
"is_captum_available",
"is_joblib_available",
+ "is_nltk_available",
"check_device",
"get_default_device",
"ndarray_to_bin_str",
diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py
index cbd03420..2a1ccc2d 100644
--- a/inseq/utils/import_utils.py
+++ b/inseq/utils/import_utils.py
@@ -7,6 +7,7 @@
_datasets_available = find_spec("datasets") is not None
_captum_available = find_spec("captum") is not None
_joblib_available = find_spec("joblib") is not None
+_nltk_available = find_spec("nltk") is not None
def is_ipywidgets_available():
@@ -35,3 +36,7 @@ def is_captum_available():
def is_joblib_available():
return _joblib_available
+
+
+def is_nltk_available():
+ return _nltk_available
diff --git a/pyproject.toml b/pyproject.toml
index 3babbe9a..379b8060 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,7 +74,7 @@ docs = [
]
lint = [
"bandit>=1.7.4",
- "safety>=2.2.0",
+ "safety>=3.1.0",
"pydoclint>=0.4.0",
"pre-commit>=2.19.0",
"pytest>=7.2.0",
@@ -93,6 +93,9 @@ notebook = [
"ipykernel>=6.29.2",
"ipywidgets>=8.1.2"
]
+nltk = [
+ "nltk>=3.8.1",
+]
[project.urls]
homepage = "https://github.com/inseq-team/inseq"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 92a9ca95..2fda0ec1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,4 +1,4 @@
-# This file was autogenerated by uv v0.1.2 via the following command:
+# This file was autogenerated by uv via the following command:
# uv pip compile --all-extras pyproject.toml -o requirements-dev.txt
aiohttp==3.9.3
# via
@@ -32,6 +32,7 @@ charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
+ # nltk
# pydoclint
# safety
# typer
@@ -43,7 +44,7 @@ contourpy==1.2.0
# via matplotlib
coverage==7.4.1
# via pytest-cov
-cryptography==42.0.2
+cryptography==42.0.5
# via authlib
cycler==0.12.1
# via matplotlib
@@ -123,7 +124,9 @@ jinja2==3.1.3
# sphinx
# torch
joblib==1.3.2
- # via scikit-learn
+ # via
+ # nltk
+ # scikit-learn
jupyter-client==8.6.0
# via ipykernel
jupyter-core==5.7.1
@@ -160,6 +163,7 @@ nest-asyncio==1.6.0
# via ipykernel
networkx==3.2.1
# via torch
+nltk==3.8.1
nodeenv==1.8.0
# via pre-commit
numpy==1.26.4
@@ -258,7 +262,9 @@ pyzmq==25.1.2
# ipykernel
# jupyter-client
regex==2023.12.25
- # via transformers
+ # via
+ # nltk
+ # transformers
requests==2.31.0
# via
# datasets
@@ -280,7 +286,7 @@ ruamel-yaml-clib==0.2.8
ruff==0.2.1
safetensors==0.4.2
# via transformers
-safety==3.0.1
+safety==3.1.0
safety-schemas==0.0.2
# via safety
scikit-learn==1.4.0
@@ -351,6 +357,7 @@ tqdm==4.66.2
# captum
# datasets
# huggingface-hub
+ # nltk
# transformers
traitlets==5.14.1
# via
diff --git a/requirements.txt b/requirements.txt
index a0a99e61..93809632 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-# This file was autogenerated by uv v0.1.2 via the following command:
+# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
captum==0.7.0
certifi==2024.2.2