Skip to content

Commit

Permalink
Minor fixes to attribute-context, manual contextless_output_next_tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Jan 30, 2024
1 parent a210b48 commit 3240149
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 55 deletions.
63 changes: 43 additions & 20 deletions inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@
CCIOutput,
filter_rank_tokens,
format_template,
get_contextless_prefix,
get_contextless_output,
get_filtered_tokens,
get_source_target_cci_scores,
prepare_outputs,
prompt_user_for_contextless_output_next_tokens,
)
from .attribute_context_viz_helpers import visualize_attribute_context

Expand Down Expand Up @@ -140,41 +141,63 @@ def attribute_context(args: AttributeContextArgs):
info=args,
)
# Part 2: Contextual Cues Imputation (CCI)
for cti_idx, cti_score, cti_tok in cti_ranked_tokens:
contextual_prefix = model.convert_tokens_to_string(
for cci_step_idx, (cti_idx, cti_score, cti_tok) in enumerate(cti_ranked_tokens):
contextual_output = model.convert_tokens_to_string(
output_full_tokens[: output_current_text_offset + cti_idx + 1], skip_special_tokens=False
)
if not contextual_prefix:
if not contextual_output:
logger.warning(
f"Empty contextual prefix for token {cti_tok} at position {cti_idx} - skipping CCI for this token."
f"Empty contextual output for token {cti_tok} at position {cti_idx} - skipping CCI for this token."
)
continue
cci_kwargs = {}
contextless_prefix = None
contextless_output = None
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
contextless_prefix = get_contextless_prefix(
model,
args.input_current_text,
output_current_tokens,
cti_idx,
args.special_tokens_to_keep,
deepcopy(args.generation_kwargs),
)
n_ctxless_next_tokens = len(args.contextless_output_next_tokens)
next_ctxless_token = None
if n_ctxless_next_tokens > 0:
if n_ctxless_next_tokens != len(cti_ranked_tokens):
raise ValueError(
"The number of manually specified contextless output next tokens must be equal to the number "
"of context-sensitive tokens identified by CTI."
)
next_ctxless_token = args.contextless_output_next_tokens[cci_step_idx]
if args.prompt_user_for_contextless_output_next_tokens:
next_ctxless_token = prompt_user_for_contextless_output_next_tokens(
output_current_tokens, cti_idx, model
)
if isinstance(next_ctxless_token, str):
next_ctxless_token = model.convert_string_to_tokens(
next_ctxless_token, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)[0]
contextless_output_tokens = output_current_tokens[:cti_idx] + [next_ctxless_token]
contextless_output = model.convert_tokens_to_string(
contextless_output_tokens, skip_special_tokens=False
)
else:
contextless_output = get_contextless_output(
model,
args.input_current_text,
output_current_tokens,
cti_idx,
args.special_tokens_to_keep,
deepcopy(args.generation_kwargs),
)
cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None
cci_kwargs["contrast_targets"] = contextless_prefix
cci_kwargs["contrast_targets"] = contextless_output
output_ctx_tokens = model.convert_string_to_tokens(
contextual_prefix, skip_special_tokens=False, as_targets=model.is_encoder_decoder
contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
output_ctxless_tokens = model.convert_string_to_tokens(
contextless_prefix, skip_special_tokens=False, as_targets=model.is_encoder_decoder
contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
tok_pos = -2 if model.is_encoder_decoder else -1
if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]:
cci_kwargs["contrast_force_inputs"] = True
pos_start = output_current_text_offset + cti_idx + int(model.is_encoder_decoder) + int(has_lang_tag)
cci_attrib_out = model.attribute(
input_full_text,
contextual_prefix,
contextual_output,
attribute_target=model.is_encoder_decoder and args.has_output_context,
show_progress=False,
attr_pos_start=pos_start,
Expand Down Expand Up @@ -208,8 +231,8 @@ def attribute_context(args: AttributeContextArgs):
cti_idx=cti_idx,
cti_token=cti_tok,
cti_score=cti_score,
contextual_prefix=contextual_prefix,
contextless_prefix=contextless_prefix,
contextual_output=contextual_output,
contextless_output=contextless_output,
input_context_scores=source_scores,
output_context_scores=target_scores,
)
Expand Down
23 changes: 23 additions & 0 deletions inseq/commands/attribute_context/attribute_context_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ class AttributeContextMethodArgs(AttributeBaseArgs):
" prompted to manually specify which part of the generated text corresponds to the output context."
),
)
contextless_output_next_tokens: list[str] = cli_arg(
default_factory=list,
help=(
"If specified, it should provide a list of one token per CCI output indicating the next token that should"
" be force-decoded as contextless output instead of the natural output produced by"
" ``get_contextless_output``. This is ignored if the ``attributed_fn`` used is not contrastive."
),
)
prompt_user_for_contextless_output_next_tokens: bool = cli_arg(
default=False,
help=(
"If specified, the user is prompted to provide the next token that should be force-decoded as contextless"
" output instead of the natural output produced by ``get_contextless_output``. This is ignored if the"
" ``attributed_fn`` used is not contrastive."
),
)
special_tokens_to_keep: list[str] = cli_arg(
default_factory=list,
help="Special tokens to preserve in the generated string, e.g. ``<brk>`` separator between context and current.",
Expand Down Expand Up @@ -177,6 +193,13 @@ class AttributeContextMethodArgs(AttributeBaseArgs):
),
)

