158 changes: 95 additions & 63 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -17,17 +17,17 @@

import tensorflow as tf

import keras_nlp
from keras_nlp import samplers
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.samplers import BeamSampler
from keras_nlp.samplers import serialize
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import truncate_at_token


@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -37,8 +37,13 @@ class GPT2CausalLM(Task):
A causal language model (LM) predicts the next token based on previous
tokens, which is the way GPT2 gets
pretrained. You can finetune `GPT2CausalLM` to generate text similar to
the custom dataset. `GPT2CausalLM` also has a method `generate()`, which
generates text based on given prompt.
the custom dataset.

This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.

This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
@@ -67,15 +72,13 @@ class GPT2CausalLM(Task):
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
```

Use a custom sampler for text generation.
Compile the `generate()` function with custom samplers.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.compile(sampler="top_p")
gpt2_lm.generate("I want to say", max_length=30)

# Use string identifier to set sampler.
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

# Construct a sampler instance.
sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
gpt2_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
```
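A minimal sketch of switching to greedy decoding with the same `compile()`-based API; the `"greedy"` identifier is assumed here, matching its use in the tests below.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Recompile with a different sampler; generation is re-traced lazily.
gpt2_lm.compile(sampler="greedy")
gpt2_lm.generate("I want to say", max_length=30)
```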

@@ -189,7 +192,9 @@ def __init__(

self.backbone = backbone
self.preprocessor = preprocessor
self.sampler = None
self.generate_function = None
# Private sampler set by compile.
self._sampler = samplers.get("top_k")

@classproperty
def presets(cls):
@@ -203,7 +208,7 @@ def backbone_cls(cls):
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def call_with_cache(self, token_ids, padding_mask, cache, cache_index):
def call_with_cache(self, token_ids, cache, cache_index):
"""Forward pass of `GPT2CausalLM` with cache.

`call_with_cache` adds an additional forward pass for the model for
@@ -213,7 +218,6 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index):

Args:
token_ids: a dense int Tensor, input token ids.
padding_mask: a dense bool Tensor, input padding mask.
cache: a dense float Tensor, the cache of key and value.
cache_index: int, or int Tensor. The index of current inputs in the
whole sequence.
@@ -237,7 +241,6 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index):
current_cache = caches[i]
x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")(
x,
decoder_padding_mask=padding_mask,
cache=current_cache,
cache_index=cache_index,
)
@@ -251,30 +254,78 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index):
)
return x, cache
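A minimal sketch (not part of this diff) of a single incremental decoding step built on `call_with_cache()`, assuming `lm` is a `GPT2CausalLM` whose cache has already been seeded with the prompt tokens before `index`:
```python
import tensorflow as tf

def decode_one_step(lm, prompt, cache, index):
    # The cache already covers every token before `index`, so only the
    # immediately preceding token needs a forward pass.
    cache_index = index - 1
    last_token = tf.slice(prompt, [0, cache_index], [-1, 1])
    logits, cache = lm.call_with_cache(last_token, cache, cache_index)
    # Greedy pick, purely for illustration; the real samplers decide this.
    next_token = tf.argmax(tf.squeeze(logits, axis=1), axis=-1)
    return next_token, cache
```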

def build_empty_cache(self, batch_size, max_length):
def _build_cache(self, prompt):
"""Build an empty cache for use with `call_with_cache()`."""
batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
return tf.zeros(shape)
cache = tf.zeros(shape)
# Seed the cache.
_, cache = self.call_with_cache(prompt, cache, 0)
return cache
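Rough cache-size arithmetic, assuming the gpt2_base_en hyperparameters (12 layers, 12 heads, hidden_dim of 768, hence a head_dim of 64); the numbers are illustrative only.
```python
batch_size, max_length = 1, 256
num_layers, num_heads, head_dim = 12, 12, 768 // 12
# [batch, layers, 2 (key/value), length, heads, head_dim]
cache_shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
num_values = 1
for dim in cache_shape:
    num_values *= dim
print(num_values)  # 4,718,592 float32 values, roughly 18 MB for a 256-token batch of 1.
```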

def _get_token_probability(
def compile(
self,
prompt,
mask,
cache=None,
cache_index=None,
*args,
run_eagerly=False,
jit_compile=True,
sampler="top_k",
**kwargs,
):
batch_size = tf.shape(prompt)[0]
prompt = tf.slice(prompt, [0, cache_index], [batch_size, 1])
return self.call_with_cache(prompt, mask, cache, cache_index)
xla_compatible = is_xla_compatible(self)
super().compile(
*args,
run_eagerly=run_eagerly,
# Only `jit_compile` if not eager and in a compatible environment.
jit_compile=jit_compile and xla_compatible and not run_eagerly,
**kwargs,
)
self._sampler = samplers.get(sampler)
# Clear the compiled generate function.
self.generate_function = None
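A usage sketch of the recompile behavior, mirroring the test changes further down: compiling with a new sampler clears the cached generate function, so the next `generate()` call re-traces.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.compile(sampler="top_p")
gpt2_lm.generate("The weather is", max_length=30)  # traces with top-p sampling
gpt2_lm.compile(sampler="greedy")
assert gpt2_lm.generate_function is None           # cleared by compile()
gpt2_lm.generate("The weather is", max_length=30)  # re-traces with greedy sampling
```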

def make_generate_function(self):
"""Create or return the compiled generation function."""
if self.generate_function is not None:
return self.generate_function

def generate_function(prompt, input_mask, min_length):
# Create and seed cache with a single forward pass.
cache = self._build_cache(prompt)

def next(prompt, state, index):
# The cache index is the index of our previous token.
cache_index = index - 1
prompt = tf.slice(prompt, [0, cache_index], [-1, 1])
logits, state = self.call_with_cache(prompt, state, cache_index)
return tf.squeeze(logits, axis=1), state

return self._sampler(
next=next,
prompt=prompt,
state=cache,
index=min_length,
mask=input_mask,
end_token_id=self.preprocessor.tokenizer.end_token_id,
)

if self.run_eagerly:
self.generate_function = generate_function
else:
# `jit_compile` is a property of keras.Model after TF 2.12.
# Use `getattr()` for backwards compatibility.
jit_compile = getattr(self, "jit_compile", True)
self.generate_function = tf.function(
generate_function, jit_compile=jit_compile
)
return self.generate_function
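A toy, eager-only illustration (not the `keras_nlp.samplers` implementation) of the contract assumed above: a sampler repeatedly calls `next` with the running prompt, threads the cache through `state`, and fills in one token per step, leaving user-provided tokens (where `mask` is True) untouched.
```python
import numpy as np
import tensorflow as tf

def toy_greedy_sampler(next, prompt, state, index, mask, end_token_id=None):
    prompt, mask = np.array(prompt), np.array(mask)
    for i in range(int(index), prompt.shape[1]):
        logits, state = next(tf.constant(prompt), state, i)
        # Greedy choice of the next token for every sequence in the batch.
        next_token = np.argmax(logits.numpy(), axis=-1)
        # Respect tokens the user already provided at this position.
        prompt[:, i] = np.where(mask[:, i], prompt[:, i], next_token)
    return tf.constant(prompt)
```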

def generate(
self,
prompt,
max_length,
sampler="top_k",
):
"""Generate text.

Expand All @@ -295,46 +346,27 @@ def generate(
"`self.preprocessor` is `None`, please make sure "
"`preprocessor` is set before calling `generate`."
)
sampler = keras_nlp.samplers.get(sampler)
if sampler.__class__ == BeamSampler:
raise ValueError(
"`BeamSampler` is not supported right now, please choose "
"another sampler, e.g., `TopPSampler`."
)
if hasattr(self, "jit_compile"):
# `jit_compile` is a public property as of tf 2.12. hasattr is for
# backward compat.
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
if self.sampler and serialize(sampler) == serialize(self.sampler):
# If the new sampler is the same as the older one, we reuse the old
# sampler to avoid recompile.
sampler = self.sampler
else:
self.sampler = sampler

# Tokenize.
prompt = tf.convert_to_tensor(prompt)
input_is_scalar = prompt.shape.rank == 0
prompt = prompt[tf.newaxis] if input_is_scalar else prompt
prompt = self.preprocessor.tokenizer(prompt)

# Create and seed the cache before generation.
token_ids = prompt
if prompt.shape.rank == 1:
token_ids = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :])
token_ids = token_ids.to_tensor(shape=(None, max_length))
# Pass a padding mask of all ones when seeding the cache. The mask will
# not affect cached key/values for input tokens we care about.
padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
batch_size = tf.shape(token_ids)[0]
cache = self.build_empty_cache(batch_size, max_length)
_, cache = self.call_with_cache(token_ids, padding_mask, cache, 0)
# Run generation.
generated = sampler(
prompt,
self._get_token_probability,
max_length=max_length,
end_token_id=self.preprocessor.tokenizer.end_token_id,
cache=cache,
)
# Pad ragged to dense tensors.
padded_shape = (None, max_length)
min_length = tf.reduce_min(prompt.row_lengths())
input_mask = tf.ones_like(prompt, tf.bool).to_tensor(shape=padded_shape)
prompt = prompt.to_tensor(shape=padded_shape)

# Run the (possibly compiled) generate function on dense inputs.
generate_function = self.make_generate_function()
output = generate_function(prompt, input_mask, min_length)

# Truncate to ragged by removing tokens after the first end token.
end_token_id = self.preprocessor.tokenizer.end_token_id
output = truncate_at_token(output, end_token_id, input_mask)

# Detokenize.
return self.preprocessor.tokenizer.detokenize(generated)
output = self.preprocessor.tokenizer.detokenize(output)
return tf.squeeze(output, 0) if input_is_scalar else output
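Illustrative only: how a ragged tokenized batch becomes the dense `prompt`, `input_mask`, and `min_length` inputs that the compiled generate function consumes (shapes here assume `max_length=8`).
```python
import tensorflow as tf

prompt = tf.ragged.constant([[10, 11, 12], [20, 21]])           # two tokenized prompts
min_length = tf.reduce_min(prompt.row_lengths())                 # -> 2
input_mask = tf.ones_like(prompt, tf.bool).to_tensor(shape=(None, 8))
prompt = prompt.to_tensor(shape=(None, 8))
# prompt:     [[10, 11, 12, 0, 0, 0, 0, 0],
#              [20, 21,  0, 0, 0, 0, 0, 0]]
# input_mask: [[T,  T,  T,  F, F, F, F, F],
#              [T,  T,  F,  F, F, F, F, F]]
# Generation starts at `min_length`, so both rows are filled from index 2 on,
# with the mask preserving the third token of the longer prompt.
```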
16 changes: 10 additions & 6 deletions keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -122,25 +122,29 @@ def test_gpt2_causal_lm_fit_no_preprocessing(self, jit_compile):
self.causal_lm_no_preprocessing.fit(self.preprocessed_dataset)

@parameterized.named_parameters(
("non_jit_compile_cache", False, True),
("non_jit_compile_non_cache", False, False),
("jit_compile_non_cache", True, False),
("jit_compile_false", False), ("jit_compile_true", True)
)
def test_gpt2_causal_lm_generate(self, jit_compile, use_cache):
def test_generate(self, jit_compile):
# Tensor input.
self.causal_lm.compile(jit_compile=jit_compile)
self.causal_lm.generate(
self.raw_batch,
max_length=10,
)

# String input
first_fn = self.causal_lm.generate_function
# String input.
prompt = " airplane"
generated = self.causal_lm.generate(
prompt,
max_length=10,
)
generated = generated.numpy().decode("utf-8")
self.assertTrue(prompt in generated)
second_fn = self.causal_lm.generate_function
# Assert we did not recompile.
self.assertEqual(first_fn, second_fn)
self.causal_lm.compile(sampler="greedy")
self.assertIsNone(self.causal_lm.generate_function)

@parameterized.named_parameters(
("tf_format", "tf", "model"),