model.generate with prefix_allowed_tokens_fn throws RuntimeError: probability tensor contains either inf, nan or element < 0 #15169

Closed
iamjanvijay opened this issue Jan 16, 2022 · 16 comments

iamjanvijay commented Jan 16, 2022

Environment info

  • transformers version: 4.15.0
  • Platform: Linux-5.4.0-90-generic-x86_64-with-debian-bullseye-sid
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.10.0+cu102 (True)
  • Tensorflow version (GPU?): 2.7.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten @Narsil

Information

Model I am using: T5ForConditionalGeneration

The problem arises when using my own modified scripts; a script to reproduce the error is included below.

The task I am working on uses my own task/dataset: it requires conditional generation from T5 such that the output vocabulary is restricted to a small set.

To reproduce

  1. Run the following script to reproduce the behaviour.
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

lm_model = 't5-small'
model = T5ForConditionalGeneration.from_pretrained(lm_model)
tokenizer = T5Tokenizer.from_pretrained(lm_model)

def restrict_decode_vocab(batch_idx, prefix_beam):
    # Once 3 tokens have been generated, only allow the ids for ' ';
    # otherwise allow the ids of a small fixed token set.
    if len(prefix_beam) == 3:
        restricted_vocab = tokenizer(' ', return_tensors="pt")['input_ids'].tolist()
    else:
        restricted_vocab = tokenizer('<extra_id_0> cute dog <extra_id_1> the <pad>', return_tensors="pt")['input_ids'].tolist()
    return restricted_vocab

source = ['The <extra_id_0> walks in <extra_id_1> park .']
source_encoding = tokenizer(source[:], padding='longest', return_tensors="pt")
input_ids, attention_mask = source_encoding['input_ids'], source_encoding['attention_mask']
decoded_beams = model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=True, num_beams=2, prefix_allowed_tokens_fn=restrict_decode_vocab, min_length=4, max_length=4, remove_invalid_values=True)
print(decoded_beams)
  2. The above script produces the following stack trace.
/home/jsingh319/uploaded_venvs/venv-koala-torch-1.10-python-3.7.12/lib/python3.7/site-packages/transformers/generation_utils.py:2259: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  next_indices = next_tokens // vocab_size
Traceback (most recent call last):
  File "reproduce_error.py", line 17, in <module>
    decoded_beams = model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=True, num_beams=2, prefix_allowed_tokens_fn=restrict_decode_vocab, min_length=4, max_length=4, remove_invalid_values=True)
  File "/home/jsingh319/uploaded_venvs/venv-koala-torch-1.10-python-3.7.12/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/jsingh319/uploaded_venvs/venv-koala-torch-1.10-python-3.7.12/lib/python3.7/site-packages/transformers/generation_utils.py", line 1220, in generate
    **model_kwargs,
  File "/home/jsingh319/uploaded_venvs/venv-koala-torch-1.10-python-3.7.12/lib/python3.7/site-packages/transformers/generation_utils.py", line 2253, in beam_sample
    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Expected behavior

No error.

Possible solution

The __call__ method of the InfNanRemoveLogitsProcessor class should include the following statement before returning scores:

scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
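For concreteness, a minimal sketch of what the patched processor could look like (a sketch, not the actual implementation: the nan/+inf handling paraphrases the processor in transformers 4.15, and only the last masking line is the proposed addition):

import torch
from transformers import LogitsProcessor

class PatchedInfNanRemoveLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[scores != scores] = 0.0                                   # nan -> 0.0
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max  # +inf -> dtype max
        # Proposed addition: map -inf to the most negative finite value
        scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
        return scores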
Narsil (Contributor) commented Jan 20, 2022

@patrickvonplaten Pinging you to get your input on this.

It seems the -inf values are explicitly set by prefix_allowed_tokens_fn, and remove_invalid_values doesn't remove float(-inf) specifically.

However, the script does currently fail.

I added a PR containing the "fix" to move things along, but given that everything is ingrained in tests and other logits processors actively use float(-inf), I am not sure this is the desired behavior.

Other options I consider viable:

  • Stop using float(-inf) directly and use torch.finfo(scores.dtype).min instead (since we would no longer introduce infinities, this should solve it).
  • Replace float(-inf) just before calling torch.multinomial (a sketch of this option follows this list).
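A self-contained sketch of that second option (the function name and placement are hypothetical; in transformers it would sit in beam_sample just before the multinomial call shown in the traceback above):

import torch

def sample_next_tokens(next_token_scores: torch.Tensor, num_beams: int) -> torch.Tensor:
    # Replace -inf with the smallest finite value so that softmax cannot
    # produce nan when an entire row has been masked out.
    safe = next_token_scores.masked_fill(
        next_token_scores == float("-inf"), torch.finfo(next_token_scores.dtype).min
    )
    probs = torch.softmax(safe, dim=-1)
    return torch.multinomial(probs, num_samples=2 * num_beams)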

patrickvonplaten (Contributor) commented:

I don't think this is related in any way to the InfNanRemoveLogitsProcessor processor. IMO, the reason for the error is that in the 3rd generation step, all values of next_token_scores are set to -inf (I think) due to the prefix_allowed_tokens_fn that you've added. This is IMO not a bug in transformers, but in the prefix_allowed_tokens_fn function, as it should not set all values to -inf.

A tip from my side, @iamjanvijay: create the PrefixConstrainedLogitsProcessor object with your function and play around with it locally (what happens at generation step 3?). I think you'll then see that it sets all values to -inf at some point, which it shouldn't do. A sketch of that experiment follows.
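A minimal sketch of that experiment, reusing model, tokenizer, and restrict_decode_vocab from the reproduction script (the dummy prefix and zero logits are assumptions, only there to simulate generation step 3):

import torch
from transformers import PrefixConstrainedLogitsProcessor

processor = PrefixConstrainedLogitsProcessor(restrict_decode_vocab, num_beams=2)

# Simulate generation step 3 for one input with 2 beams, using dummy ids/logits.
prefix = torch.ones((2, 3), dtype=torch.long)
scores = torch.zeros((2, model.config.vocab_size))

masked = processor(prefix, scores)
allowed = (~torch.isinf(masked[0])).nonzero().flatten()
print(allowed.tolist(), tokenizer.convert_ids_to_tokens(allowed.tolist()))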

iamjanvijay (Author) commented Jan 21, 2022

@patrickvonplaten @Narsil Thanks for your response. I was trying to check why this is happening. I found that if the restricted_vocab at any generation step only includes "</s>" (the end-of-sentence token), this error occurs. In other cases, the script doesn't encounter the error. I'll check whether all the elements at that generation step are set to -inf.
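One plausible mechanism for this finding (an assumption, not confirmed in this thread): the reproduction script passes min_length=4, and MinLengthLogitsProcessor masks </s> to -inf while the sequence is still shorter than min_length, so a step where only </s> is allowed would leave every token at -inf. A sketch:

import torch
from transformers import MinLengthLogitsProcessor

vocab_size, eos_id = 32128, 1            # t5-small config values
min_len = MinLengthLogitsProcessor(4, eos_token_id=eos_id)

# Suppose the prefix function left only </s> with a finite score at step 3.
scores = torch.full((2, vocab_size), float("-inf"))
scores[:, eos_id] = 0.0
prefix = torch.ones((2, 3), dtype=torch.long)   # cur_len == 3 < min_length

print(torch.isinf(min_len(prefix, scores)).all())   # tensor(True): nothing to sample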

Narsil (Contributor) commented Jan 21, 2022

I'll close my PR in the meantime. We can reopen it if needed, but I tend to agree with @patrickvonplaten that having everything set to float(-inf) can already be considered a bug.

mindojune commented:

Have you found what was causing the issue by any chance, @iamjanvijay? I'm encountering the same issue when using the generate function with BART, but I'm not using any prefix_allowed_tokens_fn, and the error usually happens after I've been training the model for a while. Like @iamjanvijay said, I suspect it has something to do with cases where some tokens are masked or filtered, but I haven't figured out where or why it's happening. I'd appreciate any pointers.

patrickvonplaten (Contributor) commented:

@mindojune - could you maybe open a new issue, as it's not related to prefix_allowed_tokens_fn?

github-actions bot commented Mar 4, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

hongyuntw commented:

@mindojune
Hi, I'm facing the same problem as you: the error usually happens after I've trained the model for a while, and I'm also using BART.
Do you have any idea why this is happening or how to fix it?
Thanks a lot.

patrickvonplaten (Contributor) commented:

Hey @hongyuntw, the reason is that BART forces the second generated token to be the id set in this config line: https://huggingface.co/facebook/bart-large/blob/main/config.json#L27. However, if you additionally use something like prefix_allowed_tokens_fn, which might not allow that id, then all ids are set to -inf, in which case the model cannot generate anything anymore. To solve this, I would probably set that config value to None.
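A minimal sketch of that workaround, assuming the id at config.json#L27 is forced_bos_token_id (its name in the facebook/bart-large config):

from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
# Stop forcing the second token so it cannot collide with prefix_allowed_tokens_fn.
model.config.forced_bos_token_id = None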

paulbricman commented Dec 12, 2022

Was running into similar issues when using prefix_allowed_tokens_fn in tandem with beam-search multinomial sampling, and realized the top_k and top_p args were sometimes preventing all the allowed tokens from being used, because the allowed tokens fell outside both cutoffs. no_repeat_ngram_size can have a similar effect.

Consider removing top_k and top_p if only allowing certain tokens is more important; a sketch follows.
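A hedged sketch of that advice, reusing the generate call from the reproduction script (parameter values are illustrative):

decoded_beams = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    do_sample=True,
    num_beams=2,
    prefix_allowed_tokens_fn=restrict_decode_vocab,
    top_k=0,      # disable top-k truncation so allowed tokens are never filtered out
    top_p=1.0,    # disable nucleus truncation for the same reason
    max_length=16,
)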

chentiao commented:

I also have this problem, but I don't know how to fix it.

Dhawgupta commented:

Same here, I am also running into this issue. Has there been any resolution for this?

chentiao commented May 5, 2023

I tried another way to avoid this problem. I was running MiniGPT-4 on a 3090: when I used the v0 weights, the problem happened. To resolve it, I tried many PyTorch versions, but it didn't work. Finally, I used the new model version, v1.1, and the problem went away. So I think the problem is related to the model; for MiniGPT-4, model decoding is related to fschat here.

Chanwhistle commented:

Did somebody solve this problem?

amyeroberts (Collaborator) commented:

cc @gante for reference

gante (Member) commented Feb 14, 2024

@Chanwhistle have a look at this comment

If you believe this comment does not apply to you, then a reproducer of the issue will be needed 🤗
