Fix prefixed generation for mismatching tokenization
gsarti committed Mar 13, 2024
1 parent 70be7cf commit 66f3f33
Showing 3 changed files with 17 additions and 3 deletions.
12 changes: 9 additions & 3 deletions inseq/commands/attribute_context/attribute_context_helpers.py
@@ -107,11 +107,14 @@ def generate_with_special_tokens(
     model: HuggingfaceModel,
     model_input: str,
     special_tokens_to_keep: list[str] = [],
+    output_generated_only: bool = True,
     **generation_kwargs,
 ) -> str:
     """Generate text preserving special tokens in ``special_tokens_to_keep``."""
     # Generate outputs, strip special tokens and remove prefix/suffix
-    output_gen = model.generate(model_input, skip_special_tokens=False, **generation_kwargs)[0]
+    output_gen = model.generate(
+        model_input, skip_special_tokens=False, output_generated_only=output_generated_only, **generation_kwargs
+    )[0]
     output_tokens = get_filtered_tokens(output_gen, model, special_tokens_to_keep, is_target=True)
     return model.convert_tokens_to_string(output_tokens, skip_special_tokens=False)

@@ -247,13 +250,15 @@ def prepare_outputs(
         model_input = concat_with_sep(input_full_text, output_current_prefix, decoder_input_output_separator)
         output_current_prefix = model_input
 
-    output_gen = generate_model_output(
+    if not model.is_encoder_decoder:
+        model_input = concat_with_sep(input_full_text, "", decoder_input_output_separator)
+
+    final_current = generate_model_output(
         model, model_input, generation_kwargs, special_tokens_to_keep, output_template, output_current_prefix, suffix
     )
 
     # Settings 3, 4
     if (has_out_ctx == use_out_ctx) and not has_out_curr:
-        final_current = output_gen if model.is_encoder_decoder or use_out_ctx else output_gen[len(model_input) :]
         return final_context, final_current.strip()
 
     # Settings 5, 6
@@ -395,6 +400,7 @@ def generate_contextless_output(
         model,
         generation_input,
         special_tokens_to_keep,
+        output_generated_only=False,
         **generation_kwargs,
     )
     return contextless_output
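
A sketch of the failure mode these hunks address (not part of the commit): the deleted line recovered the generated continuation by string slicing, output_gen[len(model_input):]. That is only safe when decode(encode(prompt)) reproduces the prompt character-for-character, which tokenizers do not guarantee once special tokens are kept (skip_special_tokens=False). Here bert-base-uncased is used purely because its tokenizer visibly inserts special tokens and lowercases:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    prompt = "His colleagues asked him to come"
    ids = tok(prompt).input_ids
    print(tok.decode(ids, skip_special_tokens=False))
    # -> "[CLS] his colleagues asked him to come [SEP]"
    # decode(encode(prompt)) != prompt, so decoded[len(prompt):] starts inside
    # the prompt text rather than at the generated continuation. Slicing token
    # ids instead (the new output_generated_only path) sidesteps the mismatch.
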
6 changes: 6 additions & 0 deletions inseq/models/huggingface_model.py
@@ -195,6 +195,7 @@ def generate(
         inputs: Union[TextInput, BatchEncoding],
         return_generation_output: bool = False,
         skip_special_tokens: bool = True,
+        output_generated_only: bool = False,
         **kwargs,
     ) -> Union[list[str], tuple[list[str], ModelOutput]]:
         """Wrapper of model.generate to handle tokenization and decoding.
@@ -204,6 +205,9 @@
                 Inputs to be provided to the model for generation.
             return_generation_output (`bool`, *optional*, defaults to False):
                 If true, generation outputs are returned alongside the generated text.
+            output_generated_only (`bool`, *optional*, defaults to False):
+                If true, only the generated text is returned. Relevant for decoder-only models that would otherwise return
+                the full input + output.
 
         Returns:
             `Union[List[str], Tuple[List[str], ModelOutput]]`: Generated text or a tuple of generated text and
@@ -220,6 +224,8 @@
             **kwargs,
         )
         sequences = generation_out.sequences
+        if output_generated_only and not self.is_encoder_decoder:
+            sequences = sequences[:, inputs.input_ids.shape[1] :]
         texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
         if return_generation_output:
             return texts, generation_out
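
The token-level slicing added above relies on a standard property of Hugging Face decoder-only generation: generate() returns prompt and continuation as a single token sequence, so dropping the first inputs.input_ids.shape[1] positions leaves exactly the newly generated tokens. A standalone sketch of the same idea (gpt2 is only an example checkpoint, not part of the commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    enc = tok("His colleagues asked him to come", return_tensors="pt")
    out = model.generate(**enc, max_new_tokens=8)
    # Decoder-only models echo the prompt: out[:, :enc.input_ids.shape[1]]
    # is the prompt itself; everything after it is newly generated.
    new_tokens = out[:, enc.input_ids.shape[1]:]
    print(tok.batch_decode(new_tokens, skip_special_tokens=True))
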
2 changes: 2 additions & 0 deletions tests/commands/test_attribute_context.py
@@ -78,6 +78,7 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel):
         model_name_or_path=deconly_model,
         input_context_text="George was sick yesterday.",
         input_current_text="His colleagues asked him to come",
+        output_current_text="to the hospital. He said he was fine",
         attributed_fn="contrast_prob_diff",
         show_viz=False,
         add_output_info=False,
@@ -212,6 +213,7 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
         input_context_text="George was sick yesterday.",
         input_current_text="His colleagues asked him if",
         output_context_text="something was wrong. He said",
+        output_current_text="he was fine.",
         attributed_fn="contrast_prob_diff",
         show_viz=False,
         add_output_info=False,
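
Both decoder-only tests now pin output_current_text instead of letting the model generate it: contrast_prob_diff scores a fixed target string with and without context, so supplying the target keeps the tests deterministic and independent of the new generation path. A hedged sketch of an equivalent programmatic call; AttributeContextArgs and attribute_context mirror the keyword arguments above but are assumed names for the command module's entry points, not a verified API:

    # Assumed import path and entry point, mirroring the test arguments above.
    from inseq.commands.attribute_context import AttributeContextArgs, attribute_context

    result = attribute_context(
        AttributeContextArgs(
            model_name_or_path="gpt2",
            input_context_text="George was sick yesterday.",
            input_current_text="His colleagues asked him to come",
            output_current_text="to the hospital. He said he was fine",
            attributed_fn="contrast_prob_diff",
            show_viz=False,
            add_output_info=False,
        )
    )
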
