Conversation

@jessechancy (Contributor)

No description provided.

@chenmoneygithub (Contributor) left a comment

Thanks! Mainly looks good, left some comments

pred = tf.keras.activations.softmax(pred, axis=-1)
# Sort preds in descending order.
sorted_preds, sorted_indices = tf.math.top_k(
    pred, k=pred.shape[-1], sorted=True
)
@chenmoneygithub (Contributor):

Let's be consistent with `-1` or `1` here.

@jessechancy (Contributor, Author):
edited
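For readers following along, here is a minimal sketch of how the quoted snippet typically continues in nucleus sampling. The toy values and the names `cumulative_probs`, `keep_mask`, and `shifted_keep_mask` echo the snippets quoted later in this review, but the exact merged code may differ:

import tensorflow as tf

p = 0.9  # nucleus probability mass (illustrative value)
pred = tf.constant([[2.0, 1.0, 0.5, 0.1]])  # toy logits

pred = tf.keras.activations.softmax(pred, axis=-1)
# Sort preds in descending order.
sorted_preds, sorted_indices = tf.math.top_k(
    pred, k=pred.shape[-1], sorted=True
)
# Keep the shortest prefix of sorted tokens whose cumulative
# probability stays below `p`, shifted by one so the token that
# crosses the threshold is also kept.
cumulative_probs = tf.cumsum(sorted_preds, axis=-1)
keep_mask = cumulative_probs <= p
shifted_keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)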

        replaced with `pad_token_id`.
    pad_token_id: int, defaults to 0. The pad token after `end_token_id`
        is received.
    filter_value: float, defaults to -Inf. The value for filtering out
@chenmoneygithub (Contributor):

Do we want to allow customization of `filter_value`? The code uses `filter_value` to set certain tokens to have probability 0?

@jessechancy (Contributor, Author):

I don't think we need to allow customization; we can just mask out unused tokens by setting `filter_value = 0`.

@chenmoneygithub (Contributor):

Yea, in that case, let's remove this argument.

)
# Filter out unmasked tokens and sample from filtered distribution.
probs = tf.where(
    shifted_keep_mask, sorted_preds, tf.fill(pred.shape, filter_value)
)
@chenmoneygithub (Contributor):

Should `filter_value` always be 0? I saw in the docstring it defaults to `-Inf`.

@jessechancy (Contributor, Author):

Yep, `filter_value` should be 0; edited.
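To make the agreed behavior concrete, here is a sketch of filtering with probability 0 instead of a configurable `filter_value`, continuing the toy variables from the sketch above (an illustration, not the merged code):

# Zero out tokens outside the nucleus; since `sorted_preds` are
# post-softmax probabilities, 0 (not -Inf) is the right mask value.
probs = tf.where(
    shifted_keep_mask, sorted_preds, tf.zeros_like(sorted_preds)
)
# `tf.random.categorical` expects logits; log(0) = -inf gives the
# filtered tokens zero chance of being drawn.
sampled = tf.random.categorical(tf.math.log(probs), num_samples=1)
# Map positions in the sorted order back to original token ids.
next_token = tf.gather(sorted_indices, sampled, batch_dims=1)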

        self.assertAllEqual(output_logit, output_probs)


class TopPSamplingTextGenerationTest(tf.test.TestCase):
@chenmoneygithub (Contributor):

Let's keep the test name and method name the same: `TopPSearchTextGenerationTest`.

@jessechancy (Contributor, Author):

Edited this and the other test names.

@chenmoneygithub (Contributor) left a comment

Only a few minor comments!


        prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
        return tf.repeat(prob, batch_size, axis=0)

    # Test that it only samples from tokens that sum up to p
@chenmoneygithub (Contributor):

nit: period at end.
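As an aside, here is a hypothetical sketch of the check this test comment describes, using the `[0.4, 0.3, 0.2, 0.1]` distribution quoted above. The class and method names here are illustrative, not the PR's actual test API:

import tensorflow as tf


class TopPMaskTest(tf.test.TestCase):
    def test_only_samples_tokens_within_p(self):
        probs = tf.constant([[0.4, 0.3, 0.2, 0.1]])
        p = 0.6
        # Keep the prefix whose cumulative probability stays below `p`,
        # shifted by one so the crossing token is kept: ids {0, 1}.
        cumulative = tf.cumsum(probs, axis=-1)
        keep = tf.concat(
            [tf.ones_like(cumulative[:, :1], dtype=tf.bool),
             (cumulative <= p)[:, :-1]],
            axis=-1,
        )
        masked = tf.where(keep, probs, tf.zeros_like(probs))
        for _ in range(100):
            token = tf.random.categorical(tf.math.log(masked), 1)[0, 0]
            # Only ids 0 and 1 carry nonzero mass under p=0.6.
            self.assertLess(int(token), 2)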

@mattdangerw (Member) left a comment

This looks great! Nice work.

Just a few minor comments.

)
if p <= 0 or p >= 1:
    raise ValueError(
        "p should be a float strictly between 0 and 1 (0 < p < 1)."
    )
@mattdangerw (Member):

Surround arg names in backticks (`p`), and also show what value we received; something like:

`p` should be in the range (0, 1). Received: `p={p}`.
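Applied to the quoted check, the suggestion would read roughly like this (a sketch of the suggested message, not necessarily the merged wording):

if p <= 0 or p >= 1:
    raise ValueError(
        f"`p` should be in the range (0, 1). Received: `p={p}`."
    )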

"""
Text generation utility based on top-p (nucleus) sampling.
Top-p search filters the top token with probabilities that sum up to p, and
@mattdangerw (Member):

Maybe we can explain this a little more clearly; also remember to backtick argument names so it's clear what they are:

Top-p search selects tokens from the smallest subset of output probabilities that sum to greater than `p`. Put another way, top-p will first order token predictions by likelihood, and ignore all tokens after the cumulative probability of selected tokens exceeds `p`.
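A small numeric illustration of that wording (toy values, not from the PR):

import tensorflow as tf

probs = tf.constant([0.5, 0.3, 0.15, 0.05])  # already sorted by likelihood
cumulative = tf.cumsum(probs)                # [0.5, 0.8, 0.95, 1.0]
# With p = 0.75, tokens 0 and 1 form the smallest subset whose
# cumulative probability exceeds `p`; tokens 2 and 3 are ignored.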

prompt = tf.fill((BATCH_SIZE, 1), START_ID)
# Print the generated sequence (token ids).
@mattdangerw (Member):

typo?

@mattdangerw (Member):

fixed

@mattdangerw merged commit 5c87ada into keras-team:master on Jun 28, 2022.