Generate: PT's `top_p` enforces `min_tokens_to_keep` when it is 1 (#24111)
Conversation
The documentation is not available anymore as the PR was closed or merged.
From the diff - I don't see how this is resolved. The checks ensure the value of …
@@ -266,9 +268,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

     # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
     sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-    if self.min_tokens_to_keep > 1:
@amyeroberts -- this line was preventing the application of `min_tokens_to_keep` when it was 1. Removing it sorts the problem :)
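For context, here is a minimal sketch of the warping step under discussion (simplified and paraphrased from the warper's `__call__`, not the library code verbatim; `top_p_mask` is an illustrative name):

```python
import torch

def top_p_mask(scores: torch.FloatTensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.FloatTensor:
    """Sketch of top-p filtering on a batch of next-token logits."""
    # Sort ascending so the most likely tokens sit at the end of each row
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Without the old `if min_tokens_to_keep > 1:` guard, this line always runs,
    # so at least `min_tokens_to_keep` tokens survive even when the value is 1
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False

    # Scatter the mask back to the original vocabulary order and filter
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, float("-inf"))
```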
e.g. if we were to set `top_p=0.0`, `.generate()` would crash due to the lack of suitable continuations, despite the default of `min_tokens_to_keep` being 1.

After this fix, setting `top_p=0.0` does not crash the code. Refresher: Top P selects the smallest set of most likely tokens whose cumulative probability is >= `top_p`. Setting it to 0.0 means it should pick exactly one token.
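A toy, self-contained illustration of the failure mode (the numbers and variable names are made up for illustration):

```python
import torch

probs = torch.tensor([[0.5, 0.3, 0.2]])                      # toy next-token distribution
cumulative_probs = probs.sort(dim=-1).values.cumsum(dim=-1)  # ascending cumsum: [0.2, 0.5, 1.0]

top_p = 0.0
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
print(sorted_indices_to_remove)  # tensor([[True, True, True]]) -> every token flagged, sampling would crash

# With the fix, the last `min_tokens_to_keep` positions are always un-flagged,
# so the single most likely token survives
min_tokens_to_keep = 1
sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
print(sorted_indices_to_remove)  # tensor([[True, True, False]])
```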
Ah, it's a stupid ambiguity-of-English issue :D From the title I thought it meant 'when `top_p` is 1'.
Thanks for fixing!
@gante I hit an issue related to this in the prior version of transformers, glad to see that it's fixed, thanks! However why don't we enforce …
@njhill true, the initial check should be against …
See huggingface/transformers#24111. I didn't add validation to the `__init__` method since it's not done for other values/warpers.
What does this PR do?

Fixes #23688

Contrary to our description in the docstring, PT's `top_p` was not enforcing `min_tokens_to_keep` when it was 1 (the default). TF and FLAX were fine. This PR corrects it, and adds a check on `min_tokens_to_keep` (must be a non-negative integer).
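As a quick sanity check, something along these lines should now run without crashing (the model name and prompt are arbitrary, not taken from the PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Before this fix, top_p=0.0 could leave no tokens to sample from and .generate() crashed;
# with min_tokens_to_keep=1 enforced, exactly one token always survives the warper
out = model.generate(**inputs, do_sample=True, top_p=0.0, max_new_tokens=5)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```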