def __post_init__(self):
if len(self.contextless_output_next_tokens) > 0 and self.prompt_user_for_contextless_output_next_tokens:
raise ValueError(
"Only one of contextless_output_next_tokens and prompt_user_for_contextless_output_next_tokens can be"
" specified."
)


@command_args_docstring
@dataclass
Expand Down
79 changes: 64 additions & 15 deletions inseq/commands/attribute_context/attribute_context_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class CCIOutput:
cti_idx: int
cti_token: str
cti_score: float
contextual_prefix: str
contextless_prefix: str
contextual_output: str
contextless_output: str
input_context_scores: Optional[list[float]] = None
output_context_scores: Optional[list[float]] = None

Expand Down Expand Up @@ -293,6 +293,15 @@ def prepare_outputs(
return final_context, final_current.strip()


def get_scores_threshold(scores: list[float], std_weight: float) -> float:
"""Compute the threshold for a given weight."""
if std_weight is None or len(scores) == 0:
return 0
if std_weight == 0 or len(scores) == 1:
return tensor(scores).mean()
return tensor(scores).mean() + std_weight * tensor(scores).std()


def filter_rank_tokens(
tokens: list[str],
scores: list[float],
Expand All @@ -301,45 +310,43 @@ def filter_rank_tokens(
) -> tuple[list[tuple[int, float, str]], float]:
indices = list(range(0, len(scores)))
token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True)
threshold = None
if std_threshold is not None:
threshold = tensor(scores).mean() + std_threshold * tensor(scores).std()
token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) >= threshold]
threshold = get_scores_threshold(scores, std_threshold)
token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) >= threshold]
if topk:
token_score_tuples = token_score_tuples[:topk]
return token_score_tuples, threshold


