From 4ddd679e0c16f8699cd85b5fa5caeac1403c2cdd Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 5 Apr 2023 18:21:25 -0700 Subject: [PATCH 01/11] rebase --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 58 +++++++-- keras_nlp/samplers/__init__.py | 1 + keras_nlp/samplers/contrastive_sampler.py | 146 ++++++++++++++++++++++ 3 files changed, 195 insertions(+), 10 deletions(-) create mode 100644 keras_nlp/samplers/contrastive_sampler.py diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 8fb1b68a55..974ea7e1e2 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -16,6 +16,7 @@ import copy import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone @@ -24,6 +25,7 @@ ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.samplers import ContrastiveSampler from keras_nlp.samplers.serialization import get as get_sampler from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -208,7 +210,9 @@ def backbone_cls(cls): def preprocessor_cls(cls): return GPT2CausalLMPreprocessor - def call_with_cache(self, token_ids, cache, cache_index): + def call_with_cache( + self, token_ids, cache, cache_index, hidden_states=None + ): """Forward pass of `GPT2CausalLM` with cache. `call_with_cache` adds an additional forward pass for the model for @@ -247,14 +251,18 @@ def call_with_cache(self, token_ids, cache, cache_index): caches[i] = next_cache cache = tf.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) + if hidden_states is not None: + hidden_states = dynamic_update_slice( + hidden_states, tf.identity(x), [0, cache_index, 0] + ) x = tf.matmul( x, self.backbone.get_layer("token_embedding").embeddings, transpose_b=True, ) - return x, cache + return x, cache, hidden_states - def _build_cache(self, prompt): + def _build_cache(self, prompt, include_hidden_states=False): """Build an empty cache for use with `call_with_cache()`.""" batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] num_layers = self.backbone.num_layers @@ -262,9 +270,17 @@ def _build_cache(self, prompt): head_dim = self.backbone.hidden_dim // self.backbone.num_heads shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] cache = tf.zeros(shape) + if include_hidden_states: + hidden_states = tf.zeros( + [batch_size, max_length, self.backbone.hidden_dim] + ) + else: + hidden_states = None # Seed the cache. - _, cache = self.call_with_cache(prompt, cache, 0) - return cache + _, cache, hidden_states = self.call_with_cache( + prompt, cache, 0, hidden_states + ) + return cache, hidden_states def compile( self, @@ -286,26 +302,41 @@ def compile( # Clear the compiled generate function. self.generate_function = None - def make_generate_function(self): + def make_generate_function(self, include_hidden_states=False): """Create or return the compiled generation function.""" if self.generate_function is not None: return self.generate_function def generate_function(prompt, input_mask, min_length): # Create and seed cache with a single forward pass. - cache = self._build_cache(prompt) + cache, hidden_states = self._build_cache( + prompt, include_hidden_states + ) def next(prompt, state, index): # The cache index is the index of our previous token. 
cache_index = index - 1 + cache = state["cache"] + hidden_states = state.get("hidden_states", None) prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) - logits, state = self.call_with_cache(prompt, state, cache_index) + logits, cache, hidden_states = self.call_with_cache( + prompt, + cache, + cache_index, + hidden_states, + ) + state["cache"] = cache + if hidden_states is not None: + state["hidden_states"] = hidden_states return tf.squeeze(logits, axis=1), state + state = {"cache": cache} + if hidden_states is not None: + state["hidden_states"] = hidden_states return self._sampler( next=next, prompt=prompt, - state=cache, + state=state, index=min_length, mask=input_mask, end_token_id=self.preprocessor.tokenizer.end_token_id, @@ -358,7 +389,14 @@ def generate( prompt = prompt.to_tensor(shape=padded_shape) # Run the (possibly compiled) generate function on dense inputs. - generate_function = self.make_generate_function() + if isinstance(self._sampler, ContrastiveSampler): + generate_function = self.make_generate_function( + include_hidden_states=True + ) + else: + generate_function = self.make_generate_function( + include_hidden_states=False + ) output = generate_function(prompt, input_mask, min_length) # Truncate to ragged by removing tokens after the first end token. diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index e76017b8db..8457f72d49 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -15,6 +15,7 @@ from tensorflow import keras from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler from keras_nlp.samplers.greedy_sampler import GreedySampler from keras_nlp.samplers.random_sampler import RandomSampler from keras_nlp.samplers.sampler import Sampler diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py new file mode 100644 index 0000000000..a93a0fbe81 --- /dev/null +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -0,0 +1,146 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contrastive Sampler.""" + +import tensorflow as tf +from tensorflow import keras +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.samplers.sampler import Sampler + + +@keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") +class ContrastiveSampler(Sampler): + def __init__( + self, + k=5, + alpha=0.5, + seed=None, + ): + super().__init__() + self.k = k + self.alpha = alpha + self.seed = seed + + def __call__( + self, + next, + prompt, + state=None, + index=0, + mask=None, + end_token_id=None, + ): + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + # Make sure max length and start index are the same dtype. 
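Here `cache_index` points at the most recent real token; the slice just below feeds only that single token forward while the key/value cache supplies attention state for everything earlier. A standalone sketch of that slice with toy values (illustrative, not the library code):

```python
import tensorflow as tf

prompt = tf.constant([[5, 8, 13, 0, 0]])  # batch of 1, max_length of 5
index = 3                                 # next position to sample
cache_index = index - 1                   # position of the previous token

# Feed only the previous token forward; the cache covers the rest.
last_token = tf.slice(prompt, [0, cache_index], [-1, 1])
print(last_token.numpy())  # [[13]]
```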
+ index = tf.cast(index, max_length.dtype) + + def create_beams(x): + """Add initial beam state.""" + x = tf.repeat(x, self.k, axis=0) + flat_shape = [batch_size * self.k] + x.shape.as_list()[1:] + return tf.reshape(x, shape=flat_shape) + + def flatten_beams(x): + """Combine the beam dim and batch dim.""" + flat_shape = [batch_size * self.k] + x.shape.as_list()[2:] + return tf.reshape(x, shape=flat_shape) + + def unflatten_beams(x): + """Separate the beam dim and batch dim.""" + unflat_shape = [batch_size, self.k] + x.shape.as_list()[1:] + return tf.reshape(x, shape=unflat_shape) + + mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask + # `tf.while_loop` will not accept `None` as a value for `loop_vars`. + state = () if state is None else state + + def cond(prompt, state, index): + if end_token_id is None: + return True + # Stop if all sequences have produced a *new* end_token_id. + end_tokens = (prompt == end_token_id) & (~mask) + prompt_done = tf.reduce_any(end_tokens, axis=-1) + return not tf.reduce_all(prompt_done) + + def body(prompt, state, index): + # Compute the softmax distribution for the next token. + logits, state = next(prompt, state, index) + probabilities = keras.activations.softmax(logits) + + prompt_beams, mask_beams = create_beams(prompt), create_beams(mask) + state_beams = tf.nest.map_structure(create_beams, state) + + top_k_probabilities, top_k_indices = tf.math.top_k( + probabilities, k=self.k, sorted=False + ) + next_token_probabilities = flatten_beams(top_k_probabilities) + next_token = flatten_beams(top_k_indices) + next_token = tf.cast(next_token, prompt.dtype) + + next_token = tf.where( + mask_beams[:, index], prompt_beams[:, index], next_token + ) + # Update the prompt with the next token. + next_token = next_token[:, tf.newaxis] + prompt_beams = dynamic_update_slice( + prompt_beams, next_token, [0, index] + ) + + _, next_state_beams = next(prompt_beams, state_beams, index + 1) + hidden_states = next_state_beams["hidden_states"] + last_token_state = hidden_states[:, index, :][:, tf.newaxis, :] + previous_states = state_beams["hidden_states"][:, :index, :] + similarity_scores = self.similarity( + previous_states, last_token_state + ) + max_similarity_scores = tf.reduce_max(similarity_scores, axis=1) + + accumulated_scores = ( + (1 - self.alpha) * next_token_probabilities + - self.alpha * max_similarity_scores + ) + unflat_score = unflatten_beams(accumulated_scores) + unflat_prompt = unflatten_beams(prompt_beams) + + win_token_indices = tf.math.argmax(unflat_score, axis=1) + prompt = tf.gather( + unflat_prompt, win_token_indices, axis=1, batch_dims=1 + ) + return (prompt, state, index + 1) + + prompt, _, _ = tf.while_loop( + cond=cond, + body=body, + loop_vars=(prompt, state, index), + maximum_iterations=(max_length - index), + ) + return prompt + + def similarity(self, h1, h2): + return tf.squeeze(tf.matmul(h1, h2, transpose_b=True), axis=-1) / ( + tf.norm(h1, axis=-1) * tf.norm(h2, axis=-1) + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "k": self.k, + "alpha": self.alpha, + "seed": self.seed, + } + ) + return config From 479c87f413308d4e4dd93a5649e15c3883e1e207 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 22 Mar 2023 13:39:56 +0800 Subject: [PATCH 02/11] better! 
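The helpers defined just below fold `k` candidate continuations into the batch dimension so a single forward pass can score all of them, then unfold again for per-batch reductions. A toy illustration of the same reshaping (hypothetical shapes, not the library code):

```python
import tensorflow as tf

batch_size, k, length = 2, 3, 4
prompt = tf.reshape(tf.range(batch_size * length), [batch_size, length])

beams = tf.repeat(prompt, k, axis=0)                 # [batch * k, length]
unflat = tf.reshape(beams, [batch_size, k, length])  # [batch, k, length]
print(beams.shape, unflat.shape)  # (6, 4) (2, 3, 4)
```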
--- keras_nlp/models/gpt2/gpt2_causal_lm.py | 55 +++++++---------------- keras_nlp/samplers/beam_sampler.py | 1 + keras_nlp/samplers/contrastive_sampler.py | 6 +-- keras_nlp/samplers/sampler.py | 3 +- 4 files changed, 21 insertions(+), 44 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 974ea7e1e2..3f1dceac4d 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -16,7 +16,6 @@ import copy import tensorflow as tf -from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone @@ -211,7 +210,10 @@ def preprocessor_cls(cls): return GPT2CausalLMPreprocessor def call_with_cache( - self, token_ids, cache, cache_index, hidden_states=None + self, + token_ids, + cache, + cache_index, ): """Forward pass of `GPT2CausalLM` with cache. @@ -251,10 +253,7 @@ def call_with_cache( caches[i] = next_cache cache = tf.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) - if hidden_states is not None: - hidden_states = dynamic_update_slice( - hidden_states, tf.identity(x), [0, cache_index, 0] - ) + hidden_states = tf.identity(x) x = tf.matmul( x, self.backbone.get_layer("token_embedding").embeddings, @@ -262,7 +261,7 @@ def call_with_cache( ) return x, cache, hidden_states - def _build_cache(self, prompt, include_hidden_states=False): + def _build_cache(self, prompt): """Build an empty cache for use with `call_with_cache()`.""" batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] num_layers = self.backbone.num_layers @@ -270,16 +269,8 @@ def _build_cache(self, prompt, include_hidden_states=False): head_dim = self.backbone.hidden_dim // self.backbone.num_heads shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] cache = tf.zeros(shape) - if include_hidden_states: - hidden_states = tf.zeros( - [batch_size, max_length, self.backbone.hidden_dim] - ) - else: - hidden_states = None # Seed the cache. - _, cache, hidden_states = self.call_with_cache( - prompt, cache, 0, hidden_states - ) + _, cache, hidden_states = self.call_with_cache(prompt, cache, 0) return cache, hidden_states def compile( @@ -302,41 +293,32 @@ def compile( # Clear the compiled generate function. self.generate_function = None - def make_generate_function(self, include_hidden_states=False): + def make_generate_function(self): """Create or return the compiled generation function.""" if self.generate_function is not None: return self.generate_function def generate_function(prompt, input_mask, min_length): # Create and seed cache with a single forward pass. - cache, hidden_states = self._build_cache( - prompt, include_hidden_states - ) + cache, hidden_states = self._build_cache(prompt) def next(prompt, state, index): # The cache index is the index of our previous token. 
cache_index = index - 1 - cache = state["cache"] - hidden_states = state.get("hidden_states", None) + cache = state prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) logits, cache, hidden_states = self.call_with_cache( prompt, cache, cache_index, - hidden_states, ) - state["cache"] = cache - if hidden_states is not None: - state["hidden_states"] = hidden_states - return tf.squeeze(logits, axis=1), state - - state = {"cache": cache} - if hidden_states is not None: - state["hidden_states"] = hidden_states + return tf.squeeze(logits, axis=1), cache, hidden_states + return self._sampler( next=next, prompt=prompt, - state=state, + state=cache, + initial_hidden_states=hidden_states, index=min_length, mask=input_mask, end_token_id=self.preprocessor.tokenizer.end_token_id, @@ -389,14 +371,7 @@ def generate( prompt = prompt.to_tensor(shape=padded_shape) # Run the (possibly compiled) generate function on dense inputs. - if isinstance(self._sampler, ContrastiveSampler): - generate_function = self.make_generate_function( - include_hidden_states=True - ) - else: - generate_function = self.make_generate_function( - include_hidden_states=False - ) + generate_function = self.make_generate_function() output = generate_function(prompt, input_mask, min_length) # Truncate to ragged by removing tokens after the first end token. diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index c20047b84b..7bbcdd37c5 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -105,6 +105,7 @@ def __call__( next, prompt, state=None, + initial_hidden_states=None, index=0, mask=None, end_token_id=None, diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index a93a0fbe81..9ec8269883 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -66,6 +66,7 @@ def unflatten_beams(x): mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask # `tf.while_loop` will not accept `None` as a value for `loop_vars`. state = () if state is None else state + logits, state, hidden_states = next(prompt, state, 0) def cond(prompt, state, index): if end_token_id is None: @@ -77,9 +78,9 @@ def cond(prompt, state, index): def body(prompt, state, index): # Compute the softmax distribution for the next token. - logits, state = next(prompt, state, index) + logits, state, next_hidden_states = next(prompt, state, index) probabilities = keras.activations.softmax(logits) - + prompt_beams, mask_beams = create_beams(prompt), create_beams(mask) state_beams = tf.nest.map_structure(create_beams, state) @@ -100,7 +101,6 @@ def body(prompt, state, index): ) _, next_state_beams = next(prompt_beams, state_beams, index + 1) - hidden_states = next_state_beams["hidden_states"] last_token_state = hidden_states[:, index, :][:, tf.newaxis, :] previous_states = state_beams["hidden_states"][:, :index, :] similarity_scores = self.similarity( diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index a86cbc3499..84a6d1a4eb 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -87,6 +87,7 @@ def __call__( next, prompt, state=None, + initial_hidden_states=None, index=0, mask=None, end_token_id=None, @@ -108,7 +109,7 @@ def cond(prompt, state, index): def body(prompt, state, index): # Compute the softmax distribution for the next token. 
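After this patch, `next` returns a `(logits, cache, hidden_states)` triple so contrastive re-ranking can see token representations without a second forward pass. A hedged stub of a conforming `next` with toy shapes and hypothetical names, not the model's actual forward pass:

```python
import tensorflow as tf

def toy_next(prompt, cache, index):
    batch_size, vocab_size, hidden_dim = 1, 26, 8
    logits = tf.random.uniform([batch_size, vocab_size])
    hidden_states = tf.zeros([batch_size, 1, hidden_dim])
    return logits, cache, hidden_states

logits, cache, hidden = toy_next(tf.zeros([1, 12], tf.int32), (), 0)
print(logits.shape, hidden.shape)  # (1, 26) (1, 1, 8)
```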
- logits, state = next(prompt, state, index) + logits, state, _ = next(prompt, state, index) probabilities = keras.activations.softmax(logits) # Compute the next token. next_token = self.get_next_token(probabilities) From 291eee5b62b7fff4a4ac3188e605ae1e75090029 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 22 Mar 2023 19:31:46 +0800 Subject: [PATCH 03/11] better! --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 2 +- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/contrastive_sampler.py | 117 ++++++++++++--- .../samplers/contrastive_sampler_test.py | 139 ++++++++++++++++++ keras_nlp/samplers/sampler.py | 14 +- 5 files changed, 250 insertions(+), 24 deletions(-) create mode 100644 keras_nlp/samplers/contrastive_sampler_test.py diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 3f1dceac4d..91dfb4a7fe 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -318,10 +318,10 @@ def next(prompt, state, index): next=next, prompt=prompt, state=cache, - initial_hidden_states=hidden_states, index=min_length, mask=input_mask, end_token_id=self.preprocessor.tokenizer.end_token_id, + init_hidden_states=hidden_states, ) if self.run_eagerly: diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 7bbcdd37c5..561595a5fd 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -105,10 +105,10 @@ def __call__( next, prompt, state=None, - initial_hidden_states=None, index=0, mask=None, end_token_id=None, + init_hidden_states=None, ): batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 9ec8269883..f8748e347e 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -19,14 +19,57 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.utils.python_utils import format_docstring +@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") class ContrastiveSampler(Sampler): + """Contrastive Sampler class. + + This sampler implements contrastive search algorithm. In short, the sampler + chooses the token having the max "score" as the next token. The "score" is + a weighted sum between token's probability and max similarity against + previous tokens. By using this joint score, contrastive sampler reduces the + behavior of duplicating seen tokens. + + Args: + k: int, the `k` value of top-k. Next token will be chosen from k tokens. + alpha: float, the weight of minus max similarity in joint score + computation. The larger the value of `alpha`, the score relies more + on the similarity than the token probability. + seed: int, defaults to None. The random seed. + + Call Args: + {{call_args}} + + Examples: + ```python + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # A uniform distribution over our alphabet. 
+ logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.ContrastiveSampler()( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, + ) + print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) + # >>> "zzzzzaaaaaaa" + ``` + """ + def __init__( self, k=5, - alpha=0.5, + alpha=0.2, seed=None, ): super().__init__() @@ -42,8 +85,10 @@ def __call__( index=0, mask=None, end_token_id=None, + init_hidden_states=None, ): batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + hidden_states = init_hidden_states # Make sure max length and start index are the same dtype. index = tf.cast(index, max_length.dtype) @@ -64,11 +109,12 @@ def unflatten_beams(x): return tf.reshape(x, shape=unflat_shape) mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask + # Compute initial logits. + logits, state, _ = next(prompt, state, index) # `tf.while_loop` will not accept `None` as a value for `loop_vars`. state = () if state is None else state - logits, state, hidden_states = next(prompt, state, 0) - def cond(prompt, state, index): + def cond(prompt, state, index, logits, hidden_states): if end_token_id is None: return True # Stop if all sequences have produced a *new* end_token_id. @@ -76,55 +122,88 @@ def cond(prompt, state, index): prompt_done = tf.reduce_any(end_tokens, axis=-1) return not tf.reduce_all(prompt_done) - def body(prompt, state, index): + def body(prompt, state, index, logits, hidden_states): # Compute the softmax distribution for the next token. - logits, state, next_hidden_states = next(prompt, state, index) probabilities = keras.activations.softmax(logits) - - prompt_beams, mask_beams = create_beams(prompt), create_beams(mask) + + # Replicate for `self.k` times to find the best token in top-k + # candidates. + prompt_beams = create_beams(prompt) + mask_beams = create_beams(mask) + hidden_states_beams = create_beams(hidden_states) state_beams = tf.nest.map_structure(create_beams, state) + # Get top-k candidate tokens and their probabilities. top_k_probabilities, top_k_indices = tf.math.top_k( probabilities, k=self.k, sorted=False ) next_token_probabilities = flatten_beams(top_k_probabilities) next_token = flatten_beams(top_k_indices) next_token = tf.cast(next_token, prompt.dtype) - next_token = tf.where( mask_beams[:, index], prompt_beams[:, index], next_token ) + # Update the prompt with the next token. next_token = next_token[:, tf.newaxis] prompt_beams = dynamic_update_slice( prompt_beams, next_token, [0, index] ) - _, next_state_beams = next(prompt_beams, state_beams, index + 1) - last_token_state = hidden_states[:, index, :][:, tf.newaxis, :] - previous_states = state_beams["hidden_states"][:, :index, :] + # Compute the logits and hidden states for top-k candidate tokens. + next_logits, state_beams, next_hidden_states_beams = next( + prompt_beams, state_beams, index + 1 + ) + + # Compute the max similarity score for top-k candidate tokens + # against previous tokens. + last_token_state = next_hidden_states_beams + previous_states = hidden_states_beams[:, :index, :] similarity_scores = self.similarity( previous_states, last_token_state ) - max_similarity_scores = tf.reduce_max(similarity_scores, axis=1) - + max_similarity_scores = tf.cast( + tf.reduce_max(similarity_scores, axis=1), + dtype=next_token_probabilities.dtype, + ) + # The final score of each candidate token is weighted sum of + # probability and similarity against previous tokens. 
accumulated_scores = ( (1 - self.alpha) * next_token_probabilities - self.alpha * max_similarity_scores ) + + # Unflatten varibles to shape [batch_size, self.k, ...] for + # gather purpose. unflat_score = unflatten_beams(accumulated_scores) unflat_prompt = unflatten_beams(prompt_beams) - + unflat_next_logits = unflatten_beams(next_logits) + unflat_next_hidden_states = unflatten_beams( + next_hidden_states_beams + ) + unflat_state = tf.nest.map_structure(unflatten_beams, state_beams) win_token_indices = tf.math.argmax(unflat_score, axis=1) - prompt = tf.gather( - unflat_prompt, win_token_indices, axis=1, batch_dims=1 + + def gather_win_token(beams): + return tf.gather(beams, win_token_indices, axis=1, batch_dims=1) + + prompt = gather_win_token(unflat_prompt) + # We avoid recomputing forward pass for each token by updating the + # state/hidden_states using the output, and pass the logits to + # next iteration step. + logits = gather_win_token(unflat_next_logits) + next_hidden_states = gather_win_token(unflat_next_hidden_states) + state = tf.nest.map_structure(gather_win_token, unflat_state) + + hidden_states = dynamic_update_slice( + hidden_states, next_hidden_states, [0, index, 0] ) - return (prompt, state, index + 1) + return (prompt, state, index + 1, logits, hidden_states) - prompt, _, _ = tf.while_loop( + prompt, _, _, _, _ = tf.while_loop( cond=cond, body=body, - loop_vars=(prompt, state, index), + loop_vars=(prompt, state, index, logits, hidden_states), maximum_iterations=(max_length - index), ) return prompt diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py new file mode 100644 index 0000000000..13ea9eb654 --- /dev/null +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -0,0 +1,139 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Top-K sampler.""" + +import tensorflow as tf +from absl.testing import parameterized + +from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler + + +class ContrastiveSamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + # Use a simple alphabet of lowercase characters to [0, 26). + self.int_lookup = {i: chr(i + ord("a")) for i in range(26)} + self.char_lookup = {v: k for k, v in self.int_lookup.items()} + self.batch_size = 1 + self.length = 13 + self.hidden_dim = 3 + self.vocab_size = len(self.int_lookup) + self.init_hidden_states = tf.ones( + [ + self.batch_size, + self.length, + self.hidden_dim, + ] + ) + + def next(prompt, state, index): + batch_size = tf.shape(prompt)[0] + # Return a distribution favoring the next char in state. 
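+                # (1 - alpha) weights the model's confidence in each
+                # candidate, while alpha weights the penalty for
+                # resembling earlier hidden states, so a larger alpha
+                # trades fluency for less repetition.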
+ logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 + hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) + return logits, state, hidden_states + + self.next = next + self.sampler = ContrastiveSampler(k=5, alpha=0.2) + + def join_as_string(self, x): + return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] + + def test_stateless_call(self): + def next(prompt, state, index): + # Return a distribution favoring the first token in the vocab. + batch_size = tf.shape(prompt)[0] + logits = tf.one_hot(0, self.vocab_size) * 1e9 + logits = tf.reshape( + tf.repeat([logits], batch_size, axis=0), + [batch_size, -1], + ) + hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) + return logits, state, hidden_states + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=next, + prompt=prompt, + index=5, + init_hidden_states=self.init_hidden_states, + ) + self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaaa"]) + + def test_stateful_call(self): + state_chars = list("zsequentiallyy") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + index=1, + init_hidden_states=self.init_hidden_states, + ) + self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) + + def test_early_stopping(self): + state_chars = list("zsequentiallyy") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + end_token_id=self.char_lookup["t"], + index=1, + init_hidden_states=self.init_hidden_states, + ) + self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentzzzzz"]) + + def test_outputs_in_top_k(self): + def next(prompt, state, index): + batch_size = tf.shape(prompt)[0] + # Return a distribution where each id is progressively less likely. 
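The `1e9` scale applied above is what makes the test deterministic: after softmax, the favored id's probability saturates to one and every other id underflows to zero. A quick standalone check of that behavior:

```python
import tensorflow as tf

logits = tf.one_hot(3, 26) * 1e9
probs = tf.nn.softmax(logits)
print(float(probs[3]))  # 1.0, all other entries are ~0.0
```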
+ logits = tf.range(self.vocab_size, 0, -1, dtype="float32") + logits = tf.repeat(logits[tf.newaxis, :], batch_size, axis=0) + hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) + return logits, state, hidden_states + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=next, + prompt=prompt, + index=1, + init_hidden_states=self.init_hidden_states, + ) + output_ids = set(output[0, 1:].numpy()) + self.assertContainsSubset(output_ids, range(5)) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_compilation(self, jit_compile): + state_chars = list("zsequentiallyy") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + + @tf.function(jit_compile=jit_compile) + def generate(prompt, state): + return self.sampler( + self.next, + prompt=prompt, + state=state, + index=1, + init_hidden_states=self.init_hidden_states, + ) + + output = generate(prompt, state) + self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 84a6d1a4eb..33592120f1 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -23,7 +23,9 @@ call_args_docstring = """ next: A function which takes in the `prompt, state, index` of the current generation loop, and outputs a tuple `(logits, state)` with the - probability for the next token and state for the next iteration. + probability for the next token and state for the next iteration. Or a + tuple `(logits, state, hidden_states)` with `hidden_states` being the + representation of the token. prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This tensor will be iteratively updated column by column with new sampled values, starting at `index`. @@ -87,10 +89,10 @@ def __call__( next, prompt, state=None, - initial_hidden_states=None, index=0, mask=None, end_token_id=None, + init_hidden_states=None, ): max_length = tf.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. @@ -109,7 +111,13 @@ def cond(prompt, state, index): def body(prompt, state, index): # Compute the softmax distribution for the next token. - logits, state, _ = next(prompt, state, index) + outputs = next(prompt, state, index) + if len(outputs) == 2: + logits, state = next(prompt, state, index) + else: + # `next` could contain a `hidden_states` return value, which is + # only used in contrastive search now. + logits, state, _ = next(prompt, state, index) probabilities = keras.activations.softmax(logits) # Compute the next token. next_token = self.get_next_token(probabilities) From 8d57bd4443a11bdc8003d23f54a093c30964b597 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 31 Mar 2023 12:29:41 +0800 Subject: [PATCH 04/11] renaming --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 25 +++++++------- keras_nlp/samplers/beam_sampler.py | 24 +++++++------- keras_nlp/samplers/contrastive_sampler.py | 33 +++++++++---------- keras_nlp/samplers/sampler.py | 40 ++++++++++------------- 4 files changed, 58 insertions(+), 64 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 91dfb4a7fe..acbdeba2fb 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -229,9 +229,10 @@ def call_with_cache( whole sequence. Returns: - A (logits, cache) tuple. 
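Because the ramp above strictly decreases, ids `0..k-1` are always the top-k candidates, which is exactly what the assertion checks. The same property in isolation:

```python
import tensorflow as tf

vocab_size, k = 26, 5
logits = tf.range(vocab_size, 0, -1, dtype="float32")
top_ids = tf.math.top_k(logits, k=k).indices
print(sorted(top_ids.numpy()))  # [0, 1, 2, 3, 4]
```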
Where the first output is the language - model logits for the input token_ids and the second output is the - cache. + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the hidden state of the input (the layer before embedding matrix + mapping), and `cache` is the decoding cache. """ token_embedding = self.backbone.get_layer("token_embedding")(token_ids) position_embedding = self.backbone.get_layer("position_embedding")( @@ -253,13 +254,13 @@ def call_with_cache( caches[i] = next_cache cache = tf.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) - hidden_states = tf.identity(x) + hidden_states = x x = tf.matmul( x, self.backbone.get_layer("token_embedding").embeddings, transpose_b=True, ) - return x, cache, hidden_states + return x, hidden_states, cache def _build_cache(self, prompt): """Build an empty cache for use with `call_with_cache()`.""" @@ -270,8 +271,8 @@ def _build_cache(self, prompt): shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] cache = tf.zeros(shape) # Seed the cache. - _, cache, hidden_states = self.call_with_cache(prompt, cache, 0) - return cache, hidden_states + _, hidden_states, cache = self.call_with_cache(prompt, cache, 0) + return hidden_states, cache def compile( self, @@ -300,28 +301,28 @@ def make_generate_function(self): def generate_function(prompt, input_mask, min_length): # Create and seed cache with a single forward pass. - cache, hidden_states = self._build_cache(prompt) + hidden_states, cache = self._build_cache(prompt) def next(prompt, state, index): # The cache index is the index of our previous token. cache_index = index - 1 cache = state prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) - logits, cache, hidden_states = self.call_with_cache( + logits, hidden_states, cache = self.call_with_cache( prompt, cache, cache_index, ) - return tf.squeeze(logits, axis=1), cache, hidden_states + return tf.squeeze(logits, axis=1), hidden_states, cache return self._sampler( next=next, prompt=prompt, - state=cache, + cache=cache, index=min_length, mask=input_mask, end_token_id=self.preprocessor.tokenizer.end_token_id, - init_hidden_states=hidden_states, + hidden_states=hidden_states, ) if self.run_eagerly: diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 561595a5fd..8bb3963637 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -50,10 +50,10 @@ class BeamSampler(Sampler): char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, cache output = keras_nlp.samplers.BeamSampler()( next=next, @@ -104,11 +104,11 @@ def __call__( self, next, prompt, - state=None, + cache=None, index=0, mask=None, end_token_id=None, - init_hidden_states=None, + hidden_states=None, ): batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. @@ -130,16 +130,16 @@ def unflatten_beams(x): mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask # `tf.while_loop` will not accept `None` as a value for `loop_vars`. - state = () if state is None else state + cache = () if cache is None else cache # Add extra sequences for each beam. 
prompt, mask = create_beams(prompt), create_beams(mask) - state = tf.nest.map_structure(create_beams, state) + cache = tf.nest.map_structure(create_beams, cache) # Setup the initial beam log-likelihoods. # On the first loop, make sure only the original beam is considered. log_probs = tf.constant([[0.0] + [-1e9] * (self.num_beams - 1)]) log_probs = flatten_beams(tf.repeat(log_probs, batch_size, axis=0)) - def cond(prompt, state, index, log_probs): + def cond(prompt, cache, index, log_probs): if end_token_id is None: return True # Stop if all sequences have produced a *new* end_token_id. @@ -147,9 +147,9 @@ def cond(prompt, state, index, log_probs): prompt_done = tf.reduce_any(end_tokens, axis=-1) return not tf.reduce_all(prompt_done) - def body(prompt, state, index, log_probs): + def body(prompt, cache, index, log_probs): # Compute the softmax distribution for the next token. - logits, state = next(prompt, state, index) + logits, _, cache = next(prompt, cache, index) vocab_size = tf.shape(logits)[-1] probs = keras.activations.softmax(logits) @@ -177,7 +177,7 @@ def gather_beams(x): return flatten_beams(x) prompt = gather_beams(prompt) - state = tf.nest.map_structure(gather_beams, state) + cache = tf.nest.map_structure(gather_beams, cache) # Update each beam with the next token. next_token = tf.cast(next_token, prompt.dtype) @@ -187,12 +187,12 @@ def gather_beams(x): next_token = next_token[:, tf.newaxis] prompt = dynamic_update_slice(prompt, next_token, [0, index]) # Return the iteration of the loop state. - return (prompt, state, index + 1, log_probs) + return (prompt, cache, index + 1, log_probs) prompt, _, _, log_probs = tf.while_loop( cond=cond, body=body, - loop_vars=(prompt, state, index, log_probs), + loop_vars=(prompt, cache, index, log_probs), maximum_iterations=(max_length - index), ) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index f8748e347e..bd76e016c6 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -51,10 +51,10 @@ class ContrastiveSampler(Sampler): char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, None, cache output = keras_nlp.samplers.ContrastiveSampler()( next=next, @@ -81,14 +81,13 @@ def __call__( self, next, prompt, - state=None, + cache=None, index=0, mask=None, end_token_id=None, - init_hidden_states=None, + hidden_states=None, ): batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] - hidden_states = init_hidden_states # Make sure max length and start index are the same dtype. index = tf.cast(index, max_length.dtype) @@ -110,11 +109,11 @@ def unflatten_beams(x): mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask # Compute initial logits. - logits, state, _ = next(prompt, state, index) + logits, _, cache = next(prompt, cache, index) # `tf.while_loop` will not accept `None` as a value for `loop_vars`. - state = () if state is None else state + cache = () if cache is None else cache - def cond(prompt, state, index, logits, hidden_states): + def cond(prompt, cache, index, logits, hidden_states): if end_token_id is None: return True # Stop if all sequences have produced a *new* end_token_id. 
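The log-likelihood seeding a few lines below is the standard beam-search trick: every beam starts from the same prompt, so all but one get a `-1e9` score to keep the first top-k step from selecting `num_beams` copies of a single candidate. In isolation, with toy sizes:

```python
import tensorflow as tf

num_beams, batch_size = 3, 2
log_probs = tf.constant([[0.0] + [-1e9] * (num_beams - 1)])
log_probs = tf.repeat(log_probs, batch_size, axis=0)  # [batch, beams]
print(log_probs.numpy())  # beam 0 alive, beams 1..k-1 start "dead"
```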
@@ -122,7 +121,7 @@ def cond(prompt, state, index, logits, hidden_states): prompt_done = tf.reduce_any(end_tokens, axis=-1) return not tf.reduce_all(prompt_done) - def body(prompt, state, index, logits, hidden_states): + def body(prompt, cache, index, logits, hidden_states): # Compute the softmax distribution for the next token. probabilities = keras.activations.softmax(logits) @@ -131,7 +130,7 @@ def body(prompt, state, index, logits, hidden_states): prompt_beams = create_beams(prompt) mask_beams = create_beams(mask) hidden_states_beams = create_beams(hidden_states) - state_beams = tf.nest.map_structure(create_beams, state) + cache_beams = tf.nest.map_structure(create_beams, cache) # Get top-k candidate tokens and their probabilities. top_k_probabilities, top_k_indices = tf.math.top_k( @@ -151,8 +150,8 @@ def body(prompt, state, index, logits, hidden_states): ) # Compute the logits and hidden states for top-k candidate tokens. - next_logits, state_beams, next_hidden_states_beams = next( - prompt_beams, state_beams, index + 1 + next_logits, next_hidden_states_beams, cache_beams = next( + prompt_beams, cache_beams, index + 1 ) # Compute the max similarity score for top-k candidate tokens @@ -181,7 +180,7 @@ def body(prompt, state, index, logits, hidden_states): unflat_next_hidden_states = unflatten_beams( next_hidden_states_beams ) - unflat_state = tf.nest.map_structure(unflatten_beams, state_beams) + unflat_cache = tf.nest.map_structure(unflatten_beams, cache_beams) win_token_indices = tf.math.argmax(unflat_score, axis=1) def gather_win_token(beams): @@ -189,21 +188,21 @@ def gather_win_token(beams): prompt = gather_win_token(unflat_prompt) # We avoid recomputing forward pass for each token by updating the - # state/hidden_states using the output, and pass the logits to + # cache/hidden_states using the output, and pass the logits to # next iteration step. logits = gather_win_token(unflat_next_logits) next_hidden_states = gather_win_token(unflat_next_hidden_states) - state = tf.nest.map_structure(gather_win_token, unflat_state) + cache = tf.nest.map_structure(gather_win_token, unflat_cache) hidden_states = dynamic_update_slice( hidden_states, next_hidden_states, [0, index, 0] ) - return (prompt, state, index + 1, logits, hidden_states) + return (prompt, cache, index + 1, logits, hidden_states) prompt, _, _, _, _ = tf.while_loop( cond=cond, body=body, - loop_vars=(prompt, state, index, logits, hidden_states), + loop_vars=(prompt, cache, index, logits, hidden_states), maximum_iterations=(max_length - index), ) return prompt diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 33592120f1..d7b78159d5 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -21,15 +21,15 @@ from keras_nlp.utils.python_utils import format_docstring call_args_docstring = """ - next: A function which takes in the `prompt, state, index` of the - current generation loop, and outputs a tuple `(logits, state)` with the - probability for the next token and state for the next iteration. Or a - tuple `(logits, state, hidden_states)` with `hidden_states` being the + next: A function which takes in the `prompt, cache, index` of the + current generation loop, and outputs a tuple + `(logits, cache, hidden_states)` with `logits` being the logits of next + token, `cache` for next iteration, and `hidden_states` being the representation of the token. prompt: A 2D integer tensor with shape `(batch_size, max_length)`. 
This tensor will be iteratively updated column by column with new sampled values, starting at `index`. - state: Optional. A tensor or nested structure of tensors that will be + cache: Optional. A tensor or nested structure of tensors that will be updated by each call to `next`. This can be used to cache computations from early iterations of the generative loop. index: Optional. The first index to start sampling at. @@ -56,7 +56,7 @@ class Sampler: - Override the `get_next_token()` method, which computes the next token based on a probability distribution over all possible vocab entries. - - Override `__call__`, if the sampling method need additional state beyond + - Override `__call__`, if the sampling method need additional cache beyond the next tokens probability distribution to sample a sequence. Please check available subclass samplers for examples. @@ -69,10 +69,10 @@ class Sampler: char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): # return a uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, None, cache output = keras_nlp.samplers.GreedySampler()( next=next, @@ -88,20 +88,20 @@ def __call__( self, next, prompt, - state=None, + cache=None, index=0, mask=None, end_token_id=None, - init_hidden_states=None, + hidden_states=None, ): max_length = tf.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. index = tf.cast(index, max_length.dtype) mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask # `tf.while_loop` will not accept `None` as a value for `loop_vars`. - state = () if state is None else state + cache = () if cache is None else cache - def cond(prompt, state, index): + def cond(prompt, cache, index): if end_token_id is None: return True # Stop if all sequences have produced a *new* end_token_id. @@ -109,15 +109,9 @@ def cond(prompt, state, index): prompt_done = tf.reduce_any(end_tokens, axis=-1) return not tf.reduce_all(prompt_done) - def body(prompt, state, index): + def body(prompt, cache, index): # Compute the softmax distribution for the next token. - outputs = next(prompt, state, index) - if len(outputs) == 2: - logits, state = next(prompt, state, index) - else: - # `next` could contain a `hidden_states` return value, which is - # only used in contrastive search now. - logits, state, _ = next(prompt, state, index) + logits, _, cache = next(prompt, cache, index) probabilities = keras.activations.softmax(logits) # Compute the next token. next_token = self.get_next_token(probabilities) @@ -130,13 +124,13 @@ def body(prompt, state, index): # Update the prompt with the next token. next_token = next_token[:, tf.newaxis] prompt = dynamic_update_slice(prompt, next_token, [0, index]) - # Return the next prompt, state and incremented index. - return (prompt, state, index + 1) + # Return the next prompt, cache and incremented index. 
+ return (prompt, cache, index + 1) prompt, _, _ = tf.while_loop( cond=cond, body=body, - loop_vars=(prompt, state, index), + loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), ) return prompt From 3fe775ebc88b260e8f93e132d7e151aeafe52038 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 31 Mar 2023 13:22:39 +0800 Subject: [PATCH 05/11] even better style --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 6 +-- keras_nlp/samplers/beam_sampler_test.py | 34 ++++++------- .../samplers/contrastive_sampler_test.py | 50 +++++++++---------- keras_nlp/samplers/greedy_sampler_test.py | 38 +++++++------- keras_nlp/samplers/sampler.py | 4 +- keras_nlp/samplers/top_k_sampler_test.py | 38 +++++++------- keras_nlp/samplers/top_p_sampler_test.py | 38 +++++++------- 7 files changed, 104 insertions(+), 104 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index acbdeba2fb..55d3e57527 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -229,9 +229,9 @@ def call_with_cache( whole sequence. Returns: - A (logits, hidden_states, cache) tuple. Where `logits` is the - language model logits for the input token_ids, `hidden_states` is - the hidden state of the input (the layer before embedding matrix + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the hidden state of the input (the layer before embedding matrix mapping), and `cache` is the decoding cache. """ token_embedding = self.backbone.get_layer("token_embedding")(token_ids) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index e20c66035e..03f73a2e31 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -30,10 +30,10 @@ def setUp(self): self.length = 12 self.vocab_size = len(self.int_lookup) - def next(prompt, state, index): - # Return a distribution favoring the next char in state. - logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 - return logits, state + def next(prompt, cache, index): + # Return a distribution favoring the next char in cache. + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 + return logits, None, cache self.next = next self.sampler = BeamSampler(num_beams=5) @@ -43,11 +43,11 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution favoring the first token in the vocab. 
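The `tf.while_loop` below drives the whole generation: `body` fills one column per iteration and `maximum_iterations` bounds the walk to the prompt's tail. A self-contained miniature of the same pattern (a sketch, not the library implementation):

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

length = 6
prompt = tf.zeros([1, length], dtype=tf.int32)

def cond(prompt, index):
    return index < length

def body(prompt, index):
    # Pretend the sampler picked token id `index + 1` for this column.
    next_token = tf.fill([1, 1], index + 1)
    prompt = dynamic_update_slice(prompt, next_token, [0, index])
    return prompt, index + 1

prompt, _ = tf.while_loop(cond, body, (prompt, 0))
print(prompt.numpy())  # [[1 2 3 4 5 6]]
```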
logits = np.zeros((self.batch_size, self.vocab_size)) logits[:, 0] = 1e9 - return tf.constant(logits, dtype="float32"), state + return tf.constant(logits, dtype="float32"), None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -58,13 +58,13 @@ def next(prompt, state, index): self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual(self.join_as_string(output), ["sequentially"]) @@ -90,13 +90,13 @@ def test_return_all_beams(self): ) def test_early_stopping(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) @@ -105,13 +105,13 @@ def test_early_stopping(self): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): - return self.sampler(self.next, prompt=prompt, state=state) + def generate(prompt, cache): + return self.sampler(self.next, prompt=prompt, cache=cache) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index 13ea9eb654..a0d46884f0 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -29,7 +29,7 @@ def setUp(self): self.length = 13 self.hidden_dim = 3 self.vocab_size = len(self.int_lookup) - self.init_hidden_states = tf.ones( + self.hidden_states = tf.ones( [ self.batch_size, self.length, @@ -37,12 +37,12 @@ def setUp(self): ] ) - def next(prompt, state, index): + def next(prompt, cache, index): batch_size = tf.shape(prompt)[0] - # Return a distribution favoring the next char in state. - logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 + # Return a distribution favoring the next char in cache. + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) - return logits, state, hidden_states + return logits, hidden_states, cache self.next = next self.sampler = ContrastiveSampler(k=5, alpha=0.2) @@ -51,7 +51,7 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution favoring the first token in the vocab. 
batch_size = tf.shape(prompt)[0] logits = tf.one_hot(0, self.vocab_size) * 1e9 @@ -60,59 +60,59 @@ def next(prompt, state, index): [batch_size, -1], ) hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) - return logits, state, hidden_states + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=next, prompt=prompt, index=5, - init_hidden_states=self.init_hidden_states, + hidden_states=self.hidden_states, ) self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaaa"]) def test_stateful_call(self): - state_chars = list("zsequentiallyy") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("zsequentiallyy") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, index=1, - init_hidden_states=self.init_hidden_states, + hidden_states=self.hidden_states, ) self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) def test_early_stopping(self): - state_chars = list("zsequentiallyy") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("zsequentiallyy") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], index=1, - init_hidden_states=self.init_hidden_states, + hidden_states=self.hidden_states, ) self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentzzzzz"]) def test_outputs_in_top_k(self): - def next(prompt, state, index): + def next(prompt, cache, index): batch_size = tf.shape(prompt)[0] # Return a distribution where each id is progressively less likely. 
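Throughout these tests the hidden states are all-ones vectors, so the cosine similarity between any candidate and any previous state is exactly 1; the penalty term becomes a constant and token probability alone picks the winner. Mirroring the sampler's `similarity()` on such inputs:

```python
import tensorflow as tf

h1 = tf.ones([1, 4, 3])  # "previous" states
h2 = tf.ones([1, 1, 3])  # candidate state
similarity = tf.squeeze(tf.matmul(h1, h2, transpose_b=True), axis=-1) / (
    tf.norm(h1, axis=-1) * tf.norm(h2, axis=-1)
)
print(similarity.numpy())  # [[1. 1. 1. 1.]]
```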
logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], batch_size, axis=0) hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) - return logits, state, hidden_states + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=next, prompt=prompt, index=1, - init_hidden_states=self.init_hidden_states, + hidden_states=self.hidden_states, ) output_ids = set(output[0, 1:].numpy()) self.assertContainsSubset(output_ids, range(5)) @@ -121,19 +121,19 @@ def next(prompt, state, index): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("zsequentiallyy") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("zsequentiallyy") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): + def generate(prompt, cache): return self.sampler( self.next, prompt=prompt, - state=state, + cache=cache, index=1, - init_hidden_states=self.init_hidden_states, + hidden_states=self.hidden_states, ) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 21c237e56b..f120ad6701 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -30,10 +30,10 @@ def setUp(self): self.length = 12 self.vocab_size = len(self.int_lookup) - def next(prompt, state, index): - # Return a distribution favoring the next char in state. - logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 - return logits, state + def next(prompt, cache, index): + # Return a distribution favoring the next char in cache. + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 + return logits, None, cache self.next = next self.sampler = GreedySampler() @@ -42,11 +42,11 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution favoring the first token in the vocab. 
logits = np.zeros((self.batch_size, self.vocab_size)) logits[:, 0] = 1e9 - return tf.constant(logits), state + return tf.constant(logits), None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -57,34 +57,34 @@ def next(prompt, state, index): self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_early_stopping(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) def test_is_greedy(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution where each id is progressively less likely. logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) - return logits, state + return logits, None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -98,13 +98,13 @@ def next(prompt, state, index): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): - return self.sampler(self.next, prompt=prompt, state=state) + def generate(prompt, cache): + return self.sampler(self.next, prompt=prompt, cache=cache) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index d7b78159d5..964df843e9 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -22,8 +22,8 @@ call_args_docstring = """ next: A function which takes in the `prompt, cache, index` of the - current generation loop, and outputs a tuple - `(logits, cache, hidden_states)` with `logits` being the logits of next + current generation loop, and outputs a tuple + `(logits, cache, hidden_states)` with `logits` being the logits of next token, `cache` for next iteration, and `hidden_states` being the representation of the token. prompt: A 2D integer tensor with shape `(batch_size, max_length)`. 
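`GreedySampler` reduces to an argmax over the softmax output, so the spike this stub places at id 0 forces every generated position to `a`. The core reduction in isolation:

```python
import tensorflow as tf

logits = tf.constant([[0.5, 9.0, 1.5]])
probs = tf.nn.softmax(logits)
print(int(tf.argmax(probs, axis=-1)[0]))  # 1, the largest logit wins
```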
This diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index 9efc45f41d..b726a88b8f 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -30,10 +30,10 @@ def setUp(self): self.length = 12 self.vocab_size = len(self.int_lookup) - def next(prompt, state, index): - # Return a distribution favoring the next char in state. - logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 - return logits, state + def next(prompt, cache, index): + # Return a distribution favoring the next char in cache. + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 + return logits, None, cache self.next = next self.sampler = TopKSampler(k=5) @@ -42,11 +42,11 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution favoring the first token in the vocab. logits = np.zeros((self.batch_size, self.vocab_size)) logits[:, 0] = 1e9 - return tf.constant(logits), state + return tf.constant(logits), None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -57,34 +57,34 @@ def next(prompt, state, index): self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_early_stopping(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) def test_outputs_in_top_k(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution where each id is progressively less likely. 
logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) - return logits, state + return logits, None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -98,13 +98,13 @@ def next(prompt, state, index): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): - return self.sampler(self.next, prompt=prompt, state=state) + def generate(prompt, cache): + return self.sampler(self.next, prompt=prompt, cache=cache) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index 73efc485cd..5d3b59c0d9 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -30,10 +30,10 @@ def setUp(self): self.length = 12 self.vocab_size = len(self.int_lookup) - def next(prompt, state, index): - # Return a distribution favoring the next char in state. - logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 - return logits, state + def next(prompt, cache, index): + # Return a distribution favoring the next char in cache. + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 + return logits, None, cache self.next = next self.sampler = TopPSampler(p=0.1) @@ -42,11 +42,11 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): # Return a distribution favoring the first token in the vocab. 
logits = np.zeros((self.batch_size, self.vocab_size)) logits[:, 0] = 1e9 - return tf.constant(logits), state + return tf.constant(logits), None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -57,32 +57,32 @@ def next(prompt, state, index): self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_early_stopping(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) def test_outputs_in_top_p(self): - def next(prompt, state, index): + def next(prompt, cache, index): logits = np.zeros((self.batch_size, self.vocab_size)) - return tf.constant(logits), state + return tf.constant(logits), None, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = TopPSampler(p=(2.0 / self.vocab_size))( @@ -96,13 +96,13 @@ def next(prompt, state, index): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): - return self.sampler(self.next, prompt=prompt, state=state) + def generate(prompt, cache): + return self.sampler(self.next, prompt=prompt, cache=cache) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output), ["sequentially"]) From cb085a3f59ca07c40b5daa73fde8e6ed8fae9376 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 3 Apr 2023 21:32:10 -0700 Subject: [PATCH 06/11] address comments --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 3 +- keras_nlp/samplers/beam_sampler.py | 8 +++-- keras_nlp/samplers/beam_sampler_test.py | 18 +++++++--- keras_nlp/samplers/contrastive_sampler.py | 22 ++++++++---- .../samplers/contrastive_sampler_test.py | 34 ++++++++++--------- keras_nlp/samplers/greedy_sampler.py | 7 ++-- keras_nlp/samplers/greedy_sampler_test.py | 22 ++++++++---- keras_nlp/samplers/top_k_sampler.py | 5 +-- keras_nlp/samplers/top_k_sampler_test.py | 22 ++++++++---- keras_nlp/samplers/top_p_sampler.py | 5 +-- keras_nlp/samplers/top_p_sampler_test.py | 21 +++++++++--- 11 files changed, 110 insertions(+), 57 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 55d3e57527..072eb6ed04 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -303,10 +303,9 @@ def generate_function(prompt, 
input_mask, min_length): # Create and seed cache with a single forward pass. hidden_states, cache = self._build_cache(prompt) - def next(prompt, state, index): + def next(prompt, cache, index): # The cache index is the index of our previous token. cache_index = index - 1 - cache = state prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 8bb3963637..46b571be88 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -51,13 +51,15 @@ class BeamSampler(Sampler): batch_size, length, vocab_size = 1, 12, len(int_lookup) def next(prompt, cache, index): + prompt_batch_size = tf.shape(prompt)[0] + hidden_states = tf.ones((prompt_batch_size, 1, 10)) # A uniform distribution over our alphabet. - logits = tf.ones((batch_size, vocab_size)) - return logits, cache + logits = tf.ones((prompt_batch_size, vocab_size)) + return logits, hidden_states, cache output = keras_nlp.samplers.BeamSampler()( next=next, - prompt=tf.fill((batch_size, length,), char_lookup['z']), + prompt=tf.fill((batch_size, length), char_lookup["z"]), index=5, ) print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 03f73a2e31..a7820cf4bf 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for Beam sampler.""" -import numpy as np import tensorflow as tf from absl.testing import parameterized @@ -31,9 +30,11 @@ def setUp(self): self.vocab_size = len(self.int_lookup) def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 - return logits, None, cache + return logits, hidden_states, cache self.next = next self.sampler = BeamSampler(num_beams=5) @@ -44,10 +45,17 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the first token in the vocab. - logits = np.zeros((self.batch_size, self.vocab_size)) - logits[:, 0] = 1e9 - return tf.constant(logits, dtype="float32"), None, cache + logits = ( + tf.one_hot( + tf.zeros(self.batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 + ) + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index bd76e016c6..5310b58240 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -47,19 +47,24 @@ class ContrastiveSampler(Sampler): Examples: ```python # Use a simple alphabet of lowercase characters to [0, 26). - int_lookup = {i: chr(i + ord('a')) for i in range(26)} + int_lookup = {i: chr(i + ord("a")) for i in range(26)} char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) + hidden_size = 5 + index = 5 def next(prompt, cache, index): + prompt_batch_size = tf.shape(prompt)[0] + hidden_states = tf.ones((prompt_batch_size, 1, hidden_size)) # A uniform distribution over our alphabet. 
-        logits = tf.ones((batch_size, vocab_size))
-        return logits, None, cache
+        logits = tf.ones((prompt_batch_size, vocab_size))
+        return logits, hidden_states, cache
 
     output = keras_nlp.samplers.ContrastiveSampler()(
         next=next,
-        prompt=tf.fill((batch_size, length,), char_lookup['z']),
-        index=5,
+        prompt=tf.fill((batch_size, length), char_lookup["z"]),
+        index=index,
+        hidden_states=tf.ones([batch_size, index, hidden_size]),
     )
     print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
     # >>> "zzzzzaaaaaaa"
@@ -81,11 +86,11 @@ def __call__(
         self,
         next,
         prompt,
+        hidden_states,
         cache=None,
         index=0,
         mask=None,
         end_token_id=None,
-        hidden_states=None,
     ):
         batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
         # Make sure max length and start index are the same dtype.
@@ -165,13 +170,16 @@ def body(prompt, cache, index, logits, hidden_states):
                 tf.reduce_max(similarity_scores, axis=1),
                 dtype=next_token_probabilities.dtype,
             )
+            if index == 0:
+                # If the index is 0, there are no previous states, so we set
+                # `max_similarity_scores` to zero for all beams.
+                max_similarity_scores = tf.zeros_like(max_similarity_scores)
             # The final score of each candidate token is a weighted sum of
             # probability and similarity against previous tokens.
             accumulated_scores = (
                 (1 - self.alpha) * next_token_probabilities
                 - self.alpha * max_similarity_scores
             )
-
             # Unflatten variables to shape [batch_size, self.k, ...] for
             # gather purposes.
             unflat_score = unflatten_beams(accumulated_scores)
diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py
index a0d46884f0..18e7620007 100644
--- a/keras_nlp/samplers/contrastive_sampler_test.py
+++ b/keras_nlp/samplers/contrastive_sampler_test.py
@@ -26,7 +26,7 @@ def setUp(self):
         self.int_lookup = {i: chr(i + ord("a")) for i in range(26)}
         self.char_lookup = {v: k for k, v in self.int_lookup.items()}
         self.batch_size = 1
-        self.length = 13
+        self.length = 12
         self.hidden_dim = 3
         self.vocab_size = len(self.int_lookup)
         self.hidden_states = tf.ones(
@@ -54,10 +54,12 @@ def test_stateless_call(self):
         def next(prompt, cache, index):
             # Return a distribution favoring the first token in the vocab.
batch_size = tf.shape(prompt)[0] - logits = tf.one_hot(0, self.vocab_size) * 1e9 - logits = tf.reshape( - tf.repeat([logits], batch_size, axis=0), - [batch_size, -1], + logits = ( + tf.one_hot( + tf.zeros(batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 ) hidden_states = tf.ones([batch_size, 1, self.hidden_dim]) return logits, hidden_states, cache @@ -69,12 +71,12 @@ def next(prompt, cache, index): index=5, hidden_states=self.hidden_states, ) - self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaaa"]) + self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - cache_chars = list("zsequentiallyy") + cache_chars = list("sequentiallyy") cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) - prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["s"]) output = self.sampler( next=self.next, prompt=prompt, @@ -82,21 +84,21 @@ def test_stateful_call(self): index=1, hidden_states=self.hidden_states, ) - self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) + self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_early_stopping(self): - cache_chars = list("zsequentiallyy") + cache_chars = list("sequentiallyy") cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) - prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["s"]) output = self.sampler( next=self.next, prompt=prompt, cache=cache, end_token_id=self.char_lookup["t"], - index=1, + index=0, hidden_states=self.hidden_states, ) - self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentzzzzz"]) + self.assertEqual(self.join_as_string(output), ["sequentsssss"]) def test_outputs_in_top_k(self): def next(prompt, cache, index): @@ -121,9 +123,9 @@ def next(prompt, cache, index): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - cache_chars = list("zsequentiallyy") + cache_chars = list("sequentiallyy") cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) - prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["s"]) @tf.function(jit_compile=jit_compile) def generate(prompt, cache): @@ -136,4 +138,4 @@ def generate(prompt, cache): ) output = generate(prompt, cache) - self.assertEqual(self.join_as_string(output[:, 1:]), ["sequentially"]) + self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 33b433e479..0e46fa6205 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -39,10 +39,11 @@ class GreedySampler(Sampler): char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): - # return a uniform distribution over our alphabet. + def next(prompt, cache, index): + hidden_states = tf.ones((batch_size, 1, 10)) + # A uniform distribution over our alphabet. 
logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, hidden_states, cache output = keras_nlp.samplers.GreedySampler()( next=next, diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index f120ad6701..a20486c50d 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for Greedy sampler.""" -import numpy as np import tensorflow as tf from absl.testing import parameterized @@ -31,9 +30,11 @@ def setUp(self): self.vocab_size = len(self.int_lookup) def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 - return logits, None, cache + return logits, hidden_states, cache self.next = next self.sampler = GreedySampler() @@ -43,10 +44,17 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the first token in the vocab. - logits = np.zeros((self.batch_size, self.vocab_size)) - logits[:, 0] = 1e9 - return tf.constant(logits), None, cache + logits = ( + tf.one_hot( + tf.zeros(self.batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 + ) + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -81,10 +89,12 @@ def test_early_stopping(self): def test_is_greedy(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution where each id is progressively less likely. logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) - return logits, None, cache + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 69ac857cd1..f3782462e4 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -44,10 +44,11 @@ class TopKSampler(Sampler): char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): + hidden_states = tf.ones((batch_size, 1, 10)) # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, hidden_states, cache output = keras_nlp.samplers.TopKSampler(k=3)( next=next, diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index b726a88b8f..31c0a06d02 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for Top-K sampler.""" -import numpy as np import tensorflow as tf from absl.testing import parameterized @@ -31,9 +30,11 @@ def setUp(self): self.vocab_size = len(self.int_lookup) def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the next char in cache. 
logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 - return logits, None, cache + return logits, hidden_states, cache self.next = next self.sampler = TopKSampler(k=5) @@ -43,10 +44,17 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the first token in the vocab. - logits = np.zeros((self.batch_size, self.vocab_size)) - logits[:, 0] = 1e9 - return tf.constant(logits), None, cache + logits = ( + tf.one_hot( + tf.zeros(self.batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 + ) + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -81,10 +89,12 @@ def test_early_stopping(self): def test_outputs_in_top_k(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution where each id is progressively less likely. logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) - return logits, None, cache + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 167c43852b..18a78d7e87 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -46,10 +46,11 @@ class TopPSampler(Sampler): char_lookup = {v: k for k, v in int_lookup.items()} batch_size, length, vocab_size = 1, 12, len(int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): + hidden_states = tf.ones((batch_size, 1, 10)) # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, hidden_states, cache output = keras_nlp.samplers.TopPSampler(p=0.1)( next=next, diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index 5d3b59c0d9..271679cfa2 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -31,9 +31,11 @@ def setUp(self): self.vocab_size = len(self.int_lookup) def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 - return logits, None, cache + return logits, hidden_states, cache self.next = next self.sampler = TopPSampler(p=0.1) @@ -43,10 +45,17 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([10]) # Return a distribution favoring the first token in the vocab. - logits = np.zeros((self.batch_size, self.vocab_size)) - logits[:, 0] = 1e9 - return tf.constant(logits), None, cache + logits = ( + tf.one_hot( + tf.zeros(self.batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 + ) + return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -81,8 +90,10 @@ def test_early_stopping(self): def test_outputs_in_top_p(self): def next(prompt, cache, index): + # Dummy hidden states. 
+ hidden_states = tf.ones([10]) logits = np.zeros((self.batch_size, self.vocab_size)) - return tf.constant(logits), None, cache + return tf.constant(logits), hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = TopPSampler(p=(2.0 / self.vocab_size))( From e5877b35be4b85c628ef23bb0d22ad97fd351b1c Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 5 Apr 2023 18:18:18 -0700 Subject: [PATCH 07/11] fix comments --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 16 +++--- keras_nlp/samplers/beam_sampler_test.py | 6 ++- keras_nlp/samplers/contrastive_sampler.py | 40 +++++++++----- .../samplers/contrastive_sampler_test.py | 52 +++++++++++++++++-- keras_nlp/samplers/greedy_sampler_test.py | 6 +-- keras_nlp/samplers/top_k_sampler_test.py | 6 +-- keras_nlp/samplers/top_p_sampler_test.py | 6 +-- 7 files changed, 98 insertions(+), 34 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 072eb6ed04..c08b2686bc 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -231,8 +231,8 @@ def call_with_cache( Returns: A (logits, hidden_states, cache) tuple. Where `logits` is the language model logits for the input token_ids, `hidden_states` is - the hidden state of the input (the layer before embedding matrix - mapping), and `cache` is the decoding cache. + the final hidden representation of the input tokens, and `cache` is + the decoding cache. """ token_embedding = self.backbone.get_layer("token_embedding")(token_ids) position_embedding = self.backbone.get_layer("position_embedding")( @@ -255,12 +255,12 @@ def call_with_cache( cache = tf.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) hidden_states = x - x = tf.matmul( - x, + logits = tf.matmul( + hidden_states, self.backbone.get_layer("token_embedding").embeddings, transpose_b=True, ) - return x, hidden_states, cache + return logits, hidden_states, cache def _build_cache(self, prompt): """Build an empty cache for use with `call_with_cache()`.""" @@ -312,7 +312,11 @@ def next(prompt, cache, index): cache, cache_index, ) - return tf.squeeze(logits, axis=1), hidden_states, cache + return ( + tf.squeeze(logits, axis=1), + tf.squeeze(hidden_states, axis=1), + cache, + ) return self._sampler( next=next, diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index a7820cf4bf..5ebfe60327 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -30,8 +30,9 @@ def setUp(self): self.vocab_size = len(self.int_lookup) def next(prompt, cache, index): + batch_size = tf.shape(prompt)[0] # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([batch_size, 5]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 return logits, hidden_states, cache @@ -45,8 +46,9 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): + batch_size = tf.shape(prompt)[0] # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([batch_size, 5]) # Return a distribution favoring the first token in the vocab. 
            logits = (
                 tf.one_hot(
                     tf.zeros(self.batch_size, dtype=tf.int32),
                     self.vocab_size,
                 )
                 * 1e9
             )
             return logits, hidden_states, cache
 
         prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
         output = self.sampler(
diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py
index 5310b58240..45629457c7 100644
--- a/keras_nlp/samplers/contrastive_sampler.py
+++ b/keras_nlp/samplers/contrastive_sampler.py
@@ -86,12 +86,17 @@ def __call__(
         self,
         next,
         prompt,
-        hidden_states,
         cache=None,
         index=0,
         mask=None,
         end_token_id=None,
+        hidden_states=None,
     ):
+        if hidden_states is None:
+            raise ValueError(
+                "`ContrastiveSampler` requires passing a `hidden_states`, but "
+                "received `None`."
+            )
         batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
         # Make sure max length and start index are the same dtype.
         index = tf.cast(index, max_length.dtype)
@@ -162,12 +167,11 @@ def body(prompt, cache, index, logits, hidden_states):
             # Compute the max similarity score for top-k candidate tokens
             # against previous tokens.
             last_token_state = next_hidden_states_beams
-            previous_states = hidden_states_beams[:, :index, :]
             similarity_scores = self.similarity(
-                previous_states, last_token_state
+                hidden_states_beams, last_token_state
             )
             max_similarity_scores = tf.cast(
-                tf.reduce_max(similarity_scores, axis=1),
+                tf.reduce_max(similarity_scores[:, :index], axis=1),
                 dtype=next_token_probabilities.dtype,
             )
             if index == 0:
@@ -189,21 +193,28 @@ def body(prompt, cache, index, logits, hidden_states):
                 next_hidden_states_beams
             )
             unflat_cache = tf.nest.map_structure(unflatten_beams, cache_beams)
-            win_token_indices = tf.math.argmax(unflat_score, axis=1)
+            best_token_indices = tf.math.argmax(unflat_score, axis=1)
 
-            def gather_win_token(beams):
-                return tf.gather(beams, win_token_indices, axis=1, batch_dims=1)
+            def gather_best_token(beams):
+                return tf.gather(
+                    beams,
+                    best_token_indices,
+                    axis=1,
+                    batch_dims=1,
+                )
 
-            prompt = gather_win_token(unflat_prompt)
+            prompt = gather_best_token(unflat_prompt)
 
             # We avoid recomputing forward pass for each token by updating the
             # cache/hidden_states using the output, and pass the logits to
             # next iteration step.
-            logits = gather_win_token(unflat_next_logits)
-            next_hidden_states = gather_win_token(unflat_next_hidden_states)
-            cache = tf.nest.map_structure(gather_win_token, unflat_cache)
+            logits = gather_best_token(unflat_next_logits)
+            next_hidden_states = gather_best_token(unflat_next_hidden_states)
+            cache = tf.nest.map_structure(gather_best_token, unflat_cache)
             hidden_states = dynamic_update_slice(
-                hidden_states, next_hidden_states, [0, index, 0]
+                hidden_states,
+                next_hidden_states[:, tf.newaxis, :],
+                [0, index, 0],
             )
             return (prompt, cache, index + 1, logits, hidden_states)
@@ -216,8 +227,9 @@ def gather_win_token(beams):
         return prompt
 
     def similarity(self, h1, h2):
-        return tf.squeeze(tf.matmul(h1, h2, transpose_b=True), axis=-1) / (
-            tf.norm(h1, axis=-1) * tf.norm(h2, axis=-1)
+        h2 = h2[..., tf.newaxis]
+        return tf.squeeze(tf.matmul(h1, h2), axis=-1) / (
+            tf.norm(h1, axis=-1) * tf.norm(h2, axis=-2)
         )
 
     def get_config(self):
diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py
index 18e7620007..34c38f1986 100644
--- a/keras_nlp/samplers/contrastive_sampler_test.py
+++ b/keras_nlp/samplers/contrastive_sampler_test.py
@@ -41,7 +41,7 @@ def next(prompt, cache, index):
             batch_size = tf.shape(prompt)[0]
             # Return a distribution favoring the next char in cache.
            logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9
-            hidden_states = tf.ones([batch_size, 1, self.hidden_dim])
+            hidden_states = tf.ones([batch_size, self.hidden_dim])
             return logits, hidden_states, cache
 
         self.next = next
@@ -61,7 +61,7 @@ def next(prompt, cache, index):
                 )
                 * 1e9
             )
-            hidden_states = tf.ones([batch_size, 1, self.hidden_dim])
+            hidden_states = tf.ones([batch_size, self.hidden_dim])
             return logits, hidden_states, cache
 
         prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
@@ -106,7 +106,7 @@ def next(prompt, cache, index):
             # Return a distribution where each id is progressively less likely.
             logits = tf.range(self.vocab_size, 0, -1, dtype="float32")
             logits = tf.repeat(logits[tf.newaxis, :], batch_size, axis=0)
-            hidden_states = tf.ones([batch_size, 1, self.hidden_dim])
+            hidden_states = tf.ones([batch_size, self.hidden_dim])
             return logits, hidden_states, cache
 
         prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
@@ -119,6 +119,52 @@ def next(prompt, cache, index):
         output_ids = set(output[0, 1:].numpy())
         self.assertContainsSubset(output_ids, range(5))
 
+    def test_alpha_penalty_work(self):
+        def next(prompt, cache, index):
+            batch_size = tf.shape(prompt)[0]
+            best_token_id = 7
+            logits = tf.ones([batch_size, self.vocab_size])
+            # Favoring `best_token_id` in the logits.
+            logits += (
+                tf.one_hot(
+                    tf.zeros(self.batch_size, dtype=tf.int32) + best_token_id,
+                    self.vocab_size,
+                )
+                * 1e9
+            )
+
+            # Set the hidden states for `best_token_id` to [1, 1, ..., 1], so it
+            # gets the max similarity penalty score.
+            mask_of_best_token = prompt[:, index - 1] == best_token_id
+            random_states = tf.random.uniform([batch_size, self.hidden_dim]) * (
+                1 - tf.cast(mask_of_best_token, dtype=tf.float32)[:, tf.newaxis]
+            )
+            hidden_states = (
+                tf.ones([batch_size, self.hidden_dim])
+                * tf.cast(mask_of_best_token, dtype=tf.float32)[:, tf.newaxis]
+            )
+            hidden_states = hidden_states + random_states
+            return logits, hidden_states, cache
+
+        prompt = tf.fill((1, self.length), self.char_lookup["z"])
+        hidden_states = tf.ones([1, self.length, self.hidden_dim]) + 1e-5
+        output = self.sampler(
+            next=next,
+            prompt=prompt,
+            index=5,
+            hidden_states=hidden_states,
+        )
+        self.assertEqual(self.join_as_string(output), ["zzzzzhhhhhhh"])
+
+        sampler = ContrastiveSampler(k=5, alpha=1.0)
+        output = sampler(
+            next=next,
+            prompt=prompt,
+            index=5,
+            hidden_states=hidden_states,
+        )
+        self.assertTrue("h" not in self.join_as_string(output))
+
     @parameterized.named_parameters(
         ("jit_compile_false", False), ("jit_compile_true", True)
     )
diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py
index a20486c50d..f45902f525 100644
--- a/keras_nlp/samplers/greedy_sampler_test.py
+++ b/keras_nlp/samplers/greedy_sampler_test.py
@@ -31,7 +31,7 @@ def setUp(self):
 
         def next(prompt, cache, index):
             # Dummy hidden states.
-            hidden_states = tf.ones([10])
+            hidden_states = tf.ones([self.batch_size, 5])
             # Return a distribution favoring the next char in cache.
             logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9
             return logits, hidden_states, cache
@@ -45,7 +45,7 @@ def join_as_string(self, x):
 
     def test_stateless_call(self):
         def next(prompt, cache, index):
             # Dummy hidden states.
-            hidden_states = tf.ones([10])
+            hidden_states = tf.ones([self.batch_size, 5])
             # Return a distribution favoring the first token in the vocab.
logits = ( tf.one_hot( @@ -90,7 +90,7 @@ def test_early_stopping(self): def test_is_greedy(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution where each id is progressively less likely. logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index 31c0a06d02..10ce77b956 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -31,7 +31,7 @@ def setUp(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 return logits, hidden_states, cache @@ -45,7 +45,7 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the first token in the vocab. logits = ( tf.one_hot( @@ -90,7 +90,7 @@ def test_early_stopping(self): def test_outputs_in_top_k(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution where each id is progressively less likely. logits = tf.range(self.vocab_size, 0, -1, dtype="float32") logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index 271679cfa2..68afbb283a 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -32,7 +32,7 @@ def setUp(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the next char in cache. logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 return logits, hidden_states, cache @@ -46,7 +46,7 @@ def join_as_string(self, x): def test_stateless_call(self): def next(prompt, cache, index): # Dummy hidden states. - hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the first token in the vocab. logits = ( tf.one_hot( @@ -91,7 +91,7 @@ def test_early_stopping(self): def test_outputs_in_top_p(self): def next(prompt, cache, index): # Dummy hidden states. 
- hidden_states = tf.ones([10]) + hidden_states = tf.ones([self.batch_size, 5]) logits = np.zeros((self.batch_size, self.vocab_size)) return tf.constant(logits), hidden_states, cache From 95a86e5bb17beb4f80e38ced37e6d864ff4c01fe Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 6 Apr 2023 14:54:41 -0700 Subject: [PATCH 08/11] small fix --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 1 - keras_nlp/samplers/serialization.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index c08b2686bc..110853f7ee 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -24,7 +24,6 @@ ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.samplers import ContrastiveSampler from keras_nlp.samplers.serialization import get as get_sampler from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty diff --git a/keras_nlp/samplers/serialization.py b/keras_nlp/samplers/serialization.py index 56adb7d226..a1becf23c5 100644 --- a/keras_nlp/samplers/serialization.py +++ b/keras_nlp/samplers/serialization.py @@ -16,6 +16,7 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler from keras_nlp.samplers.greedy_sampler import GreedySampler from keras_nlp.samplers.random_sampler import RandomSampler from keras_nlp.samplers.top_k_sampler import TopKSampler @@ -32,6 +33,7 @@ def deserialize(config, custom_objects=None): """Return a `Sampler` object from its config.""" all_classes = { "beam": BeamSampler, + "contrastive": ContrastiveSampler, "greedy": GreedySampler, "random": RandomSampler, "top_k": TopKSampler, From 379d2430588c2a432682a315fc9b8613827bc054 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 6 Apr 2023 18:11:39 -0700 Subject: [PATCH 09/11] merge master --- keras_nlp/samplers/beam_sampler_test.py | 6 ++-- keras_nlp/samplers/random_sampler_test.py | 36 +++++++++++++---------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 5ebfe60327..ff0d671834 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -79,13 +79,13 @@ def test_stateful_call(self): self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_return_all_beams(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) sorted_prompts, sorted_log_probs = self.sampler_all_beams( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual( diff --git a/keras_nlp/samplers/random_sampler_test.py b/keras_nlp/samplers/random_sampler_test.py index 97f41b638b..2c08b09780 100644 --- a/keras_nlp/samplers/random_sampler_test.py +++ b/keras_nlp/samplers/random_sampler_test.py @@ -30,10 +30,12 @@ def setUp(self): self.length = 12 self.vocab_size = len(self.int_lookup) - def next(prompt, state, index): + def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the next char in state. 
- logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 - return logits, state + logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9 + return logits, hidden_states, cache self.next = next self.sampler = RandomSampler() @@ -42,11 +44,13 @@ def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] def test_stateless_call(self): - def next(prompt, state, index): + def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([self.batch_size, 5]) # Return a distribution favoring the first token in the vocab. logits = np.zeros((self.batch_size, self.vocab_size)) logits[:, 0] = 1e9 - return tf.constant(logits), state + return tf.constant(logits), hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( @@ -57,24 +61,24 @@ def next(prompt, state, index): self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) def test_stateful_call(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, ) self.assertEqual(self.join_as_string(output), ["sequentially"]) def test_early_stopping(self): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) output = self.sampler( next=self.next, prompt=prompt, - state=state, + cache=cache, end_token_id=self.char_lookup["t"], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) @@ -83,13 +87,13 @@ def test_early_stopping(self): ("jit_compile_false", False), ("jit_compile_true", True) ) def test_compilation(self, jit_compile): - state_chars = list("sequentially") - state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + cache_chars = list("sequentially") + cache = tf.constant([[self.char_lookup[c] for c in cache_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) @tf.function(jit_compile=jit_compile) - def generate(prompt, state): - return self.sampler(self.next, prompt=prompt, state=state) + def generate(prompt, cache): + return self.sampler(self.next, prompt=prompt, cache=cache) - output = generate(prompt, state) + output = generate(prompt, cache) self.assertEqual(self.join_as_string(output), ["sequentially"]) From 27c8ae1d91ae054eb2619ba19df9d81c61638c21 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 7 Apr 2023 14:57:08 -0700 Subject: [PATCH 10/11] one more pass --- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/contrastive_sampler.py | 9 ++++----- keras_nlp/samplers/contrastive_sampler_test.py | 4 ++-- keras_nlp/samplers/greedy_sampler.py | 2 +- keras_nlp/samplers/random_sampler.py | 3 ++- keras_nlp/samplers/top_k_sampler.py | 2 +- keras_nlp/samplers/top_p_sampler.py | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 46b571be88..43c333d8d4 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -52,7 +52,7 @@ class BeamSampler(Sampler): def next(prompt, cache, index): 
prompt_batch_size = tf.shape(prompt)[0] - hidden_states = tf.ones((prompt_batch_size, 1, 10)) + hidden_states = tf.ones((prompt_batch_size, 10)) # A uniform distribution over our alphabet. logits = tf.ones((prompt_batch_size, vocab_size)) return logits, hidden_states, cache diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 45629457c7..e89f4d2b0c 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -55,7 +55,7 @@ class ContrastiveSampler(Sampler): def next(prompt, cache, index): prompt_batch_size = tf.shape(prompt)[0] - hidden_states = tf.ones((prompt_batch_size, 1, hidden_size)) + hidden_states = tf.ones((prompt_batch_size, hidden_size)) # A uniform distribution over our alphabet. logits = tf.ones((prompt_batch_size, vocab_size)) return logits, hidden_states, cache @@ -67,14 +67,14 @@ def next(prompt, cache, index): hidden_states=tf.ones([batch_size, index, hidden_size]), ) print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> "zzzzzaaaaaaa" + # >>> "zzzzzeeeeeee" ``` """ def __init__( self, k=5, - alpha=0.2, + alpha=0.6, seed=None, ): super().__init__() @@ -166,9 +166,8 @@ def body(prompt, cache, index, logits, hidden_states): # Compute the max similarity score for top-k candidate tokens # against previous tokens. - last_token_state = next_hidden_states_beams similarity_scores = self.similarity( - hidden_states_beams, last_token_state + hidden_states_beams, next_hidden_states_beams ) max_similarity_scores = tf.cast( tf.reduce_max(similarity_scores[:, :index], axis=1), diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index 34c38f1986..f19f4cb17f 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -119,10 +119,10 @@ def next(prompt, cache, index): output_ids = set(output[0, 1:].numpy()) self.assertContainsSubset(output_ids, range(5)) - def test_alpha_penalty_work(self): + def test_alpha_penalty(self): def next(prompt, cache, index): batch_size = tf.shape(prompt)[0] - best_token_id = 7 + best_token_id = self.int_lookup("h") logits = tf.ones([batch_size, self.vocab_size]) # Favoring `best_token_id` in the logits. logits += ( diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 0e46fa6205..9d4cf7389e 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -40,7 +40,7 @@ class GreedySampler(Sampler): batch_size, length, vocab_size = 1, 12, len(int_lookup) def next(prompt, cache, index): - hidden_states = tf.ones((batch_size, 1, 10)) + hidden_states = tf.ones((batch_size, 10)) # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) return logits, hidden_states, cache diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index 17030293c8..9d2c731757 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -44,9 +44,10 @@ class RandomSampler(Sampler): batch_size, length, vocab_size = 1, 12, len(int_lookup) def next(prompt, state, index): + hidden_states = tf.ones((batch_size, 10)) # A uniform distribution over our alphabet. 
logits = tf.ones((batch_size, vocab_size)) - return logits, state + return logits, hidden_states, state output = keras_nlp.samplers.RandomSampler()( next=next, diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index f3782462e4..68f369a996 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -45,7 +45,7 @@ class TopKSampler(Sampler): batch_size, length, vocab_size = 1, 12, len(int_lookup) def next(prompt, cache, index): - hidden_states = tf.ones((batch_size, 1, 10)) + hidden_states = tf.ones((batch_size, 10)) # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) return logits, hidden_states, cache diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 18a78d7e87..f38331b9a3 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -47,7 +47,7 @@ class TopPSampler(Sampler): batch_size, length, vocab_size = 1, 12, len(int_lookup) def next(prompt, cache, index): - hidden_states = tf.ones((batch_size, 1, 10)) + hidden_states = tf.ones((batch_size, 10)) # A uniform distribution over our alphabet. logits = tf.ones((batch_size, vocab_size)) return logits, hidden_states, cache From 4eb17c4584b74528a3e89e6a7842862fbbbfeabe Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 7 Apr 2023 16:08:14 -0700 Subject: [PATCH 11/11] fix tests --- keras_nlp/samplers/contrastive_sampler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index f19f4cb17f..8981c809cb 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -122,7 +122,7 @@ def next(prompt, cache, index): def test_alpha_penalty(self): def next(prompt, cache, index): batch_size = tf.shape(prompt)[0] - best_token_id = self.int_lookup("h") + best_token_id = self.char_lookup["h"] logits = tf.ones([batch_size, self.vocab_size]) # Favoring `best_token_id` in the logits. logits += (
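
As a reference for reviewers, here is a minimal, self-contained sketch (not part of the patch) of the contrastive search scoring rule that the `ContrastiveSampler` changes above implement: each of the top-`k` candidate tokens is scored as `(1 - alpha) * probability - alpha * max_similarity`, where `max_similarity` is the candidate hidden state's highest cosine similarity against the hidden states of all previously generated tokens, and the best-scoring candidate is kept. Variable names and shapes below are illustrative assumptions, not this PR's API:

```python
import tensorflow as tf

# Illustrative shapes: k=3 candidate tokens, 5 previous tokens, hidden size 4.
k, num_previous, hidden_dim = 3, 5, 4
alpha = 0.6

next_token_probabilities = tf.constant([0.5, 0.3, 0.2])  # p of each candidate
previous_states = tf.random.uniform([num_previous, hidden_dim])
candidate_states = tf.random.uniform([k, hidden_dim])  # one state per candidate

# Cosine similarity of every candidate against every previous token state.
norm_prev = tf.nn.l2_normalize(previous_states, axis=-1)
norm_cand = tf.nn.l2_normalize(candidate_states, axis=-1)
similarity = tf.matmul(norm_cand, norm_prev, transpose_b=True)  # [k, num_previous]
max_similarity = tf.reduce_max(similarity, axis=-1)  # [k]

# Degeneration penalty: confident but repetitive candidates are down-weighted.
scores = (1 - alpha) * next_token_probabilities - alpha * max_similarity
best_candidate = tf.argmax(scores)
```

Since `serialization.py` now registers the `"contrastive"` key, the sampler should also be selectable by name, e.g. `gpt2_lm.compile(sampler="contrastive")`, in addition to passing an explicit `keras_nlp.samplers.ContrastiveSampler(k=5, alpha=0.6)` instance.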