
[Generation] Allow inputs_embeds as an input #14443

Conversation

@patrickvonplaten (Contributor) commented Nov 18, 2021

What does this PR do?

This PR allows inputs_embeds to be used as an input argument for generate(). Fixes: #12218
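
A minimal usage sketch of what this enables for an encoder-decoder model (checkpoint name, prompt, and decoding settings are purely illustrative, not part of the PR):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: Hello there", return_tensors="pt").input_ids
# Look up the embeddings manually instead of passing token ids.
inputs_embeds = model.get_input_embeddings()(input_ids)

# With this PR, `generate()` can consume the embeddings directly for encoder-decoder models.
output_ids = model.generate(inputs_embeds=inputs_embeds, max_length=20)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))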

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

inputs_embeds = model.get_input_embeddings()(input_ids)

# cannot generate from `inputs_embeds` for decoder only
with pytest.raises(ValueError):
Contributor (Author):

Decoder-only models can't generate from inputs_embeds since the predicted ids are appended to the starting ids, so it's assumed that the starting ids are input_ids and not inputs_embeds.
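
A hedged sketch of the behaviour being tested here, using a stock GPT-2 checkpoint purely for illustration: with this PR, a decoder-only model is expected to reject inputs_embeds.

import pytest
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# Decoder-only generation appends the predicted ids to the starting ids,
# so it needs real `input_ids`; passing only embeddings should raise.
with pytest.raises(ValueError):
    model.generate(inputs_embeds=inputs_embeds)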

Collaborator:

Why not use self.assertRaises(ValueError) here?

@sgugger (Collaborator) left a comment:

Thanks for adding support for this!

src/transformers/generation_utils.py (review comments, outdated/resolved)
tests/test_generation_utils.py (review comments, outdated/resolved)

@Narsil (Contributor) left a comment:

LGTM, some nits (+ agree with @sgugger's).

model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
    torch_device
)
model.config.eos_token_id = None
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
Contributor:


What's the length of this?
Without that information it's hard to tell whether self.assertEqual(output_sequences.shape, (1, 5)) is actually correct.
(I know it's an encoder-decoder, but it becomes important in the other, decoder-only test.)

Maybe as an additional guarantee we could force the decoder_input_ids (to prove to the reader that some tokens were generated)?

Contributor (Author):


Well, it comes pretty much only from max_length=5.

Contributor (Author):


added a comment

model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
    torch_device
)
model.config.eos_token_id = None
Contributor:


Why is that necessary for this test? Is it so EOS is not produced before max_length?
Maybe add a small comment where this is present?

Contributor (Author):


Just to make sure it can't finish before hitting max_length.
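
Putting the two points together, a sketch of why the shape assertion discussed above holds (tokenizer, article, torch_device, and the assertion come from the surrounding test context, so treat the exact names as assumptions):

model = BartForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-bart", max_length=5
).to(torch_device)
# No EOS token means generation can never stop early, so it always runs to max_length.
model.config.eos_token_id = None

input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)

output_sequences = model.generate(inputs_embeds=inputs_embeds)
# One sequence of exactly max_length=5 tokens.
self.assertEqual(output_sequences.shape, (1, 5))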

Successfully merging this pull request may close these issues.

T5 model seq2seq text generation using word embeddings instead of token_ids does not work