def get_contextless_prefix(
def get_contextless_output(
model: HuggingfaceModel,
input_current_text: str,
output_current_tokens: list[str],
cti_idx: int,
special_tokens_to_keep: list[str] = [],
generation_kwargs: dict[str, Any] = {},
) -> tuple[str, str]:
"""Generate the contextless prefix for the current token identified as context-sensitive."""
output_current_prefix_tokens = output_current_tokens[:cti_idx]
output_current_prefix = model.convert_tokens_to_string(output_current_prefix_tokens, skip_special_tokens=False)
"""Generate the contextless output for the current token identified as context-sensitive."""
contextual_prefix_tokens = output_current_tokens[:cti_idx]
contextual_prefix = model.convert_tokens_to_string(contextual_prefix_tokens, skip_special_tokens=False)
if model.is_encoder_decoder:
# One extra token for the EOS which is always forced at the end for encoder-decoders
generation_kwargs["max_new_tokens"] = 2
decoder_input_ids = model.encode(output_current_prefix, as_targets=True).input_ids
decoder_input_ids = model.encode(contextual_prefix, as_targets=True).input_ids
if int(decoder_input_ids[0, -1]) == model.eos_token_id:
decoder_input_ids = decoder_input_ids[0, :-1][None, ...]
generation_kwargs["decoder_input_ids"] = decoder_input_ids
generation_input = input_current_text
else:
generation_kwargs["max_new_tokens"] = 1
space = " " if output_current_prefix and not output_current_prefix.startswith((" ", "\n")) else ""
generation_input = input_current_text + space + output_current_prefix
output_contextless = generate_with_special_tokens(
space = " " if contextual_prefix and not contextual_prefix.startswith((" ", "\n")) else ""
generation_input = input_current_text + space + contextual_prefix
contextless_output = generate_with_special_tokens(
model,
generation_input,
special_tokens_to_keep,
**generation_kwargs,
)
return output_contextless
return contextless_output


def get_source_target_cci_scores(
Expand Down Expand Up @@ -377,3 +384,45 @@ def get_source_target_cci_scores(
prefix_len = len(output_prefix_tokens) + int(not model.is_encoder_decoder) * len(input_full_tokens)
output_scores = output_scores[prefix_len : len(output_context_tokens) + prefix_len]
return input_scores, output_scores


def prompt_user_for_contextless_output_next_tokens(
output_current_tokens: list[str],
cti_idx: int,
model: HuggingfaceModel,
special_tokens_to_keep: list[str] = [],
) -> Optional[str]:
"""Prompt the user to provide the next tokens of the contextless output.
Args:
output_current_tokens (str): list of tokens of the current output
cti_idx (int): index of the current token identified as context-sensitive
Returns:
str: next tokens of the contextless output specified by the user. If None, the user does not want to specify
the contextless output.
"""
contextual_prefix_tokens = output_current_tokens[:cti_idx]
contextual_prefix = model.convert_tokens_to_string(contextual_prefix_tokens, skip_special_tokens=False)
contextual_output_token = get_filtered_tokens(
output_current_tokens[cti_idx],
model,
special_tokens_to_keep=special_tokens_to_keep,
is_target=True,
replace_special_characters=True,
)[0]
while True:
force_contextless_output = Confirm.ask(
f'\n:arrow_right: Contextual prefix: "[bold]{contextual_prefix}[/bold]"'
f'\n:question: The token [bold]"{contextual_output_token}"[/bold] is produced in the contextual setting.'
" Do you want to specify a word for comparison?"
)
if not force_contextless_output:
return None
provided_contextless_output = Prompt.ask(
":writing_hand: Please enter the word to use for comparison with the contextual output:"
)
if provided_contextless_output.strip():
break
rprint("[prompt.invalid]The provided word is empty. Please provide a non-empty word.")
return provided_contextless_output
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from typing import Literal, Optional, Union

from rich.console import Console
from torch import tensor

from ... import load_model
from ...models import HuggingfaceModel
from .attribute_context_args import AttributeContextArgs
from .attribute_context_helpers import AttributeContextOutput, filter_rank_tokens, get_filtered_tokens
from .attribute_context_helpers import (
AttributeContextOutput,
filter_rank_tokens,
get_filtered_tokens,
get_scores_threshold,
)


def get_formatted_procedure_details(args: AttributeContextArgs) -> str:
Expand Down Expand Up @@ -135,10 +139,7 @@ def visualize_attribute_context(
elif not isinstance(model, HuggingfaceModel):
raise TypeError(f"Unsupported model type {type(model)} for visualization.")
if cti_threshold is None and len(output.cti_scores) > 1:
cti_threshold = (
tensor(output.cti_scores).mean()
+ output.info.context_sensitivity_std_threshold * tensor(output.cti_scores).std()
)
cti_threshold = get_scores_threshold(output.cti_scores, output.info.context_sensitivity_std_threshold)
viz += "\n\n" + get_formatted_attribute_context_results(model, output.info, output, cti_threshold)
with console.capture() as _:
console.print(viz, soft_wrap=False)
Expand Down
28 changes: 14 additions & 14 deletions tests/commands/test_attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def test_in_out_ctx_encdec_whitespace_sep(encdec_model: MarianMTModel):
cti_idx=0,
cti_token="▁Où",
cti_score=1.36,
contextual_prefix="Où",
contextless_prefix="Où",
contextual_output="Où",
contextless_output="Où",
input_context_scores=[0.01, 0.01, 0.01, 0.01, 0.01],
output_context_scores=[],
)
Expand Down Expand Up @@ -94,8 +94,8 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel):
cti_idx=2,
cti_token="Ġhospital",
cti_score=0.55,
contextual_prefix="George was sick yesterday. His colleagues asked him to come to the hospital",
contextless_prefix="His colleagues asked him to come to the office",
contextual_output="George was sick yesterday. His colleagues asked him to come to the hospital",
contextless_output="His colleagues asked him to come to the office",
input_context_scores=[0.39, 0.29, 0.52, 0.26, 0.16],
output_context_scores=None,
)
Expand Down Expand Up @@ -161,8 +161,8 @@ def test_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
cti_idx=0,
cti_token="20 → Ġ20",
cti_score=4.53,
contextual_prefix="Question: How many pairs of legs do 10 horses have?\n\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20",
contextless_prefix="Question: How many pairs of legs do 10 horses have?\n",
contextual_output="Question: How many pairs of legs do 10 horses have?\n\nLet's think step by step:\n1. A horse has 4 legs.\n2. 10 horses have 40 legs.\n3. 40 legs make 20 pairs of legs.\n\nAnswer:\n20",
contextless_output="Question: How many pairs of legs do 10 horses have?\n",
input_context_scores=None,
output_context_scores=[0.0] * 28,
),
Expand Down Expand Up @@ -197,8 +197,8 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
cti_idx=2,
cti_token="Ġfine",
cti_score=1.5,
contextual_prefix="George was sick yesterday. His colleagues asked him if something was wrong. He said he was fine",
contextless_prefix="His colleagues asked him if he was a",
contextual_output="George was sick yesterday. His colleagues asked him if something was wrong. He said he was fine",
contextless_output="His colleagues asked him if he was a",
input_context_scores=[0.19, 0.15, 0.33, 0.13, 0.15],
output_context_scores=[0.08, 0.07, 0.14, 0.12, 0.09, 0.14],
)
Expand Down Expand Up @@ -236,8 +236,8 @@ def test_in_ctx_encdec_special_sep():
cti_idx=3,
cti_token="elles",
cti_score=0.32,
contextual_prefix="Où sont-elles",
contextless_prefix="Où sont-elles",
contextual_output="Où sont-elles",
contextless_output="Où sont-elles",
input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0],
output_context_scores=None,
)
Expand Down Expand Up @@ -277,8 +277,8 @@ def test_in_out_ctx_encdec_special_sep():
cti_idx=3,
cti_token="elles",
cti_score=3.99,
contextual_prefix="Les filles étaient parties.<brk> Où sont-elles",
contextless_prefix="Où sont-ils",
contextual_output="Les filles étaient parties.<brk> Où sont-elles",
contextless_output="Où sont-ils",
input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0],
output_context_scores=[0.0] * 5,
)
Expand Down Expand Up @@ -319,8 +319,8 @@ def test_in_out_ctx_encdec_langtag_whitespace_sep():
cti_idx=4,
cti_token="elles",
cti_score=4.49,
contextual_prefix="Les filles étaient loin. Où sont-elles",
contextless_prefix="Où sont-ils",
contextual_output="Les filles étaient loin. Où sont-elles",
contextless_output="Où sont-ils",
input_context_scores=[0.0, 0.0, 0.0, 0.0, 0.0],
output_context_scores=[0.0, 0.01, 0.0, 0.0, 0.0],
)
Expand Down

0 comments on commit 3240149

Please sign in to comment.