diff --git a/CHANGELOG.md b/CHANGELOG.md index faeca60e..b4a36ee2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,21 +4,15 @@ ## šŸš€ Features -- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238)) -- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237)) - -## šŸ”§ Fixes & Refactoring - -- Fix `ContiguousSpanAggregator` and `SubwordAggregator` edge case of single-step generation ([#247](https://github.com/inseq-team/inseq/pull/247)) -- Move tensors to CPU right away in the forward pass to avoid OOM when cloning ([#245](https://github.com/inseq-team/inseq/pull/245)) -- Fix `remap_from_filtered` behavior on sequence_scores tensors. ([#245](https://github.com/inseq-team/inseq/pull/245)) -- Use torch-native padding when converting lists of `FeatureAttributionStepOutput` to `FeatureAttributionSequenceOutput` in `get_sequences_from_batched_steps`. ([#245](https://github.com/inseq-team/inseq/pull/245)) -- Bump `ruff` version ([#245](https://github.com/inseq-team/inseq/pull/245)) -- Drop `poetry` in favor of [`uv`](https://github.com/astral-sh/uv) to accelerate package installation and simplify config in `pyproject.toml`. ([#249](https://github.com/inseq-team/inseq/pull/249)) -- Drop `darglint` in favor of `pydoclint`. ([#249](https://github.com/inseq-team/inseq/pull/249)) -- Replace Arxiv with ACL Anthology badge in `README`. ([#249](https://github.com/inseq-team/inseq/pull/249)) -- Add first version of `CHANGELOG.md` ([#249](https://github.com/inseq-team/inseq/pull/249)) -- Added multithread support for running tests using `pytest-xdist` +- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config. + +## šŸ”§ Fixes and Refactoring + +- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal positions in the tensor were set to NaN if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)). + +- Fix the pad token in cases where it is not specified by default in the loaded model (e.g. for Qwen models) ([#269](https://github.com/inseq-team/inseq/pull/269)). + +- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention by default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)). ## šŸ“ Documentation and Tutorials @@ -26,4 +20,4 @@ ## šŸ’„ Breaking Changes -*No changes* +*No changes* \ No newline at end of file diff --git a/Makefile b/Makefile index f0eb967b..a03222c5 100644 --- a/Makefile +++ b/Makefile @@ -60,7 +60,7 @@ install-dev: .PHONY: install-ci install-ci: - make uv-activate && uv pip install -e .[lint] + make uv-activate && uv pip install -r requirements-dev.txt .PHONY: update-deps update-deps: diff --git a/README.md b/README.md index 79aa0563..583d8a88 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,10 @@ Use the `inseq.list_feature_attribution_methods` function to list all available - `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016) +- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al., 
2023) + +- `reagent`: [ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models](https://arxiv.org/abs/2402.00794) (Zhao et al., 2024) + #### Step functions Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported: @@ -301,7 +305,10 @@ If you use Inseq in your research we suggest to include a mention to the specifi ## Research using Inseq -Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. If you know more, please let us know or submit a pull request (*last updated: February 2024*). +Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below. + +> [!TIP] +> Last update: April 2024. Please open a pull request to add your publication to the list.
2023 @@ -322,6 +329,7 @@ Inseq has been used in various research projects. A list of known publications t
  1. LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools (Wang et al., 2024)
  2. ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models (Zhao et al., 2024)
  3.
+  4. Revisiting subword tokenization: A case study on affixal negation in large language models (Truong et al., 2024)
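As a quick reference for the `value_zeroing` entry added to the README above, here is a minimal usage sketch mirroring the `reagent` example shipped in the docs; the `gpt2-medium` checkpoint and the prompt are illustrative assumptions, not part of this diff:

```python
import inseq

# Load a model with the newly registered value_zeroing attribution method.
# Per the changelog entry for #267, models defaulting to SDPA attention (e.g. Gemma)
# no longer require passing model_kwargs={"attn_implementation": "eager"}.
model = inseq.load_model("gpt2-medium", "value_zeroing")

# Attribute a generation and visualize the resulting context-mixing scores.
out = model.attribute("Super Mario Land is a game that developed by")
out.show()
```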
diff --git a/docs/source/conf.py b/docs/source/conf.py index 12d27a66..d2805c1e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,13 +21,13 @@ # -- Project information ----------------------------------------------------- project = "inseq" -copyright = "2023, The Inseq Team, Licensed under the Apache License, Version 2.0" +copyright = "2024 , The Inseq Team, Licensed under the Apache License, Version 2.0" author = "The Inseq Team" # The short X.Y version -version = "0.6" +version = "0.7" # The full version, including alpha/beta/rc tags -release = "0.6.0.dev0" +release = "0.7.0.dev0" # Prefix link to point to master, comment this during version release and uncomment below line diff --git a/docs/source/main_classes/cli.rst b/docs/source/main_classes/cli.rst index 1793360c..8230e020 100644 --- a/docs/source/main_classes/cli.rst +++ b/docs/source/main_classes/cli.rst @@ -23,7 +23,7 @@ Three commands are supported: - ``inseq attribute-dataset``: Extends ``attribute`` to full dataset using Hugging Face ``datasets.load_dataset`` API. -- ``inseq attribute-context``: Detects and attribute context dependence for generation tasks using the approach of `Sarti et al. (2023) `__. +- ``inseq attribute-context``: Detects and attribute context dependence for generation tasks using the approach of `Sarti et al. (2023) `__. ``attribute`` ----------------------------------------------------------------------------------------------------------------------- @@ -47,6 +47,6 @@ The ``attribute-dataset`` command extends the ``attribute`` command to full data ----------------------------------------------------------------------------------------------------------------------- The ``attribute-context`` command detects and attributes context dependence for generation tasks using the approach of -`Sarti et al. (2023) `__. The command takes the following arguments: +`Sarti et al. (2023) `__. The command takes the following arguments: -.. autoclass:: inseq.commands.attribute_context.attribute_context_args.AttributeContextArgs \ No newline at end of file +.. autoclass:: inseq.commands.attribute_context.attribute_context_args.AttributeContextArgs diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst index 174b405c..d7c4f5fc 100644 --- a/docs/source/main_classes/feature_attribution.rst +++ b/docs/source/main_classes/feature_attribution.rst @@ -17,7 +17,7 @@ Attribution Methods .. autoclass:: inseq.attr.FeatureAttribution :members: -Gradient Attribution Methods +Gradient-based Attribution Methods ----------------------------------------------------------------------------------------------------------------------- .. autoclass:: inseq.attr.feat.GradientAttributionRegistry @@ -67,7 +67,7 @@ Layer Attribution Methods :members: -Attention Attribution Methods +Internals-based Attribution Methods ----------------------------------------------------------------------------------------------------------------------- .. autoclass:: inseq.attr.feat.InternalsAttributionRegistry @@ -76,3 +76,39 @@ Attention Attribution Methods .. autoclass:: inseq.attr.feat.AttentionWeightsAttribution :members: + +Perturbation-based Attribution Methods +----------------------------------------------------------------------------------------------------------------------- + +.. autoclass:: inseq.attr.feat.PerturbationAttributionRegistry + :members: + +.. autoclass:: inseq.attr.feat.OcclusionAttribution + :members: + +.. 
autoclass:: inseq.attr.feat.LimeAttribution + :members: + +.. autoclass:: inseq.attr.feat.ValueZeroingAttribution + :members: + +.. autoclass:: inseq.attr.feat.ReagentAttribution + :members: + + .. automethod:: __init__ + +.. code:: python + + import inseq + + model = inseq.load_model( + "gpt2-medium", + "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 + ) + out = model.attribute("Super Mario Land is a game that developed by") + out.show() diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py index cc07f530..7a81014f 100644 --- a/inseq/attr/feat/__init__.py +++ b/inseq/attr/feat/__init__.py @@ -17,6 +17,9 @@ from .perturbation_attribution import ( LimeAttribution, OcclusionAttribution, + PerturbationAttributionRegistry, + ReagentAttribution, + ValueZeroingAttribution, ) __all__ = [ @@ -39,4 +42,7 @@ "OcclusionAttribution", "LimeAttribution", "SequentialIntegratedGradientsAttribution", + "ValueZeroingAttribution", + "PerturbationAttributionRegistry", + "ReagentAttribution", ] diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py index 8da4f899..a9679845 100644 --- a/inseq/attr/feat/attribution_utils.py +++ b/inseq/attr/feat/attribution_utils.py @@ -144,11 +144,15 @@ def extract_args( def get_source_target_attributions( attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]], is_encoder_decoder: bool, + has_sequence_scores: bool = False, ) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]: if isinstance(attr, tuple): if is_encoder_decoder: - return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) + if has_sequence_scores: + return (attr[0], attr[1], attr[2]) + else: + return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) else: - return (None, attr[0]) + return (None, None, attr[0]) if has_sequence_scores else (None, attr[0]) else: return (attr, None) if is_encoder_decoder else (None, attr) diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py index 250cd700..3d832bc3 100644 --- a/inseq/attr/feat/feature_attribution.py +++ b/inseq/attr/feat/feature_attribution.py @@ -114,6 +114,7 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool = self.use_hidden_states: bool = False self.use_predicted_target: bool = True self.use_model_config: bool = False + self.is_final_step_method: bool = False if hook_to_model: self.hook(**kwargs) @@ -272,6 +273,35 @@ def _run_compatibility_checks(self, attributed_fn) -> None: " method." 
) + @staticmethod + def _build_multistep_output_from_single_step( + single_step_output: FeatureAttributionStepOutput, + attr_pos_start: int, + attr_pos_end: int, + ) -> list[FeatureAttributionStepOutput]: + if single_step_output.step_scores: + raise ValueError("step_scores are not supported for final step attribution methods.") + num_seq = len(single_step_output.prefix) + steps = [] + for pos_idx in range(attr_pos_start, attr_pos_end): + step_output = single_step_output.clone_empty() + step_output.source = single_step_output.source + step_output.prefix = [single_step_output.prefix[seq_idx][:pos_idx] for seq_idx in range(num_seq)] + step_output.target = ( + single_step_output.target + if pos_idx == attr_pos_end - 1 + else [[single_step_output.prefix[seq_idx][pos_idx]] for seq_idx in range(num_seq)] + ) + if single_step_output.source_attributions is not None: + step_output.source_attributions = single_step_output.source_attributions[:, :, pos_idx - 1] + if single_step_output.target_attributions is not None: + step_output.target_attributions = single_step_output.target_attributions[:, :pos_idx, pos_idx - 1] + single_step_output.step_scores = {} + if single_step_output.sequence_scores is not None: + step_output.sequence_scores = single_step_output.sequence_scores + steps.append(step_output) + return steps + def format_contrastive_targets( self, target_sequences: TextSequences, @@ -416,9 +446,9 @@ def attribute( target_lengths=targets_lengths, method_name=self.method_name, show=show_progress, - pretty=pretty_progress, + pretty=False if self.is_final_step_method else pretty_progress, attr_pos_start=attr_pos_start, - attr_pos_end=attr_pos_end, + attr_pos_end=1 if self.is_final_step_method else attr_pos_end, ) whitespace_indexes = find_char_indexes(sequences.targets, " ") attribution_outputs = [] @@ -427,6 +457,8 @@ def attribute( # Attribution loop for generation for step in range(attr_pos_start, iter_pos_end): + if self.is_final_step_method and step != iter_pos_end - 1: + continue tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True) step_output = self.filtered_attribute_step( batch[:step], @@ -450,7 +482,7 @@ def attribute( contrast_targets_alignments=contrast_targets_alignments, ) attribution_outputs.append(step_output) - if pretty_progress: + if pretty_progress and not self.is_final_step_method: tgt_tokens = batch.target_tokens skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start) attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1) @@ -464,19 +496,24 @@ def attribute( skipped_suffixes, whitespace_indexes, show=show_progress, - pretty=pretty_progress, + pretty=True, ) else: - update_progress_bar(pbar, show=show_progress, pretty=pretty_progress) + update_progress_bar(pbar, show=show_progress, pretty=False) end = datetime.now() - close_progress_bar(pbar, show=show_progress, pretty=pretty_progress) + close_progress_bar(pbar, show=show_progress, pretty=False if self.is_final_step_method else pretty_progress) batch.detach().to("cpu") + if self.is_final_step_method: + attribution_outputs = self._build_multistep_output_from_single_step( + attribution_outputs[0], + attr_pos_start=attr_pos_start, + attr_pos_end=iter_pos_end, + ) out = FeatureAttributionOutput( sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions( attributions=attribution_outputs, tokenized_target_sentences=target_tokens_with_ids, - pad_id=self.attribution_model.pad_token, - 
has_bos_token=self.attribution_model.is_encoder_decoder, + pad_token=self.attribution_model.pad_token, attr_pos_end=attr_pos_end, ), step_attributions=attribution_outputs if output_step_attributions else None, @@ -593,7 +630,7 @@ def filtered_attribute_step( step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu") # Reinsert finished sentences if target_attention_mask is not None and is_filtered: - step_output.remap_from_filtered(target_attention_mask, orig_batch) + step_output.remap_from_filtered(target_attention_mask, orig_batch, self.is_final_step_method) step_output = step_output.detach().to("cpu") return step_output diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py index 9c6e8923..f3f0479d 100644 --- a/inseq/attr/feat/internals_attribution.py +++ b/inseq/attr/feat/internals_attribution.py @@ -17,11 +17,10 @@ from typing import Any, Optional from captum._utils.typing import TensorOrTupleOfTensorsGeneric -from captum.attr._utils.attribution import Attribution from ...data import MultiDimensionalFeatureAttributionStepOutput from ...utils import Registry -from ...utils.typing import MultiLayerMultiUnitScoreTensor +from ...utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor from .feature_attribution import FeatureAttribution logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class AttentionWeightsAttribution(InternalsAttributionRegistry): method_name = "attention" - class AttentionWeights(Attribution): + class AttentionWeights(InseqAttribution): @staticmethod def has_convergence_delta() -> bool: return False @@ -74,9 +73,9 @@ def attribute( :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: A step output containing attention weights for each layer and head, with shape :obj:`(batch_size, seq_len, n_layers, n_heads)`. """ - # We adopt the format [batch_size, sequence_length, num_layers, num_heads] + # We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads] # for consistency with other multi-unit methods (e.g. gradient attribution) - decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2) + decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) if self.forward_func.is_encoder_decoder: sequence_scores = {} if len(inputs) > 1: @@ -85,10 +84,11 @@ def attribute( target_attributions = None sequence_scores["decoder_self_attentions"] = decoder_self_attentions sequence_scores["encoder_self_attentions"] = ( - encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2) + encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) ) + cross_attentions = cross_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) return MultiDimensionalFeatureAttributionStepOutput( - source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2), + source_attributions=cross_attentions, target_attributions=target_attributions, sequence_scores=sequence_scores, _num_dimensions=2, # num_layers, num_heads @@ -106,6 +106,8 @@ def __init__(self, attribution_model, **kwargs): self.use_attention_weights = True # Does not rely on predicted output (i.e. 
decoding strategy agnostic) self.use_predicted_target = False + # Needs only the final generation step to extract scores + self.is_final_step_method = True self.method = self.AttentionWeights(attribution_model) def attribute_step( diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 0e533b4a..7a4d726f 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,13 +1,17 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder +from .reagent import Reagent from .rollout import rollout_fn from .sequential_integrated_gradients import SequentialIntegratedGradients +from .value_zeroing import ValueZeroing __all__ = [ "DiscretetizedIntegratedGradients", "MonotonicPathBuilder", + "ValueZeroing", "Lime", + "Reagent", "SequentialIntegratedGradients", "rollout_fn", ] diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py new file mode 100644 index 00000000..080a2131 --- /dev/null +++ b/inseq/attr/feat/ops/reagent.py @@ -0,0 +1,134 @@ +from typing import TYPE_CHECKING, Any, Union + +import torch +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric +from torch import Tensor +from typing_extensions import override + +from ....utils.typing import InseqAttribution +from .reagent_core import ( + AggregateRationalizer, + DeltaProbImportanceScoreEvaluator, + POSTagTokenSampler, + TopKStoppingConditionEvaluator, + UniformTokenReplacer, +) + +if TYPE_CHECKING: + from ....models import HuggingfaceModel + + +class Reagent(InseqAttribution): + r"""Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. + + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + `__ + + Args: + forward_func (callable): The forward function of the model or any modification of it + keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. + keep_ratio (float): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
+ stopping_condition_top_k (int): Threshold indicating that the stop condition achieved when the predicted target + exist in top k predictions + replacing_ratio (float): replacing ratio of tokens for probing + max_probe_steps (int): max_probe_steps + num_probes (int): number of probes in parallel + + Example: + ``` + import inseq + + model = inseq.load_model("gpt2-medium", "reagent", + keep_top_n=5, + stopping_condition_top_k=3, + replacing_ratio=0.3, + max_probe_steps=3000, + num_probes=8 + ) + out = model.attribute("Super Mario Land is a game that developed by") + out.show() + ``` + """ + + def __init__( + self, + attribution_model: "HuggingfaceModel", + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 16, + ) -> None: + super().__init__(attribution_model) + + model = attribution_model.model + tokenizer = attribution_model.tokenizer + model_name = attribution_model.model_name + + sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device) + stopping_condition_evaluator = TopKStoppingConditionEvaluator( + model=model, + sampler=sampler, + top_k=stopping_condition_top_k, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, + ) + importance_score_evaluator = DeltaProbImportanceScoreEvaluator( + model=model, + tokenizer=tokenizer, + token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio), + stopping_condition_evaluator=stopping_condition_evaluator, + max_steps=max_probe_steps, + ) + + self.rationalizer = AggregateRationalizer( + importance_score_evaluator=importance_score_evaluator, + batch_size=num_probes, + overlap_threshold=0, + overlap_strict_pos=True, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + ) + + @override + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + _target: TargetType = None, + additional_forward_args: Any = None, + ) -> Union[ + TensorOrTupleOfTensorsGeneric, + tuple[TensorOrTupleOfTensorsGeneric, Tensor], + ]: + """Implement attribute""" + # encoder-decoder + if self.forward_func.is_encoder_decoder: + # with target-side attribution + if len(inputs) > 1: + self.rationalizer( + additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True + ) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) + return ( + res[:, : additional_forward_args[0].shape[1], :], + res[:, additional_forward_args[0].shape[1] :, :], + ) + # source-side only + else: + self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2]) + # decoder-only + self.rationalizer(additional_forward_args[0], additional_forward_args[1]) + mean_importance_score = torch.unsqueeze(self.rationalizer.mean_importance_score, 0) + res = torch.unsqueeze(mean_importance_score, 2).repeat(1, 1, inputs[0].shape[2]) + return (res,) diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py new file mode 100644 index 00000000..13917c00 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/__init__.py @@ -0,0 +1,13 @@ +from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator +from .rationalizer import AggregateRationalizer +from .stopping_condition_evaluator import TopKStoppingConditionEvaluator +from .token_replacer 
import UniformTokenReplacer +from .token_sampler import POSTagTokenSampler + +__all__ = [ + "DeltaProbImportanceScoreEvaluator", + "AggregateRationalizer", + "TopKStoppingConditionEvaluator", + "UniformTokenReplacer", + "POSTagTokenSampler", +] diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py new file mode 100644 index 00000000..adc79b7b --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from jaxtyping import Float +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from typing_extensions import override + +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor +from .stopping_condition_evaluator import StoppingConditionEvaluator +from .token_replacer import TokenReplacer + + +class BaseImportanceScoreEvaluator(ABC): + """Importance Score Evaluator""" + + def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None: + """Base Constructor + + Args: + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model + tokenizer: A Huggingface AutoTokenizer + """ + self.model = model + self.tokenizer = tokenizer + self.importance_score = None + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + + """ + raise NotImplementedError() + + +class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator): + """Importance Score Evaluator""" + + @override + def __init__( + self, + model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, + tokenizer: AutoTokenizer, + token_replacer: TokenReplacer, + stopping_condition_evaluator: StoppingConditionEvaluator, + max_steps: float, + ) -> None: + """Constructor + + Args: + model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model + tokenizer: A Huggingface AutoTokenizer + token_replacer: A TokenReplacer + stopping_condition_evaluator: A StoppingConditionEvaluator + """ + super().__init__(model, tokenizer) + self.token_replacer = token_replacer + self.stopping_condition_evaluator = stopping_condition_evaluator + self.max_steps = max_steps + self.importance_score = None + self.num_steps = 0 + + def update_importance_score( + self, + logit_importance_score: MultipleScoresPerStepTensor, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + prob_original_target: Float[torch.Tensor, "batch_size 1"], + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Update importance score by one step + + Args: + logit_importance_score: Current importance score in logistic scale [batch, sequence] + input_ids: input tensor [batch, sequence] + target_id: target tensor [batch] + prob_original_target: predictive probability of the 
target on the original sequence [batch, 1] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + logit_importance_score: updated importance score in logistic scale [batch, sequence] + """ + # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}} + if not attribute_target: + input_ids_replaced, mask_replacing = self.token_replacer(input_ids) + else: + ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] + + logging.debug(f"Replacing mask: { mask_replacing }") + logging.debug( + f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }" + ) + + # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}}) + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] + prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id] + + # Compute changes delta = p^{(y)} - \hat{p^{(y)}} + delta_prob_target = prob_original_target - prob_replaced_target + logging.debug(f"likelihood delta: { delta_prob_target }") + + # Update importance scores based on delta (magnitude) and replacement (direction) + delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target + # TODO: better solution? + # Rescaling from [-1, 1] to [0, 1] before logit function + logit_delta_score = torch.logit(delta_score * 0.5 + 0.5) + logit_importance_score = logit_importance_score + logit_delta_score + logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }") + return logit_importance_score + + @override + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> MultipleScoresPerStepTensor: + """Evaluate importance score of input sequence + + Args: + input_ids: input sequence [batch, sequence] + target_id: target token [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + importance_score: evaluated importance score for each token in the input [batch, sequence] + """ + self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) + + # Inference p^{(y)} = p(y_{t+1}|y_{1...t}) + if decoder_input_ids is None: + logits_original = self.model(input_ids)["logits"] + else: + logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"] + + prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id] + + # Initialize importance score s for each token in the sequence y_{1...t} + if not attribute_target: + logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device) + else: + logit_importance_score = torch.rand( + (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device + ) + logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }") + + # TODO: limit max steps + self.num_steps = 0 + 
while self.num_steps < self.max_steps: + self.num_steps += 1 + # Update importance score + logit_importance_score_update = self.update_importance_score( + logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target + ) + logit_importance_score = ( + ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update + + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score + ) + self.importance_score = torch.softmax(logit_importance_score, -1) + + # Evaluate stop condition + self.stop_mask = self.stop_mask | self.stopping_condition_evaluator( + input_ids, target_id, self.importance_score, decoder_input_ids, attribute_target + ) + if torch.prod(self.stop_mask) > 0: + break + + logging.info(f"Importance score evaluated in {self.num_steps} steps.") + return torch.softmax(logit_importance_score, -1) diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py new file mode 100644 index 00000000..ab7c2be8 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -0,0 +1,128 @@ +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from jaxtyping import Int64 +from typing_extensions import override + +from .....utils.typing import IdsTensor, TargetIdsTensor +from .importance_score_evaluator import BaseImportanceScoreEvaluator + + +class BaseRationalizer(ABC): + def __init__(self, importance_score_evaluator: BaseImportanceScoreEvaluator) -> None: + super().__init__() + self.importance_score_evaluator = importance_score_evaluator + self.mean_importance_score = None + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> Int64[torch.Tensor, "batch_size other_dims"]: + """Compute rational of a sequence on a target + + Args: + input_ids: The sequence [batch, sequence] (first dimension need to be 1) + target_id: The target [batch] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + raise NotImplementedError() + + +class AggregateRationalizer(BaseRationalizer): + """AggregateRationalizer""" + + @override + def __init__( + self, + importance_score_evaluator: BaseImportanceScoreEvaluator, + batch_size: int, + overlap_threshold: int, + overlap_strict_pos: bool = True, + keep_top_n: int = 0, + keep_ratio: float = 0, + ) -> None: + """Constructor + + Args: + importance_score_evaluator: A ImportanceScoreEvaluator + batch_size: Batch size for aggregate + overlap_threshold: Overlap threshold of rational tokens within a batch + overlap_strict_pos: Whether overlap strict to position ot not + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. 
+ """ + super().__init__(importance_score_evaluator) + self.batch_size = batch_size + self.overlap_threshold = overlap_threshold + self.overlap_strict_pos = overlap_strict_pos + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio + assert overlap_strict_pos, "overlap_strict_pos = False is not supported yet" + + @override + @torch.no_grad() + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> Int64[torch.Tensor, "batch_size other_dims"]: + """Compute rational of a sequence on a target + + Args: + input_ids: A tensor of ids of shape [batch, sequence_len] + target_id: A tensor of predicted targets of size [batch] + decoder_input_ids (optional): A tensor of ids representing the decoder input sequence for + ``AutoModelForSeq2SeqLM``, with shape [batch, sequence_len] + attribute_target: whether attribute target for encoder-decoder models + + Return: + pos_top_n: rational position in the sequence [batch, rational_size] + + """ + assert input_ids.shape[0] == 1, "the first dimension of input (batch_size) need to be 1" + batch_input_ids = input_ids.repeat(self.batch_size, 1) + batch_decoder_input_ids = ( + decoder_input_ids.repeat(self.batch_size, 1) if decoder_input_ids is not None else None + ) + batch_importance_score = self.importance_score_evaluator( + batch_input_ids, target_id, batch_decoder_input_ids, attribute_target + ) + importance_score_masked = batch_importance_score * torch.unsqueeze( + self.importance_score_evaluator.stop_mask, -1 + ) + self.mean_importance_score = torch.sum(importance_score_masked, dim=0) / torch.sum( + self.importance_score_evaluator.stop_mask + ) + pos_sorted = torch.argsort(batch_importance_score, dim=-1, descending=True) + top_n = int(math.ceil(self.keep_ratio * input_ids.shape[-1])) if not self.keep_top_n else self.keep_top_n + pos_top_n = pos_sorted[:, :top_n] + self.pos_top_n = pos_top_n + if self.overlap_strict_pos: + count_overlap = torch.bincount(pos_top_n.flatten(), minlength=input_ids.shape[1]) + pos_top_n_overlap = torch.unsqueeze( + torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + ) + return pos_top_n_overlap + else: + raise NotImplementedError("overlap_strict_pos = False not been supported yet") + # TODO: Convert back to pos + # token_id_top_n = input_ids[0, pos_top_n] + # count_overlap = torch.bincount(token_id_top_n.flatten(), minlength=input_ids.shape[1]) + # _token_id_top_n_overlap = torch.unsqueeze( + # torch.nonzero(count_overlap >= self.overlap_threshold, as_tuple=True)[0], 0 + # ) diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py new file mode 100644 index 00000000..fd3bb67d --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -0,0 +1,136 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from transformers import AutoModelForCausalLM + +from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor +from .token_replacer import RankingTokenReplacer +from .token_sampler import TokenSampler + + +class StoppingConditionEvaluator(ABC): + """Base class for Stopping Condition Evaluators""" + + @abstractmethod + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: 
bool = False, + ) -> TargetIdsTensor: + """Evaluate stop condition according to the specified strategy. + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + + """ + raise NotImplementedError() + + +class TopKStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Evaluator stopping when target exist among the top k predictions, + while top n tokens based on importance_score are not been replaced. + """ + + def __init__( + self, + model: AutoModelForCausalLM, + sampler: TokenSampler, + top_k: int, + keep_top_n: int = 0, + keep_ratio: float = 0, + invert_keep: bool = False, + ) -> None: + """Constructor for the TopKStoppingConditionEvaluator class. + + Args: + model: A Huggingface ``AutoModelForCausalLM``. + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object to sample replacement tokens. + top_k: Top K predictions in which the target must be included in order to achieve the stopping condition. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. + """ + self.model = model + self.top_k = top_k + self.replacer = RankingTokenReplacer(sampler, keep_top_n, keep_ratio, invert_keep) + + def __call__( + self, + input_ids: IdsTensor, + target_id: TargetIdsTensor, + importance_score: MultipleScoresPerStepTensor, + decoder_input_ids: Optional[IdsTensor] = None, + attribute_target: bool = False, + ) -> TargetIdsTensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + """ + # Replace tokens with low importance score and then inference \hat{y^{(e)}_{t+1}} + self.replacer.set_score(importance_score) + if not attribute_target: + input_ids_replaced, mask_replacing = self.replacer(input_ids) + else: + ids_replaced, mask_replacing = self.replacer(torch.cat((input_ids, decoder_input_ids), 1)) + input_ids_replaced = ids_replaced[:, : input_ids.shape[1]] + decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :] + + logging.debug(f"Replacing mask based on importance score -> { mask_replacing }") + + # Whether the result \hat{y^{(e)}_{t+1}} consistent with y_{t+1} + assert not input_ids_replaced.requires_grad, "Error: auto-diff engine not disabled" + with torch.no_grad(): + kwargs = {"input_ids": input_ids_replaced} + if decoder_input_ids is not None: + kwargs["decoder_input_ids"] = decoder_input_ids_replaced if attribute_target else decoder_input_ids + logits_replaced = self.model(**kwargs)["logits"] + 
ids_prediction_sorted = torch.argsort(logits_replaced[:, -1, :], descending=True) + ids_prediction_top_k = ids_prediction_sorted[:, : self.top_k] + match_mask = ids_prediction_top_k == target_id + match_hit = torch.sum(match_mask, dim=-1, dtype=torch.bool) + return match_hit + + +class DummyStoppingConditionEvaluator(StoppingConditionEvaluator): + """ + Stopping Condition Evaluator which stop when target exist in top k predictions, + while top n tokens based on importance_score are not been replaced. + """ + + def __call__(self, input_ids: IdsTensor, **kwargs) -> TargetIdsTensor: + """Evaluate stop condition + + Args: + input_ids: Input sequence [batch, sequence] + target_id: Target token [batch] + importance_score: Importance score of the input [batch, sequence] + attribute_target: whether attribute target for encoder-decoder models + + Return: + Boolean flag per sequence signaling whether the stop condition was reached [batch] + """ + return torch.ones([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device) diff --git a/inseq/attr/feat/ops/reagent_core/token_replacer.py b/inseq/attr/feat/ops/reagent_core/token_replacer.py new file mode 100644 index 00000000..0d889144 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_replacer.py @@ -0,0 +1,111 @@ +import math +from abc import ABC, abstractmethod + +import torch +from typing_extensions import override + +from .....utils.typing import IdsTensor +from .token_sampler import TokenSampler + + +class TokenReplacer(ABC): + """ + Base class for token replacers + + """ + + def __init__(self, sampler: TokenSampler) -> None: + self.sampler = sampler + + @abstractmethod + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Replace tokens according to the specified strategy. + + Args: + input: input sequence [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + + """ + raise NotImplementedError() + + +class RankingTokenReplacer(TokenReplacer): + """Replace tokens in a sequence based on top-N ranking""" + + @override + def __init__( + self, sampler: TokenSampler, keep_top_n: int = 0, keep_ratio: float = 0, invert_keep: bool = False + ) -> None: + """Constructor for the RankingTokenReplacer class. + + Args: + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + keep_top_n: If set to a value greater than 0, the top n tokens based on their importance score will be + kept, and the rest will be flagged for replacement. If set to 0, the top n will be determined by + ``keep_ratio``. + keep_ratio: If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep: If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. 
+ """ + super().__init__(sampler) + self.keep_top_n = keep_top_n + self.keep_ratio = keep_ratio + self.invert_keep = invert_keep + + def set_score(self, value: torch.Tensor) -> None: + pos_sorted = torch.argsort(value, descending=True) + top_n = int(math.ceil(self.keep_ratio * value.shape[-1])) if not self.keep_top_n else self.keep_top_n + pos_top_n = pos_sorted[..., :top_n] + self.replacement_mask = torch.ones_like(value, device=value.device, dtype=torch.bool).scatter( + -1, pos_top_n, self.invert_keep + ) + + @override + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Sample a sequence + + Args: + input: Input sequence of ids of shape [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + """ + token_sampled = self.sampler(input) + input_replaced = input * ~self.replacement_mask + token_sampled * self.replacement_mask + return input_replaced, self.replacement_mask + + +class UniformTokenReplacer(TokenReplacer): + """Replace tokens in a sequence where selecting is base on uniform distribution""" + + @override + def __init__(self, sampler: TokenSampler, ratio: float) -> None: + """Constructor + + Args: + sampler: A :class:`~inseq.attr.feat.ops.reagent_core.TokenSampler` object for sampling replacement tokens. + ratio: Ratio of tokens to replace in the sequence. + """ + super().__init__(sampler) + self.ratio = ratio + + @override + def __call__(self, input: IdsTensor) -> tuple[IdsTensor, IdsTensor]: + """Sample a sequence + + Args: + input: Input sequence of ids of shape [batch, sequence] + + Returns: + input_replaced: A replaced sequence [batch, sequence] + replacement_mask: Boolean mask identifying which token has been replaced [batch, sequence] + """ + sample_uniform = torch.rand(input.shape, device=input.device) + replacement_mask = sample_uniform < self.ratio + token_sampled = self.sampler(input) + input_replaced = input * ~replacement_mask + token_sampled * replacement_mask + return input_replaced, replacement_mask diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py new file mode 100644 index 00000000..7ca41bf2 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -0,0 +1,107 @@ +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing_extensions import override + +from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available +from .....utils.typing import IdsTensor + +logger = logging.getLogger(__name__) + + +class TokenSampler(ABC): + """Base class for token samplers""" + + @abstractmethod + def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: + """Sample tokens according to the specified strategy. 
+ + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + """ + raise NotImplementedError() + + +class POSTagTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution on a set of words with the same POS tag.""" + + def __init__( + self, + tokenizer: Union[str, PreTrainedTokenizerBase], + identifier: str = "pos_tag_sampler", + save_cache: bool = True, + overwrite_cache: bool = False, + cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache", + device: Optional[str] = None, + tokenizer_kwargs: Optional[dict[str, Any]] = {}, + ) -> None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) + cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl" + self.pos2ids = self.build_pos_mapping_from_vocab( + cache_dir, + cache_filename, + save_cache, + overwrite_cache, + tokenizer=self.tokenizer, + ) + num_postags = len(self.pos2ids) + self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device) + for pos_idx, ids in enumerate(self.pos2ids.values()): + self.id2pos[ids] = pos_idx + self.num_ids_per_pos = torch.tensor( + [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device + ) + self.offsets = torch.sum( + torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos, + dim=-1, + ) + self.compact_idx = torch.cat( + tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values()) + ) + + @staticmethod + @cache_results + def build_pos_mapping_from_vocab( + tokenizer: PreTrainedTokenizerBase, + log_every: int = 5000, + ) -> dict[str, list[int]]: + """Build mapping from POS tags to list of token ids from tokenizer's vocabulary.""" + if not is_nltk_available(): + raise ImportError("nltk is required to build POS tag mapping. Please install nltk.") + import nltk + + nltk.download("averaged_perceptron_tagger") + pos2ids = defaultdict(list) + for i in range(tokenizer.vocab_size): + word = tokenizer.decode([i]) + _, tag = nltk.pos_tag([word.strip()])[0] + pos2ids[tag].append(i) + if i % log_every == 0: + logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") + return pos2ids + + @override + def __call__(self, input_ids: IdsTensor) -> IdsTensor: + """Sample a tensor + + Args: + input: input tensor [batch, sequence] + + Returns: + token_uniform: A sampled tensor where its shape is the same with the input + """ + input_ids_pos = self.id2pos[input_ids] + sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) + compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long() + return self.compact_idx[compact_group_idx] diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py new file mode 100644 index 00000000..ed95eb12 --- /dev/null +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -0,0 +1,394 @@ +# Copyright 2023 The Inseq Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import Enum +from types import FrameType +from typing import TYPE_CHECKING, Callable, Optional + +import torch +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from torch import nn +from torch.utils.hooks import RemovableHandle + +from ....utils import ( + find_block_stack, + get_post_variable_assignment_hook, + recursive_get_submodule, + validate_indices, +) +from ....utils.typing import ( + EmbeddingsTensor, + InseqAttribution, + MultiLayerEmbeddingsTensor, + MultiLayerScoreTensor, + OneOrMoreIndices, + OneOrMoreIndicesDict, +) + +if TYPE_CHECKING: + from ....models import HuggingfaceModel + +logger = logging.getLogger(__name__) + + +class ValueZeroingSimilarityMetric(Enum): + COSINE = "cosine" + EUCLIDEAN = "euclidean" + + +class ValueZeroingModule(Enum): + DECODER = "decoder" + ENCODER = "encoder" + + +class ValueZeroing(InseqAttribution): + """Value Zeroing method for feature attribution. + + Introduced by `Mohebbi et al. (2023) `__ to quantify context mixing inside + Transformer models. The method is based on the observation that context mixing is regulated by the value vectors + of the attention mechanism. The method consists of two steps: + + 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model. + 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it + as a measure of context mixing for the given token at the given layer. + + The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at + specific layers, or by aggregating them across layers. + + Attributes: + SIMILARITY_METRICS (:obj:`Dict[str, Callable]`): + Dictionary of available similarity metrics to be used forvcomputing the distance between hidden states + produced with and without the zeroing operation. Converted to distances as 1 - produced values. + forward_func (:obj:`AttributionModel`): + The attribution model to be used for value zeroing. + clean_block_output_states (:obj:`Dict[int, torch.Tensor]`): + Dictionary to store the hidden states produced by the model without the zeroing operation. + corrupted_block_output_states (:obj:`Dict[int, torch.Tensor]`): + Dictionary to store the hidden states produced by the model with the zeroing operation. + """ + + SIMILARITY_METRICS = { + "cosine": nn.CosineSimilarity(dim=-1), + "euclidean": lambda x, y: torch.cdist(x, y, p=2), + } + + def __init__(self, forward_func: "HuggingfaceModel") -> None: + super().__init__(forward_func) + self.clean_block_output_states: dict[int, EmbeddingsTensor] = {} + self.corrupted_block_output_states: dict[int, EmbeddingsTensor] = {} + + @staticmethod + def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]: + """Returns a hook to zero the value vectors of the attention mechanism. + + Args: + varname (:obj:`str`, optional): The name of the variable containing the value vectors. 
The variable + is expected to be a 3D tensor of shape (batch_size, num_heads, seq_len) and is retrieved from the + local variables of the execution frame during the forward pass. + """ + + def value_zeroing_forward_mid_hook( + frame: FrameType, + zeroed_token_index: Optional[int] = None, + zeroed_units_indices: Optional[OneOrMoreIndices] = None, + batch_size: int = 1, + ) -> None: + if varname not in frame.f_locals: + raise ValueError( + f"Variable {varname} not found in the local frame." + f"Other variable names: {', '.join(frame.f_locals.keys())}" + ) + # Zeroing value vectors corresponding to the given token index + if zeroed_token_index is not None: + values_size = frame.f_locals[varname].size() + if len(values_size) == 3: # Assume merged shape (bsz * num_heads, seq_len, hidden_size) e.g. Whisper + values = frame.f_locals[varname].view(batch_size, -1, *values_size[1:]) + elif len(values_size) == 4: # Assume per-head shape (bsz, num_heads, seq_len, hidden_size) e.g. GPT-2 + values = frame.f_locals[varname].clone() + else: + raise ValueError( + f"Value vector shape {frame.f_locals[varname].size()} not supported. " + "Supported shapes: (batch_size, num_heads, seq_len, hidden_size) or " + "(batch_size * num_heads, seq_len, hidden_size)" + ) + zeroed_units_indices = validate_indices(values, 1, zeroed_units_indices).to(values.device) + zeroed_token_index = torch.tensor(zeroed_token_index, device=values.device) + # Mask heads corresponding to zeroed units and tokens corresponding to zeroed tokens + values[:, zeroed_units_indices, zeroed_token_index] = 0 + if len(values_size) == 3: + frame.f_locals[varname] = values.view(-1, *values_size[1:]) + elif len(values_size) == 4: + frame.f_locals[varname] = values + + return value_zeroing_forward_mid_hook + + def get_states_extract_and_patch_hook(self, block_idx: int, hidden_state_idx: int = 0) -> Callable[..., None]: + """Returns a hook to extract the produced hidden states (corrupted by value zeroing) + and patch them with pre-computed clean states that will be passed onwards in the model forward. + + Args: + block_idx (:obj:`int`): The idx of the block at which the hook is applied, used to store extracted states. + hidden_state_idx (:obj:`int`, optional): The index of the hidden state in the model output tuple. + """ + + def states_extract_and_patch_forward_hook(module, args, output) -> None: + self.corrupted_block_output_states[block_idx] = output[hidden_state_idx].clone().float().detach().cpu() + + # Rebuild the output tuple patching the clean states at the place of the corrupted ones + output = ( + output[:hidden_state_idx] + + (self.clean_block_output_states[block_idx].to(output[hidden_state_idx].device),) + + output[hidden_state_idx + 1 :] + ) + return output + + return states_extract_and_patch_forward_hook + + @staticmethod + def has_convergence_delta() -> bool: + return False + + def compute_modules_post_zeroing_similarity( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + hidden_states: MultiLayerEmbeddingsTensor, + attention_module_name: str, + attributed_seq_len: Optional[int] = None, + similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, + mode: str = ValueZeroingModule.DECODER.value, + zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + min_score_threshold: float = 1e-5, + use_causal_mask: bool = False, + ) -> MultiLayerScoreTensor: + """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block. 
+ + Args: + modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for. + hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean + counterparts when computing the similarity. + attention_module_name (:obj:`str`): The name of the attention module to zero the values for. + attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed + to be the same as the length of the hidden states. + similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine". + mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder". + zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and + `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads + that should be zeroed to compute corrupted states. + - If None, all attention heads across all layers are zeroed. + - If an integer, the same attention head is zeroed across all layers. + - If a tuple of two integers, the attention heads in the range are zeroed across all layers. + - If a list of integers, the attention heads in the list are zeroed across all layers. + - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for + the corresponding layer. Any missing layer will not be zeroed. + Default: None. + min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the + similarity. Default: 1e-5. + use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores Default: False. + + Returns: + :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layer]`` containing distances + (1 - similarity score) between original and corrupted states for each layer. + """ + if mode == ValueZeroingModule.DECODER.value: + modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder()) + elif mode == ValueZeroingModule.ENCODER.value: + modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder()) + else: + raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.") + if attributed_seq_len is None: + attributed_seq_len = hidden_states.size(2) + batch_size = hidden_states.size(0) + generated_seq_len = hidden_states.size(2) + num_layers = len(modules) + + # Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the + # embedding layer, and we are only interested in the transformer blocks outputs. + self.clean_block_output_states = { + block_idx: hidden_states[:, block_idx + 1, ...].clone().detach().cpu() for block_idx in range(len(modules)) + } + # Scores for every layer of the model + all_scores = torch.ones( + batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device + ) * float("nan") + + # Hooks: + # 1. states_extract_and_patch_hook on the transformer block stores corrupted states and force clean states + # as the output of the block forward pass, i.e. the zeroing is done independently across layers. + # 2. value_zeroing_hook on the attention module performs the value zeroing by replacing the "value" tensor + # during the forward (name is config-dependent) with a zeroed version for the specified token index. 
+ # + # State extraction hooks can be registered only once since they are token-independent + # Skip last block since its states are not used raw, but may have further transformations applied to them + # (e.g. LayerNorm, Dropout). These are extracted separately from the model outputs. + states_extraction_hook_handles: list[RemovableHandle] = [] + for block_idx in range(len(modules) - 1): + states_extract_and_patch_hook = self.get_states_extract_and_patch_hook(block_idx, hidden_state_idx=0) + states_extraction_hook_handles.append( + modules[block_idx].register_forward_hook(states_extract_and_patch_hook) + ) + # Zeroing is done for every token in the sequence separately (O(n) complexity) + for token_idx in range(attributed_seq_len): + value_zeroing_hook_handles: list[RemovableHandle] = [] + # Value zeroing hooks are registered for every token separately since they are token-dependent + for block_idx, block in enumerate(modules): + attention_module = recursive_get_submodule(block, attention_module_name) + if attention_module is None: + raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.") + if isinstance(zeroed_units_indices, dict): + if block_idx not in zeroed_units_indices: + continue + zeroed_units_indices_block = zeroed_units_indices[block_idx] + else: + zeroed_units_indices_block = zeroed_units_indices + value_zeroing_hook = get_post_variable_assignment_hook( + module=attention_module, + varname=self.forward_func.config.value_vector, + hook_fn=self.get_value_zeroing_hook(self.forward_func.config.value_vector), + zeroed_token_index=token_idx, + zeroed_units_indices=zeroed_units_indices_block, + batch_size=batch_size, + ) + value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_hook) + value_zeroing_hook_handles.append(value_zeroing_hook_handle) + + # Run forward pass with hooks. Fills self.corrupted_hidden_states with corrupted states across layers + # when zeroing the specified token index. + with torch.no_grad(): + output = self.forward_func.forward_with_output( + *inputs, *additional_forward_args, output_hidden_states=True + ) + # Extract last layer states directly from the model outputs + # This allows us to handle the presence of additional transformations (e.g. LayerNorm, Dropout) + # in the last layer automatically. 
+ corrupted_states_dict = self.forward_func.get_hidden_states_dict(output) + corrupted_decoder_last_hidden_state = ( + corrupted_states_dict[f"{mode}_hidden_states"][:, -1, ...].clone().detach().cpu() + ) + self.corrupted_block_output_states[len(modules) - 1] = corrupted_decoder_last_hidden_state + for handle in value_zeroing_hook_handles: + handle.remove() + for block_idx in range(len(modules)): + similarity_scores = self.SIMILARITY_METRICS[similarity_metric]( + self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx] + ) + if use_causal_mask: + all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:] + else: + all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores + self.corrupted_block_output_states = {} + for handle in states_extraction_hook_handles: + handle.remove() + self.clean_block_output_states = {} + all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores) + # Normalize scores to sum to 1 + per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True) + per_token_sum_score[per_token_sum_score == 0] = 1 + all_scores = all_scores / per_token_sum_score + + # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers] + return all_scores.permute(0, 3, 2, 1) + + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, + encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + output_decoder_self_scores: bool = True, + output_encoder_self_scores: bool = True, + ) -> TensorOrTupleOfTensorsGeneric: + """Perform attribution using the Value Zeroing method. + + Args: + similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between + hidden states produced with and without the zeroing operation. Default: cosine similarity. + zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and + `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads + that should be zeroed to compute corrupted states. + - If None, all attention heads across all layers are zeroed. + - If an integer, the same attention head is zeroed across all layers. + - If a tuple of two integers, the attention heads in the range are zeroed across all layers. + - If a list of integers, the attention heads in the list are zeroed across all layers. + - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for + the corresponding layer. + + Default: None (all heads are zeroed for every layer). + encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1, + source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for + encoder-decoders models. Default: None. + decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1, + target_seq_len, hidden_size]`` containing hidden states of the decoder. 
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or + if target-side attribution is requested using `attribute_target=True`. Default: True. + output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + encoder self-attention value vectors in encoder-decoder models. Default: True. + + Returns: + `TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution + """ + if similarity_metric not in self.SIMILARITY_METRICS: + raise ValueError( + f"Similarity metric {similarity_metric} not available." + f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}" + ) + decoder_scores = None + if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1: + decoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=decoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=decoder_zeroed_units_indices, + use_causal_mask=True, + ) + # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values + # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py + if self.forward_func.is_encoder_decoder: + encoder_scores = None + if output_encoder_self_scores: + encoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=encoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.ENCODER.value, + zeroed_units_indices=encoder_zeroed_units_indices, + ) + cross_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=decoder_hidden_states, + attributed_seq_len=encoder_hidden_states.size(2), + attention_module_name=self.forward_func.config.cross_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=cross_zeroed_units_indices, + ) + return encoder_scores, cross_scores, decoder_scores + elif encoder_zeroed_units_indices is not None or cross_zeroed_units_indices is not None: + logger.warning( + "Zeroing indices for encoder and cross-attentions were specified, but the model is not an " + "encoder-decoder. Use `decoder_zeroed_units_indices` to parametrize zeroing for the decoder module." 
+ ) + return (decoder_scores,) diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index fbebb780..498093af 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -1,16 +1,20 @@ import logging -from typing import Any +from typing import TYPE_CHECKING, Any from captum.attr import Occlusion from ...data import ( CoarseFeatureAttributionStepOutput, GranularFeatureAttributionStepOutput, + MultiDimensionalFeatureAttributionStepOutput, ) from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime +from .ops import Lime, Reagent, ValueZeroing + +if TYPE_CHECKING: + from ...models import HuggingfaceModel logger = logging.getLogger(__name__) @@ -117,3 +121,165 @@ def attribute_step( target_attributions=out.target_attributions, sequence_scores=out.sequence_scores, ) + + +class ReagentAttribution(PerturbationAttributionRegistry): + """Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. + + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models `__ + """ + + method_name = "reagent" + + def __init__( + self, + attribution_model: "HuggingfaceModel", + keep_top_n: int = 5, + keep_ratio: float = None, + invert_keep: bool = False, + stopping_condition_top_k: int = 3, + replacing_ratio: float = 0.3, + max_probe_steps: int = 3000, + num_probes: int = 16, + ): + """ReAGent method constructor. + + Args: + keep_top_n (:obj:`int`, `optional`): If set to a value greater than 0, the top n tokens based on their importance score will be + kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``. Default: ``5``. + keep_ratio (:obj:`float`, `optional`): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep. + invert_keep (:obj:`bool`, `optional`): If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be + replaced instead of being kept. Default: ``False``. + stopping_condition_top_k (:obj:`int`, `optional`): Threshold indicating that the stop condition achieved when the predicted target + exist in top k predictions. Default: ``3``. + replacing_ratio (:obj:`float`, `optional`): replacing ratio of tokens for probing. Default: ``0.3``. + max_probe_steps (:obj:`int`, `optional`): Max number of steps before stopping the probing. Default: ``3000``. + num_probes (:obj:`int`, `optional`): Number of probes performed in parallel. Default: ``16``. 
+ """ + super().__init__(attribution_model) + # Custom target attribution is currently not supported + self.use_predicted_target = False + self.method = Reagent( + attribution_model=self.attribution_model, + keep_top_n=keep_top_n, + keep_ratio=keep_ratio, + invert_keep=invert_keep, + stopping_condition_top_k=stopping_condition_top_k, + replacing_ratio=replacing_ratio, + max_probe_steps=max_probe_steps, + num_probes=num_probes, + ) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any] = {}, + ) -> GranularFeatureAttributionStepOutput: + out = super().attribute_step(attribute_fn_main_args, attribution_args) + return GranularFeatureAttributionStepOutput( + source_attributions=out.source_attributions, + target_attributions=out.target_attributions, + sequence_scores=out.sequence_scores, + ) + + +class ValueZeroingAttribution(PerturbationAttributionRegistry): + """Value Zeroing method for feature attribution. + + Introduced by `Mohebbi et al. (2023) `__ to quantify context mixing + in Transformer models. The method is based on the observation that context mixing is regulated by the value vectors + of the attention mechanism. The method consists of two steps: + + 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model. + 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it + as a measure of context mixing for the given token at the given layer. + + The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at + specific layers, or by aggregating them across layers. + + Reference implementations: + - Original implementation: `hmohebbi/ValueZeroing `__ + - Encoder-decoder implementation: `hmohebbi/ContextMixingASR `__ + + Args: + similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between + hidden states produced with and without the zeroing operation. Options: cosine, euclidean. Default: cosine. + encoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): The indices of + the attention heads that should be zeroed to compute corrupted states in the encoder self-attention module. + Not used for decoder-only models, or if ``output_encoder_self_scores`` is False. Format + + - None: all attention heads across all layers are zeroed. + - int: the same attention head is zeroed across all layers. + - tuple of two integers: the attention heads in the range are zeroed across all layers. + - list of integers: the attention heads in the list are zeroed across all layers. + - dictionary: the keys are the layer indices and the values are the zeroed attention heads for the corresponding layer. + + Default: None (all heads are zeroed for every encoder layer). + decoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as + ``encoder_zeroed_units_indices`` but for the decoder self-attention module. Not used for encoder-decoder + models or if ``output_decoder_self_scores`` is False. Default: None (all heads are zeroed for every decoder layer). + cross_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as + ``encoder_zeroed_units_indices`` but for the cross-attention module in encoder-decoder models. Not used + if the model is decoder-only. Default: None (all heads are zeroed for every layer). 
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or + if target-side attribution is requested using `attribute_target=True`. Default: True. + output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + encoder self-attention value vectors in encoder-decoder models. Default: True. + + Returns: + :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: The final dimension returned by the method + is ``[attributed_seq_len, generated_seq_len, num_layers]``. If ``output_decoder_self_scores`` and + ``output_encoder_self_scores`` are True, the respective scores are returned in the ``sequence_scores`` + output dictionary. + """ + + method_name = "value_zeroing" + + def __init__(self, attribution_model, **kwargs): + super().__init__(attribution_model, hook_to_model=False) + # Hidden states will be passed to the attribute_step method + self.use_hidden_states = True + # Does not rely on predicted output (i.e. decoding strategy agnostic) + self.use_predicted_target = False + # Uses model configuration to access attention module and value vector variable + self.use_model_config = True + # Needs only the final generation step to extract scores + self.is_final_step_method = True + self.method = ValueZeroing(attribution_model) + self.hook(**kwargs) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any] = {}, + ) -> MultiDimensionalFeatureAttributionStepOutput: + attr = self.method.attribute(**attribute_fn_main_args, **attribution_args) + encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions( + attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True + ) + sequence_scores = {} + if self.attribution_model.is_encoder_decoder: + if len(attribute_fn_main_args["inputs"]) > 1: + target_attributions = decoder_self_scores.to("cpu") + else: + target_attributions = None + if decoder_self_scores is not None: + sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu") + if encoder_self_scores is not None: + sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu") + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=decoder_cross_scores.to("cpu"), + target_attributions=target_attributions, + sequence_scores=sequence_scores, + _num_dimensions=1, # num_layers + ) + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=None, + target_attributions=decoder_self_scores, + _num_dimensions=1, # num_layers + ) diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index d9cd6092..83aa8d6e 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -462,8 +462,7 @@ def register_step_function( attribution targets by gradient-based feature attribution methods. Args: - fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture - unused ones when defining your function): + fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture unused ones when defining your function): - :obj:`attribution_model`: an :class:`~inseq.models.AttributionModel` instance, corresponding to the model used for computing the score. 
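The two registry classes above make the new methods selectable through their `method_name` identifiers, like any other attribution method. A minimal usage sketch follows (not part of the diff; `gpt2` and the input text are placeholders, and default hyperparameters are assumed):

```python
import inseq

# "value_zeroing" and "reagent" resolve to the registry classes defined above via method_name.
vz_model = inseq.load_model("gpt2", "value_zeroing")
vz_out = vz_model.attribute("Value vectors regulate context mixing in Transformers.")
vz_out.show()

# ReAGent follows the same pattern; its constructor defaults (keep_top_n=5, num_probes=16, ...) apply.
re_model = inseq.load_model("gpt2", "reagent")
re_out = re_model.attribute("Value vectors regulate context mixing in Transformers.")
re_out.show()
```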
diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py index f4cbb7cb..1eb72126 100644 --- a/inseq/commands/attribute_context/attribute_context.py +++ b/inseq/commands/attribute_context/attribute_context.py @@ -34,6 +34,7 @@ from .attribute_context_helpers import ( AttributeContextOutput, CCIOutput, + concat_with_sep, filter_rank_tokens, format_template, get_contextless_output, @@ -69,10 +70,6 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM # Prepare input/outputs (generate if necessary) input_full_text = format_template(args.input_template, args.input_current_text, args.input_context_text) - if "{current}" in args.contextless_input_current_text: - args.input_current_text = args.contextless_input_current_text.format(current=args.input_current_text) - else: - args.input_current_text = args.contextless_input_current_text args.output_context_text, args.output_current_text = prepare_outputs( model=model, input_context_text=args.input_context_text, @@ -92,8 +89,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM if args.input_context_text is not None: input_context_tokens = get_filtered_tokens(args.input_context_text, model, args.special_tokens_to_keep) if not model.is_encoder_decoder: - sep = args.decoder_input_output_separator if not output_full_text.startswith((" ", "\n")) else "" - output_full_text = input_full_text + sep + output_full_text + output_full_text = concat_with_sep(input_full_text, output_full_text, args.decoder_input_output_separator) output_current_tokens = get_filtered_tokens( args.output_current_text, model, args.special_tokens_to_keep, is_target=True ) @@ -105,16 +101,18 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM input_full_tokens = get_filtered_tokens(input_full_text, model, args.special_tokens_to_keep) output_full_tokens = get_filtered_tokens(output_full_text, model, args.special_tokens_to_keep, is_target=True) output_current_text_offset = len(output_full_tokens) - len(output_current_tokens) - if model.is_encoder_decoder: - prefixed_output_current_text = args.output_current_text - else: - sep = args.decoder_input_output_separator if not args.output_current_text.startswith((" ", "\n")) else "" - prefixed_output_current_text = args.input_current_text + sep + args.output_current_text + formatted_input_current_text = args.contextless_input_current_text.format(current=args.input_current_text) + formatted_output_current_text = args.contextless_output_current_text.format(current=args.output_current_text) + if not model.is_encoder_decoder: + formatted_input_current_text = concat_with_sep( + formatted_input_current_text, "", args.decoder_input_output_separator + ) + formatted_output_current_text = formatted_input_current_text + formatted_output_current_text # Part 1: Context-sensitive Token Identification (CTI) cti_out = model.attribute( - args.input_current_text, - prefixed_output_current_text, + formatted_input_current_text.rstrip(" "), + formatted_output_current_text, attribute_target=model.is_encoder_decoder, step_scores=[args.context_sensitivity_metric], contrast_sources=input_full_text if model.is_encoder_decoder else None, @@ -126,6 +124,11 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM cti_out.show(do_aggregation=False) start_pos = 1 if has_lang_tag else 0 + contextless_output_prefix = args.contextless_output_current_text.split("{current}")[0] + 
contextless_output_prefix_tokens = get_filtered_tokens( + contextless_output_prefix, model, args.special_tokens_to_keep, is_target=True + ) + start_pos += len(contextless_output_prefix_tokens) cti_scores = cti_out.step_scores[args.context_sensitivity_metric][start_pos:].tolist() cti_tokens = [t.token for t in cti_out.target][start_pos + cti_out.attr_pos_start :] if model.is_encoder_decoder: @@ -149,18 +152,29 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM ) # Part 2: Contextual Cues Imputation (CCI) for cci_step_idx, (cti_idx, cti_score, cti_tok) in enumerate(cti_ranked_tokens): + contextual_input = model.convert_tokens_to_string(input_full_tokens, skip_special_tokens=False).lstrip(" ") contextual_output = model.convert_tokens_to_string( output_full_tokens[: output_current_text_offset + cti_idx + 1], skip_special_tokens=False - ) + ).lstrip(" ") if not contextual_output: - contextual_output = output_full_tokens[output_current_text_offset + cti_idx] - + output_ctx_tokens = [output_full_tokens[output_current_text_offset + cti_idx]] + if model.is_encoder_decoder: + output_ctx_tokens.append(model.pad_token) + contextual_output = model.convert_tokens_to_string(output_ctx_tokens, skip_special_tokens=True) + else: + output_ctx_tokens = model.convert_string_to_tokens( + contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder + ) cci_kwargs = {} contextless_output = None if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn): + if not model.is_encoder_decoder: + formatted_input_current_text = concat_with_sep( + formatted_input_current_text, contextless_output_prefix, args.decoder_input_output_separator + ) contextless_output = get_contextless_output( model, - args.input_current_text, + formatted_input_current_text, output_current_tokens, cti_idx, cti_ranked_tokens, @@ -171,20 +185,18 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM args.special_tokens_to_keep, deepcopy(args.generation_kwargs), ) - cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None + cci_kwargs["contrast_sources"] = formatted_input_current_text if model.is_encoder_decoder else None cci_kwargs["contrast_targets"] = contextless_output - output_ctx_tokens = model.convert_string_to_tokens( - contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder - ) output_ctxless_tokens = model.convert_string_to_tokens( contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder ) tok_pos = -2 if model.is_encoder_decoder else -1 if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]: cci_kwargs["contrast_force_inputs"] = True - pos_start = output_current_text_offset + cti_idx + int(model.is_encoder_decoder) + int(has_lang_tag) + bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token) + pos_start = output_current_text_offset + cti_idx + bos_offset + int(has_lang_tag) cci_attrib_out = model.attribute( - input_full_text, + contextual_input, contextual_output, attribute_target=model.is_encoder_decoder and args.has_output_context, show_progress=False, @@ -206,6 +218,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM model, cci_attrib_out, args.input_template, + args.input_current_text, input_context_tokens, input_full_tokens, args.output_template, @@ -213,6 +226,7 @@ def attribute_context_with_model(args: 
AttributeContextArgs, model: HuggingfaceM args.has_input_context, args.has_output_context, has_lang_tag, + args.decoder_input_output_separator, args.special_tokens_to_keep, ) cci_out = CCIOutput( diff --git a/inseq/commands/attribute_context/attribute_context_args.py b/inseq/commands/attribute_context/attribute_context_args.py index 716eb5b6..295d5cc2 100644 --- a/inseq/commands/attribute_context/attribute_context_args.py +++ b/inseq/commands/attribute_context/attribute_context_args.py @@ -91,6 +91,17 @@ class AttributeContextInputArgs: " be used as-is for the contrastive comparison, enabling contrastive comparison with different inputs." ), ) + contextless_output_current_text: Optional[str] = cli_arg( + default=None, + help=( + "The output current text or template to use in the contrastive comparison with contextual output. By default" + " it is the same as ``output_current_text``, but it can be useful in cases where the context is nested " + "inside the current text (e.g. for an ``output_template`` like \n{context}\n{current}\n we " + "can use this parameter to format the contextless version as \n{current}\n)." + "If it contains the tag {current}, it will be infilled with the ``output_current_text``. Otherwise, it will" + " be used as-is for the contrastive comparison, enabling contrastive comparison with different outputs." + ), + ) @command_args_docstring @@ -184,7 +195,7 @@ class AttributeContextOutputArgs: default=False, help=( "If specified, the intermediate outputs produced by the Inseq library for context-sensitive target " - "identification (CTI) and contextual cues imputation (CCI) are shown during the process.", + "identification (CTI) and contextual cues imputation (CCI) are shown during the process." ), ) save_path: Optional[str] = cli_arg( @@ -212,8 +223,19 @@ class AttributeContextArgs(AttributeContextInputArgs, AttributeContextMethodArgs def __repr__(self): return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" + @classmethod + def _to_dict(cls, val: Any) -> dict[str, Any]: + if val is None or isinstance(val, (str, int, float, bool)): + return val + elif isinstance(val, dict): + return {k: cls._to_dict(v) for k, v in val.items()} + elif isinstance(val, (list, tuple)): + return [cls._to_dict(v) for v in val] + else: + return str(val) + def to_dict(self) -> dict[str, Any]: - return dict(self.__dict__.items()) + return self._to_dict(self.__dict__) def __post_init__(self): if ( @@ -236,6 +258,18 @@ def __post_init__(self): self.output_template = "{current}" if self.output_context_text is None else "{context} {current}" if self.contextless_input_current_text is None: self.contextless_input_current_text = "{current}" + if "{current}" not in self.contextless_input_current_text: + raise ValueError( + "{current} placeholder is missing from contextless_input_current_text template" + f" {self.contextless_input_current_text}." + ) + if self.contextless_output_current_text is None: + self.contextless_output_current_text = "{current}" + if "{current}" not in self.contextless_output_current_text: + raise ValueError( + "{current} placeholder is missing from contextless_output_current_text template" + f" {self.contextless_output_current_text}." 
+ ) self.has_input_context = "{context}" in self.input_template self.has_output_context = "{context}" in self.output_template if not self.input_current_text: diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py index f436bba4..cb8793e2 100644 --- a/inseq/commands/attribute_context/attribute_context_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -72,6 +72,14 @@ def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput": return out +def concat_with_sep(s1: str, s2: str, sep: str) -> bool: + """Adds separator between two strings if needed.""" + need_sep = not s1.endswith(sep) and not s2.startswith(sep) + if need_sep: + return s1 + sep + s2 + return s1 + s2 + + def format_template(template: str, current: str, context: Optional[str] = None) -> str: kwargs = {"current": current} if context is not None: @@ -89,7 +97,7 @@ def get_filtered_tokens( """Tokenize text and filter out special tokens, keeping only those in ``special_tokens_to_keep``.""" as_targets = is_target and model.is_encoder_decoder return [ - t.replace("Ä ", " ").replace("Ċ", " ").replace("ā–", " ") if replace_special_characters else t + t.replace("Ä ", " ").replace("Ċ", "\n").replace("ā–", " ") if replace_special_characters else t for t in model.convert_string_to_tokens(text, skip_special_tokens=False, as_targets=as_targets) if t not in model.special_tokens or t in special_tokens_to_keep ] @@ -99,11 +107,14 @@ def generate_with_special_tokens( model: HuggingfaceModel, model_input: str, special_tokens_to_keep: list[str] = [], + output_generated_only: bool = True, **generation_kwargs, ) -> str: """Generate text preserving special tokens in ``special_tokens_to_keep``.""" # Generate outputs, strip special tokens and remove prefix/suffix - output_gen = model.generate(model_input, skip_special_tokens=False, **generation_kwargs)[0] + output_gen = model.generate( + model_input, skip_special_tokens=False, output_generated_only=output_generated_only, **generation_kwargs + )[0] output_tokens = get_filtered_tokens(output_gen, model, special_tokens_to_keep, is_target=True) return model.convert_tokens_to_string(output_tokens, skip_special_tokens=False) @@ -236,20 +247,19 @@ def prepare_outputs( if "forced_bos_token_id" in generation_kwargs: generation_kwargs["decoder_input_ids"][0, 0] = generation_kwargs["forced_bos_token_id"] else: - sep = "" - if output_current_prefix and not output_current_prefix.startswith((" ", "\n")): - sep = decoder_input_output_separator - model_input = input_full_text + sep + output_current_prefix + model_input = concat_with_sep(input_full_text, output_current_prefix, decoder_input_output_separator) output_current_prefix = model_input + if not model.is_encoder_decoder: + model_input = concat_with_sep(input_full_text, "", decoder_input_output_separator) + output_gen = generate_model_output( model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix ) # Settings 3, 4 if (has_out_ctx == use_out_ctx) and not has_out_curr: - final_current = output_gen if model.is_encoder_decoder or use_out_ctx else output_gen[len(model_input) :] - return final_context, final_current.strip() + return final_context, output_gen.strip() # Settings 5, 6 # Try splitting the output into context and current text using ``separator``. 
As we have no guarantees of its @@ -385,14 +395,12 @@ def generate_contextless_output( generation_input = input_current_text else: generation_kwargs["max_new_tokens"] = 1 - sep = "" - if contextual_prefix and not contextual_prefix.startswith((" ", "\n")): - sep = decoder_input_output_separator - generation_input = input_current_text + sep + contextual_prefix + generation_input = concat_with_sep(input_current_text, contextual_prefix, decoder_input_output_separator) contextless_output = generate_with_special_tokens( model, generation_input, special_tokens_to_keep, + output_generated_only=False, **generation_kwargs, ) return contextless_output @@ -402,6 +410,7 @@ def get_source_target_cci_scores( model: HuggingfaceModel, cci_attrib_out: FeatureAttributionSequenceOutput, input_template: str, + input_current_text: str, input_context_tokens: list[str], input_full_tokens: list[str], output_template: str, @@ -409,6 +418,7 @@ def get_source_target_cci_scores( has_input_context: bool, has_output_context: bool, model_has_lang_tag: bool, + decoder_input_output_separator: str, special_tokens_to_keep: list[str] = [], ) -> tuple[Optional[list[float]], Optional[list[float]]]: """Extract attribution scores for the input and output contexts.""" @@ -417,20 +427,26 @@ def get_source_target_cci_scores( if model.is_encoder_decoder: input_scores = cci_attrib_out.source_attributions[:, 0].tolist() if model_has_lang_tag: - input_scores = input_scores[1:] + input_scores = input_scores[2:] else: input_scores = cci_attrib_out.target_attributions[:, 0].tolist() input_prefix, *_ = input_template.partition("{context}") + if "{current}" in input_prefix: + input_prefix = input_prefix.format(current=input_current_text) input_prefix_tokens = get_filtered_tokens(input_prefix, model, special_tokens_to_keep, is_target=False) input_prefix_len = len(input_prefix_tokens) input_scores = input_scores[input_prefix_len : len(input_context_tokens) + input_prefix_len] if has_output_context: output_scores = cci_attrib_out.target_attributions[:, 0].tolist() if model_has_lang_tag: - output_scores = output_scores[1:] + output_scores = output_scores[2:] output_prefix, *_ = output_template.partition("{context}") + if not model.is_encoder_decoder and output_prefix: + output_prefix = decoder_input_output_separator + output_prefix output_prefix_tokens = get_filtered_tokens(output_prefix, model, special_tokens_to_keep, is_target=True) - prefix_len = len(output_prefix_tokens) + int(not model.is_encoder_decoder) * len(input_full_tokens) + prefix_len = len(output_prefix_tokens) + if not model.is_encoder_decoder: + prefix_len += len(input_full_tokens) output_scores = output_scores[prefix_len : len(output_context_tokens) + prefix_len] return input_scores, output_scores diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py index 6c322e12..8a2dff76 100644 --- a/inseq/commands/attribute_context/attribute_context_viz_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py @@ -145,7 +145,7 @@ def visualize_attribute_context( console.print(viz, soft_wrap=False) html = console.export_html() if output.info.viz_path: - with open(output.info.viz_path, "w") as f: + with open(output.info.viz_path, "w", encoding="utf-8") as f: f.write(html) if output.info.show_viz: console.print(viz, soft_wrap=False) diff --git a/inseq/commands/commands_utils.py b/inseq/commands/commands_utils.py index 7701409a..dbfb8ac4 100644 --- 
a/inseq/commands/commands_utils.py +++ b/inseq/commands/commands_utils.py @@ -18,5 +18,4 @@ def command_args_docstring(cls): field_help = field.metadata.get("help", "") docstring += textwrap.dedent(f"\n**{field.name}** (``{field_type}``): {field_help}\n") cls.__doc__ = docstring - print(docstring) return cls diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py index bb475707..dbae5352 100644 --- a/inseq/data/aggregator.py +++ b/inseq/data/aggregator.py @@ -12,9 +12,10 @@ aggregate_token_sequence, available_classes, extract_signature_args, + validate_indices, ) from ..utils import normalize as normalize_fn -from ..utils.typing import IndexSpan, TokenWithId +from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId from .aggregation_functions import AggregationFunction from .data_utils import TensorWrapper @@ -305,7 +306,7 @@ def _process_attribution_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -366,7 +367,7 @@ def aggregate_source_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -380,7 +381,7 @@ def aggregate_target_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -398,7 +399,7 @@ def aggregate_sequence_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, **kwargs, ): if aggregate_fn.takes_sequence_scores: @@ -439,46 +440,12 @@ def is_compatible(attr: "FeatureAttributionSequenceOutput"): def _filter_scores( scores: torch.Tensor, dim: int = -1, - indices: Union[int, tuple[int, int], list[int], None] = None, + indices: Optional[OneOrMoreIndices] = None, ) -> torch.Tensor: - n_units = scores.shape[dim] - - if hasattr(indices, "__iter__"): - if len(indices) == 0: - raise RuntimeError("At least two indices must be specified for aggregation.") - if len(indices) == 1: - indices = indices[0] - + indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device)) if isinstance(indices, int): - if indices not in range(-n_units, n_units): - raise IndexError(f"Index out of range. Scores only have {n_units} units.") - indices = indices if indices >= 0 else n_units + indices - return scores.select(dim, torch.tensor(indices, device=scores.device)) - else: - if indices is None: - indices = (0, n_units) - logger.info("No indices specified for extraction. Using all units by default.") - - # Convert negative indices to positive indices - if hasattr(indices, "__iter__"): - indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices]) - if not hasattr(indices, "__iter__") or ( - len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1] - ): - raise RuntimeError( - "A (start, end) tuple of indices representing a span, a list of individual indices" - " or a single index must be specified for select_idx." 
- ) - max_idx_val = n_units if isinstance(indices, list) else n_units + 1 - if not all(h in range(-n_units, max_idx_val) for h in indices): - raise IndexError("One or more index out of range. Scores only have {n_units} units.") - if len(set(indices)) != len(indices): - raise IndexError("Duplicate indices are not allowed.") - if isinstance(indices, tuple): - scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device)) - else: - scores = scores.index_select(dim, torch.tensor(indices, device=scores.device)) - return scores + return indexed.squeeze(dim) + return indexed @staticmethod def _aggregate_scores( diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index f3671244..7841cf7c 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -12,6 +12,7 @@ get_sequences_from_batched_steps, json_advanced_dump, json_advanced_load, + pad_with_nan, pretty_dict, remap_from_filtered, ) @@ -178,9 +179,8 @@ def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callab def from_step_attributions( cls, attributions: list["FeatureAttributionStepOutput"], - tokenized_target_sentences: Optional[list[list[TokenWithId]]] = None, - pad_id: Optional[Any] = None, - has_bos_token: bool = True, + tokenized_target_sentences: list[list[TokenWithId]], + pad_token: Optional[Any] = None, attr_pos_end: Optional[int] = None, ) -> list["FeatureAttributionSequenceOutput"]: """Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple @@ -198,36 +198,35 @@ def from_step_attributions( num_sequences = len(attr.prefix) if not all(len(attr.prefix) == num_sequences for attr in attributions): raise ValueError("All the attributions must include the same number of sequences.") - seq_attributions = [] - sources = None - if attr.source_attributions is not None: - sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)] - targets = [ - drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences) - ] - if tokenized_target_sentences is None: - tokenized_target_sentences = targets - if has_bos_token: - tokenized_target_sentences = [tok_seq[1:] for tok_seq in tokenized_target_sentences] - tokenized_target_sentences = [ - drop_padding(tokenized_target_sentences[seq_id], pad_id) for seq_id in range(num_sequences) - ] + seq_attributions: list[FeatureAttributionSequenceOutput] = [] + sources = [] + targets = [] + pos_start = [] + for seq_idx in range(num_sequences): + if attr.source_attributions is not None: + sources.append(drop_padding(attr.source[seq_idx], pad_token)) + curr_target = [a.target[seq_idx][0] for a in attributions] + targets.append(drop_padding(curr_target, pad_token)) + if all(attr.prefix[seq_idx][0] == pad_token for seq_idx in range(num_sequences)): + tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding( + tokenized_target_sentences[seq_idx][1:], pad_token + ) + else: + tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token) if attr_pos_end is None: attr_pos_end = max(len(t) for t in tokenized_target_sentences) - pos_start = [ - min(len(tokenized_target_sentences[seq_id]), attr_pos_end) - len(targets[seq_id]) - for seq_id in range(num_sequences) - ] - for seq_id in range(num_sequences): - source = tokenized_target_sentences[seq_id][: pos_start[seq_id]] if sources is None else sources[seq_id] - seq_attributions.append( - attr.get_sequence_cls( - 
source=source, - target=tokenized_target_sentences[seq_id], - attr_pos_start=pos_start[seq_id], - attr_pos_end=attr_pos_end, - ) + for seq_idx in range(num_sequences): + # If the model is decoder-only, the source is the input prefix + curr_pos_start = min(len(tokenized_target_sentences[seq_idx]), attr_pos_end) - len(targets[seq_idx]) + pos_start.append(curr_pos_start) + source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx] + curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls( + source=source, + target=tokenized_target_sentences[seq_idx], + attr_pos_start=pos_start[seq_idx], + attr_pos_end=attr_pos_end, ) + seq_attributions.append(curr_seq_attribution) if attr.source_attributions is not None: source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions]) for seq_id in range(num_sequences): @@ -241,18 +240,13 @@ def from_step_attributions( [att.target_attributions for att in attributions], padding_dims=[1] ) for seq_id in range(num_sequences): - if has_bos_token: - target_attributions[seq_id] = target_attributions[seq_id][1:, ...] start_idx = max(pos_start) - pos_start[seq_id] end_idx = start_idx + len(tokenized_target_sentences[seq_id]) target_attributions[seq_id] = target_attributions[seq_id][ start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203 ] if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]): - empty_final_row = torch.ones( - 1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device - ) * float("nan") - target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0) + target_attributions[seq_id] = pad_with_nan(target_attributions[seq_id], dim=0, pad_size=1) seq_attributions[seq_id].target_attributions = target_attributions[seq_id] if attr.step_scores is not None: step_scores = [{} for _ in range(num_sequences)] @@ -427,47 +421,51 @@ def remap_from_filtered( self, target_attention_mask: TargetIdsTensor, batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + is_final_step_method: bool = False, ) -> None: """Remaps the attributions to the original shape of the input sequence.""" + batch_size = ( + len(batch.sources.input_tokens) if self.source_attributions is not None else len(batch.target_tokens) + ) + source_len = len(batch.sources.input_tokens[0]) + target_len = len(batch.target_tokens[0]) + # Normal per-step attribution outputs have shape (batch_size, seq_len, ...) + other_dims_start_idx = 2 + # Final step attribution outputs have shape (batch_size, seq_len, seq_len, ...) 
+ if is_final_step_method: + other_dims_start_idx += 1 + other_dims = ( + self.source_attributions.shape[other_dims_start_idx:] + if self.source_attributions is not None + else self.target_attributions.shape[other_dims_start_idx:] + ) if self.source_attributions is not None: self.source_attributions = remap_from_filtered( - original_shape=(len(batch.sources.input_tokens), *self.source_attributions.shape[1:]), + original_shape=(batch_size, *self.source_attributions.shape[1:]), mask=target_attention_mask, filtered=self.source_attributions, ) if self.target_attributions is not None: self.target_attributions = remap_from_filtered( - original_shape=(len(batch.target_tokens), *self.target_attributions.shape[1:]), + original_shape=(batch_size, *self.target_attributions.shape[1:]), mask=target_attention_mask, filtered=self.target_attributions, ) if self.step_scores is not None: for score_name, score_tensor in self.step_scores.items(): self.step_scores[score_name] = remap_from_filtered( - original_shape=(len(batch.target_tokens), 1), + original_shape=(batch_size, 1), mask=target_attention_mask, filtered=score_tensor.unsqueeze(-1), ).squeeze(-1) if self.sequence_scores is not None: for score_name, score_tensor in self.sequence_scores.items(): if score_name.startswith("decoder"): - original_shape = ( - len(batch.target_tokens), - self.target_attributions.shape[1], - *self.target_attributions.shape[1:], - ) + original_shape = (batch_size, target_len, target_len, *other_dims) elif score_name.startswith("encoder"): - original_shape = ( - len(batch.sources.input_tokens), - self.source_attributions.shape[1], - *self.source_attributions.shape[1:], - ) + original_shape = (batch_size, source_len, source_len, *other_dims) else: # default case: cross-attention - original_shape = ( - len(batch.sources.input_tokens), - self.target_attributions.shape[1], - *self.source_attributions.shape[1:], - ) + original_shape = (batch_size, source_len, target_len, *other_dims) self.sequence_scores[score_name] = remap_from_filtered( original_shape=original_shape, mask=target_attention_mask, diff --git a/inseq/data/data_utils.py b/inseq/data/data_utils.py index b907d627..d0f90203 100644 --- a/inseq/data/data_utils.py +++ b/inseq/data/data_utils.py @@ -112,7 +112,7 @@ def _torch(attr): def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool: try: if isinstance(self_attr, torch.Tensor): - return torch.allclose(self_attr, other_attr, equal_nan=True) + return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5) elif isinstance(self_attr, dict): return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys()) else: @@ -175,6 +175,10 @@ def clone(self: TensorClass) -> TensorClass: out_params[field.name] = None return self.__class__(**out_params) + def clone_empty(self: TensorClass) -> TensorClass: + out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None} + return self.__class__(**out_params) + def to_dict(self: TensorClass) -> dict[str, Any]: return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} diff --git a/inseq/data/viz.py b/inseq/data/viz.py index fef73b58..5721dff1 100644 --- a/inseq/data/viz.py +++ b/inseq/data/viz.py @@ -26,6 +26,7 @@ from rich.color import Color from rich.console import Console from rich.live import Live +from rich.markup import escape from rich.padding import Padding from rich.panel import Panel from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn @@ -255,7 +256,7 @@ def 
get_saliency_heatmap_rich( ): columns = [Column(header="", justify="right", overflow="fold")] for column_label in column_labels: - columns.append(Column(header=column_label, justify="center", overflow="fold")) + columns.append(Column(header=escape(column_label), justify="center", overflow="fold")) table = Table( *columns, title=f"{label + ' ' if label else ''}Saliency Heatmap", @@ -266,7 +267,7 @@ def get_saliency_heatmap_rich( ) if scores is not None: for row_index in range(scores.shape[0]): - row = [Text(row_labels[row_index], style="bold")] + row = [Text(escape(row_labels[row_index]), style="bold")] for col_index in range(scores.shape[1]): color = Color.from_rgb(*input_colors[row_index][col_index]) score = "" @@ -281,7 +282,7 @@ def get_saliency_heatmap_rich( else: threshold = step_scores_threshold.get(step_score_name, 0.5) style = lambda val, limit: "bold" if abs(val) >= limit and isinstance(val, float) else "" - score_row = [Text(step_score_name, style="bold")] + score_row = [Text(escape(step_score_name), style="bold")] for score in step_score_values: curr_score = round(score.item(), 2) if isinstance(score, float) else score.item() score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold))) @@ -317,13 +318,13 @@ def get_progress_bar( TimeRemainingColumn(), ) for idx, (tgt, tgt_len) in enumerate(zip(sequences.targets, target_lengths)): - clean_tgt = tgt.replace("\n", "\\n") + clean_tgt = escape(tgt.replace("\n", "\\n")) job_progress.add_task(f"{idx}. {clean_tgt}", total=tgt_len) progress_table = Table.grid() row_contents = [ Panel.fit( job_progress, - title=f"[b]Attributing with {method_name}", + title=f"[b]Attributing with {escape(method_name)}", border_style="green", padding=(1, 2), ) @@ -331,7 +332,7 @@ def get_progress_bar( if sequences.sources is not None: sources = [] for idx, src in enumerate(sequences.sources): - clean_src = src.replace("\n", "\\n") + clean_src = escape(src.replace("\n", "\\n")) sources.append(f"{idx}. 
{clean_src}") row_contents = [ Panel.fit( @@ -370,7 +371,7 @@ def update_progress_bar( past_length = 0 for split, color in zip(split_targets, ["grey58", "green", "orange1", "grey58"]): if split[job.id]: - formatted_desc += f"[{color}]" + split[job.id].replace("\n", "\\n") + "[/]" + formatted_desc += f"[{color}]" + escape(split[job.id].replace("\n", "\\n")) + "[/]" past_length += len(split[job.id]) if past_length in whitespace_indexes[job.id]: formatted_desc += " " diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py index c74e82c8..2f259a4c 100644 --- a/inseq/models/attribution_model.py +++ b/inseq/models/attribution_model.py @@ -219,6 +219,7 @@ def __init__(self, **kwargs) -> None: self.pad_token: Optional[str] = None self.embed_scale: Optional[float] = None self._device: Optional[str] = None + self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None self.attribution_method: Optional[FeatureAttribution] = None self.is_hooked: bool = False self._default_attributed_fn_id: str = "probability" @@ -386,6 +387,24 @@ def attribute( original_device = self.device if device is not None: self.device = device + attribution_method = self.get_attribution_method(method, override_default_attribution) + attributed_fn = self.get_attributed_fn(attributed_fn) + attribution_args, attributed_fn_args, step_scores_args = extract_args( + attribution_method, + attributed_fn, + step_scores, + default_args=self.formatter.get_step_function_reserved_args(), + **kwargs, + ) + if isnotebook(): + logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.") + pretty_progress = False + if attribution_method.is_final_step_method: + if step_scores: + raise ValueError( + "Step scores are not supported for final step methods since they do not iterate over the full" + " sequence. Please remove the step scores and compute them separatly passing method='dummy'." + ) input_texts, generated_texts = format_input_texts(input_texts, generated_texts) has_generated_texts = generated_texts is not None if not self.is_encoder_decoder: @@ -411,36 +430,30 @@ def attribute( f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)." ) logger.debug(f"reference_texts={generated_texts}") - attribution_method = self.get_attribution_method(method, override_default_attribution) - attributed_fn = self.get_attributed_fn(attributed_fn) - attribution_args, attributed_fn_args, step_scores_args = extract_args( - attribution_method, - attributed_fn, - step_scores, - default_args=self.formatter.get_step_function_reserved_args(), - **kwargs, - ) - if isnotebook(): - logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.") - pretty_progress = False if not self.is_encoder_decoder: assert all( generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts)) ), "Forced generations with decoder-only models must start with the input texts." if has_generated_texts and len(input_texts) > 1: - logger.info( + logger.warning( "Batched constrained decoding is currently not supported for decoder-only models." " Using batch size of 1." ) batch_size = 1 if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None): - logger.info( + logger.warning( "Custom attribution positions are currently not supported when batching generations for" " decoder-only models. Using batch size of 1." 
) batch_size = 1 + elif attribution_method.is_final_step_method and len(input_texts) > 1: + logger.warning( + "Batched attribution with encoder-decoder models currently not supported for final-step methods." + " Using batch size of 1." + ) + batch_size = 1 if attribution_method.method_name == "lime": - logger.info("Batched attribution currently not supported for LIME. Using batch size of 1.") + logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.") batch_size = 1 attribution_outputs = attribution_method.prepare_and_attribute( input_texts, diff --git a/inseq/models/decoder_only.py b/inseq/models/decoder_only.py index f4355e42..e9f8a25e 100644 --- a/inseq/models/decoder_only.py +++ b/inseq/models/decoder_only.py @@ -119,7 +119,7 @@ def enrich_step_output( contrast_target_ids = contrast_batch.target_ids[:, contrast_aligned_idx] step_output.target = join_token_ids( tokens=target_tokens, - ids=attribution_model.convert_ids_to_tokens(contrast_target_ids), + ids=attribution_model.convert_ids_to_tokens(contrast_target_ids, skip_special_tokens=False), contrast_tokens=attribution_model.convert_ids_to_tokens( contrast_target_ids[None, ...], skip_special_tokens=False ), diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index c966372f..bb44f21c 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -95,9 +95,6 @@ def __init__( if isinstance(model, PreTrainedModel): self.model = model else: - if "output_attentions" not in model_kwargs: - model_kwargs["output_attentions"] = True - self.model = self._autoclass.from_pretrained(model, **model_kwargs) self.model_name = self.model.config.name_or_path self.tokenizer_name = tokenizer if isinstance(tokenizer, str) else None @@ -112,14 +109,23 @@ def __init__( self.tokenizer = tokenizer else: self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) - if self.model.config.pad_token_id is not None: - self.pad_token = self._convert_ids_to_tokens(self.model.config.pad_token_id, skip_special_tokens=False) + self.eos_token_id = getattr(self.model.config, "eos_token_id", None) + pad_token_id = self.model.config.pad_token_id + if pad_token_id is None: + if self.tokenizer.pad_token_id is None: + logger.info(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.") + pad_token_id = self.eos_token_id + else: + pad_token_id = self.tokenizer.pad_token_id + self.pad_token = self._convert_ids_to_tokens(pad_token_id, skip_special_tokens=False) + if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.pad_token + if self.model.config.pad_token_id is None: + self.model.config.pad_token_id = pad_token_id self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None) if self.bos_token_id is None: self.bos_token_id = self.model.config.bos_token_id self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False) - self.eos_token_id = getattr(self.model.config, "eos_token_id", None) if self.eos_token_id is None: self.eos_token_id = self.tokenizer.pad_token_id if self.tokenizer.unk_token_id is None: @@ -127,6 +133,9 @@ def __init__( self.embed_scale = 1.0 self.encoder_int_embeds = None self.decoder_int_embeds = None + self.device_map = None + if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None: + self.device_map = self.model.hf_device_map self.is_encoder_decoder = self.model.config.is_encoder_decoder self.configure_embeddings_scale() 
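For illustration, the pad token fallback introduced in `HuggingfaceModel.__init__` above resolves the padding id in a fixed order: model config, then tokenizer, then `eos_token_id`. A minimal standalone sketch of that order, assuming a plain `transformers` checkpoint (the helper name and the `gpt2` checkpoint are only illustrative, not part of the diff):

```python
from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

def resolve_pad_token_id(model, tokenizer) -> Optional[int]:
    """Mirror the fallback order used above: config pad -> tokenizer pad -> eos token."""
    pad_token_id = model.config.pad_token_id
    if pad_token_id is None:
        if tokenizer.pad_token_id is not None:
            pad_token_id = tokenizer.pad_token_id
        else:
            # Checkpoints shipped without a pad token (e.g. some Qwen models) reuse EOS
            pad_token_id = getattr(model.config, "eos_token_id", None)
    return pad_token_id

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pad_id = resolve_pad_token_id(model, tokenizer)

# Propagate the resolved id so that tokenizer and model config agree
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(pad_id)
if model.config.pad_token_id is None:
    model.config.pad_token_id = pad_id
```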
self.setup(device, attribution_method, **kwargs) @@ -162,16 +171,19 @@ def device(self, new_device: str) -> None: is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) is_quantized = is_loaded_in_8bit or is_loaded_in_4bit + has_device_map = self.device_map is not None # Enable compatibility with 8bit models if self.model: - if not is_quantized: - self.model.to(self._device) - else: + if is_quantized: mode = "8bit" if is_loaded_in_8bit else "4bit" logger.warning( f"The model is loaded in {mode} mode. The device cannot be changed after loading the model." ) + elif has_device_map: + logger.warning("The model is loaded with a device map. The device cannot be changed after loading.") + else: + self.model.to(self._device) @abstractmethod def configure_embeddings_scale(self) -> None: @@ -195,6 +207,7 @@ def generate( inputs: Union[TextInput, BatchEncoding], return_generation_output: bool = False, skip_special_tokens: bool = True, + output_generated_only: bool = False, **kwargs, ) -> Union[list[str], tuple[list[str], ModelOutput]]: """Wrapper of model.generate to handle tokenization and decoding. @@ -204,6 +217,9 @@ def generate( Inputs to be provided to the model for generation. return_generation_output (`bool`, *optional*, defaults to False): If true, generation outputs are returned alongside the generated text. + output_generated_only (`bool`, *optional*, defaults to False): + If true, only the generated text is returned. Relevant for decoder-only models that would otherwise return + the full input + output. Returns: `Union[List[str], Tuple[List[str], ModelOutput]]`: Generated text or a tuple of generated text and @@ -220,6 +236,8 @@ def generate( **kwargs, ) sequences = generation_out.sequences + if output_generated_only and not self.is_encoder_decoder: + sequences = sequences[:, inputs.input_ids.shape[1] :] texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens) if return_generation_output: return texts, generation_out diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py index 05b8a468..52d9d47b 100644 --- a/inseq/models/model_config.py +++ b/inseq/models/model_config.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass from pathlib import Path +from typing import Optional import yaml @@ -10,14 +11,25 @@ @dataclass class ModelConfig: """Configuration used by the methods for which the attribute ``use_model_config=True``. + Args: - attention_module (:obj:`str`): - The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in - transformers). Can be identified by looking at the name of the attribute instantiating the attention module + self_attention_module (:obj:`str`): + The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in + transformers). Can be identified by looking at the name of the self-attention module attribute in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2). + cross_attention_module (:obj:`str`): + The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models + in transformers). Can be identified by looking at the name of the cross-attention module attribute + in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`). 
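The `output_generated_only` flag documented above trims the echoed prompt from decoder-only generations by slicing away the first `input_ids.shape[1]` positions. A rough equivalent using `transformers` directly (the `gpt2` checkpoint and prompt are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Hello ladies and"], return_tensors="pt")
sequences = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)

# Decoder-only models return prompt + continuation; keep only the newly generated part
generated_only = sequences[:, inputs.input_ids.shape[1]:]
print(tokenizer.batch_decode(generated_only, skip_special_tokens=True))
```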
+ value_vector (:obj:`str`): + The name of the variable in the forward pass of the attention module containing the value vector + (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of + the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2). """ - attention_module: str + self_attention_module: str + value_vector: str + cross_attention_module: Optional[str] = None MODEL_CONFIGS = { diff --git a/inseq/models/model_config.yaml b/inseq/models/model_config.yaml index b48ed209..1b2433a4 100644 --- a/inseq/models/model_config.yaml +++ b/inseq/models/model_config.yaml @@ -1,2 +1,129 @@ +# Decoder-only models +BioGptForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +BloomForCausalLM: + self_attention_module: "self_attention" + value_vector: "value_layer" +CodeGenForCausalLM: + self_attention_module: "attn" + value_vector: "value" +CohereForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +DbrxForCausalLM: + self_attention_module: "attn" + value_vector: "value_states" +FalconForCausalLM: + self_attention_module: "self_attention" + value_vector: "value_layer" +GemmaForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +GPTBigCodeForCausalLM: + self_attention_module: "attn" + value_vector: "value" +GPTJForCausalLM: + self_attention_module: "attn" + value_vector: "value" GPT2LMHeadModel: - attention_module: "attn" \ No newline at end of file + self_attention_module: "attn" + value_vector: "value" +GPTNeoForCausalLM: + self_attention_module: "attn" + value_vector: "value" +GPTNeoXForCausalLM: + self_attention_module: "attention" + value_vector: "value" +LlamaForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MistralForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MixtralForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MptForCausalLM: + self_attention_module: "attn" + value_vector: "value_states" +OlmoForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +OpenAIGPTLMHeadModel: + self_attention_module: "attn" + value_vector: "value" +OPTForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +PhiForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +Phi3ForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +Qwen2ForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +Qwen2MoeForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +StableLmForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +StarCoder2ForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +XGLMForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" + +# Encoder-decoder models +BartForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +MarianMTModel: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +FSMTForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "v" +M2M100ForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" 
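The expanded `model_config.yaml` above maps each supported architecture to the attribute names needed by methods with `use_model_config=True`. A small sketch of how such entries populate the new `ModelConfig` dataclass (the inline YAML copies two entries from the file above; the parsing loop is only illustrative):

```python
from dataclasses import dataclass
from typing import Optional

import yaml

@dataclass
class ModelConfig:
    self_attention_module: str
    value_vector: str
    cross_attention_module: Optional[str] = None

example_yaml = """
GPT2LMHeadModel:
  self_attention_module: "attn"
  value_vector: "value"
MarianMTModel:
  self_attention_module: "self_attn"
  cross_attention_module: "encoder_attn"
  value_vector: "value_states"
"""

configs = {name: ModelConfig(**fields) for name, fields in yaml.safe_load(example_yaml).items()}
print(configs["GPT2LMHeadModel"].self_attention_module)  # attn
print(configs["MarianMTModel"].cross_attention_module)   # encoder_attn
```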
+MBartForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +MT5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" +NllbMoeForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value_states" +PegasusForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +SeamlessM4TForTextToText: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value" +SeamlessM4Tv2ForTextToText: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value" +T5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" +UMT5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 29f81615..92a763cf 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -8,11 +8,14 @@ MissingAttributionMethodError, UnknownAttributionMethodError, ) +from .hooks import get_post_variable_assignment_hook from .import_utils import ( + is_accelerate_available, is_captum_available, is_datasets_available, is_ipywidgets_available, is_joblib_available, + is_nltk_available, is_scikitlearn_available, is_sentencepiece_available, is_transformers_available, @@ -49,12 +52,16 @@ check_device, euclidean_distance, filter_logits, + find_block_stack, get_default_device, get_front_padding, get_sequences_from_batched_steps, normalize, + pad_with_nan, + recursive_get_submodule, remap_from_filtered, top_p_logits_mask, + validate_indices, ) __all__ = [ @@ -94,6 +101,7 @@ "is_datasets_available", "is_captum_available", "is_joblib_available", + "is_nltk_available", "check_device", "get_default_device", "ndarray_to_bin_str", @@ -118,4 +126,9 @@ "top_p_logits_mask", "filter_logits", "cli_arg", + "get_post_variable_assignment_hook", + "validate_indices", + "pad_with_nan", + "recursive_get_submodule", + "is_accelerate_available", ] diff --git a/inseq/utils/hooks.py b/inseq/utils/hooks.py new file mode 100644 index 00000000..98fd07e5 --- /dev/null +++ b/inseq/utils/hooks.py @@ -0,0 +1,109 @@ +import re +from inspect import getsourcelines +from sys import gettrace, settrace +from types import FrameType +from typing import Callable, Optional + +from torch import nn + +from .misc import get_left_padding + + +def get_last_variable_assignment_position( + module: nn.Module, + varname: str, + fname: str = "forward", +) -> Optional[int]: + """Extract the code line number of the last variable assignment for a variable of interest in the specified method + of a `nn.Module` object. + + Args: + module (`nn.Module`): + A PyTorch module containing a method with a variable assignment after which the hook should be executed. + varname (`str`): + The name of the variable to use as anchor for the hook. + fname (`str`, *optional*, defaults to "forward"): + The name of the method in which the variable assignment should be searched. + + Returns: + `Optional[int]`: Returns the line number in the file (not relative to the method) of the last variable + assignment. Returns None if no assignment to the variable was found. 
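To make the role of `get_last_variable_assignment_position` concrete, here is a small sketch assuming the `inseq.utils.hooks` module path introduced in this diff; the toy module is invented for the example and must live in a source file so `inspect` can read it:

```python
import torch.nn as nn

from inseq.utils.hooks import get_last_variable_assignment_position

class ToyAttention(nn.Module):
    def forward(self, hidden):
        query = hidden + 1.0
        value = hidden * 2.0  # last assignment to `value` in forward()
        return query + value

# Position used to anchor the hook (right after the last `value = ...` assignment),
# or None if the variable is never assigned in forward().
print(get_last_variable_assignment_position(ToyAttention(), "value"))
```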
+ """ + # Matches any assignment of variable varname + pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$" + code, startline = getsourcelines(getattr(module, fname)) + line_numbers = [] + i = 0 + while i < len(code): + line = code[i] + # Handles multi-line assignments + if re.match(pattern, line): + parentheses_count = line.count("(") - line.count(")") + ends_with_newline = lambda l: l.strip().endswith("\\") + follow_indent = lambda l, i: len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(l) + while (ends_with_newline(line) or follow_indent(line, i) or parentheses_count > 0) and len(code) > i + 1: + i += 1 + line = code[i] + parentheses_count += line.count("(") - line.count(")") + line_numbers.append(i) + i += 1 + if len(line_numbers) == 0: + return None + return line_numbers[-1] + startline + 1 + + +def get_post_variable_assignment_hook( + module: nn.Module, + varname: str, + fname: str = "forward", + hook_fn: Callable[[FrameType], None] = lambda **kwargs: None, + **kwargs, +) -> Callable[[], None]: + """Creates a hook that is called after the last variable assignment in the specified method of a `nn.Module`. + + This is a hacky method using the ``sys.settrace()`` function to circumvent the limited hook points of Pytorch hooks + and set a custom hook point dynamically. This approach is preferred to ensure a broader compatibility with Hugging + Face transformers models that do not provide hook points in their architectures for the moment. + + Args: + module (`nn.Module`): + A PyTorch module containing a method with a variable assignment after which the hook should be executed. + varname (`str`): + The name of the variable to use as anchor for the hook. + fname (`str`, *optional*, defaults to "forward"): + The name of the method in which the variable assignment should be searched. + hook_fn (`Callable[[FrameType], None]`, *optional*, defaults to lambdaframe): + A custom hook function that is called after the last variable assignment in the specified method. The first + parameter is the current frame in the execution at the hook point, and any additional arguments can be + passed when creating the hook. ``frame.f_locals`` is a dictionary containing all local variables. + + Returns: + The hook function that can be registered with the module. If hooking the module's ``forward()`` method, the + hook can be registered with Pytorch native hook methods. 
+ """ + hook_line_num = get_last_variable_assignment_position(module, varname, fname) + curr_trace_fn = gettrace() + if hook_line_num is None: + raise ValueError(f"Could not find assignment to {varname} in {module}'s {fname}() method") + + def var_tracer(frame, event, arg=None): + curr_line_num = frame.f_lineno + curr_func_name = frame.f_code.co_name + + # Matches the first executable line after hook_line_num in the same function of the same module + if ( + event == "line" + and curr_line_num >= hook_line_num + and curr_func_name == fname + and isinstance(frame.f_locals.get("self"), nn.Module) + and frame.f_locals.get("self")._get_name() == module._get_name() + ): + # Call the custom hook providing the current frame and any additional arguments as context + hook_fn(frame, **kwargs) + settrace(curr_trace_fn) + return var_tracer + + def hook(*args, **kwargs): + settrace(var_tracer) + + return hook diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py index cbd03420..e8ae455e 100644 --- a/inseq/utils/import_utils.py +++ b/inseq/utils/import_utils.py @@ -7,6 +7,8 @@ _datasets_available = find_spec("datasets") is not None _captum_available = find_spec("captum") is not None _joblib_available = find_spec("joblib") is not None +_nltk_available = find_spec("nltk") is not None +_accelerate_available = find_spec("accelerate") is not None def is_ipywidgets_available(): @@ -35,3 +37,11 @@ def is_captum_available(): def is_joblib_available(): return _joblib_available + + +def is_nltk_available(): + return _nltk_available + + +def is_accelerate_available(): + return _accelerate_available diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py index e09e5df7..628995bc 100644 --- a/inseq/utils/misc.py +++ b/inseq/utils/misc.py @@ -10,7 +10,6 @@ from functools import wraps from importlib import import_module from inspect import signature -from itertools import dropwhile from numbers import Number from os import PathLike, fsync from typing import Any, Callable, Optional, Union @@ -171,10 +170,10 @@ def pad(seq: Sequence[Sequence[Any]], pad_id: Any): return seq -def drop_padding(seq: Sequence[Any], pad_id: Any): +def drop_padding(seq: Sequence[TokenWithId], pad_id: str): if pad_id is None: return seq - return list(reversed(list(dropwhile(lambda x: x == pad_id, reversed(seq))))) + return [x for x in seq if x.token != pad_id] def isnotebook(): @@ -435,3 +434,8 @@ def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str] else: removed_token_idxs += [idx] return clean_tokens, removed_token_idxs + + +def get_left_padding(text: str): + """Returns the number of spaces at the beginning of a string.""" + return len(text) - len(text.lstrip()) diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py index 88e807cc..86acd635 100644 --- a/inseq/utils/torch_utils.py +++ b/inseq/utils/torch_utils.py @@ -5,11 +5,14 @@ import torch import torch.nn.functional as F from jaxtyping import Int, Num +from torch import nn from torch.backends.cuda import is_built as is_cuda_built from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_built as is_mps_built from torch.cuda import is_available as is_cuda_available +from .typing import OneOrMoreIndices + if TYPE_CHECKING: pass @@ -244,3 +247,118 @@ def get_default_device() -> str: return "cpu" else: return "cpu" + + +def find_block_stack(module): + """Recursively searches for the first instance of a `nn.ModuleList` submodule within a given `torch.nn.Module`. 
+ + Args: + module (:obj:`torch.nn.Module`): A Pytorch :obj:`nn.Module` object. + + Returns: + :obj:`torch.nn.ModuleList`: The first instance of a :obj:`nn.Module` submodule found within the given object. + None: If no `nn.ModuleList` submodule is found within the given `nn.Module` object. + """ + # Check if the current module is an instance of nn.ModuleList + if isinstance(module, nn.ModuleList): + return module + + # Recursively search for nn.ModuleList in the submodules of the current module + for submodule in module.children(): + module_list = find_block_stack(submodule) + if module_list is not None: + return module_list + + # If nn.ModuleList is not found in any submodules, return None + return None + + +def validate_indices( + scores: torch.Tensor, + dim: int = -1, + indices: Optional[OneOrMoreIndices] = None, +) -> OneOrMoreIndices: + """Validates a set of indices for a given dimension of a tensor of scores. Supports single indices, spans and lists + of indices, including negative indices to specify positions relative to the end of the tensor. + + Args: + scores (torch.Tensor): The tensor of scores. + dim (int, optional): The dimension of the tensor that will be indexed. Defaults to -1. + indices (Union[int, tuple[int, int], list[int], None], optional): + - If an integer, it is interpreted as a single index for the dimension. + - If a tuple of two integers, it is interpreted as a span of indices for the dimension. + - If a list of integers, it is interpreted as a list of individual indices for the dimension. + + Returns: + ``Union[int, tuple[int, int], list[int]]``: The validated list of positive indices for indexing the dimension. + """ + if dim >= scores.ndim: + raise IndexError(f"Dimension {dim} is greater than tensor dimension {scores.ndim}") + n_units = scores.shape[dim] + if not isinstance(indices, (int, tuple, list)) and indices is not None: + raise TypeError( + "Indices must be an integer, a (start, end) tuple of indices representing a span, a list of individual" + " indices or a single index." + ) + if hasattr(indices, "__iter__"): + if len(indices) == 0: + raise RuntimeError("An empty sequence of indices is not allowed.") + if len(indices) == 1: + indices = indices[0] + + if isinstance(indices, int): + if indices not in range(-n_units, n_units): + raise IndexError(f"Index out of range. Scores only have {n_units} units.") + indices = indices if indices >= 0 else n_units + indices + return torch.tensor(indices) + else: + if indices is None: + indices = (0, n_units) + logger.info("No indices specified. Using all indices by default.") + + # Convert negative indices to positive indices + if hasattr(indices, "__iter__"): + indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices]) + if not hasattr(indices, "__iter__") or ( + len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1] + ): + raise RuntimeError( + "A (start, end) tuple of indices representing a span, a list of individual indices" + " or a single index must be specified." + ) + max_idx_val = n_units if isinstance(indices, list) else n_units + 1 + if not all(h in range(-n_units, max_idx_val) for h in indices): + raise IndexError(f"One or more index out of range. 
Scores only have {n_units} units.") + if len(set(indices)) != len(indices): + raise IndexError("Duplicate indices are not allowed.") + if isinstance(indices, tuple): + return torch.arange(indices[0], indices[1]) + else: + return torch.tensor(indices) + + +def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False) -> torch.Tensor: + """Utility to pad a tensor with nan values along a given dimension.""" + nan_tensor = torch.ones( + *t.shape[:dim], + pad_size, + *t.shape[dim + 1 :], + device=t.device, + ) * float("nan") + if front: + return torch.cat([nan_tensor, t], dim=dim) + return torch.cat([t, nan_tensor], dim=dim) + + +def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]: + if target == "": + return parent + mod = None + if hasattr(parent, target): + mod = getattr(parent, target) + else: + for submodule in parent.children(): + mod = recursive_get_submodule(submodule, target) + if mod is not None: + break + return mod diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py index 7599bbc7..4eec4a5b 100644 --- a/inseq/utils/typing.py +++ b/inseq/utils/typing.py @@ -1,13 +1,17 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import torch +from captum.attr._utils.attribution import Attribution from jaxtyping import Float, Float32, Int64 from transformers import PreTrainedModel TextInput = Union[str, Sequence[str]] +if TYPE_CHECKING: + from inseq.models import AttributionModel + @dataclass class TokenWithId: @@ -28,6 +32,34 @@ def __eq__(self, other: Union[str, int, "TokenWithId"]): return False +class InseqAttribution(Attribution): + """A wrapper class for the Captum library's Attribution class to type hint the ``forward_func`` attribute + as an :class:`~inseq.models.AttributionModel`. + """ + + def __init__(self, forward_func: "AttributionModel") -> None: + r""" + Args: + forward_func (:class:`~inseq.models.AttributionModel`): The model hooker to the attribution method. 
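`validate_indices`, added above and exported from `inseq.utils`, normalizes the index formats accepted by unit-level methods (single index, span, list, negatives). A few illustrative calls on a dummy score tensor:

```python
import torch

from inseq.utils import validate_indices

scores = torch.rand(2, 5, 12)  # e.g. (batch, layers, heads)

print(validate_indices(scores, dim=-1, indices=3))         # tensor(3): a single head
print(validate_indices(scores, dim=-1, indices=(0, 4)))    # tensor([0, 1, 2, 3]): a span of heads
print(validate_indices(scores, dim=-1, indices=[-1, -2]))  # tensor([11, 10]): negatives resolved
print(validate_indices(scores, dim=-1))                    # all 12 heads when left unspecified
```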
+ """ + self.forward_func = forward_func + + attribute: Callable + + @property + def multiplies_by_inputs(self): + return False + + def has_convergence_delta(self) -> bool: + return False + + compute_convergence_delta: Callable + + @classmethod + def get_name(cls: type["InseqAttribution"]) -> str: + return "".join([char if char.islower() or idx == 0 else " " + char for idx, char in enumerate(cls.__name__)]) + + @dataclass class TextSequences: targets: TextInput @@ -40,6 +72,8 @@ class TextSequences: OneOrMoreAttributionSequences = Sequence[Sequence[float]] IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]] +OneOrMoreIndices = Union[int, list[int], tuple[int, int]] +OneOrMoreIndicesDict = dict[int, OneOrMoreIndices] IdsTensor = Int64[torch.Tensor, "batch_size seq_len"] TargetIdsTensor = Int64[torch.Tensor, "batch_size"] diff --git a/pyproject.toml b/pyproject.toml index 3babbe9a..32592ee8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inseq" -version = "0.6.0.dev0" +version = "0.7.0.dev0" description = "Interpretability for Sequence Generation Models šŸ”" readme = "README.md" requires-python = ">=3.9" @@ -74,7 +74,7 @@ docs = [ ] lint = [ "bandit>=1.7.4", - "safety>=2.2.0", + "safety>=3.1.0", "pydoclint>=0.4.0", "pre-commit>=2.19.0", "pytest>=7.2.0", @@ -93,6 +93,9 @@ notebook = [ "ipykernel>=6.29.2", "ipywidgets>=8.1.2" ] +nltk = [ + "nltk>=3.8.1", +] [project.urls] homepage = "https://github.com/inseq-team/inseq" diff --git a/requirements-dev.txt b/requirements-dev.txt index 91a4d3f2..2fda0ec1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile --all-extras pyproject.toml -o requirements-dev.txt aiohttp==3.9.3 # via @@ -32,6 +32,7 @@ charset-normalizer==3.3.2 # via requests click==8.1.7 # via + # nltk # pydoclint # safety # typer @@ -43,7 +44,7 @@ contourpy==1.2.0 # via matplotlib coverage==7.4.1 # via pytest-cov -cryptography==42.0.2 +cryptography==42.0.5 # via authlib cycler==0.12.1 # via matplotlib @@ -123,7 +124,9 @@ jinja2==3.1.3 # sphinx # torch joblib==1.3.2 - # via scikit-learn + # via + # nltk + # scikit-learn jupyter-client==8.6.0 # via ipykernel jupyter-core==5.7.1 @@ -160,6 +163,7 @@ nest-asyncio==1.6.0 # via ipykernel networkx==3.2.1 # via torch +nltk==3.8.1 nodeenv==1.8.0 # via pre-commit numpy==1.26.4 @@ -258,7 +262,9 @@ pyzmq==25.1.2 # ipykernel # jupyter-client regex==2023.12.25 - # via transformers + # via + # nltk + # transformers requests==2.31.0 # via # datasets @@ -280,7 +286,7 @@ ruamel-yaml-clib==0.2.8 ruff==0.2.1 safetensors==0.4.2 # via transformers -safety==3.0.1 +safety==3.1.0 safety-schemas==0.0.2 # via safety scikit-learn==1.4.0 @@ -351,6 +357,7 @@ tqdm==4.66.2 # captum # datasets # huggingface-hub + # nltk # transformers traitlets==5.14.1 # via @@ -361,7 +368,7 @@ traitlets==5.14.1 # jupyter-client # jupyter-core # matplotlib-inline -transformers==4.37.2 +transformers==4.38.1 typeguard==2.13.3 # via jaxtyping typer==0.9.0 diff --git a/requirements.txt b/requirements.txt index 9f392d72..93809632 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# This file was autogenerated by uv v0.1.2 via the following command: +# This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 certifi==2024.2.2 @@ -93,7 +93,7 @@ 
tqdm==4.66.2 # captum # huggingface-hub # transformers -transformers==4.37.2 +transformers==4.38.1 typeguard==2.13.3 # via jaxtyping typing-extensions==4.9.0 diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py index 07b2045c..80856176 100644 --- a/tests/attr/feat/test_feature_attribution.py +++ b/tests/attr/feat/test_feature_attribution.py @@ -1,8 +1,14 @@ +from typing import Any, Optional + import torch +from captum._utils.typing import TensorOrTupleOfTensorsGeneric from pytest import fixture import inseq +from inseq.attr.feat.internals_attribution import InternalsAttributionRegistry +from inseq.data import MultiDimensionalFeatureAttributionStepOutput from inseq.models import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel +from inseq.utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor @fixture(scope="session") @@ -69,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu "orig_tgt": "I soldati della pace ONU", "contrast_tgt": "Le forze militari di pace delle Nazioni Unite", "alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]], - "aligned_tgts": ["ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], + "aligned_tgts": ["", "ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], } out = saliency_mt_model_larger.attribute( aligned["src"], @@ -129,3 +135,122 @@ def test_mcd_weighted_attribution_gpt(saliency_gpt_model): ) attribution_scores = out.sequence_attributions[0].target_attributions assert isinstance(attribution_scores, torch.Tensor) + + +class MultiStepAttentionWeights(InseqAttribution): + """Variant of the AttentionWeights class with is_final_step_method = False. + As a result, the attention matrix is computed and sliced at every generation step. + We define it here to test consistency with the final step method. + """ + + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + ) -> MultiDimensionalFeatureAttributionStepOutput: + # We adopt the format [batch_size, sequence_length, num_layers, num_heads] + # for consistency with other multi-unit methods (e.g. 
gradient attribution) + decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2) + if self.forward_func.is_encoder_decoder: + sequence_scores = {} + if len(inputs) > 1: + target_attributions = decoder_self_attentions + else: + target_attributions = None + sequence_scores["decoder_self_attentions"] = decoder_self_attentions + sequence_scores["encoder_self_attentions"] = ( + encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) + ) + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2), + target_attributions=target_attributions, + sequence_scores=sequence_scores, + _num_dimensions=2, # num_layers, num_heads + ) + else: + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=None, + target_attributions=decoder_self_attentions, + _num_dimensions=2, # num_layers, num_heads + ) + + +class MultiStepAttentionWeightsAttribution(InternalsAttributionRegistry): + """Variant of the basic attention attribution method computing attention weights at every generation step.""" + + method_name = "per_step_attention" + + def __init__(self, attribution_model, **kwargs): + super().__init__(attribution_model) + # Attention weights will be passed to the attribute_step method + self.use_attention_weights = True + # Does not rely on predicted output (i.e. decoding strategy agnostic) + self.use_predicted_target = False + self.method = MultiStepAttentionWeights(attribution_model) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any], + ) -> MultiDimensionalFeatureAttributionStepOutput: + return self.method.attribute(**attribute_fn_main_args, **attribution_args) + + +def test_seq2seq_final_step_per_step_conformity(saliency_mt_model_larger: HuggingfaceEncoderDecoderModel): + out_per_step = saliency_mt_model_larger.attribute( + "Hello ladies and badgers!", + method="per_step_attention", + attribute_target=True, + show_progress=False, + output_step_attributions=True, + ) + out_final_step = saliency_mt_model_larger.attribute( + "Hello ladies and badgers!", + method="attention", + attribute_target=True, + show_progress=False, + output_step_attributions=True, + ) + assert out_per_step[0] == out_final_step[0] + + +def test_gpt_final_step_per_step_conformity(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel): + out_per_step = saliency_gpt_model_larger.attribute( + "Hello ladies and badgers!", + method="per_step_attention", + show_progress=False, + output_step_attributions=True, + ) + out_final_step = saliency_gpt_model_larger.attribute( + "Hello ladies and badgers!", + method="attention", + show_progress=False, + output_step_attributions=True, + ) + assert out_per_step[0] == out_final_step[0] + + +# Batching for Seq2Seq models is not supported when using is_final_step methods +# Passing several sentences will attributed them one by one under the hood +# def test_seq2seq_multi_step_attention_weights_batched_full_match(saliency_mt_model: HuggingfaceEncoderDecoderModel): + + +def test_gpt_multi_step_attention_weights_batched_full_match(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel): + out_per_step = saliency_gpt_model_larger.attribute( + ["Hello world!", "Colorless green ideas sleep furiously."], + method="per_step_attention", + show_progress=False, + ) + out_final_step = saliency_gpt_model_larger.attribute( + ["Hello world!", "Colorless green ideas sleep furiously."], + method="attention", + 
show_progress=False, + ) + for i in range(2): + assert out_per_step[i].target_attributions.shape == out_final_step[i].target_attributions.shape + assert torch.allclose( + out_per_step[i].target_attributions, out_final_step[i].target_attributions, equal_nan=True, atol=1e-5 + ) diff --git a/tests/commands/test_attribute_context.py b/tests/commands/test_attribute_context.py index d747213e..1013b1cc 100644 --- a/tests/commands/test_attribute_context.py +++ b/tests/commands/test_attribute_context.py @@ -41,6 +41,7 @@ def test_in_out_ctx_encdec_whitespace_sep(encdec_model: MarianMTModel): input_template="{context} {current}", attributed_fn="contrast_prob_diff", show_viz=False, + show_intermediate_outputs=True, # Pre-defining natural model outputs to avoid user input in unit tests output_context_text="", output_current_text="OĆ¹ sont-elles?", @@ -77,6 +78,7 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel): model_name_or_path=deconly_model, input_context_text="George was sick yesterday.", input_current_text="His colleagues asked him to come", + output_current_text="to the hospital. He said he was fine", attributed_fn="contrast_prob_diff", show_viz=False, add_output_info=False, @@ -110,12 +112,14 @@ def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel): # Base case for context-aware decoder-only model with forced output context mocking a reasoning chain. out_ctx_deconly = AttributeContextArgs( model_name_or_path=deconly_model, - output_template="\n\nLet's think step by step:\n{context}\n\nAnswer:\n{current}", + output_template="Let's think step by step:\n{context}\n\nAnswer:\n{current}", input_template="{current}", input_current_text="Question: How many pairs of legs do 10 horses have?", output_context_text="1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.", output_current_text="20 pairs of legs.", attributed_fn="contrast_prob_diff", + decoder_input_output_separator="\n", + contextless_output_current_text="Answer:\n{current}", show_viz=False, add_output_info=False, ) @@ -155,16 +159,45 @@ def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel): ], output_current="20 pairs of legs.", output_current_tokens=["20", "Ä pairs", "Ä of", "Ä legs", "."], - cti_scores=[4.53, 1.33, 0.43, 0.74, 0.93], + cti_scores=[0.77, 1.39, 0.54, 0.48, 0.91], cci_scores=[ CCIOutput( - cti_idx=0, - cti_token="20 ā†’ Ä 20", - cti_score=4.53, - contextual_output="Question: How many pairs of legs do 10 horses have?\n\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20", - contextless_output="Question: How many pairs of legs do 10 horses have?\n", + cti_idx=1, + cti_token="Ä pairs", + cti_score=1.39, + contextual_output="Question: How many pairs of legs do 10 horses have?\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 
40 legs make 20 pairs of legs.\n\nAnswer:\n20 pairs", + contextless_output="Question: How many pairs of legs do 10 horses have?\nAnswer:\n20 horses", input_context_scores=None, - output_context_scores=[0.0] * 28, + output_context_scores=[ + 0.1, + 0.1, + 0.06, + 0.19, + 0.05, + 0.07, + 0.17, + 0.1, + 0.08, + 0.07, + 0.11, + 0.22, + 0.44, + 0.07, + 0.15, + 0.17, + 0.12, + 0.13, + 0.06, + 0.14, + 0.11, + 0.19, + 0.16, + 0.34, + 1.33, + 0.04, + 0.13, + 0.07, + ], ), ], info=None, @@ -180,6 +213,7 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel): input_context_text="George was sick yesterday.", input_current_text="His colleagues asked him if", output_context_text="something was wrong. He said", + output_current_text="he was fine.", attributed_fn="contrast_prob_diff", show_viz=False, add_output_info=False, diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py index eb5086ca..f7e7c3e5 100644 --- a/tests/data/test_aggregator.py +++ b/tests/data/test_aggregator.py @@ -39,14 +39,14 @@ def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDe ) seqattr = out.sequence_attributions[0] assert seqattr.source_attributions.shape == (6, 7, 512) - assert seqattr.target_attributions.shape == (7, 7, 512) + assert seqattr.target_attributions.shape == (8, 7, 512) assert seqattr.step_scores["probability"].shape == (7,) for i, step in enumerate(out.step_attributions): assert step.source_attributions.shape == (1, 6, 512) assert step.target_attributions.shape == (1, i + 1, 512) out_agg = seqattr.aggregate() assert out_agg.source_attributions.shape == (6, 7) - assert out_agg.target_attributions.shape == (7, 7) + assert out_agg.target_attributions.shape == (8, 7) assert out_agg.step_scores["probability"].shape == (7,) @@ -56,9 +56,9 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder ) seqattr = out.sequence_attributions[0] out_agg = seqattr.aggregate(ContiguousSpanAggregator, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg.source_attributions.shape == (5, 4, 512) - assert out_agg.target_attributions.shape == (4, 4, 512) - assert out_agg.step_scores["probability"].shape == (4,) + assert out_agg.source_attributions.shape == (5, 5, 512) + assert out_agg.target_attributions.shape == (5, 5, 512) + assert out_agg.step_scores["probability"].shape == (5,) def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel): @@ -76,14 +76,14 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): seqattr = out.sequence_attributions[0] squeezesum = AggregatorPipeline([ContiguousSpanAggregator, SequenceAttributionAggregator]) out_agg_squeezesum = seqattr.aggregate(squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum.source_attributions.shape == (5, 4) - assert out_agg_squeezesum.target_attributions.shape == (4, 4) - assert out_agg_squeezesum.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum.source_attributions.shape == (5, 5) + assert out_agg_squeezesum.target_attributions.shape == (5, 5) + assert out_agg_squeezesum.step_scores["probability"].shape == (5,) sumsqueeze = AggregatorPipeline([SequenceAttributionAggregator, ContiguousSpanAggregator]) out_agg_sumsqueeze = seqattr.aggregate(sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_sumsqueeze.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze.target_attributions.shape == (4, 4) - assert 
out_agg_sumsqueeze.step_scores["probability"].shape == (4,) + assert out_agg_sumsqueeze.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.step_scores["probability"].shape == (5,) assert not torch.allclose(out_agg_squeezesum.source_attributions, out_agg_sumsqueeze.source_attributions) assert not torch.allclose(out_agg_squeezesum.target_attributions, out_agg_sumsqueeze.target_attributions) # Named indexing version @@ -91,12 +91,12 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): named_sumsqueeze = ["scores", "spans"] out_agg_squeezesum_named = seqattr.aggregate(named_squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) out_agg_sumsqueeze_named = seqattr.aggregate(named_sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum_named.source_attributions.shape == (5, 4) - assert out_agg_squeezesum_named.target_attributions.shape == (4, 4) - assert out_agg_squeezesum_named.step_scores["probability"].shape == (4,) - assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze_named.target_attributions.shape == (4, 4) - assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum_named.source_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.target_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.step_scores["probability"].shape == (5,) + assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (5,) assert not torch.allclose( out_agg_squeezesum_named.source_attributions, out_agg_sumsqueeze_named.source_attributions ) diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json index fc029eec..53123526 100644 --- a/tests/fixtures/aggregator.json +++ b/tests/fixtures/aggregator.json @@ -36,6 +36,7 @@ ], "target": "Inseq \u00e8 un framework per l'attribuzione automatica di modelli sequenziali.", "target_subwords": [ + "", "\u2581In", "se", "q", @@ -58,6 +59,7 @@ "" ], "target_merged": [ + "", "\u2581Inseq", "\u2581\u00e8", "\u2581un", diff --git a/tests/inference_commons.py b/tests/inference_commons.py index 3da21068..19810018 100644 --- a/tests/inference_commons.py +++ b/tests/inference_commons.py @@ -1,3 +1,6 @@ +import json +import os + from inseq.data import EncoderDecoderBatch from inseq.utils import json_advanced_load @@ -9,3 +12,8 @@ def get_example_batches(): dict_batches["batches"] = [batch.torch() for batch in dict_batches["batches"]] assert all(isinstance(batch, EncoderDecoderBatch) for batch in dict_batches["batches"]) return dict_batches + + +def load_examples() -> dict: + file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/huggingface_model.json") + return json.load(open(file)) diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py index 993c07ac..72da4a2f 100644 --- a/tests/models/test_huggingface_model.py +++ b/tests/models/test_huggingface_model.py @@ -2,8 +2,6 @@ since it is bugged is not very elegant, this will need to be refactored. 
""" -import json -import os import pytest import torch @@ -15,8 +13,9 @@ from inseq.data import FeatureAttributionOutput, FeatureAttributionSequenceOutput from inseq.utils import get_default_device -EXAMPLES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/huggingface_model.json") -EXAMPLES = json.load(open(EXAMPLES_FILE)) +from ..inference_commons import load_examples + +EXAMPLES = load_examples() USE_REFERENCE_TEXT = [True, False] ATTRIBUTE_TARGET = [True, False] @@ -275,8 +274,8 @@ def test_attribute_slice_seq2seq(saliency_mt_model): assert ex2.attr_pos_start == len(ex2.target) assert ex2.attr_pos_end == len(ex2.target) assert ex2.source_attributions.shape[1] == 0 and ex2.target_attributions.shape[1] == 0 - assert ex3.attr_pos_start == 12 - assert ex3.attr_pos_end == 15 + assert ex3.attr_pos_start == 13 + assert ex3.attr_pos_end == 16 assert ex1.source_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[0] == ex1.attr_pos_end @@ -303,12 +302,12 @@ def test_attribute_decoder(saliency_gpt2_model): assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[0] == ex1.attr_pos_end # Empty attributions outputs have start and end set to seq length - assert ex2.attr_pos_start == 17 - assert ex2.attr_pos_end == 22 + assert ex2.attr_pos_start == 9 + assert ex2.attr_pos_end == 14 assert ex2.target_attributions.shape[1] == ex2.attr_pos_end - ex2.attr_pos_start assert ex2.target_attributions.shape[0] == ex2.attr_pos_end - assert ex3.attr_pos_start == 17 - assert ex3.attr_pos_end == 22 + assert ex3.attr_pos_start == 12 + assert ex3.attr_pos_end == 17 assert ex3.target_attributions.shape[1] == ex3.attr_pos_end - ex3.attr_pos_start assert ex3.target_attributions.shape[0] == ex3.attr_pos_end assert out.info["attr_pos_start"] == 17