diff --git a/CHANGELOG.md b/CHANGELOG.md
index faeca60e..b4a36ee2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,21 +4,15 @@
## 🚀 Features
-- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
-- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
-
-## 🔧 Fixes & Refactoring
-
-- Fix `ContiguousSpanAggregator` and `SubwordAggregator` edge case of single-step generation ([#247](https://github.com/inseq-team/inseq/pull/247))
-- Move tensors to CPU right away in the forward pass to avoid OOM when cloning ([#245](https://github.com/inseq-team/inseq/pull/245))
-- Fix `remap_from_filtered` behavior on sequence_scores tensors. ([#245](https://github.com/inseq-team/inseq/pull/245))
-- Use torch-native padding when converting lists of `FeatureAttributionStepOutput` to `FeatureAttributionSequenceOutput` in `get_sequences_from_batched_steps`. ([#245](https://github.com/inseq-team/inseq/pull/245))
-- Bump `ruff` version ([#245](https://github.com/inseq-team/inseq/pull/245))
-- Drop `poetry` in favor of [`uv`](https://github.com/astral-sh/uv) to accelerate package installation and simplify config in `pyproject.toml`. ([#249](https://github.com/inseq-team/inseq/pull/249))
-- Drop `darglint` in favor of `pydoclint`. ([#249](https://github.com/inseq-team/inseq/pull/249))
-- Replace Arxiv with ACL Anthology badge in `README`. ([#249](https://github.com/inseq-team/inseq/pull/249))
-- Add first version of `CHANGELOG.md` ([#249](https://github.com/inseq-team/inseq/pull/249))
-- Added multithread support for running tests using `pytest-xdist`
+- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config.
+
+## 🔧 Fixes and Refactoring
+
+- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal positions in the tensor were set to `nan` if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)).
+
+- Fix the pad token in cases where it is not specified by default in the loaded model (e.g. for Qwen models) ([#269](https://github.com/inseq-team/inseq/pull/269)).
+
+- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)).
## 📝 Documentation and Tutorials
@@ -26,4 +20,4 @@
## 💥 Breaking Changes
-*No changes*
+*No changes*
\ No newline at end of file
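As a quick sanity check of the `value_zeroing` fix above, the method should now run on SDPA-default models without overriding the attention implementation. A minimal sketch, assuming a working `inseq` install; `google/gemma-2b` is an illustrative model choice:

```python
import inseq

# No model_kwargs={"attn_implementation": "eager"} needed after the #267 fix
model = inseq.load_model("google/gemma-2b", "value_zeroing")
out = model.attribute("The quick brown fox jumps over the lazy dog.")
out.show()
```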
diff --git a/Makefile b/Makefile
index f0eb967b..a03222c5 100644
--- a/Makefile
+++ b/Makefile
@@ -60,7 +60,7 @@ install-dev:
.PHONY: install-ci
install-ci:
- make uv-activate && uv pip install -e .[lint]
+ make uv-activate && uv pip install -r requirements-dev.txt
.PHONY: update-deps
update-deps:
diff --git a/README.md b/README.md
index 79aa0563..583d8a88 100644
--- a/README.md
+++ b/README.md
@@ -147,6 +147,10 @@ Use the `inseq.list_feature_attribution_methods` function to list all available
- `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016)
+- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al., 2023)
+
+- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024)
+
#### Step functions
Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
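A minimal sketch of the two usages described above, assuming `gpt2` and the built-in `probability` step function:

```python
import inseq

model = inseq.load_model("gpt2", "saliency")
# Record the model's probability for each generated token alongside attributions
out = model.attribute(
    "The developer argued with the designer because",
    step_scores=["probability"],
)
out.show()
```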
@@ -301,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi
## Research using Inseq
-Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*).
+Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.
+
+> [!TIP]
+> Last update: April 2024. Please open a pull request to add your publication to the list.
2023
@@ -322,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
- LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools (Wang et al., 2024)
- ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models (Zhao et al., 2024)
+ - Revisiting subword tokenization: A case study on affixal negation in large language models (Truong et al., 2024)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 12d27a66..d2805c1e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,13 +21,13 @@
# -- Project information -----------------------------------------------------
project = "inseq"
-copyright = "2023, The Inseq Team, Licensed under the Apache License, Version 2.0"
+copyright = "2024 , The Inseq Team, Licensed under the Apache License, Version 2.0"
author = "The Inseq Team"
# The short X.Y version
-version = "0.6"
+version = "0.7"
# The full version, including alpha/beta/rc tags
-release = "0.6.0.dev0"
+release = "0.7.0.dev0"
# Prefix link to point to master, comment this during version release and uncomment below line
diff --git a/docs/source/main_classes/cli.rst b/docs/source/main_classes/cli.rst
index 1793360c..8230e020 100644
--- a/docs/source/main_classes/cli.rst
+++ b/docs/source/main_classes/cli.rst
@@ -23,7 +23,7 @@ Three commands are supported:
- ``inseq attribute-dataset``: Extends ``attribute`` to full dataset using Hugging Face ``datasets.load_dataset`` API.
-- ``inseq attribute-context``: Detects and attribute context dependence for generation tasks using the approach of `Sarti et al. (2023) `__.
+- ``inseq attribute-context``: Detects and attributes context dependence for generation tasks using the approach of `Sarti et al. (2023) `__.
``attribute``
-----------------------------------------------------------------------------------------------------------------------
@@ -47,6 +47,6 @@ The ``attribute-dataset`` command extends the ``attribute`` command to full data
-----------------------------------------------------------------------------------------------------------------------
The ``attribute-context`` command detects and attributes context dependence for generation tasks using the approach of
-`Sarti et al. (2023) `__. The command takes the following arguments:
+`Sarti et al. (2023) `__. The command takes the following arguments:
-.. autoclass:: inseq.commands.attribute_context.attribute_context_args.AttributeContextArgs
\ No newline at end of file
+.. autoclass:: inseq.commands.attribute_context.attribute_context_args.AttributeContextArgs
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst
index 174b405c..d7c4f5fc 100644
--- a/docs/source/main_classes/feature_attribution.rst
+++ b/docs/source/main_classes/feature_attribution.rst
@@ -17,7 +17,7 @@ Attribution Methods
.. autoclass:: inseq.attr.FeatureAttribution
:members:
-Gradient Attribution Methods
+Gradient-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------
.. autoclass:: inseq.attr.feat.GradientAttributionRegistry
@@ -67,7 +67,7 @@ Layer Attribution Methods
:members:
-Attention Attribution Methods
+Internals-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------
.. autoclass:: inseq.attr.feat.InternalsAttributionRegistry
@@ -76,3 +76,39 @@ Attention Attribution Methods
.. autoclass:: inseq.attr.feat.AttentionWeightsAttribution
:members:
+
+Perturbation-based Attribution Methods
+-----------------------------------------------------------------------------------------------------------------------
+
+.. autoclass:: inseq.attr.feat.PerturbationAttributionRegistry
+ :members:
+
+.. autoclass:: inseq.attr.feat.OcclusionAttribution
+ :members:
+
+.. autoclass:: inseq.attr.feat.LimeAttribution
+ :members:
+
+.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
+ :members:
+
+.. autoclass:: inseq.attr.feat.ReagentAttribution
+ :members:
+
+ .. automethod:: __init__
+
+.. code:: python
+
+ import inseq
+
+ model = inseq.load_model(
+ "gpt2-medium",
+ "reagent",
+ keep_top_n=5,
+ stopping_condition_top_k=3,
+ replacing_ratio=0.3,
+ max_probe_steps=3000,
+ num_probes=8
+ )
+ out = model.attribute("Super Mario Land is a game that developed by")
+ out.show()
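For symmetry with the ``reagent`` example above, a minimal ``value_zeroing`` sketch (the model and prompt mirror the previous example and are purely illustrative):

```python
import inseq

model = inseq.load_model("gpt2-medium", "value_zeroing")
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```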
diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py
index cc07f530..7a81014f 100644
--- a/inseq/attr/feat/__init__.py
+++ b/inseq/attr/feat/__init__.py
@@ -17,6 +17,9 @@
from .perturbation_attribution import (
LimeAttribution,
OcclusionAttribution,
+ PerturbationAttributionRegistry,
+ ReagentAttribution,
+ ValueZeroingAttribution,
)
__all__ = [
@@ -39,4 +42,7 @@
"OcclusionAttribution",
"LimeAttribution",
"SequentialIntegratedGradientsAttribution",
+ "ValueZeroingAttribution",
+ "PerturbationAttributionRegistry",
+ "ReagentAttribution",
]
diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py
index 8da4f899..a9679845 100644
--- a/inseq/attr/feat/attribution_utils.py
+++ b/inseq/attr/feat/attribution_utils.py
@@ -144,11 +144,15 @@ def extract_args(
def get_source_target_attributions(
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
is_encoder_decoder: bool,
+ has_sequence_scores: bool = False,
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
if isinstance(attr, tuple):
if is_encoder_decoder:
- return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
+ if has_sequence_scores:
+ return (attr[0], attr[1], attr[2])
+ else:
+ return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
else:
- return (None, attr[0])
+ return (None, None, attr[0]) if has_sequence_scores else (None, attr[0])
else:
return (attr, None) if is_encoder_decoder else (None, attr)
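To make the routing above concrete, a small sketch with hypothetical tensor shapes (`get_source_target_attributions` is internal API):

```python
import torch

from inseq.attr.feat.attribution_utils import get_source_target_attributions

# Encoder-decoder: a (source, target) tuple is split into its two slots
src, tgt = torch.rand(1, 6, 8), torch.rand(1, 4, 8)
source_attr, target_attr = get_source_target_attributions((src, tgt), is_encoder_decoder=True)
# Decoder-only: a single tensor lands in the target slot -> (None, attr)
_, target_only = get_source_target_attributions(torch.rand(1, 4, 8), is_encoder_decoder=False)
```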
diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py
index 250cd700..3d832bc3 100644
--- a/inseq/attr/feat/feature_attribution.py
+++ b/inseq/attr/feat/feature_attribution.py
@@ -114,6 +114,7 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
self.use_hidden_states: bool = False
self.use_predicted_target: bool = True
self.use_model_config: bool = False
+ self.is_final_step_method: bool = False
if hook_to_model:
self.hook(**kwargs)
@@ -272,6 +273,35 @@ def _run_compatibility_checks(self, attributed_fn) -> None:
" method."
)
+ @staticmethod
+ def _build_multistep_output_from_single_step(
+ single_step_output: FeatureAttributionStepOutput,
+ attr_pos_start: int,
+ attr_pos_end: int,
+ ) -> list[FeatureAttributionStepOutput]:
+ if single_step_output.step_scores:
+ raise ValueError("step_scores are not supported for final step attribution methods.")
+ num_seq = len(single_step_output.prefix)
+ steps = []
+ for pos_idx in range(attr_pos_start, attr_pos_end):
+ step_output = single_step_output.clone_empty()
+ step_output.source = single_step_output.source
+ step_output.prefix = [single_step_output.prefix[seq_idx][:pos_idx] for seq_idx in range(num_seq)]
+ step_output.target = (
+ single_step_output.target
+ if pos_idx == attr_pos_end - 1
+ else [[single_step_output.prefix[seq_idx][pos_idx]] for seq_idx in range(num_seq)]
+ )
+ if single_step_output.source_attributions is not None:
+ step_output.source_attributions = single_step_output.source_attributions[:, :, pos_idx - 1]
+ if single_step_output.target_attributions is not None:
+ step_output.target_attributions = single_step_output.target_attributions[:, :pos_idx, pos_idx - 1]
+ single_step_output.step_scores = {}
+ if single_step_output.sequence_scores is not None:
+ step_output.sequence_scores = single_step_output.sequence_scores
+ steps.append(step_output)
+ return steps
+
def format_contrastive_targets(
self,
target_sequences: TextSequences,
@@ -416,9 +446,9 @@ def attribute(
target_lengths=targets_lengths,
method_name=self.method_name,
show=show_progress,
- pretty=pretty_progress,
+ pretty=False if self.is_final_step_method else pretty_progress,
attr_pos_start=attr_pos_start,
- attr_pos_end=attr_pos_end,
+ attr_pos_end=1 if self.is_final_step_method else attr_pos_end,
)
whitespace_indexes = find_char_indexes(sequences.targets, " ")
attribution_outputs = []
@@ -427,6 +457,8 @@ def attribute(
# Attribution loop for generation
for step in range(attr_pos_start, iter_pos_end):
+ if self.is_final_step_method and step != iter_pos_end - 1:
+ continue
tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True)
step_output = self.filtered_attribute_step(
batch[:step],
@@ -450,7 +482,7 @@ def attribute(
contrast_targets_alignments=contrast_targets_alignments,
)
attribution_outputs.append(step_output)
- if pretty_progress:
+ if pretty_progress and not self.is_final_step_method:
tgt_tokens = batch.target_tokens
skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start)
attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1)
@@ -464,19 +496,24 @@ def attribute(
skipped_suffixes,
whitespace_indexes,
show=show_progress,
- pretty=pretty_progress,
+ pretty=True,
)
else:
- update_progress_bar(pbar, show=show_progress, pretty=pretty_progress)
+ update_progress_bar(pbar, show=show_progress, pretty=False)
end = datetime.now()
- close_progress_bar(pbar, show=show_progress, pretty=pretty_progress)
+ close_progress_bar(pbar, show=show_progress, pretty=False if self.is_final_step_method else pretty_progress)
batch.detach().to("cpu")
+ if self.is_final_step_method:
+ attribution_outputs = self._build_multistep_output_from_single_step(
+ attribution_outputs[0],
+ attr_pos_start=attr_pos_start,
+ attr_pos_end=iter_pos_end,
+ )
out = FeatureAttributionOutput(
sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions(
attributions=attribution_outputs,
tokenized_target_sentences=target_tokens_with_ids,
- pad_id=self.attribution_model.pad_token,
- has_bos_token=self.attribution_model.is_encoder_decoder,
+ pad_token=self.attribution_model.pad_token,
attr_pos_end=attr_pos_end,
),
step_attributions=attribution_outputs if output_step_attributions else None,
@@ -593,7 +630,7 @@ def filtered_attribute_step(
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
- step_output.remap_from_filtered(target_attention_mask, orig_batch)
+ step_output.remap_from_filtered(target_attention_mask, orig_batch, self.is_final_step_method)
step_output = step_output.detach().to("cpu")
return step_output
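The expansion performed by `_build_multistep_output_from_single_step` boils down to slicing the final-step attribution tensor once per generated position. A torch-only sketch with hypothetical shapes:

```python
import torch

# Final-step target attributions: (batch, prefix_len, generated_len)
full = torch.rand(1, 5, 5)
# One slice per generation step, mirroring [:, :pos_idx, pos_idx - 1] above
per_step = [full[:, :pos_idx, pos_idx - 1] for pos_idx in range(1, 6)]
```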
diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py
index 9c6e8923..f3f0479d 100644
--- a/inseq/attr/feat/internals_attribution.py
+++ b/inseq/attr/feat/internals_attribution.py
@@ -17,11 +17,10 @@
from typing import Any, Optional
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
-from captum.attr._utils.attribution import Attribution
from ...data import MultiDimensionalFeatureAttributionStepOutput
from ...utils import Registry
-from ...utils.typing import MultiLayerMultiUnitScoreTensor
+from ...utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor
from .feature_attribution import FeatureAttribution
logger = logging.getLogger(__name__)
@@ -38,7 +37,7 @@ class AttentionWeightsAttribution(InternalsAttributionRegistry):
method_name = "attention"
- class AttentionWeights(Attribution):
+ class AttentionWeights(InseqAttribution):
@staticmethod
def has_convergence_delta() -> bool:
return False
@@ -74,9 +73,9 @@ def attribute(
:class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: A step output containing attention
weights for each layer and head, with shape :obj:`(batch_size, seq_len, n_layers, n_heads)`.
"""
- # We adopt the format [batch_size, sequence_length, num_layers, num_heads]
+ # We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads]
# for consistency with other multi-unit methods (e.g. gradient attribution)
- decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
+ decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
if self.forward_func.is_encoder_decoder:
sequence_scores = {}
if len(inputs) > 1:
@@ -85,10 +84,11 @@ def attribute(
target_attributions = None
sequence_scores["decoder_self_attentions"] = decoder_self_attentions
sequence_scores["encoder_self_attentions"] = (
- encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2)
+ encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
)
+ cross_attentions = cross_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
return MultiDimensionalFeatureAttributionStepOutput(
- source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
+ source_attributions=cross_attentions,
target_attributions=target_attributions,
sequence_scores=sequence_scores,
_num_dimensions=2, # num_layers, num_heads
@@ -106,6 +106,8 @@ def __init__(self, attribution_model, **kwargs):
self.use_attention_weights = True
# Does not rely on predicted output (i.e. decoding strategy agnostic)
self.use_predicted_target = False
+ # Needs only the final generation step to extract scores
+ self.is_final_step_method = True
self.method = self.AttentionWeights(attribution_model)
def attribute_step(
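With `is_final_step_method` set, the `attention` method now extracts scores for all positions from a single forward pass on the full sequence. A minimal usage sketch, assuming `gpt2` as an illustrative model:

```python
import inseq

model = inseq.load_model("gpt2", "attention")
out = model.attribute("Hello world, how are you?")
out.show()
```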
diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py
index 0e533b4a..7a4d726f 100644
--- a/inseq/attr/feat/ops/__init__.py
+++ b/inseq/attr/feat/ops/__init__.py
@@ -1,13 +1,17 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
+from .reagent import Reagent
from .rollout import rollout_fn
from .sequential_integrated_gradients import SequentialIntegratedGradients
+from .value_zeroing import ValueZeroing
__all__ = [
"DiscretetizedIntegratedGradients",
"MonotonicPathBuilder",
+ "ValueZeroing",
"Lime",
+ "Reagent",
"SequentialIntegratedGradients",
"rollout_fn",
]
diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py
new file mode 100644
index 00000000..080a2131
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent.py
@@ -0,0 +1,134 @@
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
+from torch import Tensor
+from typing_extensions import override
+
+from ....utils.typing import InseqAttribution
+from .reagent_core import (
+ AggregateRationalizer,
+ DeltaProbImportanceScoreEvaluator,
+ POSTagTokenSampler,
+ TopKStoppingConditionEvaluator,
+ UniformTokenReplacer,
+)
+
+if TYPE_CHECKING:
+ from ....models import HuggingfaceModel
+
+
+class Reagent(InseqAttribution):
+ r"""Recursive attribution generator (ReAGent) method.
+
+ Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+ alternative predicted by a LM.
+
+ Reference implementation:
+ `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
+ `__
+
+ Args:
+ forward_func (callable): The forward function of the model or any modification of it
+ keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``.
+ keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+        stopping_condition_top_k (int): Threshold indicating that the stop condition is achieved when the predicted
+            target exists in the top k predictions
+        replacing_ratio (float): Ratio of tokens to replace when probing
+        max_probe_steps (int): Maximum number of probing steps
+        num_probes (int): Number of probes to run in parallel
+
+ Example:
+ ```
+ import inseq
+
+ model = inseq.load_model("gpt2-medium", "reagent",
+ keep_top_n=5,
+ stopping_condition_top_k=3,
+ replacing_ratio=0.3,
+ max_probe_steps=3000,
+ num_probes=8
+ )
+ out = model.attribute("Super Mario Land is a game that developed by")
+ out.show()
+ ```
+ """
+
+ def __init__(
+ self,
+ attribution_model: "HuggingfaceModel",
+ keep_top_n: int = 5,
+        keep_ratio: Optional[float] = None,
+ invert_keep: bool = False,
+ stopping_condition_top_k: int = 3,
+ replacing_ratio: float = 0.3,
+ max_probe_steps: int = 3000,
+ num_probes: int = 16,
+ ) -> None:
+ super().__init__(attribution_model)
+
+ model = attribution_model.model
+ tokenizer = attribution_model.tokenizer
+ model_name = attribution_model.model_name
+
+ sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device)
+ stopping_condition_evaluator = TopKStoppingConditionEvaluator(
+ model=model,
+ sampler=sampler,
+ top_k=stopping_condition_top_k,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ invert_keep=invert_keep,
+ )
+ importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
+ model=model,
+ tokenizer=tokenizer,
+ token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio),
+ stopping_condition_evaluator=stopping_condition_evaluator,
+ max_steps=max_probe_steps,
+ )
+
+ self.rationalizer = AggregateRationalizer(
+ importance_score_evaluator=importance_score_evaluator,
+ batch_size=num_probes,
+ overlap_threshold=0,
+ overlap_strict_pos=True,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ )
+
+ @override
+ def attribute( # type: ignore
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ _target: TargetType = None,
+ additional_forward_args: Any = None,
+ ) -> Union[
+ TensorOrTupleOfTensorsGeneric,
+ tuple[TensorOrTupleOfTensorsGeneric, Tensor],
+ ]:
+ """Implement attribute"""
+ # encoder-decoder
+ if self.forward_func.is_encoder_decoder:
+ # with target-side attribution
+ if len(inputs) > 1:
+ self.rationalizer(
+ additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True
+ )
+ mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
+ res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
+ return (
+ res[:, : additional_forward_args[0].shape[1], :],
+ res[:, additional_forward_args[0].shape[1] :, :],
+ )
+            # source-side only
+            else:
+                self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2])
+        # decoder-only
+        else:
+            self.rationalizer(additional_forward_args[0], additional_forward_args[1])
+ mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0)
+ res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2])
+ return (res,)
diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py
new file mode 100644
index 00000000..13917c00
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/__init__.py
@@ -0,0 +1,13 @@
+from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator
+from .rationalizer import AggregateRationalizer
+from .stopping_condition_evaluator import TopKStoppingConditionEvaluator
+from .token_replacer import UniformTokenReplacer
+from .token_sampler import POSTagTokenSampler
+
+__all__ = [
+ "DeltaProbImportanceScoreEvaluator",
+ "AggregateRationalizer",
+ "TopKStoppingConditionEvaluator",
+ "UniformTokenReplacer",
+ "POSTagTokenSampler",
+]
diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py
new file mode 100644
index 00000000..adc79b7b
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from jaxtyping import Float
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor
+from .stopping_condition_evaluator import StoppingConditionEvaluator
+from .token_replacer import TokenReplacer
+
+
+class BaseImportanceScoreEvaluator(ABC):
+ """Importance Score Evaluator"""
+
+ def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None:
+ """Base Constructor
+
+ Args:
+ model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
+ tokenizer: A Huggingface AutoTokenizer
+ """
+ self.model = model
+ self.tokenizer = tokenizer
+ self.importance_score = None
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Evaluate importance score of input sequence
+
+ Args:
+ input_ids: input sequence [batch, sequence]
+ target_id: target token [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ importance_score: evaluated importance score for each token in the input [batch, sequence]
+
+ """
+ raise NotImplementedError()
+
+
+class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator):
+ """Importance Score Evaluator"""
+
+ @override
+ def __init__(
+ self,
+ model: AutoModelForCausalLM | AutoModelForSeq2SeqLM,
+ tokenizer: AutoTokenizer,
+ token_replacer: TokenReplacer,
+ stopping_condition_evaluator: StoppingConditionEvaluator,
+ max_steps: float,
+ ) -> None:
+ """Constructor
+
+ Args:
+ model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
+ tokenizer: A Huggingface AutoTokenizer
+ token_replacer: A TokenReplacer
+            stopping_condition_evaluator: A StoppingConditionEvaluator
+            max_steps: Maximum number of probing steps
+ """
+ super().__init__(model, tokenizer)
+ self.token_replacer = token_replacer
+ self.stopping_condition_evaluator = stopping_condition_evaluator
+ self.max_steps = max_steps
+ self.importance_score = None
+ self.num_steps = 0
+
+ def update_importance_score(
+ self,
+ logit_importance_score: MultipleScoresPerStepTensor,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ prob_original_target: Float[torch.Tensor, "batch_size 1"],
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Update importance score by one step
+
+ Args:
+ logit_importance_score: Current importance score in logistic scale [batch, sequence]
+ input_ids: input tensor [batch, sequence]
+ target_id: target tensor [batch]
+ prob_original_target: predictive probability of the target on the original sequence [batch, 1]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ logit_importance_score: updated importance score in logistic scale [batch, sequence]
+ """
+ # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}}
+ if not attribute_target:
+ input_ids_replaced, mask_replacing = self.token_replacer(input_ids)
+ else:
+ ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1))
+ input_ids_replaced = ids_replaced[:, : input_ids.shape[1]]
+ decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :]
+
+ logging.debug(f"Replacing mask: { mask_replacing }")
+ logging.debug(
+ f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }"
+ )
+
+ # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}})
+ kwargs = {"input_ids": input_ids_replaced}
+ if decoder_input_ids is not None:
+ kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids
+ logits_replaced = self.model(**kwargs)["logits"]
+ prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id]
+
+ # Compute changes delta = p^{(y)} - \hat{p^{(y)}}
+ delta_prob_target = prob_original_target - prob_replaced_target
+ logging.debug(f"likelihood delta: { delta_prob_target }")
+
+ # Update importance scores based on delta (magnitude) and replacement (direction)
+ delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target
+ # TODO: better solution?
+ # Rescaling from [-1, 1] to [0, 1] before logit function
+ logit_delta_score = torch.logit(delta_score * 0.5 + 0.5)
+ logit_importance_score = logit_importance_score + logit_delta_score
+ logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }")
+ return logit_importance_score
+
+ @override
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> MultipleScoresPerStepTensor:
+ """Evaluate importance score of input sequence
+
+ Args:
+ input_ids: input sequence [batch, sequence]
+ target_id: target token [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ importance_score: evaluated importance score for each token in the input [batch, sequence]
+ """
+ self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)
+
+ # Inference p^{(y)} = p(y_{t+1}|y_{1...t})
+ if decoder_input_ids is None:
+ logits_original = self.model(input_ids)["logits"]
+ else:
+ logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"]
+
+ prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id]
+
+ # Initialize importance score s for each token in the sequence y_{1...t}
+ if not attribute_target:
+ logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device)
+ else:
+ logit_importance_score = torch.rand(
+ (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device
+ )
+ logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }")
+
+ # TODO: limit max steps
+ self.num_steps = 0
+ while self.num_steps < self.max_steps:
+ self.num_steps += 1
+ # Update importance score
+ logit_importance_score_update = self.update_importance_score(
+ logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target
+ )
+ logit_importance_score = (
+ ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update
+ + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score
+ )
+ self.importance_score = torch.softmax(logit_importance_score, -1)
+
+ # Evaluate stop condition
+ self.stop_mask = self.stop_mask | self.stopping_condition_evaluator(
+ input_ids, target_id, self.importance_score, decoder_input_ids, attribute_target
+ )
+ if torch.prod(self.stop_mask) > 0:
+ break
+
+ logging.info(f"Importance score evaluated in {self.num_steps} steps.")
+ return torch.softmax(logit_importance_score, -1)
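A torch-only numeric sketch of the update rule in `update_importance_score` above (all values hypothetical):

```python
import torch

delta = torch.tensor([[0.20], [-0.05]])  # p^(y) - \hat{p^(y)} per probe
mask_replacing = torch.tensor([[True, False], [False, True]])
# Sign the delta by replacement direction, as in the update above
delta_score = mask_replacing * delta + ~mask_replacing * -delta
# Rescale [-1, 1] -> [0, 1], then accumulate in logit space
logit_delta_score = torch.logit(delta_score * 0.5 + 0.5)
```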
diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py
new file mode 100644
index 00000000..ab7c2be8
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py
@@ -0,0 +1,128 @@
+import math
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from jaxtyping import Int64
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor, TargetIdsTensor
+from .importance_score_evaluator import BaseImportanceScoreEvaluator
+
+
+class BaseRationalizer(ABC):
+ def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None:
+ super().__init__()
+ self.importance_score_evaluator = importance_score_evaluator
+ self.mean_importance_score = None
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> Int64[torch.Tensor, "batch_size other_dims"]:
+ """Compute rational of a sequence on a target
+
+ Args:
+            input_ids: The sequence [batch, sequence] (first dimension needs to be 1)
+ target_id: The target [batch]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+            pos_top_n: rationale positions in the sequence [batch, rationale_size]
+
+ """
+ raise NotImplementedError()
+
+
+class AggregateRationalizer(BaseRationalizer):
+ """AggregateRationalizer"""
+
+ @override
+ def __init__(
+ self,
+ importance_score_evaluator: BaseImportanceScoreEvaluator,
+ batch_size: int,
+ overlap_threshold: int,
+ overlap_strict_pos: bool = True,
+ keep_top_n: int = 0,
+ keep_ratio: float = 0,
+ ) -> None:
+ """Constructor
+
+ Args:
+ importance_score_evaluator: A ImportanceScoreEvaluator
+ batch_size: Batch size for aggregate
+ overlap_threshold: Overlap threshold of rational tokens within a batch
+            overlap_strict_pos: Whether the overlap is strict to position or not
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ """
+ super().__init__(importance_score_evaluator)
+ self.batch_size = batch_size
+ self.overlap_threshold = overlap_threshold
+ self.overlap_strict_pos = overlap_strict_pos
+ self.keep_top_n = keep_top_n
+ self.keep_ratio = keep_ratio
+ assert overlap_strict_pos, "overlap_strict_pos = False is not supported yet"
+
+ @override
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> Int64[torch.Tensor, "batch_size other_dims"]:
+ """Compute rational of a sequence on a target
+
+ Args:
+ input_ids: A tensor of ids of shape [batch, sequence_len]
+ target_id: A tensor of predicted targets of size [batch]
+ decoder_input_ids (optional): A tensor of ids representing the decoder input sequence for
+ ``AutoModelForSeq2SeqLM``, with shape [batch, sequence_len]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+            pos_top_n: rationale positions in the sequence [batch, rationale_size]
+
+ """
+        assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) needs to be 1"
+ batch_input_ids = input_ids.repeat(self.batch_size, 1)
+ batch_decoder_input_ids = (
+ decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None
+ )
+ batch_importance_score = self.importance_score_evaluator(
+ batch_input_ids, target_id, batch_decoder_input_ids, attribute_target
+ )
+ importance_score_masked = batch_importance_score * torch.unsqueeze(
+ self.importance_score_evaluator.stop_mask, -1
+ )
+ self.mean_importance_score = torch.sum(importance_score_masked, dim=0) / torch.sum(
+ self.importance_score_evaluator.stop_mask
+ )
+ pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True)
+ top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n
+ pos_top_n = pos_sorted[:, :top_n]
+ self.pos_top_n = pos_top_n
+ if self.overlap_strict_pos:
+ count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1])
+ pos_top_n_overlap = torch.unsqueeze(
+ torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0
+ )
+ return pos_top_n_overlap
+ else:
+ raise NotImplementedError("overlap_strict_pos = False not been supported yet")
+ # TODO: Convert back to pos
+ # token_id_top_n = input_ids[0, pos_top_n]
+ # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1])
+ # _token_id_top_n_overlap = torch.unsqueeze(
+ # torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0
+ # )
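A numeric sketch of the strict-position overlap aggregation above, with hypothetical top-n positions from three probes:

```python
import torch

pos_top_n = torch.tensor([[0, 2], [2, 3], [2, 0]])  # top positions per probe
count_overlap = torch.bincount(pos_top_n.flatten(), minlength=5)
# Keep positions selected by at least overlap_threshold probes (here 2)
rationale_pos = torch.nonzero(count_overlap >= 2, as_tuple=True)[0]  # tensor([0, 2])
```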
diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py
new file mode 100644
index 00000000..fd3bb67d
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py
@@ -0,0 +1,136 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM
+
+from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor
+from .token_replacer import RankingTokenReplacer
+from .token_sampler import TokenSampler
+
+
+class StoppingConditionEvaluator(ABC):
+ """Base class for Stopping Condition Evaluators"""
+
+ @abstractmethod
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ importance_score: MultipleScoresPerStepTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> TargetIdsTensor:
+ """Evaluate stop condition according to the specified strategy.
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+
+ """
+ raise NotImplementedError()
+
+
+class TopKStoppingConditionEvaluator(StoppingConditionEvaluator):
+ """
+    Evaluator that stops when the target exists among the top k predictions,
+    while the top n tokens based on importance_score are not replaced.
+ """
+
+ def __init__(
+ self,
+ model: AutoModelForCausalLM,
+ sampler: TokenSampler,
+ top_k: int,
+ keep_top_n: int = 0,
+ keep_ratio: float = 0,
+ invert_keep: bool = False,
+ ) -> None:
+ """Constructor for the TopKStoppingConditionEvaluator class.
+
+ Args:
+ model: A Huggingface ``AutoModelForCausalLM``.
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens.
+ top_k: Top K predictions in which the target must be included in order to achieve the stopping condition.
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+ """
+ self.model = model
+ self.top_k = top_k
+ self.replacer = RankingTokenReplacer(sampler, keep_top_n, keep_ratio, invert_keep)
+
+ def __call__(
+ self,
+ input_ids: IdsTensor,
+ target_id: TargetIdsTensor,
+ importance_score: MultipleScoresPerStepTensor,
+ decoder_input_ids: Optional[IdsTensor] = None,
+ attribute_target: bool = False,
+ ) -> TargetIdsTensor:
+ """Evaluate stop condition
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+ decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+ """
+        # Replace tokens with low importance scores, then infer \hat{y^{(e)}_{t+1}}
+ self.replacer.set_score(importance_score)
+ if not attribute_target:
+ input_ids_replaced, mask_replacing = self.replacer(input_ids)
+ else:
+ ids_replaced, mask_replacing = self.replacer(torch.cat((input_ids, decoder_input_ids), 1))
+ input_ids_replaced = ids_replaced[:, : input_ids.shape[1]]
+ decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :]
+
+ logging.debug(f"Replacing mask based on importance score -> { mask_replacing }")
+
+        # Check whether the result \hat{y^{(e)}_{t+1}} is consistent with y_{t+1}
+ assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled"
+ with torch.no_grad():
+ kwargs = {"input_ids": input_ids_replaced}
+ if decoder_input_ids is not None:
+ kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids
+ logits_replaced = self.model(**kwargs)["logits"]
+ ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True)
+ ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k]
+ match_mask = ids_prediction_top_k == target_id
+ match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool)
+ return match_hit
+
+
+class DummyStoppingConditionEvaluator(StoppingConditionEvaluator):
+ """
+    Dummy stopping condition evaluator that always reports the stop condition
+    as reached, regardless of target predictions or importance scores.
+ """
+
+ def __call__(self, input_ids: IdsTensor, **kwargs) -> TargetIdsTensor:
+ """Evaluate stop condition
+
+ Args:
+ input_ids: Input sequence [batch, sequence]
+ target_id: Target token [batch]
+ importance_score: Importance score of the input [batch, sequence]
+            attribute_target: whether to attribute the target for encoder-decoder models
+
+ Return:
+ Boolean flag per sequence signaling whether the stop condition was reached [batch]
+ """
+ return torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)
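The top-k hit test at the core of `TopKStoppingConditionEvaluator` reduces to an argsort over the last-position logits and a comparison; a torch-only sketch with hypothetical values:

```python
import torch

logits_replaced = torch.rand(2, 7, 100)  # (batch, seq, vocab)
target_id = torch.tensor([42])
ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True)
match_hit = torch.sum(ids_prediction_sorted[:, :3] == target_id, dim=-1, dtype=torch.bool)
```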
diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py
new file mode 100644
index 00000000..0d889144
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py
@@ -0,0 +1,111 @@
+import math
+from abc import ABC, abstractmethod
+
+import torch
+from typing_extensions import override
+
+from .....utils.typing import IdsTensor
+from .token_sampler import TokenSampler
+
+
+class TokenReplacer(ABC):
+ """
+ Base class for token replacers
+
+ """
+
+ def __init__(self, sampler: TokenSampler) -> None:
+ self.sampler = sampler
+
+ @abstractmethod
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Replace tokens according to the specified strategy.
+
+ Args:
+ input: input sequence [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+
+ """
+ raise NotImplementedError()
+
+
+class RankingTokenReplacer(TokenReplacer):
+ """Replace tokens in a sequence based on top-N ranking"""
+
+ @override
+ def __init__(
+ self, sampler: TokenSampler, keep_top_n: int = 0, keep_ratio: float = 0, invert_keep: bool = False
+ ) -> None:
+ """Constructor for the RankingTokenReplacer class.
+
+ Args:
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens.
+ keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be
+ kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by
+ ``keep_ratio``.
+ keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
+ replaced instead of being kept.
+ """
+ super().__init__(sampler)
+ self.keep_top_n = keep_top_n
+ self.keep_ratio = keep_ratio
+ self.invert_keep = invert_keep
+
+ def set_score(self, value: torch.Tensor) -> None:
+ pos_sorted = torch.argsort(value, descending=True)
+ top_n = int(math.ceil(self.keep_ratio * value.shape[-1])) if not self.keep_top_n else self.keep_top_n
+ pos_top_n = pos_sorted[..., :top_n]
+ self.replacement_mask = torch.ones_like(value, device=value.device, dtype=torch.bool).scatter(
+ -1, pos_top_n, self.invert_keep
+ )
+
+ @override
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Sample a sequence
+
+ Args:
+ input: Input sequence of ids of shape [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+ """
+ token_sampled = self.sampler(input)
+ input_replaced = input * ~self.replacement_mask + token_sampled * self.replacement_mask
+ return input_replaced, self.replacement_mask
+
+
+class UniformTokenReplacer(TokenReplacer):
+ """Replace tokens in a sequence where selecting is base on uniform distribution"""
+
+ @override
+ def __init__(self, sampler: TokenSampler, ratio: float) -> None:
+ """Constructor
+
+ Args:
+ sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens.
+ ratio: Ratio of tokens to replace in the sequence.
+ """
+ super().__init__(sampler)
+ self.ratio = ratio
+
+ @override
+ def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]:
+ """Sample a sequence
+
+ Args:
+ input: Input sequence of ids of shape [batch, sequence]
+
+ Returns:
+ input_replaced: A replaced sequence [batch, sequence]
+ replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence]
+ """
+ sample_uniform = torch.rand(input.shape, device=input.device)
+ replacement_mask = sample_uniform < self.ratio
+ token_sampled = self.sampler(input)
+ input_replaced = input * ~replacement_mask + token_sampled * replacement_mask
+ return input_replaced, replacement_mask
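Both replacers share the same masked blend of original and sampled ids; a compact sketch with hypothetical values (the random ids stand in for a `TokenSampler`):

```python
import torch

input_ids = torch.tensor([[5, 8, 13, 21]])
replacement_mask = torch.rand(input_ids.shape) < 0.5    # uniform selection
token_sampled = torch.randint(0, 100, input_ids.shape)  # stand-in for a TokenSampler
input_replaced = input_ids * ~replacement_mask + token_sampled * replacement_mask
```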
diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py
new file mode 100644
index 00000000..7ca41bf2
--- /dev/null
+++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py
@@ -0,0 +1,107 @@
+import logging
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from typing_extensions import override
+
+from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available
+from .....utils.typing import IdsTensor
+
+logger = logging.getLogger(__name__)
+
+
+class TokenSampler(ABC):
+ """Base class for token samplers"""
+
+ @abstractmethod
+ def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor:
+ """Sample tokens according to the specified strategy.
+
+ Args:
+ input: input tensor [batch, sequence]
+
+ Returns:
+            token_uniform: A sampled tensor with the same shape as the input
+ """
+ raise NotImplementedError()
+
+
+class POSTagTokenSampler(TokenSampler):
+ """Sample tokens from Uniform distribution on a set of words with the same POS tag."""
+
+ def __init__(
+ self,
+ tokenizer: Union[str, PreTrainedTokenizerBase],
+ identifier: str = "pos_tag_sampler",
+ save_cache: bool = True,
+ overwrite_cache: bool = False,
+ cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache",
+ device: Optional[str] = None,
+ tokenizer_kwargs: Optional[dict[str, Any]] = {},
+ ) -> None:
+ if isinstance(tokenizer, PreTrainedTokenizerBase):
+ self.tokenizer = tokenizer
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
+ cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl"
+ self.pos2ids = self.build_pos_mapping_from_vocab(
+ cache_dir,
+ cache_filename,
+ save_cache,
+ overwrite_cache,
+ tokenizer=self.tokenizer,
+ )
+ num_postags = len(self.pos2ids)
+ self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device)
+ for pos_idx, ids in enumerate(self.pos2ids.values()):
+ self.id2pos[ids] = pos_idx
+ self.num_ids_per_pos = torch.tensor(
+ [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device
+ )
+ self.offsets = torch.sum(
+ torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos,
+ dim=-1,
+ )
+ self.compact_idx = torch.cat(
+ tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values())
+ )
+
+ @staticmethod
+ @cache_results
+ def build_pos_mapping_from_vocab(
+ tokenizer: PreTrainedTokenizerBase,
+ log_every: int = 5000,
+ ) -> dict[str, list[int]]:
+ """Build mapping from POS tags to list of token ids from tokenizer's vocabulary."""
+ if not is_nltk_available():
+ raise ImportError("nltk is required to build POS tag mapping. Please install nltk.")
+ import nltk
+
+ nltk.download("averaged_perceptron_tagger")
+ pos2ids = defaultdict(list)
+ for i in range(tokenizer.vocab_size):
+ word = tokenizer.decode([i])
+ _, tag = nltk.pos_tag([word.strip()])[0]
+ pos2ids[tag].append(i)
+ if i % log_every == 0:
+ logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%")
+ return pos2ids
+
+ @override
+ def __call__(self, input_ids: IdsTensor) -> IdsTensor:
+ """Sample a tensor
+
+ Args:
+            input_ids: input tensor [batch, sequence]
+
+ Returns:
+            token_uniform: A sampled tensor with the same shape as the input
+ """
+ input_ids_pos = self.id2pos[input_ids]
+ sample_uniform = torch.rand(input_ids.shape, device=input_ids.device)
+ compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long()
+ return self.compact_idx[compact_group_idx]
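The offset/compact-index trick in `POSTagTokenSampler.__call__` can be checked with a tiny handmade vocabulary (all values hypothetical):

```python
import torch

# Two POS buckets: tag 0 -> ids [3, 7, 9], tag 1 -> ids [1, 4]
num_ids_per_pos = torch.tensor([3, 2])
offsets = torch.tensor([0, 3])            # start of each bucket in compact_idx
compact_idx = torch.tensor([3, 7, 9, 1, 4])
input_ids_pos = torch.tensor([0, 1])      # POS tag of each input token
sample_uniform = torch.rand(2)
idx = (sample_uniform * num_ids_per_pos[input_ids_pos] + offsets[input_ids_pos]).long()
sampled = compact_idx[idx]                # same-POS replacement ids
```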
diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py
new file mode 100644
index 00000000..ed95eb12
--- /dev/null
+++ b/inseq/attr/feat/ops/value_zeroing.py
@@ -0,0 +1,394 @@
+# Copyright 2023 The Inseq Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from enum import Enum
+from types import FrameType
+from typing import TYPE_CHECKING, Callable, Optional
+
+import torch
+from captum._utils.typing import TensorOrTupleOfTensorsGeneric
+from torch import nn
+from torch.utils.hooks import RemovableHandle
+
+from ....utils import (
+ find_block_stack,
+ get_post_variable_assignment_hook,
+ recursive_get_submodule,
+ validate_indices,
+)
+from ....utils.typing import (
+ EmbeddingsTensor,
+ InseqAttribution,
+ MultiLayerEmbeddingsTensor,
+ MultiLayerScoreTensor,
+ OneOrMoreIndices,
+ OneOrMoreIndicesDict,
+)
+
+if TYPE_CHECKING:
+ from ....models import HuggingfaceModel
+
+logger = logging.getLogger(__name__)
+
+
+class ValueZeroingSimilarityMetric(Enum):
+ COSINE = "cosine"
+ EUCLIDEAN = "euclidean"
+
+
+class ValueZeroingModule(Enum):
+ DECODER = "decoder"
+ ENCODER = "encoder"
+
+
+class ValueZeroing(InseqAttribution):
+ """Value Zeroing method for feature attribution.
+
+ Introduced by `Mohebbi et al. (2023) `__ to quantify context mixing inside
+ Transformer models. The method is based on the observation that context mixing is regulated by the value vectors
+ of the attention mechanism. The method consists of two steps:
+
+ 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model.
+ 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it
+ as a measure of context mixing for the given token at the given layer.
+
+ The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at
+ specific layers, or by aggregating them across layers.
+
+ Attributes:
+ SIMILARITY_METRICS (:obj:`Dict[str, Callable]`):
+            Dictionary of available similarity metrics to be used for computing the distance between hidden states
+ produced with and without the zeroing operation. Converted to distances as 1 - produced values.
+ forward_func (:obj:`AttributionModel`):
+ The attribution model to be used for value zeroing.
+ clean_block_output_states (:obj:`Dict[int, torch.Tensor]`):
+ Dictionary to store the hidden states produced by the model without the zeroing operation.
+ corrupted_block_output_states (:obj:`Dict[int, torch.Tensor]`):
+ Dictionary to store the hidden states produced by the model with the zeroing operation.
+ """
+
+ SIMILARITY_METRICS = {
+ "cosine": nn.CosineSimilarity(dim=-1),
+ "euclidean": lambda x, y: torch.cdist(x, y, p=2),
+ }
+
+ def __init__(self, forward_func: "HuggingfaceModel") -> None:
+ super().__init__(forward_func)
+ self.clean_block_output_states: dict[int, EmbeddingsTensor] = {}
+ self.corrupted_block_output_states: dict[int, EmbeddingsTensor] = {}
+
+ @staticmethod
+ def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]:
+ """Returns a hook to zero the value vectors of the attention mechanism.
+
+ Args:
+ varname (:obj:`str`, optional): The name of the variable containing the value vectors. The variable
+                is expected to be a tensor of shape (batch_size, num_heads, seq_len, hidden_size) or
+                (batch_size * num_heads, seq_len, hidden_size), and is retrieved from the
+                local variables of the execution frame during the forward pass.
+ """
+
+ def value_zeroing_forward_mid_hook(
+ frame: FrameType,
+ zeroed_token_index: Optional[int] = None,
+ zeroed_units_indices: Optional[OneOrMoreIndices] = None,
+ batch_size: int = 1,
+ ) -> None:
+ if varname not in frame.f_locals:
+ raise ValueError(
+ f"Variable {varname} not found in the local frame."
+ f"Other variable names: {', '.join(frame.f_locals.keys())}"
+ )
+ # Zeroing value vectors corresponding to the given token index
+ if zeroed_token_index is not None:
+ values_size = frame.f_locals[varname].size()
+ if len(values_size) == 3: # Assume merged shape (bsz * num_heads, seq_len, hidden_size) e.g. Whisper
+ values = frame.f_locals[varname].view(batch_size, -1, *values_size[1:])
+ elif len(values_size) == 4: # Assume per-head shape (bsz, num_heads, seq_len, hidden_size) e.g. GPT-2
+ values = frame.f_locals[varname].clone()
+ else:
+ raise ValueError(
+ f"Value vector shape {frame.f_locals[varname].size()} not supported. "
+ "Supported shapes: (batch_size, num_heads, seq_len, hidden_size) or "
+ "(batch_size * num_heads, seq_len, hidden_size)"
+ )
+ zeroed_units_indices = validate_indices(values, 1, zeroed_units_indices).to(values.device)
+ zeroed_token_index = torch.tensor(zeroed_token_index, device=values.device)
+ # Mask heads corresponding to zeroed units and tokens corresponding to zeroed tokens
+ values[:, zeroed_units_indices, zeroed_token_index] = 0
+ if len(values_size) == 3:
+ frame.f_locals[varname] = values.view(-1, *values_size[1:])
+ elif len(values_size) == 4:
+ frame.f_locals[varname] = values
+
+ return value_zeroing_forward_mid_hook
+
+ def get_states_extract_and_patch_hook(self, block_idx: int, hidden_state_idx: int = 0) -> Callable[..., None]:
+ """Returns a hook to extract the produced hidden states (corrupted by value zeroing)
+ and patch them with pre-computed clean states that will be passed onwards in the model forward.
+
+ Args:
+ block_idx (:obj:`int`): The idx of the block at which the hook is applied, used to store extracted states.
+ hidden_state_idx (:obj:`int`, optional): The index of the hidden state in the model output tuple.
+ """
+
+ def states_extract_and_patch_forward_hook(module, args, output) -> None:
+ self.corrupted_block_output_states[block_idx] = output[hidden_state_idx].clone().float().detach().cpu()
+
+ # Rebuild the output tuple patching the clean states at the place of the corrupted ones
+ output = (
+ output[:hidden_state_idx]
+ + (self.clean_block_output_states[block_idx].to(output[hidden_state_idx].device),)
+ + output[hidden_state_idx + 1 :]
+ )
+ return output
+
+ return states_extract_and_patch_forward_hook
+
+ @staticmethod
+ def has_convergence_delta() -> bool:
+ return False
+
+ def compute_modules_post_zeroing_similarity(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ hidden_states: MultiLayerEmbeddingsTensor,
+ attention_module_name: str,
+ attributed_seq_len: Optional[int] = None,
+ similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
+ mode: str = ValueZeroingModule.DECODER.value,
+ zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ min_score_threshold: float = 1e-5,
+ use_causal_mask: bool = False,
+ ) -> MultiLayerScoreTensor:
+ """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block.
+
+ Args:
+ modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for.
+ hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean
+ counterparts when computing the similarity.
+ attention_module_name (:obj:`str`): The name of the attention module to zero the values for.
+ attributed_seq_len (:obj:`int`, optional): The length of the sequence to attribute. If not specified, it is
+ assumed to be the same as the length of the hidden states.
+ similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine".
+ mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder".
+ zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
+ `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
+ that should be zeroed to compute corrupted states.
+ - If None, all attention heads across all layers are zeroed.
+ - If an integer, the same attention head is zeroed across all layers.
+ - If a tuple of two integers, the attention heads in the range are zeroed across all layers.
+ - If a list of integers, the attention heads in the list are zeroed across all layers.
+ - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
+ the corresponding layer. Any missing layer will not be zeroed.
+ Default: None.
+ min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the
+ similarity. Default: 1e-5.
+ use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores. Default: False.
+
+ Returns:
+ :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, attributed_seq_len, generated_seq_len,
+ num_layers]`` containing distances (1 - similarity score) between original and corrupted states for each layer.
+ """
+ if mode == ValueZeroingModule.DECODER.value:
+ modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder())
+ elif mode == ValueZeroingModule.ENCODER.value:
+ modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder())
+ else:
+ raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.")
+ if attributed_seq_len is None:
+ attributed_seq_len = hidden_states.size(2)
+ batch_size = hidden_states.size(0)
+ generated_seq_len = hidden_states.size(2)
+ num_layers = len(modules)
+
+ # Store clean hidden states for later use. The +1 offset skips the first element of the hidden states stack,
+ # which is the embedding layer output: we are only interested in the transformer blocks' outputs.
+ self.clean_block_output_states = {
+ block_idx: hidden_states[:, block_idx + 1, ...].clone().detach().cpu() for block_idx in range(len(modules))
+ }
+ # Scores for every layer of the model
+ all_scores = torch.ones(
+ batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device
+ ) * float("nan")
+
+ # Hooks:
+ # 1. states_extract_and_patch_hook on the transformer block stores corrupted states and force clean states
+ # as the output of the block forward pass, i.e. the zeroing is done independently across layers.
+ # 2. value_zeroing_hook on the attention module performs the value zeroing by replacing the "value" tensor
+ # during the forward (name is config-dependent) with a zeroed version for the specified token index.
+ #
+ # State extraction hooks can be registered only once since they are token-independent
+ # Skip last block since its states are not used raw, but may have further transformations applied to them
+ # (e.g. LayerNorm, Dropout). These are extracted separately from the model outputs.
+ states_extraction_hook_handles: list[RemovableHandle] = []
+ for block_idx in range(len(modules) - 1):
+ states_extract_and_patch_hook = self.get_states_extract_and_patch_hook(block_idx, hidden_state_idx=0)
+ states_extraction_hook_handles.append(
+ modules[block_idx].register_forward_hook(states_extract_and_patch_hook)
+ )
+ # Zeroing is done for every token in the sequence separately (O(n) complexity)
+ for token_idx in range(attributed_seq_len):
+ value_zeroing_hook_handles: list[RemovableHandle] = []
+ # Value zeroing hooks are registered for every token separately since they are token-dependent
+ for block_idx, block in enumerate(modules):
+ attention_module = recursive_get_submodule(block, attention_module_name)
+ if attention_module is None:
+ raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.")
+ if isinstance(zeroed_units_indices, dict):
+ if block_idx not in zeroed_units_indices:
+ continue
+ zeroed_units_indices_block = zeroed_units_indices[block_idx]
+ else:
+ zeroed_units_indices_block = zeroed_units_indices
+ value_zeroing_hook = get_post_variable_assignment_hook(
+ module=attention_module,
+ varname=self.forward_func.config.value_vector,
+ hook_fn=self.get_value_zeroing_hook(self.forward_func.config.value_vector),
+ zeroed_token_index=token_idx,
+ zeroed_units_indices=zeroed_units_indices_block,
+ batch_size=batch_size,
+ )
+ value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_hook)
+ value_zeroing_hook_handles.append(value_zeroing_hook_handle)
+
+ # Run forward pass with hooks. Fills self.corrupted_block_output_states with corrupted states across
+ # layers when zeroing the specified token index.
+ with torch.no_grad():
+ output = self.forward_func.forward_with_output(
+ *inputs, *additional_forward_args, output_hidden_states=True
+ )
+ # Extract last layer states directly from the model outputs
+ # This allows us to handle the presence of additional transformations (e.g. LayerNorm, Dropout)
+ # in the last layer automatically.
+ corrupted_states_dict = self.forward_func.get_hidden_states_dict(output)
+ corrupted_decoder_last_hidden_state = (
+ corrupted_states_dict[f"{mode}_hidden_states"][:, -1, ...].clone().detach().cpu()
+ )
+ self.corrupted_block_output_states[len(modules) - 1] = corrupted_decoder_last_hidden_state
+ for handle in value_zeroing_hook_handles:
+ handle.remove()
+ for block_idx in range(len(modules)):
+ similarity_scores = self.SIMILARITY_METRICS[similarity_metric](
+ self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx]
+ )
+ if use_causal_mask:
+ all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:]
+ else:
+ all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores
+ self.corrupted_block_output_states = {}
+ for handle in states_extraction_hook_handles:
+ handle.remove()
+ self.clean_block_output_states = {}
+ all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores)
+ # Normalize scores to sum to 1
+ per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True)
+ per_token_sum_score[per_token_sum_score == 0] = 1
+ all_scores = all_scores / per_token_sum_score
+
+ # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers]
+ return all_scores.permute(0, 3, 2, 1)
+
+ def attribute(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
+ encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
+ decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
+ output_decoder_self_scores: bool = True,
+ output_encoder_self_scores: bool = True,
+ ) -> TensorOrTupleOfTensorsGeneric:
+ """Perform attribution using the Value Zeroing method.
+
+ Args:
+ similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between
+ hidden states produced with and without the zeroing operation. Default: cosine similarity.
+ encoder_zeroed_units_indices, decoder_zeroed_units_indices, cross_zeroed_units_indices (:obj:`Union[int,
+ tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and `Union[int, tuple[int, int], list[int]]`
+ values, optional): The indices of the attention heads that should be zeroed to compute corrupted states in
+ the encoder self-attention, decoder self-attention and cross-attention modules, respectively.
+ - If None, all attention heads across all layers are zeroed.
+ - If an integer, the same attention head is zeroed across all layers.
+ - If a tuple of two integers, the attention heads in the range are zeroed across all layers.
+ - If a list of integers, the attention heads in the list are zeroed across all layers.
+ - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
+ the corresponding layer.
+
+ Default: None (all heads are zeroed for every layer).
+ encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
+ source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for
+ encoder-decoder models. Default: None.
+ decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
+ target_seq_len, hidden_size]`` containing hidden states of the decoder. Default: None.
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ decoder self-attention value vectors in encoder-decoder models. Cannot be False for decoder-only models,
+ or if target-side attribution is requested using ``attribute_target=True``. Default: True.
+ output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ encoder self-attention value vectors in encoder-decoder models. Default: True.
+
+ Returns:
+ `TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution
+ """
+ if similarity_metric not in self.SIMILARITY_METRICS:
+ raise ValueError(
+ f"Similarity metric {similarity_metric} not available."
+ f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
+ )
+ decoder_scores = None
+ if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1:
+ decoder_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=decoder_hidden_states,
+ attention_module_name=self.forward_func.config.self_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.DECODER.value,
+ zeroed_units_indices=decoder_zeroed_units_indices,
+ use_causal_mask=True,
+ )
+ # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
+ # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
+ if self.forward_func.is_encoder_decoder:
+ encoder_scores = None
+ if output_encoder_self_scores:
+ encoder_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=encoder_hidden_states,
+ attention_module_name=self.forward_func.config.self_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.ENCODER.value,
+ zeroed_units_indices=encoder_zeroed_units_indices,
+ )
+ cross_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=decoder_hidden_states,
+ attributed_seq_len=encoder_hidden_states.size(2),
+ attention_module_name=self.forward_func.config.cross_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.DECODER.value,
+ zeroed_units_indices=cross_zeroed_units_indices,
+ )
+ return encoder_scores, cross_scores, decoder_scores
+ elif encoder_zeroed_units_indices is not None or cross_zeroed_units_indices is not None:
+ logger.warning(
+ "Zeroing indices for encoder and cross-attentions were specified, but the model is not an "
+ "encoder-decoder. Use `decoder_zeroed_units_indices` to parametrize zeroing for the decoder module."
+ )
+ return (decoder_scores,)
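
To ground the control flow above, here is a minimal usage sketch of the value zeroing op through the public API (model name and prompt are illustrative, not part of this diff):

```python
import inseq

# Load a decoder-only model with the value_zeroing attribution method.
model = inseq.load_model("gpt2", "value_zeroing")

# The returned scores carry an extra trailing layer dimension:
# [attributed_seq_len, generated_seq_len, num_layers].
out = model.attribute("The developer argued with the designer because she")
out.show()
```
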
diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py
index fbebb780..498093af 100644
--- a/inseq/attr/feat/perturbation_attribution.py
+++ b/inseq/attr/feat/perturbation_attribution.py
@@ -1,16 +1,20 @@
import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional
from captum.attr import Occlusion
from ...data import (
CoarseFeatureAttributionStepOutput,
GranularFeatureAttributionStepOutput,
+ MultiDimensionalFeatureAttributionStepOutput,
)
from ...utils import Registry
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution
-from .ops import Lime
+from .ops import Lime, Reagent, ValueZeroing
+
+if TYPE_CHECKING:
+ from ...models import HuggingfaceModel
logger = logging.getLogger(__name__)
@@ -117,3 +121,165 @@ def attribute_step(
target_attributions=out.target_attributions,
sequence_scores=out.sequence_scores,
)
+
+
+class ReagentAttribution(PerturbationAttributionRegistry):
+ """Recursive attribution generator (ReAGent) method.
+
+ Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+ alternative predicted by an LM.
+
+ Reference implementation:
+ `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models <https://arxiv.org/abs/2402.00794>`__
+ """
+
+ method_name = "reagent"
+
+ def __init__(
+ self,
+ attribution_model: "HuggingfaceModel",
+ keep_top_n: int = 5,
+ keep_ratio: Optional[float] = None,
+ invert_keep: bool = False,
+ stopping_condition_top_k: int = 3,
+ replacing_ratio: float = 0.3,
+ max_probe_steps: int = 3000,
+ num_probes: int = 16,
+ ):
+ """ReAGent method constructor.
+
+ Args:
+ keep_top_n (:obj:`int`, `optional`): If set to a value greater than 0, the top n tokens based on their
+ importance score will be kept during inference. If set to 0, the top n is determined by ``keep_ratio``.
+ Default: ``5``.
+ keep_ratio (:obj:`float`, `optional`): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
+ invert_keep (:obj:`bool`, `optional`): If specified, the top tokens selected either via ``keep_top_n`` or
+ ``keep_ratio`` will be replaced instead of being kept. Default: ``False``.
+ stopping_condition_top_k (:obj:`int`, `optional`): Threshold for the stopping condition: it is met when the
+ predicted target appears within the top k predictions. Default: ``3``.
+ replacing_ratio (:obj:`float`, `optional`): Ratio of tokens replaced for probing. Default: ``0.3``.
+ max_probe_steps (:obj:`int`, `optional`): Maximum number of steps before stopping the probing. Default: ``3000``.
+ num_probes (:obj:`int`, `optional`): Number of probes performed in parallel. Default: ``16``.
+ """
+ super().__init__(attribution_model)
+ # Custom target attribution is currently not supported
+ self.use_predicted_target = False
+ self.method = Reagent(
+ attribution_model=self.attribution_model,
+ keep_top_n=keep_top_n,
+ keep_ratio=keep_ratio,
+ invert_keep=invert_keep,
+ stopping_condition_top_k=stopping_condition_top_k,
+ replacing_ratio=replacing_ratio,
+ max_probe_steps=max_probe_steps,
+ num_probes=num_probes,
+ )
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any] = {},
+ ) -> GranularFeatureAttributionStepOutput:
+ out = super().attribute_step(attribute_fn_main_args, attribution_args)
+ return GranularFeatureAttributionStepOutput(
+ source_attributions=out.source_attributions,
+ target_attributions=out.target_attributions,
+ sequence_scores=out.sequence_scores,
+ )
+
+
+class ValueZeroingAttribution(PerturbationAttributionRegistry):
+ """Value Zeroing method for feature attribution.
+
+ Introduced by `Mohebbi et al. (2023) <https://aclanthology.org/2023.eacl-main.245/>`__ to quantify context mixing
+ in Transformer models. The method is based on the observation that context mixing is regulated by the value vectors
+ of the attention mechanism. The method consists of two steps:
+
+ 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model.
+ 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it
+ as a measure of context mixing for the given token at the given layer.
+
+ The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at
+ specific layers, or by aggregating them across layers.
+
+ Reference implementations:
+ - Original implementation: `hmohebbi/ValueZeroing <https://github.com/hmohebbi/ValueZeroing>`__
+ - Encoder-decoder implementation: `hmohebbi/ContextMixingASR <https://github.com/hmohebbi/ContextMixingASR>`__
+
+ Args:
+ similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between
+ hidden states produced with and without the zeroing operation. Options: cosine, euclidean. Default: cosine.
+ encoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): The indices of
+ the attention heads that should be zeroed to compute corrupted states in the encoder self-attention module.
+ Not used for decoder-only models, or if ``output_encoder_self_scores`` is False. Format:
+
+ - None: all attention heads across all layers are zeroed.
+ - int: the same attention head is zeroed across all layers.
+ - tuple of two integers: the attention heads in the range are zeroed across all layers.
+ - list of integers: the attention heads in the list are zeroed across all layers.
+ - dictionary: the keys are the layer indices and the values are the zeroed attention heads for the corresponding layer.
+
+ Default: None (all heads are zeroed for every encoder layer).
+ decoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as
+ ``encoder_zeroed_units_indices`` but for the decoder self-attention module. Not used in encoder-decoder
+ models if ``output_decoder_self_scores`` is False. Default: None (all heads are zeroed for every decoder layer).
+ cross_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as
+ ``encoder_zeroed_units_indices`` but for the cross-attention module in encoder-decoder models. Not used
+ if the model is decoder-only. Default: None (all heads are zeroed for every layer).
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ decoder self-attention value vectors in encoder-decoder models. Cannot be False for decoder-only models,
+ or if target-side attribution is requested using ``attribute_target=True``. Default: True.
+ output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ encoder self-attention value vectors in encoder-decoder models. Default: True.
+
+ Returns:
+ :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: The final dimension returned by the method
+ is ``[attributed_seq_len, generated_seq_len, num_layers]``. If ``output_decoder_self_scores`` and
+ ``output_encoder_self_scores`` are True, the respective scores are returned in the ``sequence_scores``
+ output dictionary.
+ """
+
+ method_name = "value_zeroing"
+
+ def __init__(self, attribution_model, **kwargs):
+ super().__init__(attribution_model, hook_to_model=False)
+ # Hidden states will be passed to the attribute_step method
+ self.use_hidden_states = True
+ # Does not rely on predicted output (i.e. decoding strategy agnostic)
+ self.use_predicted_target = False
+ # Uses model configuration to access attention module and value vector variable
+ self.use_model_config = True
+ # Needs only the final generation step to extract scores
+ self.is_final_step_method = True
+ self.method = ValueZeroing(attribution_model)
+ self.hook(**kwargs)
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any] = {},
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
+ encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions(
+ attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True
+ )
+ sequence_scores = {}
+ if self.attribution_model.is_encoder_decoder:
+ if len(attribute_fn_main_args["inputs"]) > 1:
+ target_attributions = decoder_self_scores.to("cpu")
+ else:
+ target_attributions = None
+ if decoder_self_scores is not None:
+ sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
+ if encoder_self_scores is not None:
+ sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=decoder_cross_scores.to("cpu"),
+ target_attributions=target_attributions,
+ sequence_scores=sequence_scores,
+ _num_dimensions=1, # num_layers
+ )
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=None,
+ target_attributions=decoder_self_scores,
+ _num_dimensions=1, # num_layers
+ )
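
A hedged usage sketch of the new ReAGent registration, with hyperparameters mirroring the constructor defaults documented above (model and input text are illustrative):

```python
import inseq

model = inseq.load_model(
    "gpt2",
    "reagent",
    keep_top_n=5,               # keep the 5 most important tokens during probing
    stopping_condition_top_k=3, # stop when the target is within the top 3 predictions
    replacing_ratio=0.3,
    max_probe_steps=3000,
    num_probes=16,
)
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```
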
diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py
index d9cd6092..83aa8d6e 100644
--- a/inseq/attr/step_functions.py
+++ b/inseq/attr/step_functions.py
@@ -462,8 +462,7 @@ def register_step_function(
attribution targets by gradient-based feature attribution methods.
Args:
- fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture
- unused ones when defining your function):
+ fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture unused ones when defining your function):
- :obj:`attribution_model`: an :class:`~inseq.models.AttributionModel` instance, corresponding to the model
used for computing the score.
diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py
index f4cbb7cb..1eb72126 100644
--- a/inseq/commands/attribute_context/attribute_context.py
+++ b/inseq/commands/attribute_context/attribute_context.py
@@ -34,6 +34,7 @@
from .attribute_context_helpers import (
AttributeContextOutput,
CCIOutput,
+ concat_with_sep,
filter_rank_tokens,
format_template,
get_contextless_output,
@@ -69,10 +70,6 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
# Prepare input/outputs (generate if necessary)
input_full_text = format_template(args.input_template, args.input_current_text, args.input_context_text)
- if "{current}" in args.contextless_input_current_text:
- args.input_current_text = args.contextless_input_current_text.format(current=args.input_current_text)
- else:
- args.input_current_text = args.contextless_input_current_text
args.output_context_text, args.output_current_text = prepare_outputs(
model=model,
input_context_text=args.input_context_text,
@@ -92,8 +89,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
if args.input_context_text is not None:
input_context_tokens = get_filtered_tokens(args.input_context_text, model, args.special_tokens_to_keep)
if not model.is_encoder_decoder:
- sep = args.decoder_input_output_separator if not output_full_text.startswith((" ", "\n")) else ""
- output_full_text = input_full_text + sep + output_full_text
+ output_full_text = concat_with_sep(input_full_text, output_full_text, args.decoder_input_output_separator)
output_current_tokens = get_filtered_tokens(
args.output_current_text, model, args.special_tokens_to_keep, is_target=True
)
@@ -105,16 +101,18 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
input_full_tokens = get_filtered_tokens(input_full_text, model, args.special_tokens_to_keep)
output_full_tokens = get_filtered_tokens(output_full_text, model, args.special_tokens_to_keep, is_target=True)
output_current_text_offset = len(output_full_tokens) - len(output_current_tokens)
- if model.is_encoder_decoder:
- prefixed_output_current_text = args.output_current_text
- else:
- sep = args.decoder_input_output_separator if not args.output_current_text.startswith((" ", "\n")) else ""
- prefixed_output_current_text = args.input_current_text + sep + args.output_current_text
+ formatted_input_current_text = args.contextless_input_current_text.format(current=args.input_current_text)
+ formatted_output_current_text = args.contextless_output_current_text.format(current=args.output_current_text)
+ if not model.is_encoder_decoder:
+ formatted_input_current_text = concat_with_sep(
+ formatted_input_current_text, "", args.decoder_input_output_separator
+ )
+ formatted_output_current_text = formatted_input_current_text + formatted_output_current_text
# Part 1: Context-sensitive Token Identification (CTI)
cti_out = model.attribute(
- args.input_current_text,
- prefixed_output_current_text,
+ formatted_input_current_text.rstrip(" "),
+ formatted_output_current_text,
attribute_target=model.is_encoder_decoder,
step_scores=[args.context_sensitivity_metric],
contrast_sources=input_full_text if model.is_encoder_decoder else None,
@@ -126,6 +124,11 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
cti_out.show(do_aggregation=False)
start_pos = 1 if has_lang_tag else 0
+ contextless_output_prefix = args.contextless_output_current_text.split("{current}")[0]
+ contextless_output_prefix_tokens = get_filtered_tokens(
+ contextless_output_prefix, model, args.special_tokens_to_keep, is_target=True
+ )
+ start_pos += len(contextless_output_prefix_tokens)
cti_scores = cti_out.step_scores[args.context_sensitivity_metric][start_pos:].tolist()
cti_tokens = [t.token for t in cti_out.target][start_pos + cti_out.attr_pos_start :]
if model.is_encoder_decoder:
@@ -149,18 +152,29 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
# Part 2: Contextual Cues Imputation (CCI)
for cci_step_idx, (cti_idx, cti_score, cti_tok) in enumerate(cti_ranked_tokens):
+ contextual_input = model.convert_tokens_to_string(input_full_tokens, skip_special_tokens=False).lstrip(" ")
contextual_output = model.convert_tokens_to_string(
output_full_tokens[: output_current_text_offset + cti_idx + 1], skip_special_tokens=False
- )
+ ).lstrip(" ")
if not contextual_output:
- contextual_output = output_full_tokens[output_current_text_offset + cti_idx]
-
+ output_ctx_tokens = [output_full_tokens[output_current_text_offset + cti_idx]]
+ if model.is_encoder_decoder:
+ output_ctx_tokens.append(model.pad_token)
+ contextual_output = model.convert_tokens_to_string(output_ctx_tokens, skip_special_tokens=True)
+ else:
+ output_ctx_tokens = model.convert_string_to_tokens(
+ contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
+ )
cci_kwargs = {}
contextless_output = None
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
+ if not model.is_encoder_decoder:
+ formatted_input_current_text = concat_with_sep(
+ formatted_input_current_text, contextless_output_prefix, args.decoder_input_output_separator
+ )
contextless_output = get_contextless_output(
model,
- args.input_current_text,
+ formatted_input_current_text,
output_current_tokens,
cti_idx,
cti_ranked_tokens,
@@ -171,20 +185,18 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
args.special_tokens_to_keep,
deepcopy(args.generation_kwargs),
)
- cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None
+ cci_kwargs["contrast_sources"] = formatted_input_current_text if model.is_encoder_decoder else None
cci_kwargs["contrast_targets"] = contextless_output
- output_ctx_tokens = model.convert_string_to_tokens(
- contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
- )
output_ctxless_tokens = model.convert_string_to_tokens(
contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
tok_pos = -2 if model.is_encoder_decoder else -1
if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]:
cci_kwargs["contrast_force_inputs"] = True
- pos_start = output_current_text_offset + cti_idx + int(model.is_encoder_decoder) + int(has_lang_tag)
+ bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token)
+ pos_start = output_current_text_offset + cti_idx + bos_offset + int(has_lang_tag)
cci_attrib_out = model.attribute(
- input_full_text,
+ contextual_input,
contextual_output,
attribute_target=model.is_encoder_decoder and args.has_output_context,
show_progress=False,
@@ -206,6 +218,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
model,
cci_attrib_out,
args.input_template,
+ args.input_current_text,
input_context_tokens,
input_full_tokens,
args.output_template,
@@ -213,6 +226,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
args.has_input_context,
args.has_output_context,
has_lang_tag,
+ args.decoder_input_output_separator,
args.special_tokens_to_keep,
)
cci_out = CCIOutput(
diff --git a/inseq/commands/attribute_context/attribute_context_args.py b/inseq/commands/attribute_context/attribute_context_args.py
index 716eb5b6..295d5cc2 100644
--- a/inseq/commands/attribute_context/attribute_context_args.py
+++ b/inseq/commands/attribute_context/attribute_context_args.py
@@ -91,6 +91,17 @@ class AttributeContextInputArgs:
" be used as-is for the contrastive comparison, enabling contrastive comparison with different inputs."
),
)
+ contextless_output_current_text: Optional[str] = cli_arg(
+ default=None,
+ help=(
+ "The output current text or template to use in the contrastive comparison with contextual output. By default"
+ " it is the same as ``output_current_text``, but it can be useful in cases where the context is nested "
+ "inside the current text (e.g. for an ``output_template`` like \n{context}\n{current}\n we "
+ "can use this parameter to format the contextless version as \n{current}\n)."
+ "If it contains the tag {current}, it will be infilled with the ``output_current_text``. Otherwise, it will"
+ " be used as-is for the contrastive comparison, enabling contrastive comparison with different outputs."
+ ),
+ )
@command_args_docstring
@@ -184,7 +195,7 @@ class AttributeContextOutputArgs:
default=False,
help=(
"If specified, the intermediate outputs produced by the Inseq library for context-sensitive target "
- "identification (CTI) and contextual cues imputation (CCI) are shown during the process.",
+ "identification (CTI) and contextual cues imputation (CCI) are shown during the process."
),
)
save_path: Optional[str] = cli_arg(
@@ -212,8 +223,19 @@ class AttributeContextArgs(AttributeContextInputArgs, AttributeContextMethodArgs
def __repr__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"
+ @classmethod
+ def _to_dict(cls, val: Any) -> dict[str, Any]:
+ if val is None or isinstance(val, (str, int, float, bool)):
+ return val
+ elif isinstance(val, dict):
+ return {k: cls._to_dict(v) for k, v in val.items()}
+ elif isinstance(val, (list, tuple)):
+ return [cls._to_dict(v) for v in val]
+ else:
+ return str(val)
+
def to_dict(self) -> dict[str, Any]:
- return dict(self.__dict__.items())
+ return self._to_dict(self.__dict__)
def __post_init__(self):
if (
@@ -236,6 +258,18 @@ def __post_init__(self):
self.output_template = "{current}" if self.output_context_text is None else "{context} {current}"
if self.contextless_input_current_text is None:
self.contextless_input_current_text = "{current}"
+ if "{current}" not in self.contextless_input_current_text:
+ raise ValueError(
+ "{current} placeholder is missing from contextless_input_current_text template"
+ f" {self.contextless_input_current_text}."
+ )
+ if self.contextless_output_current_text is None:
+ self.contextless_output_current_text = "{current}"
+ if "{current}" not in self.contextless_output_current_text:
+ raise ValueError(
+ "{current} placeholder is missing from contextless_output_current_text template"
+ f" {self.contextless_output_current_text}."
+ )
self.has_input_context = "{context}" in self.input_template
self.has_output_context = "{context}" in self.output_template
if not self.input_current_text:
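
The new ``_to_dict`` classmethod makes nested argument values JSON-friendly: primitives pass through, containers recurse, and anything else is stringified. A standalone sketch of the same recursion (function name hypothetical):

```python
def to_jsonable(val):
    # Primitives pass through unchanged.
    if val is None or isinstance(val, (str, int, float, bool)):
        return val
    # Containers are converted recursively.
    if isinstance(val, dict):
        return {k: to_jsonable(v) for k, v in val.items()}
    if isinstance(val, (list, tuple)):
        return [to_jsonable(v) for v in val]
    # Everything else falls back to its string representation.
    return str(val)

print(to_jsonable({"metric": "cosine", "span": (0, 2), "fn": len}))
# {'metric': 'cosine', 'span': [0, 2], 'fn': '<built-in function len>'}
```
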
diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py
index f436bba4..cb8793e2 100644
--- a/inseq/commands/attribute_context/attribute_context_helpers.py
+++ b/inseq/commands/attribute_context/attribute_context_helpers.py
@@ -72,6 +72,14 @@ def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput":
return out
+def concat_with_sep(s1: str, s2: str, sep: str) -> str:
+ """Adds separator between two strings if needed."""
+ need_sep = not s1.endswith(sep) and not s2.startswith(sep)
+ if need_sep:
+ return s1 + sep + s2
+ return s1 + s2
+
+
def format_template(template: str, current: str, context: Optional[str] = None) -> str:
kwargs = {"current": current}
if context is not None:
@@ -89,7 +97,7 @@ def get_filtered_tokens(
"""Tokenize text and filter out special tokens, keeping only those in ``special_tokens_to_keep``."""
as_targets = is_target and model.is_encoder_decoder
return [
- t.replace("Ä ", " ").replace("Ä", " ").replace("ā", " ") if replace_special_characters else t
+ t.replace("Ä ", " ").replace("Ä", "\n").replace("ā", " ") if replace_special_characters else t
for t in model.convert_string_to_tokens(text, skip_special_tokens=False, as_targets=as_targets)
if t not in model.special_tokens or t in special_tokens_to_keep
]
@@ -99,11 +107,14 @@ def generate_with_special_tokens(
model: HuggingfaceModel,
model_input: str,
special_tokens_to_keep: list[str] = [],
+ output_generated_only: bool = True,
**generation_kwargs,
) -> str:
"""Generate text preserving special tokens in ``special_tokens_to_keep``."""
# Generate outputs, strip special tokens and remove prefix/suffix
- output_gen = model.generate(model_input, skip_special_tokens=False, **generation_kwargs)[0]
+ output_gen = model.generate(
+ model_input, skip_special_tokens=False, output_generated_only=output_generated_only, **generation_kwargs
+ )[0]
output_tokens = get_filtered_tokens(output_gen, model, special_tokens_to_keep, is_target=True)
return model.convert_tokens_to_string(output_tokens, skip_special_tokens=False)
@@ -236,20 +247,19 @@ def prepare_outputs(
if "forced_bos_token_id" in generation_kwargs:
generation_kwargs["decoder_input_ids"][0, 0] = generation_kwargs["forced_bos_token_id"]
else:
- sep = ""
- if output_current_prefix and not output_current_prefix.startswith((" ", "\n")):
- sep = decoder_input_output_separator
- model_input = input_full_text + sep + output_current_prefix
+ model_input = concat_with_sep(input_full_text, output_current_prefix, decoder_input_output_separator)
output_current_prefix = model_input
+ if not model.is_encoder_decoder:
+ model_input = concat_with_sep(input_full_text, "", decoder_input_output_separator)
+
output_gen = generate_model_output(
model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix
)
# Settings 3, 4
if (has_out_ctx == use_out_ctx) and not has_out_curr:
- final_current = output_gen if model.is_encoder_decoder or use_out_ctx else output_gen[len(model_input) :]
- return final_context, final_current.strip()
+ return final_context, output_gen.strip()
# Settings 5, 6
# Try splitting the output into context and current text using ``separator``. As we have no guarantees of its
@@ -385,14 +395,12 @@ def generate_contextless_output(
generation_input = input_current_text
else:
generation_kwargs["max_new_tokens"] = 1
- sep = ""
- if contextual_prefix and not contextual_prefix.startswith((" ", "\n")):
- sep = decoder_input_output_separator
- generation_input = input_current_text + sep + contextual_prefix
+ generation_input = concat_with_sep(input_current_text, contextual_prefix, decoder_input_output_separator)
contextless_output = generate_with_special_tokens(
model,
generation_input,
special_tokens_to_keep,
+ output_generated_only=False,
**generation_kwargs,
)
return contextless_output
@@ -402,6 +410,7 @@ def get_source_target_cci_scores(
model: HuggingfaceModel,
cci_attrib_out: FeatureAttributionSequenceOutput,
input_template: str,
+ input_current_text: str,
input_context_tokens: list[str],
input_full_tokens: list[str],
output_template: str,
@@ -409,6 +418,7 @@ def get_source_target_cci_scores(
has_input_context: bool,
has_output_context: bool,
model_has_lang_tag: bool,
+ decoder_input_output_separator: str,
special_tokens_to_keep: list[str] = [],
) -> tuple[Optional[list[float]], Optional[list[float]]]:
"""Extract attribution scores for the input and output contexts."""
@@ -417,20 +427,26 @@ def get_source_target_cci_scores(
if model.is_encoder_decoder:
input_scores = cci_attrib_out.source_attributions[:, 0].tolist()
if model_has_lang_tag:
- input_scores = input_scores[1:]
+ input_scores = input_scores[2:]
else:
input_scores = cci_attrib_out.target_attributions[:, 0].tolist()
input_prefix, *_ = input_template.partition("{context}")
+ if "{current}" in input_prefix:
+ input_prefix = input_prefix.format(current=input_current_text)
input_prefix_tokens = get_filtered_tokens(input_prefix, model, special_tokens_to_keep, is_target=False)
input_prefix_len = len(input_prefix_tokens)
input_scores = input_scores[input_prefix_len : len(input_context_tokens) + input_prefix_len]
if has_output_context:
output_scores = cci_attrib_out.target_attributions[:, 0].tolist()
if model_has_lang_tag:
- output_scores = output_scores[1:]
+ output_scores = output_scores[2:]
output_prefix, *_ = output_template.partition("{context}")
+ if not model.is_encoder_decoder and output_prefix:
+ output_prefix = decoder_input_output_separator + output_prefix
output_prefix_tokens = get_filtered_tokens(output_prefix, model, special_tokens_to_keep, is_target=True)
- prefix_len = len(output_prefix_tokens) + int(not model.is_encoder_decoder) * len(input_full_tokens)
+ prefix_len = len(output_prefix_tokens)
+ if not model.is_encoder_decoder:
+ prefix_len += len(input_full_tokens)
output_scores = output_scores[prefix_len : len(output_context_tokens) + prefix_len]
return input_scores, output_scores
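
A quick behavioral check of the new ``concat_with_sep`` helper introduced above (assertion values illustrative):

```python
def concat_with_sep(s1: str, s2: str, sep: str) -> str:
    """Adds a separator between two strings only when neither side already provides it."""
    need_sep = not s1.endswith(sep) and not s2.startswith(sep)
    if need_sep:
        return s1 + sep + s2
    return s1 + s2

assert concat_with_sep("input", "output", " ") == "input output"
assert concat_with_sep("input ", "output", " ") == "input output"  # no duplicated separator
assert concat_with_sep("input", " output", " ") == "input output"
```
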
diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py
index 6c322e12..8a2dff76 100644
--- a/inseq/commands/attribute_context/attribute_context_viz_helpers.py
+++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py
@@ -145,7 +145,7 @@ def visualize_attribute_context(
console.print(viz, soft_wrap=False)
html = console.export_html()
if output.info.viz_path:
- with open(output.info.viz_path, "w") as f:
+ with open(output.info.viz_path, "w", encoding="utf-8") as f:
f.write(html)
if output.info.show_viz:
console.print(viz, soft_wrap=False)
diff --git a/inseq/commands/commands_utils.py b/inseq/commands/commands_utils.py
index 7701409a..dbfb8ac4 100644
--- a/inseq/commands/commands_utils.py
+++ b/inseq/commands/commands_utils.py
@@ -18,5 +18,4 @@ def command_args_docstring(cls):
field_help = field.metadata.get("help", "")
docstring += textwrap.dedent(f"\n**{field.name}** (``{field_type}``): {field_help}\n")
cls.__doc__ = docstring
- print(docstring)
return cls
diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index bb475707..dbae5352 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -12,9 +12,10 @@
aggregate_token_sequence,
available_classes,
extract_signature_args,
+ validate_indices,
)
from ..utils import normalize as normalize_fn
-from ..utils.typing import IndexSpan, TokenWithId
+from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId
from .aggregation_functions import AggregationFunction
from .data_utils import TensorWrapper
@@ -305,7 +306,7 @@ def _process_attribution_scores(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -366,7 +367,7 @@ def aggregate_source_attributions(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -380,7 +381,7 @@ def aggregate_target_attributions(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -398,7 +399,7 @@ def aggregate_sequence_scores(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
**kwargs,
):
if aggregate_fn.takes_sequence_scores:
@@ -439,46 +440,12 @@ def is_compatible(attr: "FeatureAttributionSequenceOutput"):
def _filter_scores(
scores: torch.Tensor,
dim: int = -1,
- indices: Union[int, tuple[int, int], list[int], None] = None,
+ indices: Optional[OneOrMoreIndices] = None,
) -> torch.Tensor:
- n_units = scores.shape[dim]
-
- if hasattr(indices, "__iter__"):
- if len(indices) == 0:
- raise RuntimeError("At least two indices must be specified for aggregation.")
- if len(indices) == 1:
- indices = indices[0]
-
+ indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device))
if isinstance(indices, int):
- if indices not in range(-n_units, n_units):
- raise IndexError(f"Index out of range. Scores only have {n_units} units.")
- indices = indices if indices >= 0 else n_units + indices
- return scores.select(dim, torch.tensor(indices, device=scores.device))
- else:
- if indices is None:
- indices = (0, n_units)
- logger.info("No indices specified for extraction. Using all units by default.")
-
- # Convert negative indices to positive indices
- if hasattr(indices, "__iter__"):
- indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices])
- if not hasattr(indices, "__iter__") or (
- len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1]
- ):
- raise RuntimeError(
- "A (start, end) tuple of indices representing a span, a list of individual indices"
- " or a single index must be specified for select_idx."
- )
- max_idx_val = n_units if isinstance(indices, list) else n_units + 1
- if not all(h in range(-n_units, max_idx_val) for h in indices):
- raise IndexError("One or more index out of range. Scores only have {n_units} units.")
- if len(set(indices)) != len(indices):
- raise IndexError("Duplicate indices are not allowed.")
- if isinstance(indices, tuple):
- scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device))
- else:
- scores = scores.index_select(dim, torch.tensor(indices, device=scores.device))
- return scores
+ return indexed.squeeze(dim)
+ return indexed
@staticmethod
def _aggregate_scores(
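
The refactored ``_filter_scores`` defers validation to ``validate_indices`` and reduces selection to a single ``index_select``. A plain-torch sketch of the three accepted index formats (tensor values illustrative):

```python
import torch

scores = torch.arange(24.0).reshape(2, 3, 4)  # e.g. (batch, units, steps)

# A list of indices keeps exactly those units.
picked = scores.index_select(1, torch.tensor([0, 2]))          # shape (2, 2, 4)

# A (start, end) tuple selects a contiguous span of units.
span = scores.index_select(1, torch.arange(0, 2))              # shape (2, 2, 4)

# A single int selects one unit and squeezes the dimension out,
# mirroring the isinstance(indices, int) branch above.
single = scores.index_select(1, torch.tensor([1])).squeeze(1)  # shape (2, 4)
```
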
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index f3671244..7841cf7c 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -12,6 +12,7 @@
get_sequences_from_batched_steps,
json_advanced_dump,
json_advanced_load,
+ pad_with_nan,
pretty_dict,
remap_from_filtered,
)
@@ -178,9 +179,8 @@ def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callab
def from_step_attributions(
cls,
attributions: list["FeatureAttributionStepOutput"],
- tokenized_target_sentences: Optional[list[list[TokenWithId]]] = None,
- pad_id: Optional[Any] = None,
- has_bos_token: bool = True,
+ tokenized_target_sentences: list[list[TokenWithId]],
+ pad_token: Optional[Any] = None,
attr_pos_end: Optional[int] = None,
) -> list["FeatureAttributionSequenceOutput"]:
"""Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple
@@ -198,36 +198,35 @@ def from_step_attributions(
num_sequences = len(attr.prefix)
if not all(len(attr.prefix) == num_sequences for attr in attributions):
raise ValueError("All the attributions must include the same number of sequences.")
- seq_attributions = []
- sources = None
- if attr.source_attributions is not None:
- sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)]
- targets = [
- drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences)
- ]
- if tokenized_target_sentences is None:
- tokenized_target_sentences = targets
- if has_bos_token:
- tokenized_target_sentences = [tok_seq[1:] for tok_seq in tokenized_target_sentences]
- tokenized_target_sentences = [
- drop_padding(tokenized_target_sentences[seq_id], pad_id) for seq_id in range(num_sequences)
- ]
+ seq_attributions: list[FeatureAttributionSequenceOutput] = []
+ sources = []
+ targets = []
+ pos_start = []
+ for seq_idx in range(num_sequences):
+ if attr.source_attributions is not None:
+ sources.append(drop_padding(attr.source[seq_idx], pad_token))
+ curr_target = [a.target[seq_idx][0] for a in attributions]
+ targets.append(drop_padding(curr_target, pad_token))
+ if all(attr.prefix[seq_idx][0] == pad_token for seq_idx in range(num_sequences)):
+ tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding(
+ tokenized_target_sentences[seq_idx][1:], pad_token
+ )
+ else:
+ tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
if attr_pos_end is None:
attr_pos_end = max(len(t) for t in tokenized_target_sentences)
- pos_start = [
- min(len(tokenized_target_sentences[seq_id]), attr_pos_end) - len(targets[seq_id])
- for seq_id in range(num_sequences)
- ]
- for seq_id in range(num_sequences):
- source = tokenized_target_sentences[seq_id][: pos_start[seq_id]] if sources is None else sources[seq_id]
- seq_attributions.append(
- attr.get_sequence_cls(
- source=source,
- target=tokenized_target_sentences[seq_id],
- attr_pos_start=pos_start[seq_id],
- attr_pos_end=attr_pos_end,
- )
+ for seq_idx in range(num_sequences):
+ # If the model is decoder-only, the source is the input prefix
+ curr_pos_start = min(len(tokenized_target_sentences[seq_idx]), attr_pos_end) - len(targets[seq_idx])
+ pos_start.append(curr_pos_start)
+ source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx]
+ curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls(
+ source=source,
+ target=tokenized_target_sentences[seq_idx],
+ attr_pos_start=pos_start[seq_idx],
+ attr_pos_end=attr_pos_end,
)
+ seq_attributions.append(curr_seq_attribution)
if attr.source_attributions is not None:
source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions])
for seq_id in range(num_sequences):
@@ -241,18 +240,13 @@ def from_step_attributions(
[att.target_attributions for att in attributions], padding_dims=[1]
)
for seq_id in range(num_sequences):
- if has_bos_token:
- target_attributions[seq_id] = target_attributions[seq_id][1:, ...]
start_idx = max(pos_start) - pos_start[seq_id]
end_idx = start_idx + len(tokenized_target_sentences[seq_id])
target_attributions[seq_id] = target_attributions[seq_id][
start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203
]
if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]):
- empty_final_row = torch.ones(
- 1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device
- ) * float("nan")
- target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0)
+ target_attributions[seq_id] = pad_with_nan(target_attributions[seq_id], dim=0, pad_size=1)
seq_attributions[seq_id].target_attributions = target_attributions[seq_id]
if attr.step_scores is not None:
step_scores = [{} for _ in range(num_sequences)]
@@ -427,47 +421,51 @@ def remap_from_filtered(
self,
target_attention_mask: TargetIdsTensor,
batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
+ is_final_step_method: bool = False,
) -> None:
"""Remaps the attributions to the original shape of the input sequence."""
+ batch_size = (
+ len(batch.sources.input_tokens) if self.source_attributions is not None else len(batch.target_tokens)
+ )
+ source_len = len(batch.sources.input_tokens[0])
+ target_len = len(batch.target_tokens[0])
+ # Normal per-step attribution outputs have shape (batch_size, seq_len, ...)
+ other_dims_start_idx = 2
+ # Final step attribution outputs have shape (batch_size, seq_len, seq_len, ...)
+ if is_final_step_method:
+ other_dims_start_idx += 1
+ other_dims = (
+ self.source_attributions.shape[other_dims_start_idx:]
+ if self.source_attributions is not None
+ else self.target_attributions.shape[other_dims_start_idx:]
+ )
if self.source_attributions is not None:
self.source_attributions = remap_from_filtered(
- original_shape=(len(batch.sources.input_tokens), *self.source_attributions.shape[1:]),
+ original_shape=(batch_size, *self.source_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.source_attributions,
)
if self.target_attributions is not None:
self.target_attributions = remap_from_filtered(
- original_shape=(len(batch.target_tokens), *self.target_attributions.shape[1:]),
+ original_shape=(batch_size, *self.target_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.target_attributions,
)
if self.step_scores is not None:
for score_name, score_tensor in self.step_scores.items():
self.step_scores[score_name] = remap_from_filtered(
- original_shape=(len(batch.target_tokens), 1),
+ original_shape=(batch_size, 1),
mask=target_attention_mask,
filtered=score_tensor.unsqueeze(-1),
).squeeze(-1)
if self.sequence_scores is not None:
for score_name, score_tensor in self.sequence_scores.items():
if score_name.startswith("decoder"):
- original_shape = (
- len(batch.target_tokens),
- self.target_attributions.shape[1],
- *self.target_attributions.shape[1:],
- )
+ original_shape = (batch_size, target_len, target_len, *other_dims)
elif score_name.startswith("encoder"):
- original_shape = (
- len(batch.sources.input_tokens),
- self.source_attributions.shape[1],
- *self.source_attributions.shape[1:],
- )
+ original_shape = (batch_size, source_len, source_len, *other_dims)
else: # default case: cross-attention
- original_shape = (
- len(batch.sources.input_tokens),
- self.target_attributions.shape[1],
- *self.source_attributions.shape[1:],
- )
+ original_shape = (batch_size, source_len, target_len, *other_dims)
self.sequence_scores[score_name] = remap_from_filtered(
original_shape=original_shape,
mask=target_attention_mask,
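
The new ``pad_with_nan`` utility (imported from the package utils above) replaces the inline NaN-row construction removed in this file. A minimal equivalent, with behavior inferred from the deleted code:

```python
import torch

def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int) -> torch.Tensor:
    """Append pad_size NaN-filled slices along dim (inferred from the removed inline code)."""
    shape = list(t.shape)
    shape[dim] = pad_size
    nan_block = torch.full(shape, float("nan"), device=t.device, dtype=t.dtype)
    return torch.cat([t, nan_block], dim=dim)

x = torch.ones(2, 3)
assert pad_with_nan(x, dim=0, pad_size=1).shape == (3, 3)
```
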
diff --git a/inseq/data/data_utils.py b/inseq/data/data_utils.py
index b907d627..d0f90203 100644
--- a/inseq/data/data_utils.py
+++ b/inseq/data/data_utils.py
@@ -112,7 +112,7 @@ def _torch(attr):
def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool:
try:
if isinstance(self_attr, torch.Tensor):
- return torch.allclose(self_attr, other_attr, equal_nan=True)
+ return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
elif isinstance(self_attr, dict):
return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys())
else:
@@ -175,6 +175,10 @@ def clone(self: TensorClass) -> TensorClass:
out_params[field.name] = None
return self.__class__(**out_params)
+ def clone_empty(self: TensorClass) -> TensorClass:
+ out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
+ return self.__class__(**out_params)
+
def to_dict(self: TensorClass) -> dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
diff --git a/inseq/data/viz.py b/inseq/data/viz.py
index fef73b58..5721dff1 100644
--- a/inseq/data/viz.py
+++ b/inseq/data/viz.py
@@ -26,6 +26,7 @@
from rich.color import Color
from rich.console import Console
from rich.live import Live
+from rich.markup import escape
from rich.padding import Padding
from rich.panel import Panel
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
@@ -255,7 +256,7 @@ def get_saliency_heatmap_rich(
):
columns = [Column(header="", justify="right", overflow="fold")]
for column_label in column_labels:
- columns.append(Column(header=column_label, justify="center", overflow="fold"))
+ columns.append(Column(header=escape(column_label), justify="center", overflow="fold"))
table = Table(
*columns,
title=f"{label + ' ' if label else ''}Saliency Heatmap",
@@ -266,7 +267,7 @@ def get_saliency_heatmap_rich(
)
if scores is not None:
for row_index in range(scores.shape[0]):
- row = [Text(row_labels[row_index], style="bold")]
+ row = [Text(escape(row_labels[row_index]), style="bold")]
for col_index in range(scores.shape[1]):
color = Color.from_rgb(*input_colors[row_index][col_index])
score = ""
@@ -281,7 +282,7 @@ def get_saliency_heatmap_rich(
else:
threshold = step_scores_threshold.get(step_score_name, 0.5)
style = lambda val, limit: "bold" if abs(val) >= limit and isinstance(val, float) else ""
- score_row = [Text(step_score_name, style="bold")]
+ score_row = [Text(escape(step_score_name), style="bold")]
for score in step_score_values:
curr_score = round(score.item(), 2) if isinstance(score, float) else score.item()
score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold)))
@@ -317,13 +318,13 @@ def get_progress_bar(
TimeRemainingColumn(),
)
for idx, (tgt, tgt_len) in enumerate(zip(sequences.targets, target_lengths)):
- clean_tgt = tgt.replace("\n", "\\n")
+ clean_tgt = escape(tgt.replace("\n", "\\n"))
job_progress.add_task(f"{idx}. {clean_tgt}", total=tgt_len)
progress_table = Table.grid()
row_contents = [
Panel.fit(
job_progress,
- title=f"[b]Attributing with {method_name}",
+ title=f"[b]Attributing with {escape(method_name)}",
border_style="green",
padding=(1, 2),
)
@@ -331,7 +332,7 @@ def get_progress_bar(
if sequences.sources is not None:
sources = []
for idx, src in enumerate(sequences.sources):
- clean_src = src.replace("\n", "\\n")
+ clean_src = escape(src.replace("\n", "\\n"))
sources.append(f"{idx}. {clean_src}")
row_contents = [
Panel.fit(
@@ -370,7 +371,7 @@ def update_progress_bar(
past_length = 0
for split, color in zip(split_targets, ["grey58", "green", "orange1", "grey58"]):
if split[job.id]:
- formatted_desc += f"[{color}]" + split[job.id].replace("\n", "\\n") + "[/]"
+ formatted_desc += f"[{color}]" + escape(split[job.id].replace("\n", "\\n")) + "[/]"
past_length += len(split[job.id])
if past_length in whitespace_indexes[job.id]:
formatted_desc += " "
diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py
index c74e82c8..2f259a4c 100644
--- a/inseq/models/attribution_model.py
+++ b/inseq/models/attribution_model.py
@@ -219,6 +219,7 @@ def __init__(self, **kwargs) -> None:
self.pad_token: Optional[str] = None
self.embed_scale: Optional[float] = None
self._device: Optional[str] = None
+ self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None
self.attribution_method: Optional[FeatureAttribution] = None
self.is_hooked: bool = False
self._default_attributed_fn_id: str = "probability"
@@ -386,6 +387,24 @@ def attribute(
original_device = self.device
if device is not None:
self.device = device
+ attribution_method = self.get_attribution_method(method, override_default_attribution)
+ attributed_fn = self.get_attributed_fn(attributed_fn)
+ attribution_args, attributed_fn_args, step_scores_args = extract_args(
+ attribution_method,
+ attributed_fn,
+ step_scores,
+ default_args=self.formatter.get_step_function_reserved_args(),
+ **kwargs,
+ )
+ if isnotebook():
+ logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.")
+ pretty_progress = False
+ if attribution_method.is_final_step_method:
+ if step_scores:
+ raise ValueError(
+ "Step scores are not supported for final step methods since they do not iterate over the full"
+ " sequence. Please remove the step scores and compute them separatly passing method='dummy'."
+ )
input_texts, generated_texts = format_input_texts(input_texts, generated_texts)
has_generated_texts = generated_texts is not None
if not self.is_encoder_decoder:
@@ -411,36 +430,30 @@ def attribute(
f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
)
logger.debug(f"reference_texts={generated_texts}")
- attribution_method = self.get_attribution_method(method, override_default_attribution)
- attributed_fn = self.get_attributed_fn(attributed_fn)
- attribution_args, attributed_fn_args, step_scores_args = extract_args(
- attribution_method,
- attributed_fn,
- step_scores,
- default_args=self.formatter.get_step_function_reserved_args(),
- **kwargs,
- )
- if isnotebook():
- logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.")
- pretty_progress = False
if not self.is_encoder_decoder:
assert all(
generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts))
), "Forced generations with decoder-only models must start with the input texts."
if has_generated_texts and len(input_texts) > 1:
- logger.info(
+ logger.warning(
"Batched constrained decoding is currently not supported for decoder-only models."
" Using batch size of 1."
)
batch_size = 1
if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None):
- logger.info(
+ logger.warning(
"Custom attribution positions are currently not supported when batching generations for"
" decoder-only models. Using batch size of 1."
)
batch_size = 1
+ elif attribution_method.is_final_step_method and len(input_texts) > 1:
+ logger.warning(
+ "Batched attribution with encoder-decoder models currently not supported for final-step methods."
+ " Using batch size of 1."
+ )
+ batch_size = 1
if attribution_method.method_name == "lime":
- logger.info("Batched attribution currently not supported for LIME. Using batch size of 1.")
+ logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.")
batch_size = 1
attribution_outputs = attribution_method.prepare_and_attribute(
input_texts,
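
Moving argument extraction and the final-step check before generation lets incompatible configurations fail fast. A sketch of the new guard from the user's perspective, assuming the behavior described in the error message above (model choice is illustrative):

```python
import inseq

model = inseq.load_model("gpt2", "attention")  # "attention" is a final-step method

# Step scores require iterating over the full sequence, so combining them with a
# final-step method now raises a ValueError before any computation starts.
try:
    model.attribute("Hello world", step_scores=["probability"])
except ValueError:
    # Compute the step scores separately with the no-op "dummy" method instead.
    out = model.attribute("Hello world", method="dummy", step_scores=["probability"])
```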
diff --git a/inseq/models/decoder_only.py b/inseq/models/decoder_only.py
index f4355e42..e9f8a25e 100644
--- a/inseq/models/decoder_only.py
+++ b/inseq/models/decoder_only.py
@@ -119,7 +119,7 @@ def enrich_step_output(
contrast_target_ids = contrast_batch.target_ids[:, contrast_aligned_idx]
step_output.target = join_token_ids(
tokens=target_tokens,
- ids=attribution_model.convert_ids_to_tokens(contrast_target_ids),
+ ids=attribution_model.convert_ids_to_tokens(contrast_target_ids, skip_special_tokens=False),
contrast_tokens=attribution_model.convert_ids_to_tokens(
contrast_target_ids[None, ...], skip_special_tokens=False
),
diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index c966372f..bb44f21c 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -95,9 +95,6 @@ def __init__(
if isinstance(model, PreTrainedModel):
self.model = model
else:
- if "output_attentions" not in model_kwargs:
- model_kwargs["output_attentions"] = True
-
self.model = self._autoclass.from_pretrained(model, **model_kwargs)
self.model_name = self.model.config.name_or_path
self.tokenizer_name = tokenizer if isinstance(tokenizer, str) else None
@@ -112,14 +109,23 @@ def __init__(
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
- if self.model.config.pad_token_id is not None:
- self.pad_token = self._convert_ids_to_tokens(self.model.config.pad_token_id, skip_special_tokens=False)
+ self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
+ pad_token_id = self.model.config.pad_token_id
+ if pad_token_id is None:
+ if self.tokenizer.pad_token_id is None:
+ logger.info(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.")
+ pad_token_id = self.eos_token_id
+ else:
+ pad_token_id = self.tokenizer.pad_token_id
+ self.pad_token = self._convert_ids_to_tokens(pad_token_id, skip_special_tokens=False)
+ if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.pad_token
+ if self.model.config.pad_token_id is None:
+ self.model.config.pad_token_id = pad_token_id
self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None)
if self.bos_token_id is None:
self.bos_token_id = self.model.config.bos_token_id
self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False)
- self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
if self.eos_token_id is None:
self.eos_token_id = self.tokenizer.pad_token_id
if self.tokenizer.unk_token_id is None:
@@ -127,6 +133,9 @@ def __init__(
self.embed_scale = 1.0
self.encoder_int_embeds = None
self.decoder_int_embeds = None
+ self.device_map = None
+ if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None:
+ self.device_map = self.model.hf_device_map
self.is_encoder_decoder = self.model.config.is_encoder_decoder
self.configure_embeddings_scale()
self.setup(device, attribution_method, **kwargs)
@@ -162,16 +171,19 @@ def device(self, new_device: str) -> None:
is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
is_quantized = is_loaded_in_8bit or is_loaded_in_4bit
+ has_device_map = self.device_map is not None
# Enable compatibility with 8bit models
if self.model:
- if not is_quantized:
- self.model.to(self._device)
- else:
+ if is_quantized:
mode = "8bit" if is_loaded_in_8bit else "4bit"
logger.warning(
f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
)
+ elif has_device_map:
+ logger.warning("The model is loaded with a device map. The device cannot be changed after loading.")
+ else:
+ self.model.to(self._device)
@abstractmethod
def configure_embeddings_scale(self) -> None:
@@ -195,6 +207,7 @@ def generate(
inputs: Union[TextInput, BatchEncoding],
return_generation_output: bool = False,
skip_special_tokens: bool = True,
+ output_generated_only: bool = False,
**kwargs,
) -> Union[list[str], tuple[list[str], ModelOutput]]:
"""Wrapper of model.generate to handle tokenization and decoding.
@@ -204,6 +217,9 @@ def generate(
Inputs to be provided to the model for generation.
return_generation_output (`bool`, *optional*, defaults to False):
If true, generation outputs are returned alongside the generated text.
+ output_generated_only (`bool`, *optional*, defaults to False):
+ If true, only the generated text is returned. Relevant for decoder-only models that would otherwise return
+ the full input + output.
Returns:
`Union[List[str], Tuple[List[str], ModelOutput]]`: Generated text or a tuple of generated text and
@@ -220,6 +236,8 @@ def generate(
**kwargs,
)
sequences = generation_out.sequences
+ if output_generated_only and not self.is_encoder_decoder:
+ sequences = sequences[:, inputs.input_ids.shape[1] :]
texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
if return_generation_output:
return texts, generation_out
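
A short sketch of the new `output_generated_only` flag, which strips the prompt from decoder-only generations (prompt and model are illustrative):

```python
import inseq

model = inseq.load_model("gpt2", "saliency")

# Decoder-only models normally echo the prompt followed by the continuation.
full_texts = model.generate("The capital of France is", max_new_tokens=5)

# With the new flag, only the newly generated tokens are decoded and returned.
continuations = model.generate(
    "The capital of France is",
    max_new_tokens=5,
    output_generated_only=True,
)
```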
diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py
index 05b8a468..52d9d47b 100644
--- a/inseq/models/model_config.py
+++ b/inseq/models/model_config.py
@@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
+from typing import Optional
import yaml
@@ -10,14 +11,25 @@
@dataclass
class ModelConfig:
"""Configuration used by the methods for which the attribute ``use_model_config=True``.
+
Args:
- attention_module (:obj:`str`):
- The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in
- transformers). Can be identified by looking at the name of the attribute instantiating the attention module
+ self_attention_module (:obj:`str`):
+ The name of the module performing the self-attention computation (e.g. ``attn`` for the GPT-2 model in
+ transformers). Can be identified by looking at the name of the self-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
+ cross_attention_module (:obj:`str`):
+ The name of the module performing the cross-attention computation (e.g. ``encoder_attn`` for MarianMT models
+ in transformers). Can be identified by looking at the name of the cross-attention module attribute
+ in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
+ value_vector (:obj:`str`):
+ The name of the variable in the forward pass of the attention module containing the value vector
+ (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
+ the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
"""
- attention_module: str
+ self_attention_module: str
+ value_vector: str
+ cross_attention_module: Optional[str] = None
MODEL_CONFIGS = {
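
The two extra fields are consumed by methods with ``use_model_config=True`` such as `value_zeroing`. A hedged sketch of supporting an unlisted architecture by extending the module-level registry directly (the model name is hypothetical, and the exact public registration entry point may differ):

```python
from inseq.models.model_config import MODEL_CONFIGS, ModelConfig

# Hypothetical decoder-only architecture: the names must match the attention
# module attribute and the value-vector local variable in its transformer block.
MODEL_CONFIGS["MyCustomForCausalLM"] = ModelConfig(
    self_attention_module="self_attn",  # attribute on the decoder layer class
    value_vector="value_states",        # local variable in the attention forward
    cross_attention_module=None,        # decoder-only models have no cross-attention
)
```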
diff --git a/inseq/models/model_config.yaml b/inseq/models/model_config.yaml
index b48ed209..1b2433a4 100644
--- a/inseq/models/model_config.yaml
+++ b/inseq/models/model_config.yaml
@@ -1,2 +1,129 @@
+# Decoder-only models
+BioGptForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+BloomForCausalLM:
+ self_attention_module: "self_attention"
+ value_vector: "value_layer"
+CodeGenForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+CohereForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+DbrxForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value_states"
+FalconForCausalLM:
+ self_attention_module: "self_attention"
+ value_vector: "value_layer"
+GemmaForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+GPTBigCodeForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTJForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
GPT2LMHeadModel:
- attention_module: "attn"
\ No newline at end of file
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTNeoForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTNeoXForCausalLM:
+ self_attention_module: "attention"
+ value_vector: "value"
+LlamaForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MistralForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MixtralForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MptForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value_states"
+OlmoForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+OpenAIGPTLMHeadModel:
+ self_attention_module: "attn"
+ value_vector: "value"
+OPTForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+PhiForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+Phi3ForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+Qwen2ForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+Qwen2MoeForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+StableLmForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+StarCoder2ForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+XGLMForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+
+# Encoder-decoder models
+BartForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MarianMTModel:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+FSMTForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "v"
+M2M100ForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MBartForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MT5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
+NllbMoeForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value_states"
+PegasusForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+SeamlessM4TForTextToText:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value"
+SeamlessM4Tv2ForTextToText:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value"
+T5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
+UMT5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index 29f81615..92a763cf 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -8,11 +8,14 @@
MissingAttributionMethodError,
UnknownAttributionMethodError,
)
+from .hooks import get_post_variable_assignment_hook
from .import_utils import (
+ is_accelerate_available,
is_captum_available,
is_datasets_available,
is_ipywidgets_available,
is_joblib_available,
+ is_nltk_available,
is_scikitlearn_available,
is_sentencepiece_available,
is_transformers_available,
@@ -49,12 +52,16 @@
check_device,
euclidean_distance,
filter_logits,
+ find_block_stack,
get_default_device,
get_front_padding,
get_sequences_from_batched_steps,
normalize,
+ pad_with_nan,
+ recursive_get_submodule,
remap_from_filtered,
top_p_logits_mask,
+ validate_indices,
)
__all__ = [
@@ -94,6 +101,7 @@
"is_datasets_available",
"is_captum_available",
"is_joblib_available",
+ "is_nltk_available",
"check_device",
"get_default_device",
"ndarray_to_bin_str",
@@ -118,4 +126,9 @@
"top_p_logits_mask",
"filter_logits",
"cli_arg",
+ "get_post_variable_assignment_hook",
+ "validate_indices",
+ "pad_with_nan",
+ "recursive_get_submodule",
+ "is_accelerate_available",
]
diff --git a/inseq/utils/hooks.py b/inseq/utils/hooks.py
new file mode 100644
index 00000000..98fd07e5
--- /dev/null
+++ b/inseq/utils/hooks.py
@@ -0,0 +1,109 @@
+import re
+from inspect import getsourcelines
+from sys import gettrace, settrace
+from types import FrameType
+from typing import Callable, Optional
+
+from torch import nn
+
+from .misc import get_left_padding
+
+
+def get_last_variable_assignment_position(
+ module: nn.Module,
+ varname: str,
+ fname: str = "forward",
+) -> Optional[int]:
+ """Extract the code line number of the last variable assignment for a variable of interest in the specified method
+ of a `nn.Module` object.
+
+ Args:
+ module (`nn.Module`):
+ A PyTorch module containing a method with a variable assignment after which the hook should be executed.
+ varname (`str`):
+ The name of the variable to use as anchor for the hook.
+ fname (`str`, *optional*, defaults to "forward"):
+ The name of the method in which the variable assignment should be searched.
+
+ Returns:
+ `Optional[int]`: Returns the line number in the file (not relative to the method) of the last variable
+ assignment. Returns None if no assignment to the variable was found.
+ """
+ # Matches any assignment of variable varname
+ pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$"
+ code, startline = getsourcelines(getattr(module, fname))
+ line_numbers = []
+ i = 0
+ while i < len(code):
+ line = code[i]
+ # Handles multi-line assignments
+ if re.match(pattern, line):
+ parentheses_count = line.count("(") - line.count(")")
+ ends_with_newline = lambda l: l.strip().endswith("\\")
+ follow_indent = lambda l, i: len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(l)
+ while (ends_with_newline(line) or follow_indent(line, i) or parentheses_count > 0) and len(code) > i + 1:
+ i += 1
+ line = code[i]
+ parentheses_count += line.count("(") - line.count(")")
+ line_numbers.append(i)
+ i += 1
+ if len(line_numbers) == 0:
+ return None
+ return line_numbers[-1] + startline + 1
+
+
+def get_post_variable_assignment_hook(
+ module: nn.Module,
+ varname: str,
+ fname: str = "forward",
+ hook_fn: Callable[[FrameType], None] = lambda *args, **kwargs: None,
+ **kwargs,
+) -> Callable[[], None]:
+ """Creates a hook that is called after the last variable assignment in the specified method of a `nn.Module`.
+
+ This is a hacky approach that uses ``sys.settrace()`` to circumvent the limited set of hook points available in
+ PyTorch and define a custom hook point dynamically. It is preferred to ensure broader compatibility with Hugging
+ Face transformers models, which currently do not provide hook points in their architectures.
+
+ Args:
+ module (`nn.Module`):
+ A PyTorch module containing a method with a variable assignment after which the hook should be executed.
+ varname (`str`):
+ The name of the variable to use as anchor for the hook.
+ fname (`str`, *optional*, defaults to "forward"):
+ The name of the method in which the variable assignment should be searched.
+ hook_fn (`Callable[[FrameType], None]`, *optional*, defaults to a no-op lambda):
+ A custom hook function that is called after the last variable assignment in the specified method. The first
+ parameter is the current frame in the execution at the hook point, and any additional arguments can be
+ passed when creating the hook. ``frame.f_locals`` is a dictionary containing all local variables.
+
+ Returns:
+ The hook function that can be registered with the module. If hooking the module's ``forward()`` method, the
+ hook can be registered with Pytorch native hook methods.
+ """
+ hook_line_num = get_last_variable_assignment_position(module, varname, fname)
+ curr_trace_fn = gettrace()
+ if hook_line_num is None:
+ raise ValueError(f"Could not find assignment to {varname} in {module}'s {fname}() method")
+
+ def var_tracer(frame, event, arg=None):
+ curr_line_num = frame.f_lineno
+ curr_func_name = frame.f_code.co_name
+
+ # Matches the first executable line after hook_line_num in the same function of the same module
+ if (
+ event == "line"
+ and curr_line_num >= hook_line_num
+ and curr_func_name == fname
+ and isinstance(frame.f_locals.get("self"), nn.Module)
+ and frame.f_locals.get("self")._get_name() == module._get_name()
+ ):
+ # Call the custom hook providing the current frame and any additional arguments as context
+ hook_fn(frame, **kwargs)
+ settrace(curr_trace_fn)
+ return var_tracer
+
+ def hook(*args, **kwargs):
+ settrace(var_tracer)
+
+ return hook
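
A minimal usage sketch of the settrace-based hook on a toy module, run as a file so `inspect` can read the source (module and variable names are illustrative; in Inseq this mechanism backs `value_zeroing`'s access to value vectors):

```python
import torch
from torch import nn

from inseq.utils import get_post_variable_assignment_hook

class Toy(nn.Module):
    def forward(self, x):
        value = x * 2  # local variable targeted by the hook
        return value + 1

captured = {}

def grab(frame, **kwargs):
    # frame.f_locals exposes all locals at the hook point
    captured["value"] = frame.f_locals["value"].detach().clone()

toy = Toy()
hook = get_post_variable_assignment_hook(toy, varname="value", hook_fn=grab)
handle = toy.register_forward_pre_hook(hook)  # arms the tracer before forward runs
toy(torch.ones(2))
handle.remove()
print(captured["value"])  # tensor([2., 2.])
```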
diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py
index cbd03420..e8ae455e 100644
--- a/inseq/utils/import_utils.py
+++ b/inseq/utils/import_utils.py
@@ -7,6 +7,8 @@
_datasets_available = find_spec("datasets") is not None
_captum_available = find_spec("captum") is not None
_joblib_available = find_spec("joblib") is not None
+_nltk_available = find_spec("nltk") is not None
+_accelerate_available = find_spec("accelerate") is not None
def is_ipywidgets_available():
@@ -35,3 +37,11 @@ def is_captum_available():
def is_joblib_available():
return _joblib_available
+
+
+def is_nltk_available():
+ return _nltk_available
+
+
+def is_accelerate_available():
+ return _accelerate_available
diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py
index e09e5df7..628995bc 100644
--- a/inseq/utils/misc.py
+++ b/inseq/utils/misc.py
@@ -10,7 +10,6 @@
from functools import wraps
from importlib import import_module
from inspect import signature
-from itertools import dropwhile
from numbers import Number
from os import PathLike, fsync
from typing import Any, Callable, Optional, Union
@@ -171,10 +170,10 @@ def pad(seq: Sequence[Sequence[Any]], pad_id: Any):
return seq
-def drop_padding(seq: Sequence[Any], pad_id: Any):
+def drop_padding(seq: Sequence[TokenWithId], pad_id: str):
if pad_id is None:
return seq
- return list(reversed(list(dropwhile(lambda x: x == pad_id, reversed(seq)))))
+ return [x for x in seq if x.token != pad_id]
def isnotebook():
@@ -435,3 +434,8 @@ def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str]
else:
removed_token_idxs += [idx]
return clean_tokens, removed_token_idxs
+
+
+def get_left_padding(text: str):
+ """Returns the number of spaces at the beginning of a string."""
+ return len(text) - len(text.lstrip())
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index 88e807cc..86acd635 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -5,11 +5,14 @@
import torch
import torch.nn.functional as F
from jaxtyping import Int, Num
+from torch import nn
from torch.backends.cuda import is_built as is_cuda_built
from torch.backends.mps import is_available as is_mps_available
from torch.backends.mps import is_built as is_mps_built
from torch.cuda import is_available as is_cuda_available
+from .typing import OneOrMoreIndices
+
if TYPE_CHECKING:
pass
@@ -244,3 +247,118 @@ def get_default_device() -> str:
return "cpu"
else:
return "cpu"
+
+
+def find_block_stack(module):
+ """Recursively searches for the first instance of a `nn.ModuleList` submodule within a given `torch.nn.Module`.
+
+ Args:
+ module (:obj:`torch.nn.Module`): A Pytorch :obj:`nn.Module` object.
+
+ Returns:
+ :obj:`torch.nn.ModuleList`: The first instance of a :obj:`nn.ModuleList` submodule found within the given module.
+ None: If no `nn.ModuleList` submodule is found within the given `nn.Module` object.
+ """
+ # Check if the current module is an instance of nn.ModuleList
+ if isinstance(module, nn.ModuleList):
+ return module
+
+ # Recursively search for nn.ModuleList in the submodules of the current module
+ for submodule in module.children():
+ module_list = find_block_stack(submodule)
+ if module_list is not None:
+ return module_list
+
+ # If nn.ModuleList is not found in any submodules, return None
+ return None
+
+
+def validate_indices(
+ scores: torch.Tensor,
+ dim: int = -1,
+ indices: Optional[OneOrMoreIndices] = None,
+) -> OneOrMoreIndices:
+ """Validates a set of indices for a given dimension of a tensor of scores. Supports single indices, spans and lists
+ of indices, including negative indices to specify positions relative to the end of the tensor.
+
+ Args:
+ scores (torch.Tensor): The tensor of scores.
+ dim (int, optional): The dimension of the tensor that will be indexed. Defaults to -1.
+ indices (Union[int, tuple[int, int], list[int], None], optional):
+ - If an integer, it is interpreted as a single index for the dimension.
+ - If a tuple of two integers, it is interpreted as a span of indices for the dimension.
+ - If a list of integers, it is interpreted as a list of individual indices for the dimension.
+
+ Returns:
+ ``Union[int, tuple[int, int], list[int]]``: The validated list of positive indices for indexing the dimension.
+ """
+ if dim >= scores.ndim:
+ raise IndexError(f"Dimension {dim} is greater than tensor dimension {scores.ndim}")
+ n_units = scores.shape[dim]
+ if not isinstance(indices, (int, tuple, list)) and indices is not None:
+ raise TypeError(
+ "Indices must be an integer, a (start, end) tuple of indices representing a span, a list of individual"
+ " indices or a single index."
+ )
+ if hasattr(indices, "__iter__"):
+ if len(indices) == 0:
+ raise RuntimeError("An empty sequence of indices is not allowed.")
+ if len(indices) == 1:
+ indices = indices[0]
+
+ if isinstance(indices, int):
+ if indices not in range(-n_units, n_units):
+ raise IndexError(f"Index out of range. Scores only have {n_units} units.")
+ indices = indices if indices >= 0 else n_units + indices
+ return torch.tensor(indices)
+ else:
+ if indices is None:
+ indices = (0, n_units)
+ logger.info("No indices specified. Using all indices by default.")
+
+ # Convert negative indices to positive indices
+ if hasattr(indices, "__iter__"):
+ indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices])
+ if not hasattr(indices, "__iter__") or (
+ len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1]
+ ):
+ raise RuntimeError(
+ "A (start, end) tuple of indices representing a span, a list of individual indices"
+ " or a single index must be specified."
+ )
+ max_idx_val = n_units if isinstance(indices, list) else n_units + 1
+ if not all(h in range(-n_units, max_idx_val) for h in indices):
+ raise IndexError(f"One or more index out of range. Scores only have {n_units} units.")
+ if len(set(indices)) != len(indices):
+ raise IndexError("Duplicate indices are not allowed.")
+ if isinstance(indices, tuple):
+ return torch.arange(indices[0], indices[1])
+ else:
+ return torch.tensor(indices)
+
+
+def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False) -> torch.Tensor:
+ """Utility to pad a tensor with nan values along a given dimension."""
+ nan_tensor = torch.ones(
+ *t.shape[:dim],
+ pad_size,
+ *t.shape[dim + 1 :],
+ device=t.device,
+ ) * float("nan")
+ if front:
+ return torch.cat([nan_tensor, t], dim=dim)
+ return torch.cat([t, nan_tensor], dim=dim)
+
+
+def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]:
+ if target == "":
+ return parent
+ mod = None
+ if hasattr(parent, target):
+ mod = getattr(parent, target)
+ else:
+ for submodule in parent.children():
+ mod = recursive_get_submodule(submodule, target)
+ if mod is not None:
+ break
+ return mod
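
Short sketches of the new tensor utilities, matching the semantics documented above (shapes are illustrative):

```python
import torch

from inseq.utils import pad_with_nan, validate_indices

scores = torch.rand(1, 4, 12, 8)  # e.g. (batch, heads, tgt_len, src_len)

validate_indices(scores, dim=1, indices=2)       # tensor(2): a single unit
validate_indices(scores, dim=1, indices=-1)      # tensor(3): negative index wraps
validate_indices(scores, dim=1, indices=(1, 3))  # tensor([1, 2]): (start, end) span
validate_indices(scores, dim=1, indices=[0, 2])  # tensor([0, 2]): explicit list
validate_indices(scores, dim=1)                  # all units when unspecified

t = torch.ones(2, 3)
padded = pad_with_nan(t, dim=1, pad_size=2)  # shape (2, 5), last two columns NaN
front = pad_with_nan(t, dim=1, pad_size=2, front=True)  # NaNs prepended instead
```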
diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py
index 7599bbc7..4eec4a5b 100644
--- a/inseq/utils/typing.py
+++ b/inseq/utils/typing.py
@@ -1,13 +1,17 @@
from collections.abc import Sequence
from dataclasses import dataclass
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
+from captum.attr._utils.attribution import Attribution
from jaxtyping import Float, Float32, Int64
from transformers import PreTrainedModel
TextInput = Union[str, Sequence[str]]
+if TYPE_CHECKING:
+ from inseq.models import AttributionModel
+
@dataclass
class TokenWithId:
@@ -28,6 +32,34 @@ def __eq__(self, other: Union[str, int, "TokenWithId"]):
return False
+class InseqAttribution(Attribution):
+ """A wrapper class for the Captum library's Attribution class to type hint the ``forward_func`` attribute
+ as an :class:`~inseq.models.AttributionModel`.
+ """
+
+ def __init__(self, forward_func: "AttributionModel") -> None:
+ r"""
+ Args:
+ forward_func (:class:`~inseq.models.AttributionModel`): The model hooker to the attribution method.
+ """
+ self.forward_func = forward_func
+
+ attribute: Callable
+
+ @property
+ def multiplies_by_inputs(self):
+ return False
+
+ def has_convergence_delta(self) -> bool:
+ return False
+
+ compute_convergence_delta: Callable
+
+ @classmethod
+ def get_name(cls: type["InseqAttribution"]) -> str:
+ return "".join([char if char.islower() or idx == 0 else " " + char for idx, char in enumerate(cls.__name__)])
+
+
@dataclass
class TextSequences:
targets: TextInput
@@ -40,6 +72,8 @@ class TextSequences:
OneOrMoreAttributionSequences = Sequence[Sequence[float]]
IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]]
+OneOrMoreIndices = Union[int, list[int], tuple[int, int]]
+OneOrMoreIndicesDict = dict[int, OneOrMoreIndices]
IdsTensor = Int64[torch.Tensor, "batch_size seq_len"]
TargetIdsTensor = Int64[torch.Tensor, "batch_size"]
diff --git a/pyproject.toml b/pyproject.toml
index 3babbe9a..32592ee8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "inseq"
-version = "0.6.0.dev0"
+version = "0.7.0.dev0"
description = "Interpretability for Sequence Generation Models š"
readme = "README.md"
requires-python = ">=3.9"
@@ -74,7 +74,7 @@ docs = [
]
lint = [
"bandit>=1.7.4",
- "safety>=2.2.0",
+ "safety>=3.1.0",
"pydoclint>=0.4.0",
"pre-commit>=2.19.0",
"pytest>=7.2.0",
@@ -93,6 +93,9 @@ notebook = [
"ipykernel>=6.29.2",
"ipywidgets>=8.1.2"
]
+nltk = [
+ "nltk>=3.8.1",
+]
[project.urls]
homepage = "https://github.com/inseq-team/inseq"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 91a4d3f2..2fda0ec1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,4 +1,4 @@
-# This file was autogenerated by uv v0.1.2 via the following command:
+# This file was autogenerated by uv via the following command:
# uv pip compile --all-extras pyproject.toml -o requirements-dev.txt
aiohttp==3.9.3
# via
@@ -32,6 +32,7 @@ charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
+ # nltk
# pydoclint
# safety
# typer
@@ -43,7 +44,7 @@ contourpy==1.2.0
# via matplotlib
coverage==7.4.1
# via pytest-cov
-cryptography==42.0.2
+cryptography==42.0.5
# via authlib
cycler==0.12.1
# via matplotlib
@@ -123,7 +124,9 @@ jinja2==3.1.3
# sphinx
# torch
joblib==1.3.2
- # via scikit-learn
+ # via
+ # nltk
+ # scikit-learn
jupyter-client==8.6.0
# via ipykernel
jupyter-core==5.7.1
@@ -160,6 +163,7 @@ nest-asyncio==1.6.0
# via ipykernel
networkx==3.2.1
# via torch
+nltk==3.8.1
nodeenv==1.8.0
# via pre-commit
numpy==1.26.4
@@ -258,7 +262,9 @@ pyzmq==25.1.2
# ipykernel
# jupyter-client
regex==2023.12.25
- # via transformers
+ # via
+ # nltk
+ # transformers
requests==2.31.0
# via
# datasets
@@ -280,7 +286,7 @@ ruamel-yaml-clib==0.2.8
ruff==0.2.1
safetensors==0.4.2
# via transformers
-safety==3.0.1
+safety==3.1.0
safety-schemas==0.0.2
# via safety
scikit-learn==1.4.0
@@ -351,6 +357,7 @@ tqdm==4.66.2
# captum
# datasets
# huggingface-hub
+ # nltk
# transformers
traitlets==5.14.1
# via
@@ -361,7 +368,7 @@ traitlets==5.14.1
# jupyter-client
# jupyter-core
# matplotlib-inline
-transformers==4.37.2
+transformers==4.38.1
typeguard==2.13.3
# via jaxtyping
typer==0.9.0
diff --git a/requirements.txt b/requirements.txt
index 9f392d72..93809632 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-# This file was autogenerated by uv v0.1.2 via the following command:
+# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
captum==0.7.0
certifi==2024.2.2
@@ -93,7 +93,7 @@ tqdm==4.66.2
# captum
# huggingface-hub
# transformers
-transformers==4.37.2
+transformers==4.38.1
typeguard==2.13.3
# via jaxtyping
typing-extensions==4.9.0
diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py
index 07b2045c..80856176 100644
--- a/tests/attr/feat/test_feature_attribution.py
+++ b/tests/attr/feat/test_feature_attribution.py
@@ -1,8 +1,14 @@
+from typing import Any, Optional
+
import torch
+from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from pytest import fixture
import inseq
+from inseq.attr.feat.internals_attribution import InternalsAttributionRegistry
+from inseq.data import MultiDimensionalFeatureAttributionStepOutput
from inseq.models import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel
+from inseq.utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor
@fixture(scope="session")
@@ -69,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu
"orig_tgt": "I soldati della pace ONU",
"contrast_tgt": "Le forze militari di pace delle Nazioni Unite",
"alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]],
- "aligned_tgts": ["āLe ā āI", "āforze ā āsoldati", "ādi ā ādella", "āpace", "āNazioni ā āONU", ""],
+ "aligned_tgts": ["", "āLe ā āI", "āforze ā āsoldati", "ādi ā ādella", "āpace", "āNazioni ā āONU", ""],
}
out = saliency_mt_model_larger.attribute(
aligned["src"],
@@ -129,3 +135,122 @@ def test_mcd_weighted_attribution_gpt(saliency_gpt_model):
)
attribution_scores = out.sequence_attributions[0].target_attributions
assert isinstance(attribution_scores, torch.Tensor)
+
+
+class MultiStepAttentionWeights(InseqAttribution):
+ """Variant of the AttentionWeights class with is_final_step_method = False.
+ As a result, the attention matrix is computed and sliced at every generation step.
+ We define it here to test consistency with the final step method.
+ """
+
+ def attribute(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ # We adopt the format [batch_size, sequence_length, num_layers, num_heads]
+ # for consistency with other multi-unit methods (e.g. gradient attribution)
+ decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
+ if self.forward_func.is_encoder_decoder:
+ sequence_scores = {}
+ if len(inputs) > 1:
+ target_attributions = decoder_self_attentions
+ else:
+ target_attributions = None
+ sequence_scores["decoder_self_attentions"] = decoder_self_attentions
+ sequence_scores["encoder_self_attentions"] = (
+ encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
+ )
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
+ target_attributions=target_attributions,
+ sequence_scores=sequence_scores,
+ _num_dimensions=2, # num_layers, num_heads
+ )
+ else:
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=None,
+ target_attributions=decoder_self_attentions,
+ _num_dimensions=2, # num_layers, num_heads
+ )
+
+
+class MultiStepAttentionWeightsAttribution(InternalsAttributionRegistry):
+ """Variant of the basic attention attribution method computing attention weights at every generation step."""
+
+ method_name = "per_step_attention"
+
+ def __init__(self, attribution_model, **kwargs):
+ super().__init__(attribution_model)
+ # Attention weights will be passed to the attribute_step method
+ self.use_attention_weights = True
+ # Does not rely on predicted output (i.e. decoding strategy agnostic)
+ self.use_predicted_target = False
+ self.method = MultiStepAttentionWeights(attribution_model)
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any],
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ return self.method.attribute(**attribute_fn_main_args, **attribution_args)
+
+
+def test_seq2seq_final_step_per_step_conformity(saliency_mt_model_larger: HuggingfaceEncoderDecoderModel):
+ out_per_step = saliency_mt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="per_step_attention",
+ attribute_target=True,
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ out_final_step = saliency_mt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="attention",
+ attribute_target=True,
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ assert out_per_step[0] == out_final_step[0]
+
+
+def test_gpt_final_step_per_step_conformity(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel):
+ out_per_step = saliency_gpt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="per_step_attention",
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ out_final_step = saliency_gpt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="attention",
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ assert out_per_step[0] == out_final_step[0]
+
+
+# Batching for Seq2Seq models is not supported when using final-step methods.
+# Passing several sentences will attribute them one by one under the hood.
+# def test_seq2seq_multi_step_attention_weights_batched_full_match(saliency_mt_model: HuggingfaceEncoderDecoderModel):
+
+
+def test_gpt_multi_step_attention_weights_batched_full_match(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel):
+ out_per_step = saliency_gpt_model_larger.attribute(
+ ["Hello world!", "Colorless green ideas sleep furiously."],
+ method="per_step_attention",
+ show_progress=False,
+ )
+ out_final_step = saliency_gpt_model_larger.attribute(
+ ["Hello world!", "Colorless green ideas sleep furiously."],
+ method="attention",
+ show_progress=False,
+ )
+ for i in range(2):
+ assert out_per_step[i].target_attributions.shape == out_final_step[i].target_attributions.shape
+ assert torch.allclose(
+ out_per_step[i].target_attributions, out_final_step[i].target_attributions, equal_nan=True, atol=1e-5
+ )
diff --git a/tests/commands/test_attribute_context.py b/tests/commands/test_attribute_context.py
index d747213e..1013b1cc 100644
--- a/tests/commands/test_attribute_context.py
+++ b/tests/commands/test_attribute_context.py
@@ -41,6 +41,7 @@ def test_in_out_ctx_encdec_whitespace_sep(encdec_model: MarianMTModel):
input_template="{context} {current}",
attributed_fn="contrast_prob_diff",
show_viz=False,
+ show_intermediate_outputs=True,
# Pre-defining natural model outputs to avoid user input in unit tests
output_context_text="",
output_current_text="OĆ¹ sont-elles?",
@@ -77,6 +78,7 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel):
model_name_or_path=deconly_model,
input_context_text="George was sick yesterday.",
input_current_text="His colleagues asked him to come",
+ output_current_text="to the hospital. He said he was fine",
attributed_fn="contrast_prob_diff",
show_viz=False,
add_output_info=False,
@@ -110,12 +112,14 @@ def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
# Base case for context-aware decoder-only model with forced output context mocking a reasoning chain.
out_ctx_deconly = AttributeContextArgs(
model_name_or_path=deconly_model,
- output_template="\n\nLet's think step by step:\n{context}\n\nAnswer:\n{current}",
+ output_template="Let's think step by step:\n{context}\n\nAnswer:\n{current}",
input_template="{current}",
input_current_text="Question: How many pairs of legs do 10 horses have?",
output_context_text="1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.",
output_current_text="20 pairs of legs.",
attributed_fn="contrast_prob_diff",
+ decoder_input_output_separator="\n",
+ contextless_output_current_text="Answer:\n{current}",
show_viz=False,
add_output_info=False,
)
@@ -155,16 +159,45 @@ def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
],
output_current="20 pairs of legs.",
output_current_tokens=["20", "Ä pairs", "Ä of", "Ä legs", "."],
- cti_scores=[4.53, 1.33, 0.43, 0.74, 0.93],
+ cti_scores=[0.77, 1.39, 0.54, 0.48, 0.91],
cci_scores=[
CCIOutput(
- cti_idx=0,
- cti_token="20 ā Ä 20",
- cti_score=4.53,
- contextual_output="Question: How many pairs of legs do 10 horses have?\n\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20",
- contextless_output="Question: How many pairs of legs do 10 horses have?\n",
+ cti_idx=1,
+ cti_token="Ä pairs",
+ cti_score=1.39,
+ contextual_output="Question: How many pairs of legs do 10 horses have?\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20 pairs",
+ contextless_output="Question: How many pairs of legs do 10 horses have?\nAnswer:\n20 horses",
input_context_scores=None,
- output_context_scores=[0.0] * 28,
+ output_context_scores=[
+ 0.1,
+ 0.1,
+ 0.06,
+ 0.19,
+ 0.05,
+ 0.07,
+ 0.17,
+ 0.1,
+ 0.08,
+ 0.07,
+ 0.11,
+ 0.22,
+ 0.44,
+ 0.07,
+ 0.15,
+ 0.17,
+ 0.12,
+ 0.13,
+ 0.06,
+ 0.14,
+ 0.11,
+ 0.19,
+ 0.16,
+ 0.34,
+ 1.33,
+ 0.04,
+ 0.13,
+ 0.07,
+ ],
),
],
info=None,
@@ -180,6 +213,7 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
input_context_text="George was sick yesterday.",
input_current_text="His colleagues asked him if",
output_context_text="something was wrong. He said",
+ output_current_text="he was fine.",
attributed_fn="contrast_prob_diff",
show_viz=False,
add_output_info=False,
diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py
index eb5086ca..f7e7c3e5 100644
--- a/tests/data/test_aggregator.py
+++ b/tests/data/test_aggregator.py
@@ -39,14 +39,14 @@ def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDe
)
seqattr = out.sequence_attributions[0]
assert seqattr.source_attributions.shape == (6, 7, 512)
- assert seqattr.target_attributions.shape == (7, 7, 512)
+ assert seqattr.target_attributions.shape == (8, 7, 512)
assert seqattr.step_scores["probability"].shape == (7,)
for i, step in enumerate(out.step_attributions):
assert step.source_attributions.shape == (1, 6, 512)
assert step.target_attributions.shape == (1, i + 1, 512)
out_agg = seqattr.aggregate()
assert out_agg.source_attributions.shape == (6, 7)
- assert out_agg.target_attributions.shape == (7, 7)
+ assert out_agg.target_attributions.shape == (8, 7)
assert out_agg.step_scores["probability"].shape == (7,)
@@ -56,9 +56,9 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder
)
seqattr = out.sequence_attributions[0]
out_agg = seqattr.aggregate(ContiguousSpanAggregator, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg.source_attributions.shape == (5, 4, 512)
- assert out_agg.target_attributions.shape == (4, 4, 512)
- assert out_agg.step_scores["probability"].shape == (4,)
+ assert out_agg.source_attributions.shape == (5, 5, 512)
+ assert out_agg.target_attributions.shape == (5, 5, 512)
+ assert out_agg.step_scores["probability"].shape == (5,)
def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
@@ -76,14 +76,14 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel):
seqattr = out.sequence_attributions[0]
squeezesum = AggregatorPipeline([ContiguousSpanAggregator, SequenceAttributionAggregator])
out_agg_squeezesum = seqattr.aggregate(squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_squeezesum.source_attributions.shape == (5, 4)
- assert out_agg_squeezesum.target_attributions.shape == (4, 4)
- assert out_agg_squeezesum.step_scores["probability"].shape == (4,)
+ assert out_agg_squeezesum.source_attributions.shape == (5, 5)
+ assert out_agg_squeezesum.target_attributions.shape == (5, 5)
+ assert out_agg_squeezesum.step_scores["probability"].shape == (5,)
sumsqueeze = AggregatorPipeline([SequenceAttributionAggregator, ContiguousSpanAggregator])
out_agg_sumsqueeze = seqattr.aggregate(sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_sumsqueeze.source_attributions.shape == (5, 4)
- assert out_agg_sumsqueeze.target_attributions.shape == (4, 4)
- assert out_agg_sumsqueeze.step_scores["probability"].shape == (4,)
+ assert out_agg_sumsqueeze.source_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze.target_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze.step_scores["probability"].shape == (5,)
assert not torch.allclose(out_agg_squeezesum.source_attributions, out_agg_sumsqueeze.source_attributions)
assert not torch.allclose(out_agg_squeezesum.target_attributions, out_agg_sumsqueeze.target_attributions)
# Named indexing version
@@ -91,12 +91,12 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel):
named_sumsqueeze = ["scores", "spans"]
out_agg_squeezesum_named = seqattr.aggregate(named_squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
out_agg_sumsqueeze_named = seqattr.aggregate(named_sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_squeezesum_named.source_attributions.shape == (5, 4)
- assert out_agg_squeezesum_named.target_attributions.shape == (4, 4)
- assert out_agg_squeezesum_named.step_scores["probability"].shape == (4,)
- assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 4)
- assert out_agg_sumsqueeze_named.target_attributions.shape == (4, 4)
- assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (4,)
+ assert out_agg_squeezesum_named.source_attributions.shape == (5, 5)
+ assert out_agg_squeezesum_named.target_attributions.shape == (5, 5)
+ assert out_agg_squeezesum_named.step_scores["probability"].shape == (5,)
+ assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze_named.target_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (5,)
assert not torch.allclose(
out_agg_squeezesum_named.source_attributions, out_agg_sumsqueeze_named.source_attributions
)
diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json
index fc029eec..53123526 100644
--- a/tests/fixtures/aggregator.json
+++ b/tests/fixtures/aggregator.json
@@ -36,6 +36,7 @@
],
"target": "Inseq \u00e8 un framework per l'attribuzione automatica di modelli sequenziali.",
"target_subwords": [
+ "",
"\u2581In",
"se",
"q",
@@ -58,6 +59,7 @@
""
],
"target_merged": [
+ "",
"\u2581Inseq",
"\u2581\u00e8",
"\u2581un",
diff --git a/tests/inference_commons.py b/tests/inference_commons.py
index 3da21068..19810018 100644
--- a/tests/inference_commons.py
+++ b/tests/inference_commons.py
@@ -1,3 +1,6 @@
+import json
+import os
+
from inseq.data import EncoderDecoderBatch
from inseq.utils import json_advanced_load
@@ -9,3 +12,8 @@ def get_example_batches():
dict_batches["batches"] = [batch.torch() for batch in dict_batches["batches"]]
assert all(isinstance(batch, EncoderDecoderBatch) for batch in dict_batches["batches"])
return dict_batches
+
+
+def load_examples() -> dict:
+ file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/huggingface_model.json")
+    with open(file) as f:
+        return json.load(f)
diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py
index 993c07ac..72da4a2f 100644
--- a/tests/models/test_huggingface_model.py
+++ b/tests/models/test_huggingface_model.py
@@ -2,8 +2,6 @@
since it is bugged is not very elegant, this will need to be refactored.
"""
-import json
-import os
import pytest
import torch
@@ -15,8 +13,9 @@
from inseq.data import FeatureAttributionOutput, FeatureAttributionSequenceOutput
from inseq.utils import get_default_device
-EXAMPLES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/huggingface_model.json")
-EXAMPLES = json.load(open(EXAMPLES_FILE))
+from ..inference_commons import load_examples
+
+EXAMPLES = load_examples()
USE_REFERENCE_TEXT = [True, False]
ATTRIBUTE_TARGET = [True, False]
@@ -275,8 +274,8 @@ def test_attribute_slice_seq2seq(saliency_mt_model):
assert ex2.attr_pos_start == len(ex2.target)
assert ex2.attr_pos_end == len(ex2.target)
assert ex2.source_attributions.shape[1] == 0 and ex2.target_attributions.shape[1] == 0
- assert ex3.attr_pos_start == 12
- assert ex3.attr_pos_end == 15
+ assert ex3.attr_pos_start == 13
+ assert ex3.attr_pos_end == 16
assert ex1.source_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[0] == ex1.attr_pos_end
@@ -303,12 +302,12 @@ def test_attribute_decoder(saliency_gpt2_model):
assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[0] == ex1.attr_pos_end
# Empty attributions outputs have start and end set to seq length
- assert ex2.attr_pos_start == 17
- assert ex2.attr_pos_end == 22
+ assert ex2.attr_pos_start == 9
+ assert ex2.attr_pos_end == 14
assert ex2.target_attributions.shape[1] == ex2.attr_pos_end - ex2.attr_pos_start
assert ex2.target_attributions.shape[0] == ex2.attr_pos_end
- assert ex3.attr_pos_start == 17
- assert ex3.attr_pos_end == 22
+ assert ex3.attr_pos_start == 12
+ assert ex3.attr_pos_end == 17
assert ex3.target_attributions.shape[1] == ex3.attr_pos_end - ex3.attr_pos_start
assert ex3.target_attributions.shape[0] == ex3.attr_pos_end
assert out.info["attr_pos_start"] == 17