
Add contrastive search to our Sampler collection #644

@chenmoneygithub

Description

Contrastive search is an improvement over top-k search that further reduces nonsensical repetition. Starting from top-k, the implementation should not be very hard; all we need is this equation:

x_t = argmax_{v ∈ V^(k)} [ (1 − α) · p_θ(v | x_{<t}) − α · max_{1 ≤ j ≤ t−1} s(h_v, h_{x_j}) ]

where V^(k) is the set of top-k candidates under the model distribution p_θ, h_v is the model's representation of candidate v, s(·, ·) is cosine similarity, and α trades off model confidence against the degeneration penalty.

This issue will be based on #563, which creates the basic interface of our sampler class.

For reference, please check the paper "A Contrastive Framework for Neural Text Generation" (https://arxiv.org/abs/2202.06417).
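
To make the scoring rule concrete, here is a minimal TensorFlow sketch of the selection step for a single position. All names here (candidate_probs, candidate_states, context_states, alpha) are illustrative, not existing keras_nlp API:

import tensorflow as tf

def contrastive_scores(candidate_probs, candidate_states, context_states, alpha=0.6):
    """Score k candidate tokens against the t-1 tokens generated so far.

    candidate_probs: (k,) model probabilities of the top-k candidates.
    candidate_states: (k, d) hidden representations of the candidates.
    context_states: (t-1, d) hidden representations of previous tokens.
    """
    # Cosine similarity between every candidate and every previous token.
    candidate_states = tf.math.l2_normalize(candidate_states, axis=-1)
    context_states = tf.math.l2_normalize(context_states, axis=-1)
    similarity = tf.matmul(candidate_states, context_states, transpose_b=True)  # (k, t-1)
    # Degeneration penalty: max similarity with any previously seen token.
    penalty = tf.reduce_max(similarity, axis=-1)  # (k,)
    # Trade off model confidence against the penalty, per the equation above.
    return (1 - alpha) * candidate_probs - alpha * penalty

# The chosen token is the candidate with the highest score, e.g.:
# next_token_id = top_k_ids[tf.argmax(contrastive_scores(...))]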

================================================================

Adding some implementation guidance...

The contrastive sampler is a variant of the top-k sampler; the only difference is that it adds a penalty based on the maximum similarity with previously seen tokens. You need to override the sample() method instead of get_next_token() as in TopKSampler, because we have to compute the penalty for each vocab token. Here is a template for the implementation to help you get started:

import tensorflow as tf
from tensorflow import keras

# Assumes the Sampler base class from #563 lives in keras_nlp.samplers.
from keras_nlp.samplers import Sampler


@keras.utils.register_keras_serializable(package="keras_nlp")
class ContrastiveSampler(Sampler):
    """Contrastive Sampler class.
    
    {Add docstring here}
    """

    def __init__(
        self,
        k=5,
        seed=None,
        jit_compile=True,
        run_eagerly=False,
    ):
        self.k = k
        self.seed = seed
        super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly)

    def get_next_token(self, next_token_probs):
        # Not needed here; `sample()` is overridden directly instead.
        pass

    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # The index of the last non-padding token in prompt. Since all sequences
        # are aligned to the right side, the index is the same for all.
        current_index = max_length - num_steps


        def one_step(current_index, prompt, mask):
            #################################################################
            # 1. Get the top-k tokens along with their probability and
            #    representation.
            #
            # Your code goes here!
            #################################################################

            #################################################################
            # 2. Compute the penalty for each token in the top-k selection.
            #
            # Your code goes here!
            #################################################################

            #################################################################
            # 3. Update the corresponding index and mask.
            #
            # Your code goes here!
            #################################################################

            # `tf.while_loop` requires the body to return the loop variables.
            return current_index + 1, prompt, mask


        # Run the loop until `max_length` tokens have been generated.
        current_index, prompt, mask = tf.while_loop(
            cond=lambda current_index, prompt, mask: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=(current_index, prompt, mask),
        )
        return prompt
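
As a starting point for step 1 above, here is a rough sketch of the top-k selection. It assumes token_probability_fn can also return the hidden representations needed for the penalty, which goes beyond the interface from #563, so treat the signature as hypothetical:

# Hypothetical signature: assumes `token_probability_fn` also returns the
# hidden representation of each position, which the interface from #563
# does not guarantee.
probs, hidden_states = token_probability_fn(prompt, mask)
next_probs = probs[:, current_index - 1, :]  # (batch, vocab)
# Candidate tokens for this step along with their probabilities.
top_k_probs, top_k_ids = tf.math.top_k(next_probs, k=self.k)  # (batch, k)
# Step 3 then scatters the winning token back into `prompt` and flips
# `mask` at `current_index`, e.g. with tf.tensor_scatter_nd_update.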

To test your implementation, you can adapt the code snippet below (shown here with the built-in greedy sampler as a sanity check):

import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.run_eagerly = True
gpt2_lm.jit_compile = False
print(
    gpt2_lm.generate(
        "that's weird", sampler="greedy", max_length=30
    )
)

Setting gpt2_lm.run_eagerly = True makes generate() run in eager mode, which is much easier to debug.
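
Once your sampler is wired up, you can swap it into the same call. A minimal sketch, assuming the class ends up exposed as keras_nlp.samplers.ContrastiveSampler (the exact path and registration depend on how #563 lands):

print(
    gpt2_lm.generate(
        "that's weird",
        # Hypothetical path/registration; adjust to your implementation.
        sampler=keras_nlp.samplers.ContrastiveSampler(k=5),
        max_length=30,
    )
)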

This issue is a bit challenging but very rewarding!

Labels: stat:contributions welcome, type:feature
