
No way to get ONLY the generated text, not including the prompt. #17117

Closed

monsieurpooh opened this issue May 6, 2022 · 5 comments

monsieurpooh commented May 6, 2022

System Info

- `transformers` version: 4.15.0
- Platform: Windows-10-10.0.19041-SP0
- Python version: 3.8.5
- PyTorch version (GPU?): 1.10.2+cu113 (True)
- Tensorflow version (GPU?): 2.5.1 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no

Who can help?

@Narsil @patrickvonplaten

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

At first I thought I could just take a substring starting at the prompt's length. This doesn't work because there's a bug where decoding converts every instance of " ," to "," in the generated text.

For example, "Characters in this scene: King , Gertrude." becomes "Characters in this scene: King, Gertrude."

In https://github.com/huggingface/transformers/blob/main/src/transformers/generation_utils.py there are tons of options but not a single one of them allows us to specify it to ONLY return the generated text, not including the prompt.

I could do a workaround where I replace all the " ," with "," myself, but I'm sure this is a code smell which could lead to future problems.

Example code:

gen_tokens = model.generate(
    input_ids,
    do_sample=specifiedDoSample,
    temperature=specifiedTemperature,
    max_length=calculated_max_length,
    min_length=calculated_min_length,
    repetition_penalty=specifiedRepetitionPenalty,
    bad_words_ids=badWordsTokens,
)

# gen_text = tokenizer.batch_decode(gen_tokens)[0]

Expected behavior

Two possibilities: Either don't modify the prompt at all so I can substring by the prompt's length, or have an option where we get only the generated text not including the prompt.

@Narsil
Contributor

Narsil commented May 9, 2022

Hi @monsieurpooh ,

generate will not change, since it's a relatively low-level function; it does exactly what it should to the relevant tensors (encoder-decoder and decoder-only models don't work the same way, for instance).

Two suggestions:

  • Simple modification gen_text = tokenizer.batch_decode(gen_tokens[input_ids.shape[0]:])[0] (Ignore the first ids you sent)
  • Use a pipeline:
from transformers import pipeline

# This will remove the text for you.
pipe = pipeline(model="gpt2", return_full_text=False)
print(pipe("This is a test"))

Does that solve your issue?

@monsieurpooh
Author

Thanks so much for your help Narsil! After a tiny bit of debugging and learning how to slice tensors, I figured out the correct code is: tokenizer.batch_decode(gen_tokens[:, input_ids.shape[1]:])[0]
It returns the correct tokens even when there's a space after some commas and periods.
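For readers landing here, a minimal end-to-end sketch of that slicing approach; the model name and generation settings are illustrative, not from the thread:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Characters in this scene: King , Gertrude."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    gen_tokens = model.generate(input_ids, do_sample=True, max_length=60)

# Drop the prompt tokens (the first input_ids.shape[1] positions of each row)
# and decode only the newly generated continuation.
gen_text = tokenizer.batch_decode(gen_tokens[:, input_ids.shape[1]:])[0]
print(gen_text)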

@Narsil
Contributor

Narsil commented May 12, 2022

Thank you for giving the correct code here, it will help other users for sure! :)

@GonyRosenman

is there a fix for this?

I'm using a workaround like this:
encoding = tokenizer(batch['prompt'], return_tensors='pt', padding=True).to(device)
with torch.no_grad():
    generated_ids = model.generate(**encoding)
generated_ids = generated_ids[:, encoding.input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

@srama2512

Thanks so much for your help Narsil! After a tiny bit of debugging and learning how to slice tensors, I figured out the correct code is: tokenizer.batch_decode(gen_tokens[:, input_ids.shape[1]:])[0] It returns the correct tokens even when there's a space after some commas and periods.

Small observation. This works only if one of the following holds (for the general case, see the left-padding sketch after this list):

  • batch size = 1, or
  • all elements of the batch have the same input context length.
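One common way to handle a batch with different prompt lengths is left padding, so every prompt ends at the same position and the same slice works for every row. A sketch under that assumption (model choice and generation settings are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["Characters in this scene:", "To be, or not to be, that is"]
encoding = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    generated_ids = model.generate(
        **encoding, max_new_tokens=20, pad_token_id=tokenizer.pad_token_id
    )

# With left padding, every prompt occupies the last encoding.input_ids.shape[1]
# positions, so one slice removes the prompt (and its padding) for every row.
new_ids = generated_ids[:, encoding.input_ids.shape[1]:]
print(tokenizer.batch_decode(new_ids, skip_special_tokens=True))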
