Skip to content

Commit

Permalink
Fix attribute-context with non-contrastive attribution
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Mar 14, 2024
1 parent ac3d18e commit a960309
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM

cci_kwargs = {}
contextless_output = None
output_ctx_tokens = model.convert_string_to_tokens(
contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn):
if not model.is_encoder_decoder:
formatted_input_current_text = concat_with_sep(
Expand All @@ -181,9 +184,6 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
cci_kwargs["contrast_sources"] = formatted_input_current_text if model.is_encoder_decoder else None
cci_kwargs["contrast_targets"] = contextless_output
output_ctx_tokens = model.convert_string_to_tokens(
contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
output_ctxless_tokens = model.convert_string_to_tokens(
contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder
)
Expand Down
4 changes: 2 additions & 2 deletions inseq/commands/attribute_context/attribute_context_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,13 @@ def prepare_outputs(
if not model.is_encoder_decoder:
model_input = concat_with_sep(input_full_text, "", decoder_input_output_separator)

final_current = generate_model_output(
output_gen = generate_model_output(
model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix
)

# Settings 3, 4
if (has_out_ctx == use_out_ctx) and not has_out_curr:
return final_context, final_current.strip()
return final_context, output_gen.strip()

# Settings 5, 6
# Try splitting the output into context and current text using ``separator``. As we have no guarantees of its
Expand Down

0 comments on commit a960309

Please sign in to comment.