add checks for contrastive attribution methods
gsarti committed Dec 19, 2023
1 parent dd91de3 commit 092bb3a
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions inseq/commands/attribute_context.py
@@ -1,6 +1,7 @@
import logging
import re
from dataclasses import dataclass, field
from inspect import signature
from typing import Any, Dict, List, Optional, Tuple

from rich import print as rprint
@@ -9,6 +10,7 @@
from .. import (
    list_step_functions,
)
from ..attr.step_functions import get_step_function
from ..data import FeatureAttributionSequenceOutput
from ..models import HuggingfaceModel
from ..utils import cli_arg, pretty_dict
@@ -483,15 +485,23 @@ def get_contextless_contextual_outputs(
    return output_contextual, output_contextless


def get_start_attribution_position(
    output_current_text_offset: int, cti_idx: int, model: HuggingfaceModel, has_lang_tag: bool = False
) -> int:
    pos_start = output_current_text_offset + cti_idx
    if model.is_encoder_decoder:
        pos_start += 1
    if has_lang_tag:
        pos_start += 1
    return pos_start
def get_source_target_cci_scores(
    cci_attrib_out: FeatureAttributionSequenceOutput,
    input_template: str,
    input_context_tokens: List[str],
    has_input_context: bool,
    has_output_context: bool,
    model_is_encoder_decoder: bool,
    model_has_lang_tag: bool,
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
    source_scores, target_scores = None, None
    if model_is_encoder_decoder and has_input_context:
        source_scores = cci_attrib_out.source_attributions[:, 0].tolist()
        if model_has_lang_tag:
            source_scores = source_scores[1:]
    if has_output_context:
        target_scores = cci_attrib_out.target_attributions[:, 0].tolist()
    return source_scores, target_scores


def attribute_context(args: AttributeContextArgs):
@@ -595,7 +605,15 @@ def attribute_context(args: AttributeContextArgs):
            args.special_tokens_to_keep,
            args.generation_kwargs,
        )
        pos_start = get_start_attribution_position(output_current_text_offset, cti_idx, model, has_lang_tag)
        cci_kwargs = {}
        if "contrast_targets" in signature(get_step_function(args.attributed_fn)).parameters:
            cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None
            cci_kwargs["contrast_targets"] = output_full_text
            output_ctx_tokens = model.convert_string_to_tokens(output_full_contextual, skip_special_tokens=False)
            output_ctxless_tokens = model.convert_string_to_tokens(output_full_contextless, skip_special_tokens=False)
            if args.attributed_fn == "kl_divergence" or output_ctx_tokens[-1] == output_ctxless_tokens[-1]:
                cci_kwargs["contrast_force_inputs"] = True
        pos_start = output_current_text_offset + cti_idx + int(model.is_encoder_decoder) + int(has_lang_tag)
        cci_attrib_out = model.attribute(
            input_full_text,
            output_full_contextual,
@@ -604,8 +622,8 @@
            attr_pos_start=pos_start,
            attributed_fn=args.attributed_fn,
            method=args.attribution_method,
            contrast_sources=args.input_current_text if model.is_encoder_decoder else None,
            contrast_targets=output_full_contextless,
            **cci_kwargs,
            **args.attribution_kwargs,
        )
        cci_attrib_out = aggregate_attribution_scores(
            out=cci_attrib_out,
@@ -615,12 +633,9 @@
        )[0]
        if args.show_intermediate_outputs:
            cci_attrib_out.show(do_aggregation=False)
        source_scores = None
        if model.is_encoder_decoder and args.has_input_context:
            source_scores = cci_attrib_out.source_attributions[:, 0].tolist()
        target_scores = None
        if args.has_output_context:
            target_scores = cci_attrib_out.target_attributions[:, 0].tolist()
        source_scores, target_scores = get_source_target_cci_scores(
            cci_attrib_out,
        )
        cci_out = CCIOutput(
            cti_idx=cti_idx,
            cti_token=cti_tok,
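For context, the new gating logic hinges on inspecting the attributed step function's signature: the contrastive keyword arguments are only passed through when the function actually declares a `contrast_targets` parameter. A minimal, standalone sketch of this pattern follows; the functions `dummy_contrastive_fn`, `dummy_plain_fn`, and `build_cci_kwargs` are hypothetical stand-ins, not part of this commit or of Inseq's API.

from inspect import signature


def dummy_contrastive_fn(args, contrast_targets=None, contrast_sources=None):
    """Hypothetical stand-in for a contrastive step function (e.g. a probability-difference score)."""
    return 0.0


def dummy_plain_fn(args):
    """Hypothetical stand-in for a non-contrastive step function."""
    return 0.0


def build_cci_kwargs(step_fn, full_output: str) -> dict:
    # Only pass contrastive arguments when the step function can accept them,
    # mirroring the `"contrast_targets" in signature(...).parameters` check above.
    cci_kwargs = {}
    if "contrast_targets" in signature(step_fn).parameters:
        cci_kwargs["contrast_targets"] = full_output
    return cci_kwargs


print(build_cci_kwargs(dummy_contrastive_fn, "full generated output"))  # {'contrast_targets': 'full generated output'}
print(build_cci_kwargs(dummy_plain_fn, "full generated output"))        # {}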
