
Fix multiple eos_token_ids in model.generate(...) #21461

Merged

Conversation

tokestermw
Contributor

@tokestermw tokestermw commented Feb 5, 2023

What does this PR do?

Fixes #20727 for using multiple eos_token_ids

Small repro

import math
import torch

# `eos_token_id` was not defined in the original snippet; [797, 641] is a hypothetical
# choice that matches the maintainer's example below (two of the three next tokens hit an eos id)
eos_token_id = [797, 641]
unfinished_sequences = torch.tensor([1, 1, 1])
next_tokens = torch.tensor([797, 641,  98])
unfinished_sequences.mul((math.prod(next_tokens != i for i in eos_token_id)).long())  # tensor([0, 0, 1])

Error

If you run

from transformers import pipeline
generator = pipeline('text-generation', 'gpt2')
generator('hello', eos_token_id=[628, 198], do_sample=True, num_return_sequences=3)

then it errors with:

input = tensor([[-32]])
weight = Parameter containing:
tensor([[-0.0206,  0.0125, -0.0289,  ...,  0.0018, -0.0300,  0.0111],
        [-0.0239, -0.0158,...0,  0.0075,  0.0113],
        [-0.0177, -0.0268,  0.0023,  ...,  0.0135,  0.0077, -0.0042]],
       requires_grad=True)
padding_idx = -1, max_norm = None, norm_type = 2.0, scale_grad_by_freq = False, sparse = False
...


        if has_torch_function_variadic(input, weight):
            return handle_torch_function(
                embedding,
                (input, weight),
                input,
                weight,
                padding_idx=padding_idx,
                max_norm=max_norm,
                norm_type=norm_type,
                scale_grad_by_freq=scale_grad_by_freq,
                sparse=sparse,
            )
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings"
                padding_idx = weight.size(0) + padding_idx
        else:
            padding_idx = -1
        if max_norm is not None:
            # Note [embedding_renorm contiguous]
            # `embedding_renorm_` will call .contiguous() on input anyways, so we
            # call it here and take advantage of the improved locality in the
            # `embedding` call below too.
            input = input.contiguous()
            # Note [embedding_renorm set_grad_enabled]
            # XXX: equivalent to
            # with torch.no_grad():
            #   torch.embedding_renorm_
            # remove once script supports set_grad_enabled
            _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
>       return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
E       IndexError: index out of range in self

venv/lib/python3.8/site-packages/torch/nn/functional.py:2210: IndexError

Tests

pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_greedy_search --disable-warnings -vv
pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_contrastive_search --disable-warnings -vv
pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_top_k_top_sampling --disable-warnings -vv
pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_beam_search --disable-warnings -vv

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

@@ -2226,7 +2227,7 @@ def greedy_search(
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+                unfinished_sequences = unfinished_sequences.mul((math.prod(next_tokens != i for i in eos_token_id)).long())
Contributor Author

Before the fix, this value can go beyond 0 or 1, and then the next input_ids gets corrupted.
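
For illustration only (not from the thread), a minimal sketch of the difference, reusing the eos_token_id=[628, 198] values from the pipeline example above:

import math
import torch

eos_token_id = [628, 198]
next_tokens = torch.tensor([628, 198, 50])  # hypothetical next tokens; the first two hit an eos id

# the sum of the per-id "!=" masks is not a 0/1 mask: it is 2 for a token matching no eos id
sum(next_tokens != i for i in eos_token_id)         # tensor([1, 1, 2])

# the product stays 0/1 and drops to 0 as soon as any eos id matches
math.prod(next_tokens != i for i in eos_token_id)   # tensor([0, 0, 1])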

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante
Member

gante commented Feb 6, 2023

Hey @tokestermw 👋

Thank you for spotting the issues and adding a fix! One request, for two reasons: a) thin function wrappers are very undesirable, as they add another abstraction layer; b) tensor ops should ideally be done with torch operations, otherwise there will be CPU<>GPU data movement. 👉 Can you replace the implementation with something like the snippet below, which computes the same thing using torch operators?

import torch
eos_token_id = torch.tensor([797, 641])
unfinished_sequences = torch.tensor([1, 1, 1])
next_tokens = torch.tensor([797, 641, 98])
next_in_eos = next_tokens.tile((eos_token_id.shape[0], 1)).ne(eos_token_id.unsqueeze(1)).prod(dim=0)
unfinished_sequences = unfinished_sequences.mul(next_in_eos).long()
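
With these example values, next_in_eos evaluates to tensor([0, 0, 1]), so the first two sequences get marked as finished. As a side note (not from the thread), assuming a PyTorch version that has torch.isin (>= 1.10), an equivalent formulation would be:

# illustrative alternative only, not what the PR ended up using
unfinished_sequences = unfinished_sequences.mul(~torch.isin(next_tokens, eos_token_id)).long()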

@hogru

hogru commented Feb 6, 2023

I think I just found the same issue, and this is the code snippet I wanted to use to report the bug. It is probably redundant as of now, but before throwing it away, maybe it helps another user find the issue. No further comment/processing required from my point of view:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline

MODEL = "gpt2"
NUM_RETURN_SEQUENCES = 2
MAX_NEW_TOKENS = 64
CONFIG_DIR = "./generation_test"

model = AutoModelForCausalLM.from_pretrained(MODEL)
model.save_pretrained(CONFIG_DIR)

config = GenerationConfig(
    num_return_sequences=NUM_RETURN_SEQUENCES,
    max_new_tokens=MAX_NEW_TOKENS,
    return_full_text=True,
    do_sample=True,
    bos_token_id=50256,
    pad_token_id=50256,
    eos_token_id=[50000,50256],  # the 50000 is just an example to prove the issue
)
config.save_pretrained(CONFIG_DIR)
model = AutoModelForCausalLM.from_pretrained(CONFIG_DIR)

tokenizer = AutoTokenizer.from_pretrained(MODEL)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
generated = pipe("As always this is a")

print(generated[0]["generated_text"])

@tokestermw
Contributor Author

tokestermw commented Feb 8, 2023

Thanks @gante! Will make the change in a bit.

Another issue I just found with beam search + multiple eos_token_id is that, on occasion, we get this error:

ValueError: At most 3 tokens in tensor([  198,   198,   198,     0,   628, 14373], device='cuda:0') can be equal to
`eos_token_id: [198, 628]`. Make sure tensor([  198,   198,   198,     0,   628, 14373], device='cuda:0') are corrected.

This is because we generate 2 * num_beams candidate tokens:
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2766

which can fail this check when we have more than one eos_token_id:
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L612

(I can post a separate issue if that's better)
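
(Added for illustration, using the token values from the error above; not code from the thread.) With num_beams = 3, 2 * num_beams = 6 candidate tokens are kept per batch item, and with two eos ids more than num_beams of them can be eos tokens, which trips that check:

import torch

num_beams = 3
eos_token_id = torch.tensor([198, 628])
next_tokens = torch.tensor([198, 198, 198, 0, 628, 14373])  # the 2 * num_beams candidates

n_eos = torch.isin(next_tokens, eos_token_id).sum().item()
print(n_eos)               # 4
print(n_eos > num_beams)   # True -> raises the ValueError above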

@gante
Member

gante commented Feb 8, 2023

@tokestermw if that is not breaking the existing tests, yes, let's move it to a new issue.

In essence, we probably want to keep 1+len(eos_token_id) beam candidates running, to ensure we have at least 1 non-eos_token_id candidate to proceed.

Member

@gante gante left a comment

LGTM 👍

@gante gante requested a review from sgugger February 8, 2023 13:31
Collaborator

@sgugger sgugger left a comment

Thanks for your contribution!

@sgugger
Collaborator

sgugger commented Feb 8, 2023

Mmm, looks like a lot of tests have started failing @gante and @tokestermw

@tokestermw
Contributor Author

tokestermw commented Feb 8, 2023

@sgugger
Collaborator

sgugger commented Feb 8, 2023

Yes, this one has been fixed on main :-)

@sgugger sgugger merged commit 9960506 into huggingface:main Feb 8, 2023
miyu386 pushed a commit to miyu386/transformers that referenced this pull request Feb 9, 2023
* add tests with multiple eos_token_ids

* make math.prod instead of sum

* make fixup

* fix long and also use np.prod since math.prod does not exist <python 3.8

* make fixup

* add prod util

* use prod util instead of np.prod

* make fixup

* previous .long location

* use tensor ops

* remove prod

* remove prod

* update device

* make fixup

* fix none
@ydshieh
Collaborator

ydshieh commented Feb 9, 2023

Hi @tokestermw, thank you for working on this. After this PR was merged to main, there are some CI regressions. Could you take a look 🙏? Also cc @gante

To reproduce:

We can check with specific commits on the main branch:

git checkout 06d940ef  # One commit before this PR on `main`
git checkout 9960506c  # This PR - failed the following tests

Then prepare the file format for doctests

python utils/prepare_for_doc_test.py src docs

This

python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules docs/source/en/model_doc/t5.mdx::t5.mdx -sv --doctest-continue-on-failure --doctest-glob="*.mdx"

gives this error:

Expected:
    ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
Got:
    ['Das Haus ist wunderbar. Das Haus ist wunderschön. Sehr', 'Ich arbeite gerne in NYC. Ich arbeite in NYC.']

and this

python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules docs/source/en/model_doc/tapex.mdx::tapex.mdx -sv --doctest-continue-on-failure --doctest-glob="*.mdx"

gives this error:

Expected:
    [' 53', ' george clooney', ' brad pitt']
Got:
    [' 53 lithuania, french montana, french montana, french montana, french montana, french montana ...(very long non-sense string)]

@tokestermw tokestermw mentioned this pull request Feb 9, 2023
@tokestermw
Contributor Author

@ydshieh thanks, ah I see the issue 😓. We're not carrying over the unfinished_sequences.

Making a fix here: #21529
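
(A sketch of the symptom only, reusing example values from earlier in the thread; this is not the actual diff from #21529.)

import torch

eos_token_id = torch.tensor([628, 198])
unfinished_sequences = torch.tensor([0, 1, 1])  # sequence 0 already finished at an earlier step
next_tokens = torch.tensor([5, 7, 9])           # no eos generated at this step

not_eos = next_tokens.tile((eos_token_id.shape[0], 1)).ne(eos_token_id.unsqueeze(1)).prod(dim=0)

# overwriting the mask "revives" the already finished sequence 0
unfinished_sequences = not_eos                                # tensor([1, 1, 1])

# carrying the running mask over keeps finished sequences finished
# unfinished_sequences = unfinished_sequences.mul(not_eos)    # tensor([0, 1, 1])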

ArthurZucker pushed a commit to ArthurZucker/transformers that referenced this pull request Mar 2, 2023
hogru added a commit to hogru/MolReactGen that referenced this pull request Oct 3, 2023