INF encountered when using sampling with temperature. #19509

Open · 1 of 4 tasks
ElliottYan opened this issue Oct 12, 2022 · 3 comments
Labels
WIP

Comments

@ElliottYan

ElliottYan commented Oct 12, 2022

System Info

transformers version: 4.24.0 (latest at the time of writing)
When generating samples with mBART, I encounter this problem:
[screenshot of the inf error raised during generation]

Looking deeper into the code, I find that the problem stems from the beam score that is added to next_token_scores here:

next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

The initial value of beam_scores is 0, but when using a temperature such as 0.5, the accumulated score is divided by the temperature again inside the logits_warper at every step, so its magnitude grows larger and larger and eventually causes next_token_scores to overflow.
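
To make the growth concrete, here is a toy illustration of the mechanism as I understand it (my own sketch, not code from transformers; the average per-token log-probability of -10 is an assumed value):

import torch

# Toy model of the accumulation: each step adds a fresh token log-prob to the
# cumulative beam score, and the temperature warper then divides the whole sum
# by the temperature, so the cumulative part is re-divided by T at every step
# and its magnitude grows roughly like (1/T)**step.
temperature = 0.5
avg_logprob = -10.0  # assumed typical per-token log-probability
beam_score = torch.tensor(0.0, dtype=torch.float16)  # half precision overflows fastest
for step in range(20):
    beam_score = (beam_score + avg_logprob) / temperature
    print(step, beam_score.item())  # reaches -inf after roughly a dozen steps in float16

Once the cumulative score overflows to -inf, every candidate score at that step is -inf as well, the softmax over them becomes nan, and torch.multinomial fails, which is presumably what the screenshot shows.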

Who can help?

@patrickvonplaten @Narsil @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is a simple snippet that reproduces the issue.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

model = model.cuda()

# German source sentence to be translated into English
src = 'In einem Notruf erzählte Professor Shannon Lamb mit einer etwas zittrigen Stimme der Polizei, dass er seine Freundin erschossen habe und dass die Beamten zu seinem Haus kommen müssten.'

encoded_hi = tokenizer(src, return_tensors="pt", padding=True).to('cuda')

# Sampling with num_beams > 1 and temperature < 1 triggers the overflow
generated_tokens = model.generate(
    **encoded_hi,
    forced_bos_token_id=tokenizer.lang_code_to_id['en_XX'],
    temperature=0.5,
    do_sample=True,
    num_beams=10,
    num_return_sequences=10,
)

tgt_txt = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

Expected behavior

I think this should be fixed, but I'm not sure what the intended effect of beam_scores is here.

@gante
Member

gante commented Oct 12, 2022

Hi @ElliottYan 👋 Thank you for pointing it out, it seems like a bug indeed. I will look into it.

@ElliottYan
Author

Great! Looking forward to your solution.
For now, I just swapped these two lines (L2566 and L2567) and the error disappears, but I'm not sure whether that is correct.
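
For readers following along, here is my reading of what that swap amounts to (a sketch only; the "before" lines are paraphrased from the v4.24.0 beam_sample loop rather than copied, and the "after" version is my interpretation of the workaround, not an official patch):

# Before (roughly): the temperature warper runs after the cumulative
# beam_scores have been added, so the running total is re-divided by the
# temperature at every step.
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)

# After (the intent of the swap, as I understand it): warp only the fresh
# per-step scores, then add the unwarped cumulative beam_scores, so each
# token's log-probability is divided by the temperature exactly once.
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores_processed)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)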

@patrickvonplaten
Contributor

Are you using half or full precision here? Also, inf values are not necessarily a sign of a bug; it might also be that mBART has some default logits processor settings that zero out values, which then leads to -inf (cc @gante)
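
To make that distinction concrete, here is a standalone torch example (my own illustration, not code from the library): -inf scores for some tokens are ordinary masking and sampling still works; the failure only appears when every candidate score becomes -inf (or nan):

import torch

# Masking a few tokens to -inf (as e.g. a forced-BOS logits processor does)
# still yields finite probabilities, and sampling works fine.
scores = torch.full((1, 5), -1.0)
scores[0, 2:] = float("-inf")
probs = torch.softmax(scores, dim=-1)
print(probs)                         # finite, sums to 1
print(torch.multinomial(probs, 1))   # samples from the unmasked tokens

# If *every* score is -inf (e.g. after the overflow described above),
# softmax returns nan and torch.multinomial would raise the inf/nan error.
all_masked = torch.full((1, 5), float("-inf"))
print(torch.softmax(all_masked, dim=-1))  # nan everywhere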

@huggingface huggingface deleted a comment from github-actions bot Nov 14, 2022
@gante gante added the WIP label Nov 14, 2022