Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a Bug, trainer_seq2seq.py, in the else branch at Line 172, generation_inputs should be a dict #14546

Merged
merged 2 commits into from
Dec 7, 2021

Conversation

TranSirius
Copy link
Contributor

Fixing Bug

Fixes # (issue)

In trainer_seq2seq.py / Seq2SeqTrainer / prediction_step, Line 174 reads:

generated_tokens = self.model.generate(
    **generation_inputs,
    **gen_kwargs,
)

which require the generated_tokens to be a dict. However, in the else branch in Line 171, the generation_inputs is created as a Tensor object, which will cause a problem.

Fix this by creating generation_inputs as a dict, and add a key called input_ids.

@patrickvonplaten
Copy link
Contributor

Hey @TranSirius,

Thanks a lot for your PR here! It looks good to me - @sgugger can you maybe take a look as well?

@patrickvonplaten
Copy link
Contributor

Should we maybe write some tests for this use case as well?

@patrickvonplaten patrickvonplaten removed the request for review from patil-suraj December 7, 2021 13:47
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

@sgugger sgugger merged commit 39f1dff into huggingface:master Dec 7, 2021
@sgugger
Copy link
Collaborator

sgugger commented Dec 7, 2021

Oops didn't see your comment @patrickvonplaten. Adding a test would be nice to have indeed @TranSirius if you want to work on it on a separate PR.

Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
…tion_inputs should be a dict (huggingface#14546)

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants