diff --git a/README.md b/README.md
index 2833e681..1b859637 100644
--- a/README.md
+++ b/README.md
@@ -114,17 +114,17 @@ model.attribute(
 - 🚀 Feature attribution of sequence generation for most `ForConditionalGeneration` (encoder-decoder) and `ForCausalLM` (decoder-only) models from 🤗 Transformers
-- 🚀 Support for multiple feature attribution methods, sourced in part from [Captum](https://captum.ai/docs/introduction)
+- 🚀 Support for multiple feature attribution methods, extending the ones supported by [Captum](https://captum.ai/docs/introduction)
-- 🚀 Post-processing of attribution maps via `Aggregator` classes.
+- 🚀 Post-processing, filtering and merging of attribution maps via `Aggregator` classes.
 - 🚀 Attribution visualization in notebooks, browser and command line.
-- 🚀 Attribute single examples or entire 🤗 datasets with the Inseq CLI.
+- 🚀 Efficient attribution of single examples or entire 🤗 datasets with the Inseq CLI.
-- 🚀 Custom attribution of target functions, supporting advanced use cases such as contrastive and uncertainty-weighted feature attributions.
+- 🚀 Custom attribution of target functions, supporting advanced methods such as [contrastive feature attributions](https://aclanthology.org/2022.emnlp-main.14/) and [context reliance detection](https://arxiv.org/abs/2310.01188).
-- 🚀 Extraction and visualization of custom step scores (e.g. probability, entropy) alongsides attribution maps.
+- 🚀 Extraction and visualization of custom scores (e.g. probability, entropy) at every generation step alongside attribution maps.
 ### Supported methods
@@ -196,32 +196,80 @@ out.show()
 Refer to the [documentation](https://inseq.readthedocs.io/examples/custom_attribute_target.html) for an example including custom function registration.
-## Using the Inseq client
+## Using the Inseq CLI
 The Inseq library also provides useful client commands to enable repeated attribution of individual examples and even entire 🤗 datasets directly from the console. See the available options by typing `inseq -h` in the terminal after installing the package.
-For now, two commands are supported:
+Three commands are supported:
-- `ìnseq attribute`: Wraps the `attribute` method shown above, requires explicit inputs to be attributed.
+- `inseq attribute`: Wrapper for enabling `model.attribute` usage in the console.
-- `inseq attribute-dataset`: Enables attribution for a full dataset using Hugging Face `datasets.load_dataset`.
+- `inseq attribute-dataset`: Extends `attribute` to a full dataset using the Hugging Face `datasets.load_dataset` API.
-Both commands support the full range of parameters available for `attribute`, attribution visualization in the console and saving outputs to disk.
+- `inseq attribute-context`: Detects and attributes context dependence for generation tasks using the approach of [Sarti et al. (2023)](https://arxiv.org/abs/2310.01188).
-**Example:** The following command can be used to perform attribution (both source and target-side) of Italian translations for a dummy sample of 20 English sentences taken from the FLORES-101 parallel corpus, using a MarianNMT translation model from Hugging Face `transformers`. We save the visualizations in HTML format in the file `attributions.html`. See the `--help` flag for more options.
+All commands support the full range of parameters available for `attribute`, attribution visualization in the console and saving outputs to disk.
-```bash
-inseq attribute-dataset \
+
+ inseq attribute example + + The following example performs a simple feature attribution of an English sentence translated into Italian using a MarianNMT translation model from transformers. The final result is printed to the console. + ```bash + inseq attribute \ --model_name_or_path Helsinki-NLP/opus-mt-en-it \ --attribution_method saliency \ - --do_prefix_attribution \ - --dataset_name inseq/dummy_enit \ - --input_text_field en \ - --dataset_split "train[:20]" \ - --viz_path attributions.html \ - --batch_size 8 \ - --hide -``` + --input_texts "Hello world this is Inseq\! Inseq is a very nice library to perform attribution analysis" + ``` + +
+ +
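
For reference, the same single-example attribution can also be run from Python with the library API used elsewhere in the README and in this diff (`inseq.load_model` / `model.attribute`). A minimal sketch, not part of this PR:

```python
# Python counterpart of the `inseq attribute` CLI call above (editorial sketch).
import inseq

# Load the MarianNMT model together with the saliency attribution method.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency")

# Attribute the generated Italian translation to the English input tokens.
out = model.attribute(
    "Hello world this is Inseq! Inseq is a very nice library to perform attribution analysis"
)
out.show()  # render the attribution map in the notebook or console
```
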
+ inseq attribute-dataset example + + The following code can be used to perform attribution (both source and target-side) of Italian translations for a dummy sample of 20 English sentences taken from the FLORES-101 parallel corpus, using a MarianNMT translation model from Hugging Face transformers. We save the visualizations in HTML format in the file attributions.html. See the --help flag for more options. + + ```bash + inseq attribute-dataset \ + --model_name_or_path Helsinki-NLP/opus-mt-en-it \ + --attribution_method saliency \ + --do_prefix_attribution \ + --dataset_name inseq/dummy_enit \ + --input_text_field en \ + --dataset_split "train[:20]" \ + --viz_path attributions.html \ + --batch_size 8 \ + --hide + ``` +
+ +
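
The dataset command above can be approximated from Python by loading the dataset manually and passing its text column to `model.attribute`. A rough sketch, assuming the `inseq/dummy_enit` dataset exposes an `en` column as in the example (`attribute_target=True` mirrors the `--do_prefix_attribution` flag, and `out.show(return_html=True, display=False)` mirrors `--viz_path`/`--hide`):

```python
# Rough Python counterpart of the `inseq attribute-dataset` CLI call above (editorial sketch).
import inseq
from datasets import load_dataset

dataset = load_dataset("inseq/dummy_enit", split="train[:20]")
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency")

# Source and target-side attribution for the 20 sampled sentences, in batches of 8.
out = model.attribute(dataset["en"], attribute_target=True, batch_size=8)

# Save the visualization to HTML without displaying it.
html = out.show(return_html=True, display=False)
with open("attributions.html", "w") as f:
    f.write(html)
```
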
+ inseq attribute-context example
+
+ The following example uses a GPT-2 model to generate a continuation of `input_current_text`, and uses the additional context provided by `input_context_text` to estimate its influence on the generation. In this case, the output "to the hospital. He said he was fine" is produced, and the generation of the token `hospital` is found to be dependent on the context token `sick` according to the `contrast_prob_diff` step function.
+
+ ```bash
+ inseq attribute-context \
+ --model_name_or_path gpt2 \
+ --input_context_text "George was sick yesterday." \
+ --input_current_text "His colleagues asked him to come" \
+ --attributed_fn "contrast_prob_diff"
+ ```
+
+ **Result:**
+
+ ```
+ Context with [contextual cues] (std λ=1.00) followed by output sentence with {context-sensitive target spans} (std λ=1.00)
+ (CTI = "kl_divergence", CCI = "saliency" w/ "contrast_prob_diff" target)
+
+ Input context: George was sick yesterday.
+ Input current: His colleagues asked him to come
+ Output current: to the hospital. He said he was fine
+
+ #1.
+ Generated output (CTI > 0.428): to the {hospital}(0.548). He said he was fine
+ Input context (CCI > 0.460): George was [sick](0.516) yesterday.
+ ```
+
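
The same context-reliance analysis can also be driven programmatically through the dataclass introduced in this PR. A minimal sketch, assuming `AttributeContextArgs` and the `attribute_context` entry point are importable from `inseq/commands/attribute_context/` as defined in the diff below (this programmatic usage is an illustration, not a documented API):

```python
# Editorial sketch of a programmatic `attribute-context` call using the classes added in this PR.
from inseq.commands.attribute_context import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context import attribute_context

args = AttributeContextArgs(
    model_name_or_path="gpt2",
    input_context_text="George was sick yesterday.",
    input_current_text="His colleagues asked him to come",
    attributed_fn="contrast_prob_diff",
    show_viz=False,  # inspect the returned object instead of printing the visualization
)
out = attribute_context(args)
print(out.cti_scores)  # per-token context sensitivity (CTI) scores
print(out.cci_scores)  # one CCIOutput per context-sensitive token
```
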
## Planned Development diff --git a/docs/source/index.rst b/docs/source/index.rst index 147447b8..b1be0a45 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -104,6 +104,7 @@ Inseq is still in early development and is currently maintained by a small team :maxdepth: 3 :caption: API Documentation + main_classes/cli main_classes/main_functions main_classes/models main_classes/feature_attribution diff --git a/docs/source/main_classes/cli.rst b/docs/source/main_classes/cli.rst new file mode 100644 index 00000000..1793360c --- /dev/null +++ b/docs/source/main_classes/cli.rst @@ -0,0 +1,52 @@ +.. + Copyright 2024 The Inseq Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +Inseq CLI +======================================================================================================================= + +The Inseq CLI is a command line interface for the Inseq library. The CLI enables repeated attribution of individual +examples and even entire 🀗 datasets directly from the console. See the available options by typing ``inseq -h`` in the +terminal after installing the package. + +Three commands are supported: + +- ``inseq attribute``: Wrapper for enabling ``model.attribute`` usage in console. + +- ``inseq attribute-dataset``: Extends ``attribute`` to full dataset using Hugging Face ``datasets.load_dataset`` API. + +- ``inseq attribute-context``: Detects and attribute context dependence for generation tasks using the approach of `Sarti et al. (2023) `__. + +``attribute`` +----------------------------------------------------------------------------------------------------------------------- + +The ``attribute`` command enables attribution of individual examples directly from the console. The command takes the +following arguments: + +.. autoclass:: inseq.commands.attribute.attribute_args.AttributeWithInputsArgs + +``attribute-dataset`` +----------------------------------------------------------------------------------------------------------------------- + +The ``attribute-dataset`` command extends the ``attribute`` command to full datasets using the Hugging Face +``datasets.load_dataset`` API. The command takes the following arguments: + +.. autoclass:: inseq.commands.attribute_dataset.attribute_dataset_args.LoadDatasetArgs + +.. autoclass:: inseq.commands.attribute.attribute_args.AttributeExtendedArgs + +``attribute-context`` +----------------------------------------------------------------------------------------------------------------------- + +The ``attribute-context`` command detects and attributes context dependence for generation tasks using the approach of +`Sarti et al. (2023) `__. The command takes the following arguments: + +.. 
autoclass:: inseq.commands.attribute_context.attribute_context_args.AttributeContextArgs \ No newline at end of file diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index 6369fd3f..6034f094 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Tuple, Union import torch @@ -500,3 +501,7 @@ def register_step_function( if agg_name not in DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"]: DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name] = {} DEFAULT_ATTRIBUTION_AGGREGATE_DICT["step_scores"][agg_name][identifier] = aggregation_fn_identifier + + +def is_contrastive_step_function(step_fn_id: str) -> bool: + return "contrast_targets" in signature(get_step_function(step_fn_id)).parameters diff --git a/inseq/commands/attribute.py b/inseq/commands/attribute.py deleted file mode 100644 index 045898e1..00000000 --- a/inseq/commands/attribute.py +++ /dev/null @@ -1,201 +0,0 @@ -import logging -from dataclasses import dataclass, field -from typing import List, Optional - -from .. import list_feature_attribution_methods, load_model -from ..utils import get_default_device -from .base import BaseCLICommand - - -@dataclass -class AttributeBaseArgs: - model_name_or_path: str = field( - metadata={"alias": "-m", "help": "The name or path of the model on which attribution is performed."}, - ) - attribution_method: Optional[str] = field( - default="integrated_gradients", - metadata={ - "alias": "-a", - "help": "The attribution method used to perform feature attribution.", - "choices": list_feature_attribution_methods(), - }, - ) - do_prefix_attribution: bool = field( - default=False, - metadata={ - "help": "Performs the attribution procedure including the generated prefix at every step.", - }, - ) - generate_from_target_prefix: bool = field( - default=False, - metadata={ - "help": ( - "Whether the ``generated_texts`` should be used as" - "target prefixes for the generation process. If False, the ``generated_texts`` will be used as full" - "targets. This option is only available for encoder-decoder models, since the same behavior can be" - "achieved by modifying the input texts for decoder-only models. Default: False." - ) - }, - ) - step_scores: List[str] = field( - default_factory=list, metadata={"help": "Adds step scores to the attribution output."} - ) - output_step_attributions: bool = field( - default=False, metadata={"help": "Adds step-level feature attributions to the output."} - ) - include_eos_baseline: bool = field( - default=False, - metadata={ - "alias": "--eos", - "help": "Whether the EOS token should be included in the baseline, used for some attribution methods.", - }, - ) - n_approximation_steps: Optional[int] = field( - default=100, - metadata={"alias": "-n", "help": "Number of approximation steps, used for some attribution methods."}, - ) - return_convergence_delta: bool = field( - default=False, - metadata={ - "help": "Returns the convergence delta of the approximation, used for some attribution methods.", - }, - ) - batch_size: int = field( - default=8, - metadata={ - "alias": "-bs", - "help": "The batch size used for the attribution computation. 
By default, no batching is performed.", - }, - ) - attribution_batch_size: Optional[int] = field( - default=50, - metadata={ - "help": "The internal batch size used by the attribution method, used for some attribution methods.", - }, - ) - aggregate_output: bool = field( - default=False, - metadata={ - "help": "If specified, the attribution output is aggregated using its default aggregator before saving.", - }, - ) - device: str = field( - default=get_default_device(), - metadata={"alias": "--dev", "help": "The device used for inference with Pytorch. Multi-GPU is not supported."}, - ) - hide_attributions: bool = field( - default=False, - metadata={ - "alias": "--hide", - "help": "If specified, the attribution visualization are not shown in the output.", - }, - ) - save_path: Optional[str] = field( - default=None, - metadata={"alias": "-o", "help": "Path where the attribution output should be saved in JSON format."}, - ) - viz_path: Optional[str] = field( - default=None, - metadata={ - "help": "Path where the attribution visualization should be saved in HTML format.", - }, - ) - max_gen_length: Optional[int] = field( - default=None, - metadata={"alias": "-l", "help": "Max generation length for model outputs. Default: 512"}, - ) - start_pos: Optional[int] = field( - default=None, - metadata={"alias": "-s", "help": "Start position for the attribution. Default: first token"}, - ) - end_pos: Optional[int] = field( - default=None, - metadata={"alias": "-e", "help": "End position for the attribution. Default: last token"}, - ) - verbose: bool = field( - default=False, - metadata={"alias": "-v", "help": "If specified, use INFO as logging level for the attribution."}, - ) - very_verbose: bool = field( - default=False, - metadata={"alias": "-vv", "help": "If specified, use DEBUG as logging level for the attribution."}, - ) - - -@dataclass -class AttributeArgs(AttributeBaseArgs): - input_texts: List[str] = field( - default=None, metadata={"alias": "-i", "help": "One or more input texts used for generation."} - ) - generated_texts: Optional[List[str]] = field( - default=None, - metadata={ - "alias": "-g", - "help": "If specified, constrains the decoding procedure to the specified outputs.", - }, - ) - - def __post_init__(self): - if self.input_texts is None: - raise RuntimeError("Input texts must be specified.") - if isinstance(self.input_texts, str): - self.input_texts = list(self.input_texts) - if isinstance(self.generated_texts, str): - self.generated_texts = list(self.generated_texts) - - -def attribute(input_texts, generated_texts, args: AttributeBaseArgs): - if args.very_verbose: - log_level = logging.DEBUG - elif args.verbose: - log_level = logging.INFO - else: - log_level = logging.WARNING - logging.basicConfig( - level=log_level, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - model = load_model( - args.model_name_or_path, - attribution_method=args.attribution_method, - device=args.device, - ) - out = model.attribute( - input_texts, - generated_texts, - batch_size=args.batch_size, - attribute_target=args.do_prefix_attribution, - step_scores=args.step_scores, - output_step_attributions=args.output_step_attributions, - include_eos_baseline=args.include_eos_baseline, - n_steps=args.n_approximation_steps, - internal_batch_size=args.attribution_batch_size, - return_convergence_delta=args.return_convergence_delta, - device=args.device, - generation_args={"max_new_tokens": args.max_gen_length}, - 
attr_pos_start=args.start_pos, - attr_pos_end=args.end_pos, - generate_from_target_prefix=args.generate_from_target_prefix, - ) - if args.viz_path: - print(f"Saving visualization to {args.viz_path}") - html = out.show(return_html=True, display=not args.hide_attributions) - with open(args.viz_path, "w") as f: - f.write(html) - else: - out.show(display=not args.hide_attributions) - if args.save_path: - if args.aggregate_output: - out = out.aggregate() - print(f"Saving {'aggregated ' if args.aggregate_output else ''}attributions to {args.save_path}") - out.save(args.save_path, overwrite=True) - - -class AttributeCommand(BaseCLICommand): - _name = "attribute" - _help = "Perform feature attribution on one or multiple sentences" - _dataclasses = AttributeArgs - - def run(args: AttributeArgs): - attribute(args.input_texts, args.generated_texts, args) diff --git a/inseq/commands/attribute/__init__.py b/inseq/commands/attribute/__init__.py new file mode 100644 index 00000000..abb9a041 --- /dev/null +++ b/inseq/commands/attribute/__init__.py @@ -0,0 +1,10 @@ +from .attribute import AttributeCommand, aggregate_attribution_scores +from .attribute_args import AttributeBaseArgs, AttributeExtendedArgs, AttributeWithInputsArgs + +__all__ = [ + "AttributeCommand", + "aggregate_attribution_scores", + "AttributeBaseArgs", + "AttributeExtendedArgs", + "AttributeWithInputsArgs", +] diff --git a/inseq/commands/attribute/attribute.py b/inseq/commands/attribute/attribute.py new file mode 100644 index 00000000..55607abc --- /dev/null +++ b/inseq/commands/attribute/attribute.py @@ -0,0 +1,93 @@ +import logging +from typing import List, Optional + +from ... import FeatureAttributionOutput, load_model +from ..base import BaseCLICommand +from .attribute_args import AttributeExtendedArgs, AttributeWithInputsArgs + + +def aggregate_attribution_scores( + out: FeatureAttributionOutput, + selectors: Optional[List[int]] = None, + aggregators: Optional[List[str]] = None, + normalize_attributions: bool = False, +) -> FeatureAttributionOutput: + if selectors is not None and aggregators is not None: + for select_idx, aggregator_fn in zip(selectors, aggregators): + out = out.aggregate( + aggregator=aggregator_fn, + normalize=normalize_attributions, + select_idx=select_idx, + do_post_aggregation_checks=False, + ) + else: + out = out.aggregate(aggregator=aggregators, normalize=normalize_attributions) + return out + + +def attribute(input_texts, generated_texts, args: AttributeExtendedArgs): + if args.very_verbose: + log_level = logging.DEBUG + elif args.verbose: + log_level = logging.INFO + else: + log_level = logging.WARNING + logging.basicConfig( + level=log_level, + format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + model = load_model( + args.model_name_or_path, + attribution_method=args.attribution_method, + device=args.device, + model_kwargs=args.model_kwargs, + tokenizer_kwargs=args.tokenizer_kwargs, + ) + # Handle language tag for multilingual models - no need to specify it in generation kwargs + if "tgt_lang" in args.tokenizer_kwargs and "forced_bos_token_id" not in args.generation_kwargs: + tgt_lang = args.tokenizer_kwargs["tgt_lang"] + args.generation_kwargs["forced_bos_token_id"] = model.tokenizer.lang_code_to_id[tgt_lang] + + out = model.attribute( + input_texts, + generated_texts, + batch_size=args.batch_size, + attribute_target=args.attribute_target, + attributed_fn=args.attributed_fn, + step_scores=args.step_scores, + 
output_step_attributions=args.output_step_attributions, + include_eos_baseline=args.include_eos_baseline, + device=args.device, + generation_args=args.generation_kwargs, + attr_pos_start=args.start_pos, + attr_pos_end=args.end_pos, + generate_from_target_prefix=args.generate_from_target_prefix, + **args.attribution_kwargs, + ) + if args.viz_path: + print(f"Saving visualization to {args.viz_path}") + html = out.show(return_html=True, display=not args.hide_attributions) + with open(args.viz_path, "w") as f: + f.write(html) + else: + out.show(display=not args.hide_attributions) + if args.save_path: + if args.attribution_aggregators is not None: + out = aggregate_attribution_scores( + out=out, + selectors=args.attribution_selectors, + aggregators=args.attribution_aggregators, + normalize_attributions=args.normalize_attributions, + ) + print(f"Saving {'aggregated ' if args.aggregate_output else ''}attributions to {args.save_path}") + out.save(args.save_path, overwrite=True) + + +class AttributeCommand(BaseCLICommand): + _name = "attribute" + _help = "Perform feature attribution on one or multiple sentences" + _dataclasses = AttributeWithInputsArgs + + def run(args: AttributeWithInputsArgs): + attribute(args.input_texts, args.generated_texts, args) diff --git a/inseq/commands/attribute/attribute_args.py b/inseq/commands/attribute/attribute_args.py new file mode 100644 index 00000000..be6540d1 --- /dev/null +++ b/inseq/commands/attribute/attribute_args.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass +from typing import List, Optional + +from ... import ( + list_aggregation_functions, + list_aggregators, + list_feature_attribution_methods, + list_step_functions, +) +from ...utils import cli_arg, get_default_device +from ..commands_utils import command_args_docstring + + +@command_args_docstring +@dataclass +class AttributeBaseArgs: + model_name_or_path: str = cli_arg( + default=None, aliases=["-m"], help="The name or path of the model on which attribution is performed." + ) + attribution_method: Optional[str] = cli_arg( + default="saliency", + aliases=["-a"], + help="The attribution method used to perform feature attribution.", + choices=list_feature_attribution_methods(), + ) + device: str = cli_arg( + default=get_default_device(), + aliases=["--dev"], + help="The device used for inference with Pytorch. Multi-GPU is not supported.", + ) + attributed_fn: Optional[str] = cli_arg( + default=None, + aliases=["-fn"], + choices=list_step_functions(), + help=( + "The attribution target used for the attribution method. Default: ``probability``. If a" + " step function requiring additional arguments is used (e.g. ``contrast_prob_diff``), they should be" + " specified using the ``attribution_kwargs`` argument." + ), + ) + attribution_selectors: Optional[List[int]] = cli_arg( + default=None, + help=( + "The indices of the attribution scores to be used for the attribution aggregation. If specified, the" + " aggregation function is applied only to the selected scores, and the other scores are discarded." + " If not specified, the aggregation function is applied to all the scores." + ), + ) + attribution_aggregators: List[str] = cli_arg( + default=None, + help=( + "The aggregators used to aggregate the attribution scores for each context. The outcome should" + " produce one score per input token" + ), + choices=list_aggregators() + list_aggregation_functions(), + ) + normalize_attributions: bool = cli_arg( + default=False, + help=( + "Whether to normalize the attribution scores for each context. 
If ``True``, the attribution scores " + "for each context are normalized to sum up to 1, providing a relative notion of input salience." + ), + ) + model_kwargs: dict = cli_arg( + default_factory=dict, + help="Additional keyword arguments passed to the model constructor in JSON format.", + ) + tokenizer_kwargs: dict = cli_arg( + default_factory=dict, + help="Additional keyword arguments passed to the tokenizer constructor in JSON format.", + ) + generation_kwargs: dict = cli_arg( + default_factory=dict, + help="Additional keyword arguments passed to the generation method in JSON format.", + ) + attribution_kwargs: dict = cli_arg( + default_factory=dict, + help="Additional keyword arguments passed to the attribution method in JSON format.", + ) + + +@command_args_docstring +@dataclass +class AttributeExtendedArgs(AttributeBaseArgs): + attribute_target: bool = cli_arg( + default=False, + help="Performs the attribution procedure including the generated target prefix at every step.", + ) + generate_from_target_prefix: bool = cli_arg( + default=False, + help=( + "Whether the ``generated_texts`` should be used as target prefixes for the generation process. If False," + " the ``generated_texts`` are used as full targets. Option only available for encoder-decoder models," + " since for decoder-only ones it is sufficient to add prefix to input string. Default: False." + ), + ) + step_scores: List[str] = cli_arg( + default_factory=list, + help="Adds the specified step scores to the attribution output.", + choices=list_step_functions(), + ) + output_step_attributions: bool = cli_arg(default=False, help="Adds step-level feature attributions to the output.") + include_eos_baseline: bool = cli_arg( + default=False, + aliases=["--eos"], + help="Whether the EOS token should be included in the baseline, used for some attribution methods.", + ) + batch_size: int = cli_arg( + default=8, aliases=["-bs"], help="The batch size used for the attribution computation. Default: no batching." + ) + aggregate_output: bool = cli_arg( + default=False, + help="If specified, the attribution output is aggregated using its default aggregator before saving.", + ) + hide_attributions: bool = cli_arg( + default=False, + aliases=["--hide"], + help="If specified, the attribution visualization are not shown in the output.", + ) + save_path: Optional[str] = cli_arg( + default=None, + aliases=["-o"], + help="Path where the attribution output should be saved in JSON format.", + ) + viz_path: Optional[str] = cli_arg( + default=None, + help="Path where the attribution visualization should be saved in HTML format.", + ) + start_pos: Optional[int] = cli_arg( + default=None, aliases=["-s"], help="Start position for the attribution. Default: first token" + ) + end_pos: Optional[int] = cli_arg( + default=None, aliases=["-e"], help="End position for the attribution. Default: last token" + ) + verbose: bool = cli_arg( + default=False, aliases=["-v"], help="If specified, use INFO as logging level for the attribution." + ) + very_verbose: bool = cli_arg( + default=False, aliases=["-vv"], help="If specified, use DEBUG as logging level for the attribution." + ) + + +@command_args_docstring +@dataclass +class AttributeWithInputsArgs(AttributeExtendedArgs): + input_texts: List[str] = cli_arg(default=None, aliases=["-i"], help="One or more input texts used for generation.") + generated_texts: Optional[List[str]] = cli_arg( + default=None, aliases=["-g"], help="If specified, constrains the decoding procedure to the specified outputs." 
+ ) + + def __post_init__(self): + if self.input_texts is None: + raise RuntimeError("Input texts must be specified.") + if isinstance(self.input_texts, str): + self.input_texts = list(self.input_texts) + if isinstance(self.generated_texts, str): + self.generated_texts = list(self.generated_texts) diff --git a/inseq/commands/attribute_context/__init__.py b/inseq/commands/attribute_context/__init__.py new file mode 100644 index 00000000..a8d0d5b4 --- /dev/null +++ b/inseq/commands/attribute_context/__init__.py @@ -0,0 +1,10 @@ +from .attribute_context import AttributeContextCommand +from .attribute_context_args import AttributeContextArgs +from .attribute_context_helpers import AttributeContextOutput, CCIOutput + +__all__ = [ + "AttributeContextCommand", + "AttributeContextArgs", + "AttributeContextOutput", + "CCIOutput", +] diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py new file mode 100644 index 00000000..d3202e6b --- /dev/null +++ b/inseq/commands/attribute_context/attribute_context.py @@ -0,0 +1,219 @@ +"""Implementation of the context attribution process described in `Quantifying the Plausibility of Context Reliance in +Neural Machine Translation `_ for decoder-only and encoder-decoder models. + +The process consists of two steps: + - Context-sensitive Token Identification (CTI): detects which tokens in the generated output of interest are + influenced by the presence of context. + - Contextual Cues Imputation (CCI): attributes the generation of context-sensitive tokens identified in the first + step to the input and output contexts. + +Example usage: + +```bash +inseq attribute-context \ + --model_name_or_path gpt2 \ + --input_context_text "George was sick yesterday." \ + --input_current_text "His colleagues asked him" \ + --attributed_fn contrast_prob_diff +``` +""" + +import json +import warnings +from copy import deepcopy + +import transformers + +from ... 
import load_model +from ...attr.step_functions import is_contrastive_step_function +from ...models import HuggingfaceModel +from ..attribute import aggregate_attribution_scores +from ..base import BaseCLICommand +from .attribute_context_args import AttributeContextArgs +from .attribute_context_helpers import ( + AttributeContextOutput, + CCIOutput, + filter_rank_tokens, + format_template, + get_contextless_prefix, + get_filtered_tokens, + get_source_target_cci_scores, + prepare_outputs, +) +from .attribute_context_viz_helpers import handle_visualization + +warnings.filterwarnings("ignore") +transformers.logging.set_verbosity_error() + + +def attribute_context(args: AttributeContextArgs): + """Attribute the generation of context-sensitive tokens in ``output_current_text`` to input/output contexts.""" + model: HuggingfaceModel = load_model( + args.model_name_or_path, + args.attribution_method, + model_kwargs=deepcopy(args.model_kwargs), + tokenizer_kwargs=deepcopy(args.tokenizer_kwargs), + ) + + # Handle language tag for multilingual models - no need to specify it in generation kwargs + has_lang_tag = "tgt_lang" in args.tokenizer_kwargs + if has_lang_tag and "forced_bos_token_id" not in args.generation_kwargs: + tgt_lang = args.tokenizer_kwargs["tgt_lang"] + args.generation_kwargs["forced_bos_token_id"] = model.tokenizer.lang_code_to_id[tgt_lang] + + # Prepare input/outputs (generate if necessary) + input_full_text = format_template(args.input_template, args.input_current_text, args.input_context_text) + args.output_context_text, args.output_current_text = prepare_outputs( + model=model, + input_context_text=args.input_context_text, + input_full_text=input_full_text, + output_context_text=args.output_context_text, + output_current_text=args.output_current_text, + output_template=args.output_template, + align_output_context_auto=args.align_output_context_auto, + generation_kwargs=deepcopy(args.generation_kwargs), + special_tokens_to_keep=args.special_tokens_to_keep, + ) + output_full_text = format_template(args.output_template, args.output_current_text, args.output_context_text) + + # Tokenize inputs/outputs and compute offset + input_context_tokens = None + if args.input_context_text is not None: + input_context_tokens = get_filtered_tokens(args.input_context_text, model, args.special_tokens_to_keep) + if not model.is_encoder_decoder: + space = " " if not output_full_text.startswith((" ", "\n")) else "" + output_full_text = input_full_text + space + output_full_text + output_current_tokens = get_filtered_tokens( + args.output_current_text, model, args.special_tokens_to_keep, is_target=True + ) + output_context_tokens = None + if args.output_context_text is not None: + output_context_tokens = get_filtered_tokens( + args.output_context_text, model, args.special_tokens_to_keep, is_target=True + ) + input_full_tokens = get_filtered_tokens(input_full_text, model, args.special_tokens_to_keep) + output_full_tokens = get_filtered_tokens(output_full_text, model, args.special_tokens_to_keep, is_target=True) + output_current_text_offset = len(output_full_tokens) - len(output_current_tokens) + if model.is_encoder_decoder: + prefixed_output_current_text = args.output_current_text + else: + space = " " if not args.output_current_text.startswith((" ", "\n")) else "" + prefixed_output_current_text = args.input_current_text + space + args.output_current_text + + # Part 1: Context-sensitive Token Identification (CTI) + cti_out = model.attribute( + args.input_current_text, + prefixed_output_current_text, + 
attribute_target=model.is_encoder_decoder, + step_scores=[args.context_sensitivity_metric], + contrast_sources=input_full_text if model.is_encoder_decoder else None, + contrast_targets=output_full_text, + show_progress=False, + method="dummy", + )[0] + if args.show_intermediate_outputs: + cti_out.show(do_aggregation=False) + + start_pos = 1 if has_lang_tag else 0 + cti_ranked_tokens, cti_threshold = filter_rank_tokens( + tokens=[t.token for t in cti_out.target][start_pos + cti_out.attr_pos_start :], + scores=cti_out.step_scores[args.context_sensitivity_metric][start_pos:].tolist(), + std_threshold=args.context_sensitivity_std_threshold, + topk=args.context_sensitivity_topk, + ) + cti_scores = cti_out.step_scores[args.context_sensitivity_metric].tolist() + if model.is_encoder_decoder: + cti_scores = cti_scores[:-1] + if has_lang_tag: + cti_scores = cti_scores[1:] + output = AttributeContextOutput( + input_context=args.input_context_text, + input_context_tokens=input_context_tokens, + output_context=args.output_context_text, + output_context_tokens=output_context_tokens, + output_current=args.output_current_text, + output_current_tokens=output_current_tokens, + cti_scores=cti_scores, + info=args if args.add_output_info else None, + ) + # Part 2: Contextual Cues Imputation (CCI) + for cti_idx, cti_score, cti_tok in cti_ranked_tokens: + contextual_prefix = model.convert_tokens_to_string( + output_full_tokens[: output_current_text_offset + cti_idx + 1], skip_special_tokens=False + ) + cci_kwargs = {} + contextless_prefix = None + if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn): + contextless_prefix = get_contextless_prefix( + model, + args.input_current_text, + output_current_tokens, + cti_idx, + args.special_tokens_to_keep, + deepcopy(args.generation_kwargs), + ) + cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None + cci_kwargs["contrast_targets"] = contextless_prefix + output_ctx_tokens = model.convert_string_to_tokens(contextual_prefix, skip_special_tokens=False) + output_ctxless_tokens = model.convert_string_to_tokens(contextless_prefix, skip_special_tokens=False) + tok_pos = -2 if model.is_encoder_decoder else -1 + if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]: + cci_kwargs["contrast_force_inputs"] = True + pos_start = output_current_text_offset + cti_idx + int(model.is_encoder_decoder) + int(has_lang_tag) + cci_attrib_out = model.attribute( + input_full_text, + contextual_prefix, + attribute_target=model.is_encoder_decoder and args.has_output_context, + show_progress=False, + attr_pos_start=pos_start, + attributed_fn=args.attributed_fn, + method=args.attribution_method, + **cci_kwargs, + **args.attribution_kwargs, + ) + cci_attrib_out = aggregate_attribution_scores( + out=cci_attrib_out, + selectors=args.attribution_selectors, + aggregators=args.attribution_aggregators, + normalize_attributions=args.normalize_attributions, + )[0] + if args.show_intermediate_outputs: + cci_attrib_out.show(do_aggregation=False) + source_scores, target_scores = get_source_target_cci_scores( + model, + cci_attrib_out, + args.input_template, + input_context_tokens, + input_full_tokens, + args.output_template, + output_context_tokens, + args.has_input_context, + args.has_output_context, + has_lang_tag, + args.special_tokens_to_keep, + ) + cci_out = CCIOutput( + cti_idx=cti_idx, + cti_token=cti_tok, + cti_score=cti_score, + contextual_prefix=contextual_prefix, + 
contextless_prefix=contextless_prefix, + input_context_scores=source_scores, + output_context_scores=target_scores, + ) + output.cci_scores.append(cci_out) + if args.save_path: + with open(args.save_path, "w") as f: + json.dump(output.to_dict(), f, indent=4) + if args.show_viz or args.viz_path: + handle_visualization(args, model, output, cti_threshold) + return output + + +class AttributeContextCommand(BaseCLICommand): + _name = "attribute-context" + _help = "Detect context-sensitive tokens in a generated text and attribute their predictions to available context." + _dataclasses = AttributeContextArgs + + def run(args: AttributeContextArgs): + attribute_context(args) diff --git a/inseq/commands/attribute_context/attribute_context_args.py b/inseq/commands/attribute_context/attribute_context_args.py new file mode 100644 index 00000000..f102b6e6 --- /dev/null +++ b/inseq/commands/attribute_context/attribute_context_args.py @@ -0,0 +1,221 @@ +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from ... import list_step_functions +from ...attr.step_functions import is_contrastive_step_function +from ...utils import cli_arg, pretty_dict +from ..attribute import AttributeBaseArgs +from ..commands_utils import command_args_docstring + +logger = logging.getLogger(__name__) + + +@command_args_docstring +@dataclass +class AttributeContextInputArgs: + input_current_text: str = cli_arg( + default="", + help=( + "The input text used for generation. If the model is a decoder-only model, the input text is a " + "prompt used for language modeling. If the model is an encoder-decoder model, the input text is the " + "source text provided as input to the encoder. It will be formatted as {current} in the " + "``input_template``." + ), + ) + input_context_text: Optional[str] = cli_arg( + default=None, + help=( + "Additional input context influencing the generation of ``output_current_text``. If the model is a" + " decoder-only model, the input text is a prefix to the ``input_current_text`` prompt. If the model is an" + " encoder-decoder model, the input context is part of the source text provided as input to the encoder. " + " It will be formatted as {context} in the ``input_template``." + ), + ) + input_template: Optional[str] = cli_arg( + default=None, + help=( + "The template used to format model inputs. The template must contain at least the" + " ``{current}`` placeholder, which will be replaced by ``input_current_text``. If ``{context}`` is" + " also specified, input-side context will be used. Can be modified for models requiring special tokens or" + " formatting in the input text (e.g. tags to separate context and current inputs)." + " Defaults to '{context} {current}' if ``input_context_text`` is provided, '{current}' otherwise." + ), + ) + output_context_text: Optional[str] = cli_arg( + default=None, + help=( + "An output contexts for which context sensitivity should be detected. For encoder-decoder models, this" + " is a target-side prefix to the output_current_text used as input to the decoder. For decoder-only " + " models this is a portion of the model generation that should be considered as an additional context " + " (e.g. a chain-of-thought sequence). It will be formatted as {context} in the ``output_template``." + " If not provided but specified in the ``output_template``, the output context will be generated" + " along with the output current text, and user validation might be required to separate the two." 
+ ), + ) + output_current_text: Optional[str] = cli_arg( + default=None, + help=( + "The output text generated by the model when all available contexts are provided. Tokens in " + " ``output_current_text`` will be tested for context-sensitivity, and their generation will be attributed " + " to input/target contexts (if present) in case they are found to be context-sensitive. If specified, " + " this output is force-decoded. Otherwise, it is generated by the model using infilled ``input_template`` " + " and ``output_template``. It will be formatted as {current} in the ``output_template``." + ), + ) + output_template: Optional[str] = cli_arg( + default=None, + help=( + "The template used to format model outputs. The template must contain at least the" + " ``{current}`` placeholder, which will be replaced by ``output_current_text``. If ``{context}`` is" + " also specified, output-side context will be used. Can be modified for models requiring special tokens or" + " formatting in the output text (e.g. tags to separate context and current outputs)." + " Defaults to '{context} {current}' if ``output_context_text`` is provided, '{current}' otherwise." + ), + ) + has_input_context: bool = cli_arg(init=False) + has_output_context: bool = cli_arg(init=False) + + def __post_init__(self): + if self.input_template is None: + self.input_template = "{current}" if self.input_context_text is None else "{context} {current}" + if self.output_template is None: + self.output_template = "{current}" if self.output_context_text is None else "{context} {current}" + self.has_input_context = "{context}" in self.input_template + self.has_output_context = "{context}" in self.output_template + if not self.input_current_text: + raise ValueError("--input_current_text must be a non-empty string.") + if self.input_context_text and not self.has_input_context: + logger.warning( + f"input_template has format {self.input_template} (no {{context}}), but --input_context_text is" + " specified. Ignoring provided --input_context_text." + ) + self.input_context_text = None + if self.output_context_text and not self.has_output_context: + logger.warning( + f"output_template has format {self.output_template} (no {{context}}), but --output_context_text is" + " specified. Ignoring provided --output_context_text." + ) + self.output_context_text = None + if not self.input_context_text and self.has_input_context: + raise ValueError( + f"{{context}} format placeholder is present in input_template {self.input_template}," + " but --input_context_text is not specified." + ) + if "{current}" not in self.input_template: + raise ValueError(f"{{current}} format placeholder is missing from input_template {self.input_template}.") + if "{current}" not in self.output_template: + raise ValueError(f"{{current}} format placeholder is missing from output_template {self.output_template}.") + if not self.input_current_text: + raise ValueError("--input_current_text must be a non-empty string.") + if self.has_input_context and self.input_template.find("{context}") > self.input_template.find("{current}"): + raise ValueError( + f"{{context}} placeholder must appear before {{current}} in input_template '{self.input_template}'." + ) + if self.has_output_context and self.output_template.find("{context}") > self.output_template.find("{current}"): + raise ValueError( + f"{{context}} placeholder must appear before {{current}} in output_template '{self.output_template}'." 
+ ) + if not self.output_template.endswith("{current}"): + *_, suffix = self.output_template.partition("{current}") + logger.warning( + f"Suffix '{suffix}' was specified in output_template and will be used to ignore the specified suffix" + " tokens during context sensitivity detection. Make sure that the suffix corresponds to the end of the" + " output_current_text by forcing --output_current_text if necessary." + ) + + +@command_args_docstring +@dataclass +class AttributeContextMethodArgs(AttributeBaseArgs): + context_sensitivity_metric: str = cli_arg( + default="kl_divergence", + help="The contrastive metric used to detect context-sensitive tokens in ``output_current_text``.", + choices=[fn for fn in list_step_functions() if is_contrastive_step_function(fn)], + ) + align_output_context_auto: bool = cli_arg( + default=False, + help=( + "Argument used for encoder-decoder model when generating text with an output template including both" + " {context} and {current}, to attempt an automatic detection of which parts of the output belong to" + " context vs. current in absence of other explicit cues. If set to True, the input and output context" + " and current texts are aligned automatically (assuming an MT-like task), and the alignments are " + " assumed to be valid to separate the two without further user validation. Otherwise, the user is " + " prompted to manually specify which part of the generated text corresponds to the output context." + ), + ) + special_tokens_to_keep: List[str] = cli_arg( + default_factory=list, + help="Special tokens to preserve in the generated string, e.g. ```` separator between context and current.", + ) + context_sensitivity_std_threshold: float = cli_arg( + default=1.0, + help=( + "Parameter to control the selection of ``output_current_text`` tokens considered as context-sensitive for " + "moving onwards with attribution. Corresponds to the number of standard deviations above or below the mean" + " ``context_sensitivity_metric`` score for tokens to be considered context-sensitive." + ), + ) + context_sensitivity_topk: Optional[int] = cli_arg( + default=None, + help=( + "If set, after selecting the salient context-sensitive tokens with ``context_sensitivity_std_threshold`` " + "only the top-K remaining tokens are used. By default no top-k selection is performed." + ), + ) + attribution_std_threshold: float = cli_arg( + default=1.0, + help=( + "Parameter to control the selection of ``input_context_text`` and ``output_context_text`` tokens " + "considered as salient as a result for the attribution process. Corresponds to the number of standard " + "deviations above or below the mean ``attribution_method`` score for tokens to be considered salient. " + "CCI scores for all context tokens are saved in the output, but this parameter controls which tokens are " + "used in the visualization of context reliance." + ), + ) + attribution_topk: Optional[int] = cli_arg( + default=None, + help=( + "If set, after selecting the most salient tokens with ``attribution_std_threshold`` " + "only the top-K remaining tokens are used. By default no top-k selection is performed." 
+ ), + ) + + +@command_args_docstring +@dataclass +class AttributeContextOutputArgs: + show_intermediate_outputs: bool = cli_arg( + default=False, + help=( + "If specified, the intermediate outputs produced by the Inseq library for context-sensitive target " + "identification (CTI) and contextual cues imputation (CCI) are shown during the process.", + ), + ) + save_path: Optional[str] = cli_arg( + default=None, + aliases=["-o"], + help="If present, the output of the two-step process will be saved in JSON format at the specified path.", + ) + add_output_info: bool = cli_arg( + default=True, + help="If specified, additional information about the attribution process is added to the saved output.", + ) + viz_path: Optional[str] = cli_arg( + default=None, + help="If specified, the visualization produced from the output is saved in HTML format at the specified path.", + ) + show_viz: bool = cli_arg( + default=True, + help="If specified, the visualization produced from the output is shown in the terminal.", + ) + + +@command_args_docstring +@dataclass +class AttributeContextArgs(AttributeContextInputArgs, AttributeContextMethodArgs, AttributeContextOutputArgs): + def __repr__(self): + return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" + + def to_dict(self) -> Dict[str, Any]: + return dict(self.__dict__.items()) diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py new file mode 100644 index 00000000..713b58e5 --- /dev/null +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -0,0 +1,366 @@ +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from rich import print as rprint +from rich.prompt import Confirm, Prompt +from torch import tensor + +from ...data import FeatureAttributionSequenceOutput +from ...models import HuggingfaceModel +from ...utils import pretty_dict +from ...utils.alignment_utils import compute_word_aligns +from .attribute_context_args import AttributeContextArgs + +logger = logging.getLogger(__name__) + + +@dataclass +class CCIOutput: + """Output of the Contextual Cues Imputation (CCI) step.""" + + cti_idx: int + cti_token: str + cti_score: float + contextual_prefix: str + contextless_prefix: str + input_context_scores: Optional[List[float]] = None + output_context_scores: Optional[List[float]] = None + + def __repr__(self): + return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" + + def to_dict(self) -> Dict[str, Any]: + return dict(self.__dict__.items()) + + +@dataclass +class AttributeContextOutput: + """Output of the overall context attribution process.""" + + input_context: Optional[str] = None + input_context_tokens: Optional[List[str]] = None + output_context: Optional[str] = None + output_context_tokens: Optional[List[str]] = None + output_current: Optional[str] = None + output_current_tokens: Optional[List[str]] = None + cti_scores: Optional[List[float]] = None + cci_scores: List[CCIOutput] = field(default_factory=list) + info: Optional[AttributeContextArgs] = None + + def __repr__(self): + return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" + + def to_dict(self) -> Dict[str, Any]: + out_dict = {k: v for k, v in self.__dict__.items() if k not in ["cci_scores", "info"]} + out_dict["cci_scores"] = [cci_out.to_dict() for cci_out in self.cci_scores] + if self.info: + out_dict["info"] = self.info.to_dict() + return out_dict + + +def format_template(template: 
str, current: str, context: Optional[str] = None) -> str: + kwargs = {"current": current} + if context is not None: + kwargs["context"] = context + return template.format(**kwargs) + + +def get_filtered_tokens( + text: str, + model: HuggingfaceModel, + special_tokens_to_keep: List[str], + replace_special_characters: bool = False, + is_target: bool = False, +) -> List[str]: + """Tokenize text and filter out special tokens, keeping only those in ``special_tokens_to_keep``.""" + as_targets = is_target and model.is_encoder_decoder + return [ + t.replace("Ä ", " ").replace("Ċ", " ").replace("▁", " ") if replace_special_characters else t + for t in model.convert_string_to_tokens(text, skip_special_tokens=False, as_targets=as_targets) + if t not in model.special_tokens or t in special_tokens_to_keep + ] + + +def generate_with_special_tokens( + model: HuggingfaceModel, + model_input: str, + special_tokens_to_keep: List[str] = [], + **generation_kwargs, +) -> str: + """Generate text preserving special tokens in ``special_tokens_to_keep``.""" + # Generate outputs, strip special tokens and remove prefix/suffix + output_gen = model.generate(model_input, skip_special_tokens=False, **generation_kwargs)[0] + output_tokens = get_filtered_tokens(output_gen, model, special_tokens_to_keep, is_target=True) + return model.convert_tokens_to_string(output_tokens, skip_special_tokens=False) + + +def generate_model_output( + model: HuggingfaceModel, + model_input: str, + generation_kwargs: Dict[str, Any], + special_tokens_to_keep: List[str], + output_template: str, + prefix: str, + suffix: str, +) -> str: + """Generate the model output, validating the presence of a prefix/suffix and stripping them from the generation.""" + output_gen = generate_with_special_tokens(model, model_input, special_tokens_to_keep, **generation_kwargs) + if prefix: + if not output_gen.startswith(prefix): + raise ValueError( + f"Output template '{output_template}' contains prefix '{prefix}' but output '{output_gen}' does" + " not match the prefix. Please check whether the template is correct, or force context/current" + " outputs." + ) + output_gen = output_gen[len(prefix) :] + if suffix: + if not output_gen.endswith(suffix): + raise ValueError( + f"Output template {output_template} contains suffix {suffix} but output '{output_gen}' does" + " not match the suffix. Please check whether the template is correct, or force context/current" + " outputs." + ) + output_gen = output_gen[: -len(suffix)] + return output_gen + + +def prompt_user_for_context(output: str, context_candidate: Optional[str] = None) -> str: + """Prompt the user to provide the correct context for the provided output.""" + while True: + if context_candidate: + is_correct_candidate = Confirm.ask( + f'\n:arrow_right: The model generated the following output: "[bold]{output}[/bold]"' + f'\n:question: Is [bold]"{context_candidate}"[/bold] the correct context you want to attribute?' + ) + if is_correct_candidate: + user_context = context_candidate + else: + user_context = Prompt.ask( + ":writing_hand: Please enter the portion of the generated output representing the correct context" + ) + if output.startswith(user_context): + if not user_context.strip(): + use_empty_context = Confirm.ask( + ":question: The provided context is empty. Do you want to use an empty context?" + ) + if use_empty_context: + user_context = "" + else: + continue + break + rprint( + "[prompt.invalid]The provided context is invalid. 
Please provide a non-empty substring of" + " the model output above to use as context." + ) + return user_context + + +def get_output_context_from_aligned_inputs(input_context: str, output_text: str) -> str: + """Retrieve the output context from alignments between input context and the full output text.""" + aligned_context = compute_word_aligns(input_context, output_text, split_pattern=r"\s+|\b") + max_context_id = max(pair[1] for pair in aligned_context.alignments) + output_text_boundary_token = aligned_context.target_tokens[max_context_id] + # Empty spans correspond to token boundaries + spans = [m.span() for m in re.finditer(r"\s+|\b", output_text)] + tok_start_positions = list({start if start == end else end for start, end in spans}) + output_text_context_candidate_boundary = tok_start_positions[max_context_id] + len(output_text_boundary_token) + return output_text[:output_text_context_candidate_boundary] + + +def prepare_outputs( + model: HuggingfaceModel, + input_context_text: Optional[str], + input_full_text: str, + output_context_text: Optional[str], + output_current_text: Optional[str], + output_template: str, + align_output_context_auto: bool = False, + generation_kwargs: Dict[str, Any] = {}, + special_tokens_to_keep: List[str] = [], +) -> Tuple[Optional[str], str]: + """Handle model outputs and prepare them for attribution. + This procedure is valid both for encoder-decoder and decoder-only models. + + | use_out_ctx | has_out_ctx | has_out_curr | setting + |-------------|-------------|--------------|-------- + | True | True | True | 1. Use forced context + current as output + | False | False | True | 2. Use forced current as output + | True | True | False | 3. Set inputs with forced context, generate output, use as current + | False | False | False | 4. Generate output, use it as current + | True | False | False | 5. Generate output, handle context/current splitting + | True | False | True | 6. Generate output, handle context/current splitting, force current + + NOTE: If ``use_out_ctx`` is True but ``has_out_ctx`` is False, the model generation is assumed to contain both + a context and a current portion which need to be separated. ``has_out_ctx`` cannot be True if ``use_out_ctx`` + is False (pre-check in ``__post_init__``). + """ + use_out_ctx = "{context}" in output_template + has_out_ctx = output_context_text is not None + has_out_curr = output_current_text is not None + model_input = input_full_text + final_current = output_current_text + final_context = output_context_text + + # E.g. 
output template "A{context}B{current}C" -> prefix = "A", suffix = "C", separator = "B" + prefix, _ = output_template.split("{context}" if use_out_ctx else "{current}") + output_current_prefix_template, suffix = output_template.split("{current}") + separator = output_template.split("{context}")[1].split("{current}")[0] if use_out_ctx else None + + # Settings 1, 2 + if (has_out_ctx == use_out_ctx) and has_out_curr: + return final_context, final_current + + # Prepend output prefix and context, if available, if current output needs to be generated + output_current_prefix = prefix + if has_out_ctx and not has_out_curr: + output_current_prefix = output_current_prefix_template.strip().format(context=output_context_text) + if model.is_encoder_decoder: + generation_kwargs["decoder_input_ids"] = model.encode( + output_current_prefix, as_targets=True, add_special_tokens=False + ).input_ids + if "forced_bos_token_id" in generation_kwargs: + generation_kwargs["decoder_input_ids"][0, 0] = generation_kwargs["forced_bos_token_id"] + else: + space = " " if output_current_prefix and not output_current_prefix.startswith((" ", "\n")) else "" + model_input = input_full_text + space + output_current_prefix + output_current_prefix = model_input + + output_gen = generate_model_output( + model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix + ) + + # Settings 3, 4 + if (has_out_ctx == use_out_ctx) and not has_out_curr: + final_current = output_gen if model.is_encoder_decoder or use_out_ctx else output_gen[len(model_input) :] + return final_context, final_current.strip() + + # Settings 5, 6 + # Try splitting the output into context and current text using ``separator``. As we have no guarantees of its + # uniqueness (e.g. it could be whitespace, also found between tokens in context and current) we consider the + # splitting successful if exactly 2 substrings are produced. If this fails, we try splitting on punctuation. + output_context_candidate = None + separator_split_context_current_substring = output_gen.split(separator) + if len(separator_split_context_current_substring) == 2: + output_context_candidate = separator_split_context_current_substring[0] + if not output_context_candidate: + punct_expr = re.compile(r"[\s{}]+".format(re.escape(".?!,;:)]}"))) + punctuation_split_context_current_substring = [s for s in punct_expr.split(output_gen) if s] + if len(punctuation_split_context_current_substring) == 2: + output_context_candidate = punctuation_split_context_current_substring[0] + + # Final resort: if the model is an encoder-decoder model, we align the full input and full output, identifying + # which tokens correspond to context and which to current. This assumes that input and output texts are alignable + # (e.g. translations of each other). We prompt the user a yes/no question asking whether the context identified is + # correct. If not, the user is asked to provide the correct context. 
If align_output_context_auto = True, aligned
+    # texts are assumed to be correct (no user input required, to automate the procedure)
+    if not output_context_candidate and model.is_encoder_decoder and input_context_text is not None:
+        output_context_candidate = get_output_context_from_aligned_inputs(input_context_text, output_gen)
+
+    if output_context_candidate and align_output_context_auto:
+        final_context = output_context_candidate
+    else:
+        final_context = prompt_user_for_context(output_gen, output_context_candidate)
+    template_output_context = output_template.split("{current}")[0].format(context=final_context)
+    if not final_context:
+        template_output_context = template_output_context.strip()
+    final_current = output_gen[min(len(template_output_context), len(output_gen)) :]
+    if not has_out_curr and not final_current:
+        raise ValueError(
+            f"The model produced an empty current output given the specified context '{final_context}'. If no"
+            " context is generated naturally by the model, you can force an output context using the"
+            " --output_context_text option."
+        )
+    if has_out_curr:
+        logger.warning(
+            f"The model produced current text '{final_current}', but the specified output_current_text"
+            f" '{output_current_text}' is used instead. If you want to use the original current output text generated"
+            " by the model, remove the --output_current_text option."
+        )
+    return final_context, final_current.strip()
+
+
+def filter_rank_tokens(
+    tokens: List[str],
+    scores: List[float],
+    std_threshold: Optional[float] = None,
+    topk: Optional[int] = None,
+) -> Tuple[List[Tuple[int, float, str]], Optional[float]]:
+    indices = list(range(0, len(scores)))
+    token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True)
+    threshold = None
+    if std_threshold:
+        threshold = tensor(scores).mean() + std_threshold * tensor(scores).std()
+        token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) > threshold]
+    if topk:
+        token_score_tuples = token_score_tuples[:topk]
+    return token_score_tuples, threshold
+
+
+def get_contextless_prefix(
+    model: HuggingfaceModel,
+    input_current_text: str,
+    output_current_tokens: List[str],
+    cti_idx: int,
+    special_tokens_to_keep: List[str] = [],
+    generation_kwargs: Dict[str, Any] = {},
+) -> str:
+    """Generate the contextless prefix for the current token identified as context-sensitive."""
+    output_current_prefix_tokens = output_current_tokens[:cti_idx]
+    output_current_prefix = model.convert_tokens_to_string(output_current_prefix_tokens, skip_special_tokens=False)
+    if model.is_encoder_decoder:
+        # One extra token for the EOS which is always forced at the end for encoder-decoders
+        generation_kwargs["max_new_tokens"] = 2
+        decoder_input_ids = model.encode(output_current_prefix, as_targets=True).input_ids
+        if int(decoder_input_ids[0, -1]) == model.eos_token_id:
+            decoder_input_ids = decoder_input_ids[0, :-1][None, ...]
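+            # Dropping the trailing EOS keeps the forced prefix open, so the next generated token continues
+            # the current sentence instead of following an already-terminated sequence.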
+ generation_kwargs["decoder_input_ids"] = decoder_input_ids + generation_input = input_current_text + else: + generation_kwargs["max_new_tokens"] = 1 + space = " " if output_current_prefix and not output_current_prefix.startswith((" ", "\n")) else "" + generation_input = input_current_text + space + output_current_prefix + output_contextless = generate_with_special_tokens( + model, + generation_input, + special_tokens_to_keep, + **generation_kwargs, + ) + return output_contextless + + +def get_source_target_cci_scores( + model: HuggingfaceModel, + cci_attrib_out: FeatureAttributionSequenceOutput, + input_template: str, + input_context_tokens: List[str], + input_full_tokens: List[str], + output_template: str, + output_context_tokens: List[str], + has_input_context: bool, + has_output_context: bool, + model_has_lang_tag: bool, + special_tokens_to_keep: List[str] = [], +) -> Tuple[Optional[List[float]], Optional[List[float]]]: + """Extract attribution scores for the input and output contexts.""" + input_scores, output_scores = None, None + if has_input_context: + if model.is_encoder_decoder: + input_scores = cci_attrib_out.source_attributions[:, 0].tolist() + if model_has_lang_tag: + input_scores = input_scores[1:] + else: + input_scores = cci_attrib_out.target_attributions[:, 0].tolist() + input_prefix, *_ = input_template.partition("{context}") + input_prefix_tokens = get_filtered_tokens(input_prefix, model, special_tokens_to_keep, is_target=False) + input_prefix_len = len(input_prefix_tokens) + input_scores = input_scores[input_prefix_len : len(input_context_tokens) + input_prefix_len] + if has_output_context: + output_scores = cci_attrib_out.target_attributions[:, 0].tolist() + if model_has_lang_tag: + output_scores = output_scores[1:] + output_prefix, *_ = output_template.partition("{context}") + output_prefix_tokens = get_filtered_tokens(output_prefix, model, special_tokens_to_keep, is_target=True) + prefix_len = len(output_prefix_tokens) + int(not model.is_encoder_decoder) * len(input_full_tokens) + output_scores = output_scores[prefix_len : len(output_context_tokens) + prefix_len] + return input_scores, output_scores diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py new file mode 100644 index 00000000..841af392 --- /dev/null +++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py @@ -0,0 +1,128 @@ +from typing import List, Literal, Optional + +from rich.console import Console + +from ...models import HuggingfaceModel +from .attribute_context_args import AttributeContextArgs +from .attribute_context_helpers import AttributeContextOutput, filter_rank_tokens, get_filtered_tokens + + +def get_formatted_procedure_details(args: AttributeContextArgs) -> str: + def format_comment(std: Optional[float] = None, topk: Optional[int] = None) -> str: + comment = [] + if std: + comment.append(f"std λ={std:.2f}") + if topk: + comment.append(f"top {topk}") + if len(comment) > 0: + return ", ".join(comment) + return "all" + + cti_comment = format_comment(args.context_sensitivity_std_threshold, args.context_sensitivity_topk) + cci_comment = format_comment(args.attribution_std_threshold, args.attribution_topk) + input_context_comment, output_context_comment = "", "" + if args.has_input_context: + input_context_comment = f"\n[bold]Input context:[/bold]\t{args.input_context_text}" + if args.has_output_context: + output_context_comment = f"\n[bold]Output 
context:[/bold]\t{args.output_context_text}" + return ( + f"\nContext with [bold green]contextual cues[/bold green] ({cci_comment}) followed by output" + f" sentence with [bold dodger_blue1]context-sensitive target spans[/bold dodger_blue1] ({cti_comment})\n" + f'(CTI = "{args.context_sensitivity_metric}", CCI = "{args.attribution_method}" w/ "{args.attributed_fn}" ' + f"target)\n{input_context_comment}\n[bold]Input current:[/bold] {args.input_current_text}" + f"{output_context_comment}\n[bold]Output current:[/bold]\t{args.output_current_text}" + ) + + +def get_formatted_attribute_context_results( + model: HuggingfaceModel, + args: AttributeContextArgs, + output: AttributeContextOutput, + cti_threshold: float, +) -> str: + """Format the results of the context attribution process.""" + + def format_context_comment( + model: HuggingfaceModel, + has_other_context: bool, + special_tokens_to_keep: List[str], + context: str, + context_scores: List[float], + other_context_scores: Optional[List[float]] = None, + is_target: bool = False, + context_type: Literal["Input", "Output"] = "Input", + ) -> str: + context_tokens = get_filtered_tokens( + context, model, special_tokens_to_keep, replace_special_characters=True, is_target=is_target + ) + scores = context_scores + if has_other_context: + scores += other_context_scores + context_ranked_tokens, threshold = filter_rank_tokens( + tokens=context_tokens, + scores=scores, + std_threshold=args.attribution_std_threshold, + topk=args.attribution_topk, + ) + for idx, score, tok in context_ranked_tokens: + context_tokens[idx] = f"[bold green]{tok}({score:.3f})[/bold green]" + cci_threshold_comment = f"(CCI > {threshold:.3f})" + return f"\n[bold]{context_type} context {cci_threshold_comment}:[/bold]\t{''.join(context_tokens)}" + + out_string = "" + output_current_tokens = get_filtered_tokens( + output.output_current, model, args.special_tokens_to_keep, replace_special_characters=True, is_target=True + ) + cti_theshold_comment = f"(CTI > {cti_threshold:.3f})" + for example_idx, cci_out in enumerate(output.cci_scores, start=1): + curr_output_tokens = output_current_tokens.copy() + cti_idx = cci_out.cti_idx + cti_score = cci_out.cti_score + cti_tok = curr_output_tokens[cti_idx] + curr_output_tokens[cti_idx] = f"[bold dodger_blue1]{cti_tok}({cti_score:.3f})[/bold dodger_blue1]" + output_current_comment = "".join(curr_output_tokens) + input_context_comment, output_context_comment = "", "" + if args.has_input_context: + input_context_comment = format_context_comment( + model, + args.has_output_context, + args.special_tokens_to_keep, + output.input_context, + cci_out.input_context_scores, + cci_out.output_context_scores, + ) + if args.has_output_context: + output_context_comment = format_context_comment( + model, + args.has_input_context, + args.special_tokens_to_keep, + output.output_context, + cci_out.output_context_scores, + cci_out.input_context_scores, + is_target=True, + context_type="Output", + ) + out_string += ( + f"#{example_idx}." 
+ f"\n[bold]Generated output {cti_theshold_comment}:[/bold]\t{output_current_comment}" + f"{input_context_comment}{output_context_comment}\n" + ) + return out_string + + +def handle_visualization( + args: AttributeContextArgs, + model: HuggingfaceModel, + output: AttributeContextOutput, + cti_threshold: float, +) -> None: + console = Console(record=True) + viz = get_formatted_procedure_details(args) + viz += "\n\n" + get_formatted_attribute_context_results(model, args, output, cti_threshold) + if args.viz_path: + with console.capture() as _: + console.print(viz, soft_wrap=True) + with open(args.viz_path, "w") as f: + f.write(console.export_html()) + if args.show_viz: + console.print(viz, soft_wrap=True) diff --git a/inseq/commands/attribute_dataset.py b/inseq/commands/attribute_dataset.py deleted file mode 100644 index c9635ec1..00000000 --- a/inseq/commands/attribute_dataset.py +++ /dev/null @@ -1,80 +0,0 @@ -from dataclasses import dataclass, field -from typing import List, Optional, Tuple - -from ..utils import is_datasets_available -from .attribute import AttributeBaseArgs, attribute -from .base import BaseCLICommand - -if is_datasets_available(): - from datasets import load_dataset - - -@dataclass -class AttributeDatasetArgs: - dataset_name: str = field( - metadata={ - "alias": "-d", - "help": "The type of dataset to be loaded for attribution.", - }, - ) - input_text_field: Optional[str] = field( - metadata={"alias": "-f", "help": "Name of the field containing the input texts used for attribution."} - ) - generated_text_field: Optional[str] = field( - default=None, - metadata={ - "alias": "-fgen", - "help": "Name of the field containing the generated texts used for constrained decoding.", - }, - ) - dataset_config: Optional[str] = field( - default=None, metadata={"alias": "-dconf", "help": "The name of the Huggingface dataset configuration."} - ) - dataset_dir: Optional[str] = field( - default=None, metadata={"alias": "-ddir", "help": "Path to the directory containing the data files."} - ) - dataset_files: Optional[List[str]] = field( - default=None, metadata={"alias": "-dfiles", "help": "Path to the dataset files."} - ) - dataset_split: Optional[str] = field(default="train", metadata={"alias": "-dsplit", "help": "Dataset split."}) - dataset_revision: Optional[str] = field( - default=None, metadata={"alias": "-drev", "help": "The Huggingface dataset revision."} - ) - dataset_auth_token: Optional[str] = field( - default=None, metadata={"alias": "-dauth", "help": "The auth token for the Huggingface dataset."} - ) - - -def load_fields_from_dataset(dataset_args: AttributeDatasetArgs) -> Tuple[List[str], Optional[List[str]]]: - if not is_datasets_available(): - raise ImportError("The datasets library needs to be installed to use the attribute-dataset client.") - dataset = load_dataset( - dataset_args.dataset_name, - dataset_args.dataset_config, - data_dir=dataset_args.dataset_dir, - data_files=dataset_args.dataset_files, - split=dataset_args.dataset_split, - revision=dataset_args.dataset_revision, - use_auth_token=dataset_args.dataset_auth_token, - ) - df = dataset.to_pandas() - if dataset_args.input_text_field in df.columns: - input_texts = list(df[dataset_args.input_text_field]) - else: - raise ValueError(f"The input text field {dataset_args.input_text_field} does not exist in the dataset.") - generated_texts = None - if dataset_args.generated_text_field is not None: - if dataset_args.generated_text_field in df.columns: - generated_texts = list(df[dataset_args.generated_text_field]) - 
return input_texts, generated_texts - - -class AttributeDatasetCommand(BaseCLICommand): - _name = "attribute-dataset" - _help = "Perform feature attribution on a full dataset and save the results to a file" - _dataclasses = AttributeBaseArgs, AttributeDatasetArgs - - def run(args: Tuple[AttributeBaseArgs, AttributeDatasetArgs]): - attribute_args, dataset_args = args - input_texts, generated_texts = load_fields_from_dataset(dataset_args) - attribute(input_texts, generated_texts, attribute_args) diff --git a/inseq/commands/attribute_dataset/__init__.py b/inseq/commands/attribute_dataset/__init__.py new file mode 100644 index 00000000..beff886a --- /dev/null +++ b/inseq/commands/attribute_dataset/__init__.py @@ -0,0 +1,5 @@ +from .attribute_dataset import AttributeDatasetCommand + +__all__ = [ + "AttributeDatasetCommand", +] diff --git a/inseq/commands/attribute_dataset/attribute_dataset.py b/inseq/commands/attribute_dataset/attribute_dataset.py new file mode 100644 index 00000000..6d633926 --- /dev/null +++ b/inseq/commands/attribute_dataset/attribute_dataset.py @@ -0,0 +1,46 @@ +from typing import List, Optional, Tuple + +from ...utils import is_datasets_available +from ..attribute import AttributeExtendedArgs +from ..attribute.attribute import attribute +from ..base import BaseCLICommand +from .attribute_dataset_args import LoadDatasetArgs + +if is_datasets_available(): + from datasets import load_dataset + + +def load_fields_from_dataset(dataset_args: LoadDatasetArgs) -> Tuple[List[str], Optional[List[str]]]: + if not is_datasets_available(): + raise ImportError("The datasets library needs to be installed to use the attribute-dataset client.") + dataset = load_dataset( + dataset_args.dataset_name, + dataset_args.dataset_config, + data_dir=dataset_args.dataset_dir, + data_files=dataset_args.dataset_files, + split=dataset_args.dataset_split, + revision=dataset_args.dataset_revision, + token=dataset_args.dataset_auth_token, + **dataset_args.dataset_kwargs, + ) + df = dataset.to_pandas() + if dataset_args.input_text_field in df.columns: + input_texts = list(df[dataset_args.input_text_field]) + else: + raise ValueError(f"The input text field {dataset_args.input_text_field} does not exist in the dataset.") + generated_texts = None + if dataset_args.generated_text_field is not None: + if dataset_args.generated_text_field in df.columns: + generated_texts = list(df[dataset_args.generated_text_field]) + return input_texts, generated_texts + + +class AttributeDatasetCommand(BaseCLICommand): + _name = "attribute-dataset" + _help = "Perform feature attribution on a full dataset and save the results to a file" + _dataclasses = AttributeExtendedArgs, LoadDatasetArgs + + def run(args: Tuple[AttributeExtendedArgs, LoadDatasetArgs]): + attribute_args, dataset_args = args + input_texts, generated_texts = load_fields_from_dataset(dataset_args) + attribute(input_texts, generated_texts, attribute_args) diff --git a/inseq/commands/attribute_dataset/attribute_dataset_args.py b/inseq/commands/attribute_dataset/attribute_dataset_args.py new file mode 100644 index 00000000..3f02ee19 --- /dev/null +++ b/inseq/commands/attribute_dataset/attribute_dataset_args.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import List, Optional + +from ...utils import cli_arg +from ..commands_utils import command_args_docstring + + +@command_args_docstring +@dataclass +class LoadDatasetArgs: + dataset_name: str = cli_arg( + aliases=["-d", "--dataset"], + help="The type of dataset to be loaded for attribution.", + 
) + input_text_field: Optional[str] = cli_arg( + aliases=["-in", "--input"], help="Name of the field containing the input texts used for attribution." + ) + generated_text_field: Optional[str] = cli_arg( + default=None, + aliases=["-gen", "--generated"], + help="Name of the field containing the generated texts used for constrained decoding.", + ) + dataset_config: Optional[str] = cli_arg( + default=None, aliases=["--config"], help="The name of the Huggingface dataset configuration." + ) + dataset_dir: Optional[str] = cli_arg( + default=None, aliases=["--dir"], help="Path to the directory containing the data files." + ) + dataset_files: Optional[List[str]] = cli_arg(default=None, aliases=["--files"], help="Path to the dataset files.") + dataset_split: Optional[str] = cli_arg(default="train", aliases=["--split"], help="Dataset split.") + dataset_revision: Optional[str] = cli_arg( + default=None, aliases=["--revision"], help="The Huggingface dataset revision." + ) + dataset_auth_token: Optional[str] = cli_arg( + default=None, aliases=["--auth"], help="The auth token for the Huggingface dataset." + ) + dataset_kwargs: Optional[dict] = cli_arg( + default_factory=dict, + help="Additional keyword arguments passed to the dataset constructor in JSON format.", + ) diff --git a/inseq/commands/cli.py b/inseq/commands/cli.py index d7735998..b4e72257 100644 --- a/inseq/commands/cli.py +++ b/inseq/commands/cli.py @@ -1,11 +1,14 @@ """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/commands/transformers_cli.py.""" import sys +from typing import List from ..utils import InseqArgumentParser from .attribute import AttributeCommand +from .attribute_context import AttributeContextCommand from .attribute_dataset import AttributeDatasetCommand +from .base import BaseCLICommand -COMMANDS = [AttributeCommand, AttributeDatasetCommand] +COMMANDS: List[BaseCLICommand] = [AttributeCommand, AttributeDatasetCommand, AttributeContextCommand] def main(): diff --git a/inseq/commands/commands_utils.py b/inseq/commands/commands_utils.py new file mode 100644 index 00000000..ed138f7b --- /dev/null +++ b/inseq/commands/commands_utils.py @@ -0,0 +1,25 @@ +import dataclasses +import textwrap +import typing + + +def command_args_docstring(cls): + """ + A decorator that automatically generates a Google-style docstring for a dataclass. 
+ """ + docstring = f"{cls.__name__}\n\n" + fields = dataclasses.fields(cls) + resolved_hints = typing.get_type_hints(cls) + resolved_field_types = {field.name: resolved_hints[field.name] for field in fields} + if fields: + docstring += "**Attributes:**\n" + for field in fields: + field_type = resolved_field_types[field.name] + field_help = field.metadata.get("help", "") + docstring += textwrap.dedent( + f""" + **{field.name}** (``{field_type}``): {field_help} + """ + ) + cls.__doc__ = docstring + return cls diff --git a/inseq/data/viz.py b/inseq/data/viz.py index 86d924cb..59634420 100644 --- a/inseq/data/viz.py +++ b/inseq/data/viz.py @@ -23,8 +23,8 @@ import numpy as np from matplotlib.colors import Colormap from rich import box -from rich import print as rprint from rich.color import Color +from rich.console import Console from rich.live import Live from rich.padding import Padding from rich.panel import Panel @@ -102,12 +102,15 @@ def show_attributions( display(HTML(curr_html)) html_out += curr_html if not isnotebook(): + console = Console() curr_color = None if attribution.source_attributions is not None: curr_color = colors[idx] if display: print("\n\n") - rprint(get_heatmap_type(attribution, curr_color, "Source", use_html=False)) + console.print( + get_heatmap_type(attribution, curr_color, "Source", use_html=False), overflow="ignore" + ) if attribution.target_attributions is not None: curr_color = colors[idx + 1] display_scores = attribution.source_attributions is None and attribution.step_scores @@ -115,7 +118,7 @@ def show_attributions( if curr_color is None and colors: curr_color = colors[idx] print("\n\n") - rprint(get_heatmap_type(attribution, curr_color, "Target", use_html=False)) + console.print(get_heatmap_type(attribution, curr_color, "Target", use_html=False), overflow="ignore") if any(x is None for x in [attribution.source_attributions, attribution.target_attributions]): idx += 1 else: @@ -250,9 +253,9 @@ def get_saliency_heatmap_rich( label: str = "", step_scores_threshold: Union[float, Dict[str, float]] = 0.5, ): - columns = [Column(header="", justify="right")] + columns = [Column(header="", justify="right", overflow="fold")] for column_label in column_labels: - columns.append(Column(header=column_label, justify="center")) + columns.append(Column(header=column_label, justify="center", overflow="fold")) table = Table( *columns, title=f"{label + ' ' if label else ''}Saliency Heatmap", diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index 08ee0c4a..23ab3a9c 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -1,7 +1,7 @@ """HuggingFace Seq2seq model.""" import logging from abc import abstractmethod -from typing import Dict, List, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import torch from torch import long @@ -69,6 +69,8 @@ def __init__( attribution_method: Optional[str] = None, tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, device: Optional[str] = None, + model_kwargs: Optional[Dict[str, Any]] = {}, + tokenizer_kwargs: Optional[Dict[str, Any]] = {}, **kwargs, ) -> None: """AttributionModel subclass for Huggingface-compatible models. @@ -90,15 +92,13 @@ def __init__( raise ValueError( f"Invalid autoclass {self._autoclass}. Must be one of {[x.__name__ for x in SUPPORTED_AUTOCLASSES]}." 
) - model_args = kwargs.pop("model_args", {}) - model_kwargs = kwargs.pop("model_kwargs", {}) if isinstance(model, PreTrainedModel): self.model = model else: if "output_attentions" not in model_kwargs: model_kwargs["output_attentions"] = True - self.model = self._autoclass.from_pretrained(model, *model_args, **model_kwargs) + self.model = self._autoclass.from_pretrained(model, **model_kwargs) self.model_name = self.model.config.name_or_path self.tokenizer_name = tokenizer if isinstance(tokenizer, str) else None if tokenizer is None: @@ -108,13 +108,10 @@ def __init__( "Unspecified tokenizer for model loaded from scratch. Use explicit identifier as tokenizer=" "during model loading." ) - tokenizer_inputs = kwargs.pop("tokenizer_inputs", {}) - tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", {}) - if isinstance(tokenizer, PreTrainedTokenizerBase): self.tokenizer = tokenizer else: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, *tokenizer_inputs, **tokenizer_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) if self.model.config.pad_token_id is not None: self.pad_token = self.tokenizer.convert_ids_to_tokens(self.model.config.pad_token_id) self.tokenizer.pad_token = self.pad_token @@ -140,6 +137,8 @@ def load( attribution_method: Optional[str] = None, tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, device: str = None, + model_kwargs: Optional[Dict[str, Any]] = {}, + tokenizer_kwargs: Optional[Dict[str, Any]] = {}, **kwargs, ) -> "HuggingfaceModel": """Loads a HuggingFace model and tokenizer and wraps them in the appropriate AttributionModel.""" @@ -148,9 +147,13 @@ def load( else: is_encoder_decoder = model.config.is_encoder_decoder if is_encoder_decoder: - return HuggingfaceEncoderDecoderModel(model, attribution_method, tokenizer, device, **kwargs) + return HuggingfaceEncoderDecoderModel( + model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs + ) else: - return HuggingfaceDecoderOnlyModel(model, attribution_method, tokenizer, device, **kwargs) + return HuggingfaceDecoderOnlyModel( + model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs + ) @AttributionModel.device.setter def device(self, new_device: str) -> None: @@ -409,16 +412,6 @@ class HuggingfaceEncoderDecoderModel(HuggingfaceModel, EncoderDecoderAttribution _autoclass = AutoModelForSeq2SeqLM - def __init__( - self, - model: Union[str, PreTrainedModel], - attribution_method: Optional[str] = None, - tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, - device: str = None, - **kwargs, - ) -> NoReturn: - super().__init__(model, attribution_method, tokenizer, device, **kwargs) - def configure_embeddings_scale(self): encoder = self.model.get_encoder() decoder = self.model.get_decoder() @@ -470,9 +463,11 @@ def __init__( attribution_method: Optional[str] = None, tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, device: str = None, + model_kwargs: Optional[Dict[str, Any]] = {}, + tokenizer_kwargs: Optional[Dict[str, Any]] = {}, **kwargs, ) -> NoReturn: - super().__init__(model, attribution_method, tokenizer, device, **kwargs) + super().__init__(model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs) self.tokenizer.padding_side = "left" self.tokenizer.truncation_side = "left" if self.pad_token is None: diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 5f59ea04..29f81615 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -1,5 +1,5 
@@ from .alignment_utils import get_adjusted_alignments, get_aligned_idx -from .argparse import InseqArgumentParser +from .argparse import InseqArgumentParser, cli_arg from .cache import INSEQ_ARTIFACTS_CACHE, INSEQ_HOME_CACHE, cache_results from .errors import ( InseqDeprecationWarning, @@ -117,4 +117,5 @@ "get_aligned_idx", "top_p_logits_mask", "filter_logits", + "cli_arg", ] diff --git a/inseq/utils/alignment_utils.py b/inseq/utils/alignment_utils.py index 86a1a18d..1b3e256b 100644 --- a/inseq/utils/alignment_utils.py +++ b/inseq/utils/alignment_utils.py @@ -231,7 +231,7 @@ def auto_align_sequences( clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens) clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens) if len(removed_a_token_idxs) != len(removed_b_token_idxs): - logger.warning( + logger.debug( "The number of special tokens in the target and contrast sequences do not match. " "Trying to match special tokens based on their identity." ) @@ -266,7 +266,7 @@ def auto_align_sequences( alignments=a_to_b_aligns_with_special_tokens, ) except Exception as e: - logger.warning( + logger.error( "Failed to compute alignments using the aligner. " f"Please check the following error and provide custom alignments if needed.\n{e}" ) @@ -302,7 +302,7 @@ def get_adjusted_alignments( ).alignments alignments = [(a_idx, b_idx) for a_idx, b_idx in alignments if start_pos <= a_idx < end_pos] is_auto_aligned = True - logger.warning( + logger.debug( f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic " f"sequences, or for languages not covered by the aligner." ) @@ -316,13 +316,14 @@ def get_adjusted_alignments( # Filter alignments (restrict to one per token) filter_aligns = [] - for pair_idx in range(start_pos, end_pos): - match_pairs = [(p0, p1) for p0, p1 in alignments if p0 == pair_idx and 0 <= p1 < len(contrast_tokens)] - if match_pairs: - # If found, use the first match that containing an unaligned target token, first match otherwise - match_pairs_unaligned = [p for p in match_pairs if p[1] not in [f[1] for f in filter_aligns]] - valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0] - filter_aligns.append(valid_match) + if len(alignments) > 0: + for pair_idx in range(start_pos, end_pos): + match_pairs = [(p0, p1) for p0, p1 in alignments if p0 == pair_idx and 0 <= p1 < len(contrast_tokens)] + if match_pairs: + # If found, use the first match that containing an unaligned target token, first match otherwise + match_pairs_unaligned = [p for p in match_pairs if p[1] not in [f[1] for f in filter_aligns]] + valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0] + filter_aligns.append(valid_match) # Filling alignments with missing tokens if fill_missing: @@ -333,10 +334,10 @@ def get_adjusted_alignments( # Default behavior: fill missing alignments with 1:1 position alignments starting from the bottom of the # two sequences if not match_pairs: - if (len(contrast_tokens) - step_idx) < start_pos: - filled_alignments.append((pair_idx, len(contrast_tokens) - 1)) - else: + if (len(contrast_tokens) - step_idx) > 0: filled_alignments.append((pair_idx, len(contrast_tokens) - step_idx)) + else: + filled_alignments.append((pair_idx, len(contrast_tokens) - 1)) if filter_aligns != filled_alignments: existing_aligns_message = ( @@ -346,13 +347,13 @@ def get_adjusted_alignments( "No target alignments were provided for the contrastive target. " "Use e.g. 
'contrast_targets_alignments=[(0,1), ...] to provide them in model.attribute" ) - logger.warning( + logger.debug( f"{existing_aligns_message if filter_aligns else no_aligns_message}\n" "Filling missing position with right-aligned 1:1 position alignments." ) filter_aligns = sorted(set(filled_alignments), key=lambda x: (x[0], x[1])) if is_auto_aligned or (fill_missing and filter_aligns != filled_alignments): - logger.warning(f"Generated alignments: {filter_aligns}") + logger.debug(f"Generated alignments: {filter_aligns}") return filter_aligns diff --git a/inseq/utils/argparse.py b/inseq/utils/argparse.py index 1752aaec..f074274e 100644 --- a/inseq/utils/argparse.py +++ b/inseq/utils/argparse.py @@ -15,12 +15,15 @@ import dataclasses import json import sys +import types from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError from copy import copy from enum import Enum from inspect import isclass from pathlib import Path -from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints +from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints + +import yaml DataClass = NewType("DataClass", Any) DataClassType = NewType("DataClassType", Any) @@ -40,17 +43,90 @@ def string_to_bool(v): ) +def make_choice_type_function(choices: list) -> Callable[[str], Any]: + """ + Creates a mapping function from each choices string representation to the actual value. Used to support multiple + value types for a single argument. + + Args: + choices (list): List of choices. + + Returns: + Callable[[str], Any]: Mapping function from string representation to actual value for each choice. + """ + str_to_choice = {str(choice): choice for choice in choices} + return lambda arg: str_to_choice.get(arg, arg) + + +def cli_arg( + *, + aliases: Union[str, List[str]] = None, + help: str = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[[], Any] = dataclasses.MISSING, + choices: List[Any] = None, + metadata: dict = None, + **kwargs, +) -> dataclasses.Field: + """Argument helper enabling a concise syntax to create dataclass fields for parsing with `InseqArgumentParser`. + + Example comparing the use of `cli_arg` and `dataclasses.field`: + ``` + @dataclass + class Args: + regular_arg: str = dataclasses.field(default="Test", metadata={"aliases": ["--example", "-e"], "help": "Long"}) + cli_arg: str = cli_arg(default="Test", aliases=["--example", "-e"], help="What a nice syntax!") + ``` + + Args: + aliases (Union[str, List[str]], optional): + Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`. + Defaults to None. + help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None. + default (Any, optional): + Default value for the argument. If not default or default_factory is specified, the argument is required. + Defaults to dataclasses.MISSING. + default_factory (Callable[[], Any], optional): + The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide + default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`. + Defaults to dataclasses.MISSING. + choices (List[Any], optional): + List of choices. If specified, the argument will be restricted to these choices. Defaults to None. + metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None. 
+ + Returns: + Field: A `dataclasses.Field` with the desired properties. + """ + if metadata is None: + # Important, don't use as default param in function signature: dict is mutable and shared across function calls + metadata = {} + if aliases is not None: + metadata["aliases"] = aliases + if help is not None: + metadata["help"] = help + if choices is not None: + metadata["choices"] = choices + return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) + + class InseqArgumentParser(ArgumentParser): - """Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py.""" + """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py. + This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. + + The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) + arguments to the parser after initialization and you'll get the output back after parsing as an additional + namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. + """ dataclass_types: Iterable[DataClassType] def __init__(self, dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, **kwargs): - """Args: - dataclass_types: - Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. - kwargs: - (Optional) Passed to `argparse.ArgumentParser()` in the regular way. + """ + Args: + dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`, *optional*): + Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. + kwargs (`Dict[str, Any]`, *optional*): + Passed to `argparse.ArgumentParser()` in the regular way. """ # To make the default appear when using --help if "formatter_class" not in kwargs: @@ -67,7 +143,6 @@ def __init__(self, dataclass_types: Optional[Union[DataClassType, Iterable[DataC def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): field_name = f"--{field.name}" kwargs = field.metadata.copy() - alias = kwargs.pop("alias", None) # field.metadata is not used at all by Data Classes, # it is provided as a third-party extension mechanism. if isinstance(field.type, str): @@ -76,11 +151,25 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): "`typing.get_type_hints` method by default" ) + aliases = kwargs.pop("aliases", []) + if isinstance(aliases, str): + aliases = [aliases] + origin_type = getattr(field.type, "__origin__", field.type) - if origin_type is Union: - if len(field.type.__args__) != 2 or type(None) not in field.type.__args__: - raise ValueError("Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union`") - if bool not in field.type.__args__: + if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)): + if str not in field.type.__args__ and ( + len(field.type.__args__) != 2 or type(None) not in field.type.__args__ + ): + raise ValueError( + "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" + " the argument parser only supports one type per argument." + f" Problem encountered in field '{field.name}'." 
+ ) + if type(None) not in field.type.__args__: + # filter `str` in Union + field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] + origin_type = getattr(field.type, "__origin__", field.type) + elif bool not in field.type.__args__: # filter `NoneType` in Union (except for `Union[bool, NoneType]`) field.type = ( field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] @@ -90,14 +179,19 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): # A variable to store kwargs for a boolean field, if needed # so that we can init a `no_*` complement argument (see below) bool_kwargs = {} - if isinstance(field.type, type) and issubclass(field.type, Enum): - kwargs["choices"] = [x.value for x in field.type] - kwargs["type"] = type(kwargs["choices"][0]) + if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): + if origin_type is Literal: + kwargs["choices"] = field.type.__args__ + else: + kwargs["choices"] = [x.value for x in field.type] + + kwargs["type"] = make_choice_type_function(kwargs["choices"]) + if field.default is not dataclasses.MISSING: kwargs["default"] = field.default else: kwargs["required"] = True - elif field.type is bool or field.type is Optional[bool]: + elif field.type is bool or field.type == Optional[bool]: # Copy the currect kwargs to use to instantiate a `no_*` complement argument below. # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument bool_kwargs = copy(kwargs) @@ -113,6 +207,9 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): kwargs["nargs"] = "?" # This is the value that will get picked if we do --field_name (without value) kwargs["const"] = True + elif field.type is dict: + kwargs["type"] = json.loads + kwargs["default"] = {} elif isclass(origin_type) and issubclass(origin_type, list): kwargs["type"] = field.type.__args__[0] kwargs["nargs"] = "+" @@ -128,16 +225,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): kwargs["default"] = field.default_factory() else: kwargs["required"] = True - if alias is not None: - parser.add_argument(field_name, alias, **kwargs) - else: - parser.add_argument(field_name, **kwargs) + parser.add_argument(field_name, *aliases, **kwargs) # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. # Order is important for arguments with the same destination! # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down # here and we do not need those changes/additional keys. - if field.default is True and (field.type is bool or field.type is Optional[bool]): + if field.default is True and (field.type is bool or field.type == Optional[bool]): bool_kwargs["default"] = False parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs) @@ -149,13 +243,25 @@ def _add_dataclass_arguments(self, dtype: DataClassType): try: type_hints: Dict[str, type] = get_type_hints(dtype) - except NameError as err: + except NameError as ex: raise RuntimeError( - f"Type resolution failed for f{dtype}. Try declaring the class in global scope or " + f"Type resolution failed for {dtype}. 
Try declaring the class in global scope or " "removing line of `from __future__ import annotations` which opts in Postponed " "Evaluation of Annotations (PEP 563)" - ) from err - + ) from ex + except TypeError as ex: + # Remove this block when we drop Python 3.9 support + if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex): + python_version = ".".join(map(str, sys.version_info[:3])) + raise RuntimeError( + f"Type resolution failed for {dtype} on Python {python_version}. Try removing " + "line of `from __future__ import annotations` which opts in union types as " + "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To " + "support Python versions that lower than 3.10, you need to use " + "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of " + "`X | None`." + ) from ex + raise for field in dataclasses.fields(dtype): if not field.init: continue @@ -163,9 +269,15 @@ def _add_dataclass_arguments(self, dtype: DataClassType): self._parse_dataclass_field(parser, field) def parse_args_into_dataclasses( - self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None + self, + args=None, + return_remaining_strings=False, + look_for_args_file=True, + args_filename=None, + args_file_flag=None, ) -> Tuple[DataClass, ...]: - """Parse command-line args into instances of the specified dataclass types. + """ + Parse command-line args into instances of the specified dataclass types. This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at: docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args @@ -180,6 +292,9 @@ def parse_args_into_dataclasses( process, and will append its potential content to the command line args. args_filename: If not None, will uses this file instead of the ".args" file specified in the previous argument. + args_file_flag: + If not None, will look for a file in the command-line args specified with this flag. The flag can be + specified multiple times and precedence is determined by the order (last one wins). Returns: Tuple consisting of: @@ -189,17 +304,36 @@ def parse_args_into_dataclasses( after initialization. - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) """ - if args_filename or (look_for_args_file and len(sys.argv)): + + if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): + args_files = [] + if args_filename: - args_file = Path(args_filename) - else: - args_file = Path(sys.argv[0]).with_suffix(".args") + args_files.append(Path(args_filename)) + elif look_for_args_file and len(sys.argv): + args_files.append(Path(sys.argv[0]).with_suffix(".args")) + + # args files specified via command line flag should overwrite default args files so we add them last + if args_file_flag: + # Create special parser just to extract the args_file_flag values + args_file_parser = ArgumentParser() + args_file_parser.add_argument(args_file_flag, type=str, action="append") + + # Use only remaining args for further parsing (remove the args_file_flag) + cfg, args = args_file_parser.parse_known_args(args=args) + cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) - if args_file.exists(): - fargs = args_file.read_text().split() - args = fargs + args if args is not None else fargs + sys.argv[1:] - # in case of duplicate arguments the first one has precedence - # so we append rather than prepend. 
+ if cmd_args_file_paths: + args_files.extend([Path(p) for p in cmd_args_file_paths]) + + file_args = [] + for args_file in args_files: + if args_file.exists(): + file_args += args_file.read_text().split() + + # in case of duplicate arguments the last one has precedence + # args specified via the command line should overwrite args from files, so we add them last + args = file_args + args if args is not None else file_args + sys.argv[1:] namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: @@ -220,27 +354,72 @@ def parse_args_into_dataclasses( return (*outputs,) - def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: - """Alternative helper method that does not use `argparse` at all, instead loading a json file and populating - the dataclass types. + def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: """ - data = json.loads(Path(json_file).read_text()) - outputs = [] - for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype) if f.init} - inputs = {k: v for k, v in data.items() if k in keys} - obj = dtype(**inputs) - outputs.append(obj) - return (*outputs,) + Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass + types. - def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: - """Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the - dataclass types. + Args: + args (`dict`): + dict containing config values + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. """ + unused_keys = set(args.keys()) outputs = [] for dtype in self.dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in args.items() if k in keys} + unused_keys.difference_update(inputs.keys()) obj = dtype(**inputs) outputs.append(obj) - return (*outputs,) + if not allow_extra_keys and unused_keys: + raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") + return tuple(outputs) + + def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the + dataclass types. + + Args: + json_file (`str` or `os.PathLike`): + File name of the json file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + with open(Path(json_file), encoding="utf-8") as open_json_file: + data = json.loads(open_json_file.read()) + outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys) + return tuple(outputs) + + def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the + dataclass types. 
+ + Args: + yaml_file (`str` or `os.PathLike`): + File name of the yaml file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys) + return tuple(outputs) diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py index 69172d34..29b1bcbf 100644 --- a/inseq/utils/misc.py +++ b/inseq/utils/misc.py @@ -71,7 +71,7 @@ def pretty_list(l: Optional[Sequence[Any]], lpad: int = 8) -> str: return out_txt if len(l) > 20: return out_txt - return f"{out_txt}:{_pretty_list(l, lpad)}" + return f"{out_txt}: {_pretty_list(l, lpad)}" def pretty_tensor(t: Optional[Tensor] = None, lpad: int = 8) -> str: @@ -99,8 +99,12 @@ def pretty_dict(d: Dict[str, Any], lpad: int = 4) -> str: out_txt += pretty_dict(v, lpad + 4) elif hasattr(v, "to_dict") and not isinstance(v, type): out_txt += f"{v.__class__.__name__}({pretty_dict(v.to_dict(), lpad + 4)})" + elif v is None: + out_txt += "None" + elif isinstance(v, str): + out_txt += f'"{v}"' else: - out_txt += "None" if v is None else str(v) + out_txt += str(v) out_txt += ",\n" return out_txt + f"{' ' * (lpad - 4)}}}" diff --git a/tests/commands/__init__.py b/tests/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/commands/test_attribute_context.py b/tests/commands/test_attribute_context.py new file mode 100644 index 00000000..5a076f51 --- /dev/null +++ b/tests/commands/test_attribute_context.py @@ -0,0 +1,337 @@ +import pytest +from pytest import fixture +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GPT2LMHeadModel, MarianMTModel + +from inseq.commands.attribute_context import AttributeContextArgs, AttributeContextOutput, CCIOutput +from inseq.commands.attribute_context.attribute_context import attribute_context + + +@fixture(scope="session") +def encdec_model() -> MarianMTModel: + return AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr") + + +@fixture(scope="session") +def deconly_model() -> GPT2LMHeadModel: + return AutoModelForCausalLM.from_pretrained("gpt2") + + +def round_scores(cli_out: AttributeContextOutput) -> AttributeContextOutput: + cli_out.cti_scores = [round(score, 2) for score in cli_out.cti_scores] + for idx in range(len(cli_out.cci_scores)): + cci = cli_out.cci_scores[idx] + cli_out.cci_scores[idx].cti_score = round(cli_out.cci_scores[idx].cti_score, 2) + if cci.input_context_scores is not None: + cli_out.cci_scores[idx].input_context_scores = [round(s, 2) for s in cci.input_context_scores] + if cci.output_context_scores is not None: + cli_out.cci_scores[idx].output_context_scores = [round(s, 2) for s in cci.output_context_scores] + return cli_out + + +def test_in_out_ctx_encdec_whitespace_sep(encdec_model: MarianMTModel): + # Base case for context-aware encoder-decoder: no language tag, no special token in separator + # source context (input context) is translated into target context (output context). 
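+    # Note: cti_scores below has one entry per token of output_current_tokens, while each CCIOutput entry
+    # reports the attribution of every input/output context token for one context-sensitive target token.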
+ in_out_ctx_encdec_whitespace_sep = AttributeContextArgs( + model_name_or_path=encdec_model, + input_context_text="The girls were away.", + input_current_text="Where are they?", + output_template="{context} {current}", + input_template="{context} {current}", + attributed_fn="contrast_prob_diff", + show_viz=False, + # Pre-defining natural model outputs to avoid user input in unit tests + output_context_text="", + output_current_text="Où sont-elles?", + add_output_info=False, + ) + expected_output = AttributeContextOutput( + input_context="The girls were away.", + input_context_tokens=["▁The", "▁girls", "▁were", "▁away", "."], + output_context="", + output_context_tokens=[], + output_current="Où sont-elles?", + output_current_tokens=["▁Où", "▁sont", "-", "elles", "?"], + cti_scores=[1.36, 0.08, 0.34, 1.23, 0.27], + cci_scores=[ + CCIOutput( + cti_idx=0, + cti_token="▁Où", + cti_score=1.36, + contextual_prefix="Où", + contextless_prefix="Où", + input_context_scores=[0.01, 0.01, 0.01, 0.01, 0.01], + output_context_scores=[], + ), + CCIOutput( + cti_idx=3, + cti_token="elles", + cti_score=1.23, + contextual_prefix="Où sont-elles", + contextless_prefix="Où sont-ils", + input_context_scores=[0.03, 0.12, 0.03, 0.03, 0.02], + output_context_scores=[], + ), + ], + info=None, + ) + cli_out = attribute_context(in_out_ctx_encdec_whitespace_sep) + assert round_scores(cli_out) == expected_output + + +def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel): + # Base case for context-aware decoder-only model with input-only context. + in_ctx_deconly = AttributeContextArgs( + model_name_or_path=deconly_model, + input_context_text="George was sick yesterday.", + input_current_text="His colleagues asked him to come", + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + ) + expected_output = AttributeContextOutput( + input_context="George was sick yesterday.", + input_context_tokens=["George", "Ä was", "Ä sick", "Ä yesterday", "."], + output_context=None, + output_context_tokens=None, + output_current="to the hospital. He said he was fine", + output_current_tokens=["to", "Ä the", "Ä hospital", ".", "Ä He", "Ä said", "Ä he", "Ä was", "Ä fine"], + cti_scores=[0.31, 0.25, 0.55, 0.16, 0.43, 0.19, 0.13, 0.07, 0.37], + cci_scores=[ + CCIOutput( + cti_idx=2, + cti_token="Ä hospital", + cti_score=0.55, + contextual_prefix="George was sick yesterday. His colleagues asked him to come to the hospital", + contextless_prefix="His colleagues asked him to come to the office", + input_context_scores=[0.39, 0.29, 0.52, 0.26, 0.16], + output_context_scores=None, + ) + ], + info=None, + ) + cli_out = attribute_context(in_ctx_deconly) + assert round_scores(cli_out) == expected_output + + +def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel): + # Base case for context-aware decoder-only model with forced output context mocking a reasoning chain. + out_ctx_deconly = AttributeContextArgs( + model_name_or_path=deconly_model, + output_template="\n\nLet's think step by step:\n{context}\n\nAnswer:\n{current}", + input_template="{current}", + input_current_text="Question: How many pairs of legs do 10 horses have?", + output_context_text="1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.", + output_current_text="20 pairs of legs.", + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + ) + expected_output = AttributeContextOutput( + input_context=None, + input_context_tokens=None, + output_context="1. A horse has 4 legs.\n2. 
10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.", + output_context_tokens=[ + "1", + ".", + "Ä A", + "Ä horse", + "Ä has", + "Ä 4", + "Ä legs", + ".", + "Ċ", + "2", + ".", + "Ä 10", + "Ä horses", + "Ä have", + "Ä 40", + "Ä legs", + ".", + "Ċ", + "3", + ".", + "Ä 40", + "Ä legs", + "Ä make", + "Ä 20", + "Ä pairs", + "Ä of", + "Ä legs", + ".", + ], + output_current="20 pairs of legs.", + output_current_tokens=["20", "Ä pairs", "Ä of", "Ä legs", "."], + cti_scores=[4.53, 1.33, 0.43, 0.74, 0.93], + cci_scores=[ + CCIOutput( + cti_idx=0, + cti_token="20 → Ä 20", + cti_score=4.53, + contextual_prefix="Question: How many pairs of legs do 10 horses have?\n\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20", + contextless_prefix="Question: How many pairs of legs do 10 horses have?\n", + input_context_scores=None, + output_context_scores=[0.0] * 28, + ), + ], + info=None, + ) + cli_out = attribute_context(out_ctx_deconly) + assert round_scores(cli_out).cci_scores[0] == expected_output.cci_scores[0] + + +def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel): + # Base case for context-aware decoder-only model with input and forced output context. + in_out_ctx_deconly = AttributeContextArgs( + model_name_or_path=deconly_model, + input_context_text="George was sick yesterday.", + input_current_text="His colleagues asked him if", + output_context_text="something was wrong. He said", + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + ) + expected_output = AttributeContextOutput( + input_context="George was sick yesterday.", + input_context_tokens=["George", "Ä was", "Ä sick", "Ä yesterday", "."], + output_context="something was wrong. He said", + output_context_tokens=["something", "Ä was", "Ä wrong", ".", "Ä He", "Ä said"], + output_current="he was fine.", + output_current_tokens=["he", "Ä was", "Ä fine", "."], + cti_scores=[1.2, 0.72, 1.5, 0.49], + cci_scores=[ + CCIOutput( + cti_idx=2, + cti_token="Ä fine", + cti_score=1.5, + contextual_prefix="George was sick yesterday. His colleagues asked him if something was wrong. He said he was fine", + contextless_prefix="His colleagues asked him if he was a", + input_context_scores=[0.19, 0.15, 0.33, 0.13, 0.15], + output_context_scores=[0.08, 0.07, 0.14, 0.12, 0.09, 0.14], + ) + ], + info=None, + ) + cli_out = attribute_context(in_out_ctx_deconly) + assert round_scores(cli_out) == expected_output + + +def test_in_ctx_encdec_special_sep(): + # Encoder-decoder model with special separator tags in input only, context is given only in the source + # (input context) but not produced in the target (output context). 
+ in_ctx_encdec_special_sep = AttributeContextArgs( + model_name_or_path="context-mt/scat-marian-small-ctx4-cwd1-en-fr", + input_context_text="The girls were away.", + input_current_text="Where are they?", + output_template="{current}", + input_template="{context} {current}", + special_tokens_to_keep=[""], + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + ) + expected_output = AttributeContextOutput( + input_context="The girls were away.", + input_context_tokens=["▁The", "▁girls", "▁were", "▁away", "."], + output_context=None, + output_context_tokens=None, + output_current="Où sont-elles ?", + output_current_tokens=["▁Où", "▁sont", "-", "elles", "▁?"], + cti_scores=[0.08, 0.04, 0.01, 0.32, 0.06], + cci_scores=[ + CCIOutput( + cti_idx=3, + cti_token="elles", + cti_score=0.32, + contextual_prefix="Où sont-elles", + contextless_prefix="Où sont-elles", + input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0], + output_context_scores=None, + ) + ], + info=None, + ) + cli_out = attribute_context(in_ctx_encdec_special_sep) + assert round_scores(cli_out) == expected_output + + +def test_in_out_ctx_encdec_special_sep(): + # Encoder-decoder model with special separator tags in input and output, context is given in the source (input context) + # and produced in the target (output context) before the special token separator. + in_out_ctx_encdec_special_sep = AttributeContextArgs( + model_name_or_path="context-mt/scat-marian-small-target-ctx4-cwd0-en-fr", + input_context_text="The girls were away.", + input_current_text="Where are they?", + output_template="{context} {current}", + input_template="{context} {current}", + special_tokens_to_keep=[""], + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + # Pre-defining natural model outputs to avoid user input in unit tests + output_context_text="Les filles étaient parties.", + ) + expected_output = AttributeContextOutput( + input_context="The girls were away.", + input_context_tokens=["▁The", "▁girls", "▁were", "▁away", "."], + output_context="Les filles étaient parties.", + output_context_tokens=["▁Les", "▁filles", "▁étaient", "▁parties", "."], + output_current="Où sont-elles ?", + output_current_tokens=["▁Où", "▁sont", "-", "elles", "▁?"], + cti_scores=[0.17, 0.03, 0.02, 3.99, 0.0], + cci_scores=[ + CCIOutput( + cti_idx=3, + cti_token="elles", + cti_score=3.99, + contextual_prefix="Les filles étaient parties. Où sont-elles", + contextless_prefix="Où sont-ils", + input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0], + output_context_scores=[0.0] * 5, + ) + ], + ) + cli_out = attribute_context(in_out_ctx_encdec_special_sep) + assert round_scores(cli_out) == expected_output + + +@pytest.mark.slow +def test_in_out_ctx_encdec_langtag_whitespace_sep(): + # Base case for context-aware encoder-decoder model with language tag in input and output. 
+ # Context is given in the source (input context) and translated into target context (output context) + in_out_ctx_encdec_langtag_whitespace_sep = AttributeContextArgs( + model_name_or_path="facebook/mbart-large-50-one-to-many-mmt", + input_context_text="The girls were away.", + input_current_text="Where are they?", + output_template="{context} {current}", + input_template="{context} {current}", + tokenizer_kwargs={"src_lang": "en_XX", "tgt_lang": "fr_XX"}, + attributed_fn="contrast_prob_diff", + show_viz=False, + add_output_info=False, + show_intermediate_outputs=False, + # Pre-defining natural model outputs to avoid user input in unit tests + output_context_text="Les filles étaient loin.", + ) + expected_output = AttributeContextOutput( + input_context="The girls were away.", + input_context_tokens=["▁The", "▁girls", "▁were", "▁away", "."], + output_context="Les filles étaient loin.", + output_context_tokens=["▁Les", "▁filles", "▁étaient", "▁loin", "."], + output_current="Où sont-elles?", + output_current_tokens=["▁O", "ù", "▁sont", "-", "elles", "?"], + cti_scores=[0.33, 0.03, 0.21, 0.52, 4.49, 0.01], + cci_scores=[ + CCIOutput( + cti_idx=4, + cti_token="elles", + cti_score=4.49, + contextual_prefix="Les filles étaient loin. Où sont-elles", + contextless_prefix="Où sont-ils", + input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0], + output_context_scores=[0.0, 0.01, 0.0, 0.0, 0.0], + ) + ], + ) + cli_out = attribute_context(in_out_ctx_encdec_langtag_whitespace_sep) + assert round_scores(cli_out) == expected_output
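
For reference, the `attribute-context` pipeline exercised by the tests above can also be invoked programmatically. The snippet below is a minimal sketch mirroring `test_in_ctx_deconly`; the model and texts are illustrative, and any causal LM supported by Inseq should behave the same way.

```python
from inseq.commands.attribute_context import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context import attribute_context

# Illustrative configuration: decoder-only model with input-side context only.
args = AttributeContextArgs(
    model_name_or_path="gpt2",
    input_context_text="George was sick yesterday.",
    input_current_text="His colleagues asked him to come",
    attributed_fn="contrast_prob_diff",
    show_viz=False,
    add_output_info=False,
)
out = attribute_context(args)
# out.cti_scores: one context-sensitivity score per generated token.
# out.cci_scores: one CCIOutput per context-sensitive token, with per-token context attributions.
print(out.cti_scores)
print(out.cci_scores)
```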