Skip to content

Commit

Permalink
Fix attribute-context for current preceding context in input
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Feb 27, 2024
1 parent dc969bd commit 6e94e11
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
# Prepare input/outputs (generate if necessary)
input_full_text = format_template(args.input_template, args.input_current_text, args.input_context_text)
if "{current}" in args.contextless_input_current_text:
args.input_current_text = args.contextless_input_current_text.format(current=args.input_current_text)
formatted_input_current_text = args.contextless_input_current_text.format(current=args.input_current_text)
else:
args.input_current_text = args.contextless_input_current_text
formatted_input_current_text = args.contextless_input_current_text
args.output_context_text, args.output_current_text = prepare_outputs(
model=model,
input_context_text=args.input_context_text,
Expand Down Expand Up @@ -109,11 +109,11 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
prefixed_output_current_text = args.output_current_text
else:
sep = args.decoder_input_output_separator if not args.output_current_text.startswith((" ", "\n")) else ""
prefixed_output_current_text = args.input_current_text + sep + args.output_current_text
prefixed_output_current_text = formatted_input_current_text + sep + args.output_current_text

# Part 1: Context-sensitive Token Identification (CTI)
cti_out = model.attribute(
args.input_current_text,
formatted_input_current_text,
prefixed_output_current_text,
attribute_target=model.is_encoder_decoder,
step_scores=[args.context_sensitivity_metric],
Expand Down Expand Up @@ -160,7 +160,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
contextless_output = get_contextless_output(
model,
args.input_current_text,
formatted_input_current_text,
output_current_tokens,
cti_idx,
cti_ranked_tokens,
Expand All @@ -171,7 +171,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
args.special_tokens_to_keep,
deepcopy(args.generation_kwargs),
)
cci_kwargs["contrast_sources"] = args.input_current_text if model.is_encoder_decoder else None
cci_kwargs["contrast_sources"] = formatted_input_current_text if model.is_encoder_decoder else None
cci_kwargs["contrast_targets"] = contextless_output
output_ctx_tokens = model.convert_string_to_tokens(
contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
Expand Down

0 comments on commit 6e94e11

Please sign in to comment.