
Conversation

chenmoneygithub (Contributor)

resolve #963

chenmoneygithub (Contributor Author)

Tested on CPU with 25 iterations: after the fix it takes 37s, compared to 55s before.
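The gain presumably comes from no longer sorting the full vocabulary on every decoding step. A minimal standalone sketch of that effect (not the benchmark used here; the vocabulary size and `k` below are made-up values):

```python
# Not the benchmark from this PR; a rough illustration of why limiting the
# sort to a top-k subset helps. VOCAB_SIZE and K are hypothetical values.
import timeit

import numpy as np

VOCAB_SIZE = 50_000  # hypothetical vocabulary size
K = 40               # hypothetical top-k cutoff

logits = np.random.randn(VOCAB_SIZE).astype("float32")

def sort_full_vocab():
    # Baseline: sort every logit to find the top-p cutoff.
    return np.sort(logits)[::-1]

def sort_top_k_only():
    # With the k cutoff: partition out the K largest logits, then sort only those.
    top_k = np.partition(logits, -K)[-K:]
    return np.sort(top_k)[::-1]

print("full sort:   ", timeit.timeit(sort_full_vocab, number=1000))
print("top-k + sort:", timeit.timeit(sort_top_k_only, number=1000))
```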

chenmoneygithub (Contributor Author)

/gcbrun

mattdangerw (Member) left a comment

Awesome! Dropped some comments.

Args:
    p: float, the `p` value of top-p.
    sample_from_top_k: int, defaults to None. If set, only sample from top
mattdangerw (Member), Apr 11, 2023

I would just call this k honestly (parallel with the same option in top-k), and in the description say that it is a heuristic cutoff point.

chenmoneygithub (Contributor Author)

My thought is that having both `p` and `k` is a bit confusing to readers; here `p` is the main arg while `k` plays a supporting role, and using the `p` and `k` combo kind of suggests they are parallel to each other.

No strong opinion here, and I agree `sample_from_top_k` isn't a great name.

        )
        self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])

    def test_sample_from_all_tokens(self):
mattdangerw (Member)

I don't think this test adds much. If we want to add a test here, it should test that our cutoff point is working.

e.g. p=1.0, k=5, uniform logits, assert that all outputs are in the first five vocab options.
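Something along these lines, written here as a self-contained NumPy sketch of the assertion rather than the KerasNLP test that actually landed (the names and the sampling loop below are illustrative):

```python
# Sanity-check sketch: with uniform logits, p=1.0 and k=5, every sampled token
# id should land in the first five vocab entries, because everything outside
# the top k is discarded before top-p is applied.
import numpy as np

rng = np.random.default_rng(0)
vocab_size, k, p = 20, 5, 1.0
logits = np.zeros(vocab_size)  # uniform logits

# Restrict to the top-k logits first (stable sort breaks ties by index here).
top_k_ids = np.argsort(-logits, kind="stable")[:k]
probs = np.exp(logits[top_k_ids])
probs /= probs.sum()

# Standard top-p rule: keep tokens until cumulative probability reaches p,
# always including the token that crosses the threshold.
cumulative = np.cumsum(probs)
keep = np.concatenate(([True], cumulative[:-1] < p))
candidate_ids = top_k_ids[keep]
candidate_probs = probs[keep] / probs[keep].sum()

samples = rng.choice(candidate_ids, size=1000, p=candidate_probs)
assert np.all(samples < 5)  # every output is one of the first five vocab ids
```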

chenmoneygithub (Contributor Author)

This test case is just a sanity check that the sampler call with `k=None` works.

Agree the proposed test works better! Changed.

chenmoneygithub (Contributor Author)

/gcbrun

mattdangerw (Member) left a comment

Looks good! Think we should be more detailed in the docstring for this one.

Args:
    p: float, the `p` value of top-p.
    k: int, defaults to None. If set, only sample from top `k` tokens.
mattdangerw (Member), Apr 12, 2023

... If set, this argument defines a heuristic "top-k" cutoff applied before the "top-p" sampling. All logits not in the top `k` will be discarded, and the remaining logits will be sorted to find a cutoff point for `p`. Setting this argument can significantly speed up sampling by reducing the size of the sort.
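For illustration, a rough TensorFlow sketch of the behavior this docstring describes; `top_p_with_k_cutoff` is a hypothetical standalone function, not the KerasNLP implementation:

```python
# tf.math.top_k restricts and sorts only `k` logits, so the top-p cutoff is
# computed over k values instead of the full vocabulary.
import tensorflow as tf

def top_p_with_k_cutoff(logits, p=0.9, k=10, seed=None):
    # Keep the k largest logits; `values` comes back sorted in descending order.
    values, indices = tf.math.top_k(logits, k=k, sorted=True)
    probs = tf.nn.softmax(values, axis=-1)
    # Exclusive cumsum = probability mass strictly before each token, so the
    # most likely token is always kept.
    cumulative = tf.cumsum(probs, axis=-1, exclusive=True)
    # Mask out tokens past the top-p cutoff by sending their logits to -inf.
    masked = tf.where(cumulative < p, values, tf.fill(tf.shape(values), float("-inf")))
    sampled = tf.random.categorical(masked, num_samples=1, seed=seed)
    # Map positions within the top-k subset back to vocabulary ids.
    return tf.gather(indices, sampled, batch_dims=1)

# Example: batch of 1, vocab of 8 (made-up logits).
logits = tf.constant([[2.0, 1.5, 1.0, 0.5, 0.0, -1.0, -2.0, -3.0]])
print(top_p_with_k_cutoff(logits, p=0.9, k=4))
```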

chenmoneygithub (Contributor Author)

thanks!

chenmoneygithub merged commit 33fdb1f into keras-team:master on Apr 13, 2023.

Successfully merging this pull request may close these issues.

Investigate top-p performance