Generate: FLAX infers pad token in its absence and has functional example #21009
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
    eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
Took the opportunity to also copy the logic to TF, so it can also handle `eos_token_id` as a list 👀
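For readers skimming the thread, the fallback in the snippet above can be sketched in isolation. This is a minimal illustration, not the library code: `ToyGenerationConfig` and `infer_pad_token_id` are hypothetical names introduced here, and the guard on `pad_token_id is None` is taken from the PR description rather than from the quoted diff lines.

```python
import logging
from dataclasses import dataclass
from typing import List, Optional, Union

logger = logging.getLogger(__name__)


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config (not the library class)."""

    pad_token_id: Optional[int] = None
    eos_token_id: Optional[Union[int, List[int]]] = None


def infer_pad_token_id(config: ToyGenerationConfig) -> ToyGenerationConfig:
    """Fall back to the EOS token when no pad token is set.

    If `eos_token_id` is a list, its first entry is used, mirroring the
    quoted snippet. An explicitly set `pad_token_id` is left untouched.
    """
    if config.pad_token_id is None and config.eos_token_id is not None:
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(
            f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
        )
        config.pad_token_id = eos_token_id
    return config


# 50256 is GPT-2's EOS id; the second entry is an arbitrary extra stop token.
config = infer_pad_token_id(ToyGenerationConfig(eos_token_id=[50256, 0]))
print(config.pad_token_id)
```

Under these assumptions, a config that already sets `pad_token_id`, or that has neither token set, passes through unchanged.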
Thanks for the fix!
Thanks!
Thanks for the fix @gante!
What does this PR do?
Some bug fixing in advance of #21007 (the PR that adds generation config to Flax), to ensure we start from a functional Flax `generate` codebase.
In particular:
- Flax now infers `pad_token_id` when it is `None` and `eos_token_id` is not `None`, like TF and PT do. This is very helpful for open text generation examples, like with GPT2; it was an open request (#18884, "Generating with Flax fails when using Causal Language models") and one of the causes of failure in the existing example. This also covers the recent changes in #20727 ("Add custom stop token ids for generation"), where `eos_token_id` can be a list of tokens;
- An `int32` type specification was missing in the special tokens: when converted to JAX variables, JAX assumed they were `float32`;