44 changes: 30 additions & 14 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -208,7 +208,12 @@ def backbone_cls(cls):
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def call_with_cache(self, token_ids, cache, cache_index):
def call_with_cache(
self,
token_ids,
cache,
cache_index,
):
"""Forward pass of `GPT2CausalLM` with cache.

`call_with_cache` adds an additional forward pass for the model for
@@ -223,9 +228,10 @@ def call_with_cache(self, token_ids, cache, cache_index):
whole sequence.

Returns:
A (logits, cache) tuple. Where the first output is the language
model logits for the input token_ids and the second output is the
cache.
A (logits, hidden_states, cache) tuple, where `logits` is the
language model logits for the input token_ids, `hidden_states` is
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
token_embedding = self.backbone.get_layer("token_embedding")(token_ids)
position_embedding = self.backbone.get_layer("position_embedding")(
@@ -247,12 +253,13 @@ def call_with_cache(self, token_ids, cache, cache_index):
caches[i] = next_cache
cache = tf.stack(caches, axis=1)
x = self.backbone.get_layer("layer_norm")(x)
x = tf.matmul(
x,
hidden_states = x
logits = tf.matmul(
hidden_states,
self.backbone.get_layer("token_embedding").embeddings,
transpose_b=True,
)
return x, cache
return logits, hidden_states, cache
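For readers following the new three-output contract, here is a minimal sketch (not part of this diff) of how `call_with_cache` is typically driven once the cache has been seeded. The helper name, the plain Python loop, and the greedy argmax choice are illustrative assumptions, not the library's actual generation code; `_build_cache` is the private helper defined just below in this file.

import tensorflow as tf

def greedy_continue_sketch(model, prompt, index, num_steps):
    # `model` is assumed to be a `GPT2CausalLM`; `prompt` is an int tensor of
    # shape (batch_size, max_length), already filled up to position `index`.
    # One full pass over the prompt seeds the key/value cache (it also returns
    # the prompt's hidden states, unused in this greedy sketch).
    _, cache = model._build_cache(prompt)
    for i in range(index, index + num_steps):
        cache_index = i - 1
        # Only the previous token is fed; earlier keys/values come from the cache.
        last_token = tf.slice(prompt, [0, cache_index], [-1, 1])
        logits, hidden_states, cache = model.call_with_cache(
            last_token, cache, cache_index
        )
        # Greedy choice over the vocab; `logits` has shape (batch, 1, vocab).
        next_token = tf.cast(tf.argmax(logits, axis=-1), prompt.dtype)
        # Write the new token at position `i`.
        prompt = tf.concat(
            [prompt[:, :i], next_token, prompt[:, i + 1:]], axis=-1
        )
    return prompt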

def _build_cache(self, prompt):
"""Build an empty cache for use with `call_with_cache()`."""
@@ -263,8 +270,8 @@ def _build_cache(self, prompt):
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = tf.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, cache = self.call_with_cache(prompt, cache, 0)
return cache
_, hidden_states, cache = self.call_with_cache(prompt, cache, 0)
return hidden_states, cache

def compile(
self,
@@ -293,22 +300,31 @@ def make_generate_function(self):

def generate_function(prompt, input_mask, min_length):
# Create and seed cache with a single forward pass.
cache = self._build_cache(prompt)
hidden_states, cache = self._build_cache(prompt)

def next(prompt, state, index):
def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_index = index - 1
prompt = tf.slice(prompt, [0, cache_index], [-1, 1])
logits, state = self.call_with_cache(prompt, state, cache_index)
return tf.squeeze(logits, axis=1), state
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_index,
)
Review comment (Member): I would actually comment this shape manipulation so people can follow, e.g.

# Remove the sequence dim, so shape is `(batch_size, vocab_size)`.
logits = tf.squeeze(logits, axis=1)
# Remove the sequence dim, so shape is `(batch_size, hidden_dim)`.
hidden_states = tf.squeeze(hidden_states, axis=1)
return logits, hidden_states, cache

return (
tf.squeeze(logits, axis=1),
tf.squeeze(hidden_states, axis=1),
cache,
)

return self._sampler(
next=next,
prompt=prompt,
state=cache,
cache=cache,
index=min_length,
mask=input_mask,
end_token_id=self.preprocessor.tokenizer.end_token_id,
hidden_states=hidden_states,
)
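The reason the sampler now receives `hidden_states` (and why `next` returns them) is that contrastive-style decoding scores each candidate token not only by model confidence but also by how similar its hidden representation is to the tokens already generated. A rough sketch of that scoring rule follows; the function name, tensor shapes, and the `alpha` weighting are illustrative assumptions, not the implementation in this PR.

import tensorflow as tf

def contrastive_score_sketch(candidate_probs, candidate_hidden, history_hidden, alpha=0.6):
    # candidate_probs:  (num_candidates,)            model probability of each candidate.
    # candidate_hidden: (num_candidates, hidden_dim) hidden state of each candidate token.
    # history_hidden:   (history_len, hidden_dim)    hidden states of tokens generated so far.
    candidate_hidden = tf.math.l2_normalize(candidate_hidden, axis=-1)
    history_hidden = tf.math.l2_normalize(history_hidden, axis=-1)
    # Max cosine similarity between each candidate and any previously generated token.
    similarity = tf.matmul(candidate_hidden, history_hidden, transpose_b=True)
    degeneration_penalty = tf.reduce_max(similarity, axis=-1)
    # Trade off confidence against repetition; higher scores are better.
    return (1 - alpha) * candidate_probs - alpha * degeneration_penalty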

if self.run_eagerly:
1 change: 1 addition & 0 deletions keras_nlp/samplers/__init__.py
@@ -15,6 +15,7 @@
from tensorflow import keras

from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.random_sampler import RandomSampler
from keras_nlp.samplers.sampler import Sampler
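With the export above, the new sampler becomes part of the public `keras_nlp.samplers` namespace; this is what the `hidden_states` plumbing in `gpt2_causal_lm.py` exists to support. A hypothetical usage sketch, assuming the constructor exposes the usual contrastive-search knobs `k` (number of candidates) and `alpha` (penalty weight), which may differ from the final API:

import keras_nlp

# Assumed constructor arguments; check the ContrastiveSampler docstring for
# the actual names and defaults.
sampler = keras_nlp.samplers.ContrastiveSampler(k=5, alpha=0.6)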
29 changes: 16 additions & 13 deletions keras_nlp/samplers/beam_sampler.py
@@ -50,14 +50,16 @@ class BeamSampler(Sampler):
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, state, index):
def next(prompt, cache, index):
prompt_batch_size = tf.shape(prompt)[0]
hidden_states = tf.ones((prompt_batch_size, 10))
# A uniform distribution over our alphabet.
logits = tf.ones((batch_size, vocab_size))
return logits, state
logits = tf.ones((prompt_batch_size, vocab_size))
return logits, hidden_states, cache

output = keras_nlp.samplers.BeamSampler()(
next=next,
prompt=tf.fill((batch_size, length,), char_lookup['z']),
prompt=tf.fill((batch_size, length), char_lookup["z"]),
index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
@@ -104,10 +106,11 @@ def __call__(
self,
next,
prompt,
state=None,
cache=None,
index=0,
mask=None,
end_token_id=None,
hidden_states=None,
):
batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
# Make sure max length and start index are the same dtype.
@@ -129,26 +132,26 @@ def unflatten_beams(x):

mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask
# `tf.while_loop` will not accept `None` as a value for `loop_vars`.
state = () if state is None else state
cache = () if cache is None else cache
# Add extra sequences for each beam.
prompt, mask = create_beams(prompt), create_beams(mask)
state = tf.nest.map_structure(create_beams, state)
cache = tf.nest.map_structure(create_beams, cache)
# Setup the initial beam log-likelihoods.
# On the first loop, make sure only the original beam is considered.
log_probs = tf.constant([[0.0] + [-1e9] * (self.num_beams - 1)])
log_probs = flatten_beams(tf.repeat(log_probs, batch_size, axis=0))

def cond(prompt, state, index, log_probs):
def cond(prompt, cache, index, log_probs):
if end_token_id is None:
return True
# Stop if all sequences have produced a *new* end_token_id.
end_tokens = (prompt == end_token_id) & (~mask)
prompt_done = tf.reduce_any(end_tokens, axis=-1)
return not tf.reduce_all(prompt_done)

def body(prompt, state, index, log_probs):
def body(prompt, cache, index, log_probs):
# Compute the softmax distribution for the next token.
logits, state = next(prompt, state, index)
logits, _, cache = next(prompt, cache, index)
vocab_size = tf.shape(logits)[-1]
probs = keras.activations.softmax(logits)

@@ -176,7 +179,7 @@ def gather_beams(x):
return flatten_beams(x)

prompt = gather_beams(prompt)
state = tf.nest.map_structure(gather_beams, state)
cache = tf.nest.map_structure(gather_beams, cache)

# Update each beam with the next token.
next_token = tf.cast(next_token, prompt.dtype)
@@ -186,12 +189,12 @@ def gather_beams(x):
next_token = next_token[:, tf.newaxis]
prompt = dynamic_update_slice(prompt, next_token, [0, index])
# Return the iteration of the loop state.
return (prompt, state, index + 1, log_probs)
return (prompt, cache, index + 1, log_probs)

prompt, _, _, log_probs = tf.while_loop(
cond=cond,
body=body,
loop_vars=(prompt, state, index, log_probs),
loop_vars=(prompt, cache, index, log_probs),
maximum_iterations=(max_length - index),
)

56 changes: 33 additions & 23 deletions keras_nlp/samplers/beam_sampler_test.py
@@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for Beam sampler."""

import numpy as np
import tensorflow as tf
from absl.testing import parameterized

@@ -30,10 +29,13 @@ def setUp(self):
self.length = 12
self.vocab_size = len(self.int_lookup)

def next(prompt, state, index):
# Return a distribution favoring the next char in state.
logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9
return logits, state
def next(prompt, cache, index):
batch_size = tf.shape(prompt)[0]
# Dummy hidden states.
hidden_states = tf.ones([batch_size, 5])
# Return a distribution favoring the next char in cache.
logits = tf.one_hot(cache[:, index], self.vocab_size) * 1e9
return logits, hidden_states, cache

self.next = next
self.sampler = BeamSampler(num_beams=5)
@@ -43,11 +45,19 @@ def join_as_string(self, x):
return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]

def test_stateless_call(self):
def next(prompt, state, index):
def next(prompt, cache, index):
batch_size = tf.shape(prompt)[0]
# Dummy hidden states.
hidden_states = tf.ones([batch_size, 5])
# Return a distribution favoring the first token in the vocab.
logits = np.zeros((self.batch_size, self.vocab_size))
logits[:, 0] = 1e9
return tf.constant(logits, dtype="float32"), state
logits = (
tf.one_hot(
tf.zeros(self.batch_size, dtype=tf.int32),
self.vocab_size,
)
* 1e9
)
return logits, hidden_states, cache

prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
output = self.sampler(
@@ -58,24 +68,24 @@ def next(prompt, state, index):
self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"])

def test_stateful_call(self):
state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
cache_chars = list("sequentially")
cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
output = self.sampler(
next=self.next,
prompt=prompt,
state=state,
cache=cache,
)
self.assertEqual(self.join_as_string(output), ["sequentially"])

def test_return_all_beams(self):
state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
cache_chars = list("sequentially")
cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
sorted_prompts, sorted_log_probs = self.sampler_all_beams(
next=self.next,
prompt=prompt,
state=state,
cache=cache,
)

self.assertEqual(
Expand All @@ -90,13 +100,13 @@ def test_return_all_beams(self):
)

def test_early_stopping(self):
state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
cache_chars = list("sequentially")
cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
output = self.sampler(
next=self.next,
prompt=prompt,
state=state,
cache=cache,
end_token_id=self.char_lookup["t"],
)
self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])
@@ -105,13 +115,13 @@ def test_compilation(self, jit_compile):
("jit_compile_false", False), ("jit_compile_true", True)
)
def test_compilation(self, jit_compile):
state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
cache_chars = list("sequentially")
cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])

@tf.function(jit_compile=jit_compile)
def generate(prompt, state):
return self.sampler(self.next, prompt=prompt, state=state)
def generate(prompt, cache):
return self.sampler(self.next, prompt=prompt, cache=cache)

output = generate(prompt, state)
output = generate(prompt, cache)
self.assertEqual(self.join_as_string(output), ["sequentially"])