From 8e856430447db02ab96bd4236f52f709326f9aa1 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Tue, 27 Feb 2024 21:21:35 +0100 Subject: [PATCH] Fix attribute-context for current preceding context in input --- inseq/commands/attribute_context/attribute_context.py | 1 + inseq/commands/attribute_context/attribute_context_helpers.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py index f4cbb7cb..7edd72e2 100644 --- a/inseq/commands/attribute_context/attribute_context.py +++ b/inseq/commands/attribute_context/attribute_context.py @@ -206,6 +206,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM model, cci_attrib_out, args.input_template, + args.input_current_text, input_context_tokens, input_full_tokens, args.output_template, diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py index f436bba4..41e9df72 100644 --- a/inseq/commands/attribute_context/attribute_context_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -402,6 +402,7 @@ def get_source_target_cci_scores( model: HuggingfaceModel, cci_attrib_out: FeatureAttributionSequenceOutput, input_template: str, + input_current_text: str, input_context_tokens: list[str], input_full_tokens: list[str], output_template: str, @@ -421,6 +422,8 @@ def get_source_target_cci_scores( else: input_scores = cci_attrib_out.target_attributions[:, 0].tolist() input_prefix, *_ = input_template.partition("{context}") + if "{current}" in input_prefix: + input_prefix = input_prefix.format(current=input_current_text) input_prefix_tokens = get_filtered_tokens(input_prefix, model, special_tokens_to_keep, is_target=False) input_prefix_len = len(input_prefix_tokens) input_scores = input_scores[input_prefix_len : len(input_context_tokens) + input_prefix_len]