3 changes: 1 addition & 2 deletions keras_nlp/samplers/beam_sampler.py
@@ -92,8 +92,8 @@ def sample(
mask,
num_steps,
from_logits=True,
end_token_id=None,
cache=None,
token_probs=None,
):
"""Sampling logic implementation.

@@ -139,7 +139,6 @@ def one_step(beams, beams_prob, length, mask):
if from_logits:
preds = keras.activations.softmax(preds, axis=-1)
# Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.

preds = tf.reshape(preds, shape=[batch_size, -1])

cum_probs = tf.math.log(preds) + tf.repeat(
1 change: 0 additions & 1 deletion keras_nlp/samplers/beam_sampler_test.py
@@ -38,7 +38,6 @@ def setUp(self):
output_dim=self.feature_size,
),
keras.layers.Dense(self.vocab_size),
keras.layers.Softmax(),
Member:
Why change this?

Contributor Author:
It's because we now default to from_logits=True (it was False earlier for some mysterious reason), so I am aligning the unit test with it.

]
)

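A minimal sketch (not part of the diff) of the point made in the comment above: with from_logits=True the sampler normalizes the model output itself, so the test model can emit raw logits from its final Dense layer and the trailing Softmax layer becomes redundant. The tensor values below are made up for illustration.

```python
import tensorflow as tf
from tensorflow import keras

# Raw Dense output (logits); no trailing Softmax layer in the test model.
logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])

# With from_logits=True the sampler normalizes internally, roughly like this:
probs = keras.activations.softmax(logits, axis=-1)
print(float(tf.reduce_sum(probs)))  # ~1.0, a valid distribution to sample from
```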
2 changes: 1 addition & 1 deletion keras_nlp/samplers/greedy_sampler_test.py
@@ -94,7 +94,7 @@ def token_probability_fn(inputs, mask):

def test_end_token_id(self):
def token_probability_fn(inputs, mask):
batch_size = inputs.shape[0]
batch_size = tf.shape(inputs)[0]
prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
return tf.repeat(
tf.repeat(prob, batch_size, axis=0), max_length, axis=1
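A minimal sketch (not part of the diff) of why the test switches to tf.shape: when the function is traced with an unknown batch dimension, inputs.shape[0] is None, while tf.shape(inputs)[0] yields the batch size as a runtime tensor. The max_length value, shapes, and signature below are assumed for illustration only.

```python
import tensorflow as tf

max_length = 5  # assumed for illustration

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, max_length], dtype=tf.int32),
    tf.TensorSpec(shape=[None, max_length], dtype=tf.bool),
])
def token_probability_fn(inputs, mask):
    # inputs.shape[0] would be None here; tf.shape gives the runtime batch size.
    batch_size = tf.shape(inputs)[0]
    prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
    return tf.repeat(tf.repeat(prob, batch_size, axis=0), max_length, axis=1)

out = token_probability_fn(
    tf.zeros([2, max_length], tf.int32), tf.ones([2, max_length], tf.bool)
)
print(out.shape)  # (2, 5, 4)
```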
45 changes: 31 additions & 14 deletions keras_nlp/samplers/sampler.py
@@ -236,8 +236,9 @@ def __call__(
token_probability_fn,
mask,
max_length - shortest_prompt_len,
cache=cache,
from_logits=from_logits,
end_token_id=end_token_id,
cache=cache,
)
# Mask out tokens after `end_token_id`.
if end_token_id is not None:
@@ -269,6 +270,7 @@ def sample(
mask,
num_steps,
from_logits=True,
end_token_id=None,
cache=None,
):
"""Sampling logic implementation.
@@ -284,6 +286,9 @@
from_logits: bool, defaults to True. Indicate if the
`token_probability_fn` returns logits. If False,
`token_probability_fn` returns probability distributions.
end_token_id: int, defaults to None. The token marking the end of
the sequence; once it is encountered, generation is finished for
that exact sequence.
cache: a dense int tensor, the cache used in decoding. The cache
stores the key and value of each
`keras_nlp.layers.CachedMultiHeadAttention` layer to make the
@@ -299,8 +304,9 @@
# The index of the last non-padding token in prompt. Since all sequences
# are aligned to the right side, the index is the same for all.
current_index = max_length - num_steps
original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32)
Member:
we compute this twice (above as well), should we just pass it through?

Contributor Author:
It's fairly cheap to compute. I did it this way because sample already has a confusing arg list, and I would like to keep it shorter (though not by much...)


def one_step(
def body(
current_index,
prompt,
mask,
@@ -316,7 +322,10 @@ def one_step(
)
next_token_probs = tf.squeeze(probs, axis=1)
else:
probs = token_probability_fn(prompt, mask)
probs = token_probability_fn(
prompt,
mask,
)
next_token_probs = tf.gather(
probs,
tf.repeat(current_index - 1, batch_size),
@@ -343,24 +352,32 @@ def one_step(
current_index = tf.add(current_index, 1)
if cache is None:
return current_index, prompt, mask
return [current_index, prompt, mask, cache]
return current_index, prompt, mask, cache

def cond(current_index, prompt, mask, cache=None):
if end_token_id is None:
return True
end_token_seen = (prompt == end_token_id) & (
original_padding_mask == 0
)
sequence_done = tf.reduce_any(end_token_seen, axis=-1)
all_done = tf.reduce_all(sequence_done)
return not all_done

if cache is None:
_, prompt, _ = tf.while_loop(
cond=lambda current_index, prompt, mask: tf.less(
current_index, max_length
),
body=one_step,
loop_vars=[current_index, prompt, mask],
cond=cond,
body=body,
loop_vars=(current_index, prompt, mask),
maximum_iterations=num_steps,
)
return prompt
# Run a while loop till `max_length` of tokens has been generated.
_, prompt, _, _ = tf.while_loop(
cond=lambda current_index, prompt, mask, cache: tf.less(
current_index, max_length
),
body=one_step,
loop_vars=[current_index, prompt, mask, cache],
cond=cond,
body=body,
loop_vars=(current_index, prompt, mask, cache),
maximum_iterations=num_steps,
)
return prompt

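A minimal sketch (not part of the diff) of how the new cond stops the loop early: a sequence counts as done once end_token_id appears at a position that was padding in the original prompt, i.e. a generated position. All tensor values below are made up for illustration.

```python
import tensorflow as tf

end_token_id = 2
# Prompt tokens occupy the first positions; later positions are generated.
prompt = tf.constant([[5, 7, 3, 2],    # generated an end token -> done
                      [5, 7, 3, 4]])   # no end token yet -> keep going
# 1 marks original prompt positions, 0 marks positions to be generated.
original_padding_mask = tf.constant([[1, 1, 0, 0],
                                     [1, 1, 0, 0]])

end_token_seen = (prompt == end_token_id) & (original_padding_mask == 0)
sequence_done = tf.reduce_any(end_token_seen, axis=-1)  # [True, False]
all_done = tf.reduce_all(sequence_done)                 # False -> loop continues
print(sequence_done.numpy(), bool(all_done))
```

Since maximum_iterations=num_steps is still passed to tf.while_loop, generation stops at whichever comes first: every sequence producing end_token_id, or the step budget running out.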