diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py index 2ea74ff..1469f2f 100644 --- a/inseq/commands/attribute_context/attribute_context.py +++ b/inseq/commands/attribute_context/attribute_context.py @@ -161,6 +161,9 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM cci_kwargs = {} contextless_output = None + output_ctx_tokens = model.convert_string_to_tokens( + contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder + ) if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn): if not model.is_encoder_decoder: formatted_input_current_text = concat_with_sep( @@ -181,9 +184,6 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM ) cci_kwargs["contrast_sources"] = formatted_input_current_text if model.is_encoder_decoder else None cci_kwargs["contrast_targets"] = contextless_output - output_ctx_tokens = model.convert_string_to_tokens( - contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder - ) output_ctxless_tokens = model.convert_string_to_tokens( contextless_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder ) diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py index e35aa4e..cb8793e 100644 --- a/inseq/commands/attribute_context/attribute_context_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -253,13 +253,13 @@ def prepare_outputs( if not model.is_encoder_decoder: model_input = concat_with_sep(input_full_text, "", decoder_input_output_separator) - final_current = generate_model_output( + output_gen = generate_model_output( model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix ) # Settings 3, 4 if (has_out_ctx == use_out_ctx) and not has_out_curr: - return final_context, final_current.strip() + return final_context, output_gen.strip() # Settings 5, 6 # Try splitting the output into context and current text using ``separator``. As we have no guarantees of its