Description
Contrastive search is an improvement over Top-K search that further reduces nonsensical repetition. Starting from top-k, the implementation should not be very hard; all we need is this equation:
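Roughly, the selection rule from the contrastive search paper looks like the following (restated here from memory, so treat the exact notation as an approximation):

$$x_t = \underset{v \in V^{(k)}}{\arg\max}\Big[(1-\alpha)\,p_\theta(v \mid x_{<t}) \;-\; \alpha \max_{1 \le j \le t-1} s\big(h_v, h_{x_j}\big)\Big]$$

where $V^{(k)}$ is the set of top-k candidates under the model distribution $p_\theta$, $h_v$ is the hidden representation of candidate $v$, $s$ is cosine similarity, and $\alpha$ weights the degeneration penalty against model confidence.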
This issue will be based on #563, which creates the basic interface of our sampler class.
For reference, please check this paper.
================================================================
Adding some implementation guidance...
The contrastive sampler is a variant of the top-k sampler; the only difference is that it adds a penalty based on the maximum similarity with previously seen tokens. You need to override the `sample()` method instead of `get_next_token()` as in `TopKSampler`, because we have to compute the penalty for each vocab token. Here is a template for the implementation to help you start:
```python
import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler  # base interface from #563


@keras.utils.register_keras_serializable(package="keras_nlp")
class ContrastiveSampler(Sampler):
    """Contrastive Sampler class.

    {Add docstring here}
    """

    def __init__(
        self,
        k=5,
        seed=None,
        jit_compile=True,
        run_eagerly=False,
    ):
        self.k = k
        self.seed = seed
        super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly)

    def get_next_token(self, next_token_probs):
        pass

    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # The index of the last non-padding token in the prompt. Since all
        # sequences are aligned to the right side, the index is the same for
        # all of them.
        current_index = max_length - num_steps

        def one_step(current_index, prompt, mask):
            #################################################################
            # 1. Get the top-k tokens along with their probability and
            #    representation.
            #
            # Your code goes here!
            #################################################################

            #################################################################
            # 2. Compute the penalty for each token in the top-k selection.
            #
            # Your code goes here!
            #################################################################

            #################################################################
            # 3. Update the corresponding index and mask.
            #
            # Your code goes here!
            #################################################################
            # Return the updated loop variables.
            return (current_index, prompt, mask)

        # Run a while loop until `max_length` tokens have been generated.
        current_index, prompt, mask = tf.while_loop(
            cond=lambda current_index, prompt, mask: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=(current_index, prompt, mask),
        )
        return prompt
```
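For step 2, here is a minimal sketch of how the degeneration penalty could be computed with plain TF ops. The tensor names (`candidate_reps`, `history_reps`, `top_k_probs`) and the `alpha` weight are placeholders for illustration, not part of any existing API:

```python
import tensorflow as tf


def degeneration_penalty(candidate_reps, history_reps):
    """Max cosine similarity of each top-k candidate to previously seen tokens.

    `candidate_reps`: (batch_size, k, hidden_dim) representations of the
        top-k candidate tokens (hypothetical name).
    `history_reps`: (batch_size, seq_len, hidden_dim) representations of the
        tokens generated so far (hypothetical name).
    """
    # L2-normalize so a dot product becomes cosine similarity.
    candidate_reps = tf.math.l2_normalize(candidate_reps, axis=-1)
    history_reps = tf.math.l2_normalize(history_reps, axis=-1)
    # similarity[b, i, j] = cos(candidate i, history token j).
    similarity = tf.einsum("bkh,bsh->bks", candidate_reps, history_reps)
    # Penalty for each candidate = its max similarity to any seen token.
    return tf.reduce_max(similarity, axis=-1)  # (batch_size, k)


# The combined score would then look something like this, with `alpha` as the
# penalty weight and `top_k_probs` of shape (batch_size, k):
# scores = (1 - alpha) * top_k_probs - alpha * degeneration_penalty(...)
```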
To test your implementation, you can use the code snippet below:
```python
import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.run_eagerly = True
gpt2_lm.jit_compile = False

print(
    gpt2_lm.generate(
        "that's weird", sampler="greedy", max_length=30
    )
)
```
Setting `gpt2_lm.run_eagerly = True` makes `generate()` run in eager mode for easier debugging.
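Once your sampler is implemented and registered, swapping it in should just be a matter of changing the `sampler` argument. Note the import path and whether `generate()` accepts a sampler instance directly are assumptions about how the final wiring will look, not existing API:

```python
# Hypothetical usage once ContrastiveSampler exists and is registered.
print(
    gpt2_lm.generate(
        "that's weird",
        sampler=keras_nlp.samplers.ContrastiveSampler(k=5),
        max_length=30,
    )
)
```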
This issue is a bit challenging but very rewarding!