
Generate: PT's top_p enforces min_tokens_to_keep when it is 1 #24111

Merged: gante merged 2 commits into huggingface:main from top_p_min_tokens on Jun 9, 2023

Conversation

gante (Member) commented on Jun 8, 2023:

What does this PR do?

Fixes #23688

Contrary to the description in the docstring, PT's top_p was not enforcing min_tokens_to_keep when it was 1 (the default); TF and FLAX were fine. This PR corrects that and adds a validation check on min_tokens_to_keep (it must be a non-negative integer).

gante requested a review from amyeroberts on June 8, 2023 at 13:36
HuggingFaceDocBuilderDev commented on Jun 8, 2023:

The documentation is not available anymore as the PR was closed or merged.

amyeroberts (Collaborator) commented:

> top_p was not enforcing min_tokens_to_keep when it was 1

From the diff, I don't see how this is resolved. The checks validate the value of min_tokens_to_keep but don't seem to be conditional on top_p. Am I missing something?

@@ -266,9 +268,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
      # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
      sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-     if self.min_tokens_to_keep > 1:
gante (Member Author) replied:
@amyeroberts -- this line was preventing the application of min_tokens_to_keep when it was 1. Removing it sorts the problem :)
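
For context, the core of the warper after this change looks roughly like the standalone sketch below. The function name and the assumption of 2-D (batch, vocab) scores are mine for illustration; the point is that the "keep at least min_tokens_to_keep tokens" line now runs unconditionally instead of being guarded by `if self.min_tokens_to_keep > 1:`.

```python
import torch


def top_p_warp(
    scores: torch.FloatTensor,
    top_p: float,
    min_tokens_to_keep: int = 1,
    filter_value: float = -float("inf"),
) -> torch.FloatTensor:
    # Sort ascending so the cumulative sum runs from the least to the most probable token.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens whose cumulative probability falls below the top_p threshold.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # The fix: always keep at least `min_tokens_to_keep` tokens, even when it is 1.
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

    # Scatter the sorted mask back to the original vocabulary order and mask the scores.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```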

gante (Member Author) added:
e.g. if we were to set top_p=0.0, .generate() would crash due to the lack of suitable continuations, despite the default of min_tokens_to_keep being 1.

After this fix, setting top_p=0.0 no longer crashes the code. Refresher: top_p keeps the smallest set of most probable tokens whose cumulative probability reaches top_p, so setting it to 0.0 means it should pick exactly one token.
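
As a toy illustration of the behaviour described above (made-up logits over a 5-token vocabulary; the dummy input_ids tensor only exists to satisfy the call signature), something along these lines now keeps exactly one token instead of crashing:

```python
import torch
from transformers import TopPLogitsWarper

# Toy distribution over a 5-token vocabulary (batch size 1).
scores = torch.tensor([[0.10, 0.40, 0.30, 0.15, 0.05]]).log()
input_ids = torch.tensor([[0]])  # not used by the warper, only needed for the signature

warper = TopPLogitsWarper(top_p=0.0)  # min_tokens_to_keep defaults to 1
filtered = warper(input_ids, scores)

# With the fix, only the most probable token (index 1) survives; before it,
# every token was masked to -inf and sampling crashed.
print((filtered != -float("inf")).nonzero())  # tensor([[0, 1]])
```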

amyeroberts (Collaborator) replied:
Ah, it's a stupid ambiguity of English issue :D From the title I thought it meant 'when top_p is 1'.

amyeroberts (Collaborator) left a review comment:

Thanks for fixing!

gante merged commit be10092 into huggingface:main on Jun 9, 2023
22 checks passed
gante deleted the top_p_min_tokens branch on June 9, 2023 at 12:20
njhill (Contributor) commented on Jun 21, 2023:

@gante I hit an issue related to this in the prior version of transformers; glad to see that it's fixed, thanks! However, why don't we enforce min_tokens_to_keep >= 1? A value of 0 makes no sense, right?

gante (Member Author) commented on Jun 22, 2023:

@njhill True, the check should indeed be against >= 1; patching it.
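
For reference, the tightened check described here would look roughly like this sketch (the exact wording of the error message in transformers may differ):

```python
# Inside the warper's __init__ (sketch): require a positive integer rather than
# a merely non-negative one, since keeping 0 tokens makes no sense.
if not isinstance(min_tokens_to_keep, int) or min_tokens_to_keep < 1:
    raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
```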

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
OlivierDehaene pushed a commit to huggingface/text-generation-inference that referenced this pull request Jul 4, 2023
See huggingface/transformers#24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.
Development

Successfully merging this pull request may close these issues.

LlamaForCausalLM generate() runtime error when top_p=0