
Commit ba982e2

Fix tests and docstrings
1 parent 3001685 commit ba982e2

12 files changed with 458 additions and 698 deletions

keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 43 additions & 21 deletions
@@ -17,15 +17,15 @@

 import tensorflow as tf

-import keras_nlp
+from keras_nlp import samplers
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
     GPT2CausalLMPreprocessor,
 )
 from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
 from keras_nlp.models.task import Task
-from keras_nlp.samplers import serialize
+from keras_nlp.utils.keras_utils import is_xla_compatible
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.tf_utils import truncate_at

@@ -37,8 +37,12 @@ class GPT2CausalLM(Task):
     A causal language model (LM) predicts the next token based on previous
     tokens, which is the way GPT2 gets pretrained. You can finetune
     `GPT2CausalLM` to generate text similar to
-    the custom dataset. `GPT2CausalLM` also has a method `generate()`, which
-    generates text based on given prompt.
+    the custom dataset.
+
+    `GPT2CausalLM` has a method `generate()`, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different samplers to control generation.

     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during

@@ -67,15 +71,13 @@ class GPT2CausalLM(Task):
     gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
     ```

-    Use a custom sampler for text generation.
+    Compile the `generate()` function with custom samplers.
     ```python
     gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+    gpt2_lm.compile(sampler="top_p")
+    gpt2_lm.generate("I want to say", max_length=30)

-    # Use string identifier to set sampler.
-    gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")
-
-    # Construct a sampler instance.
-    sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
-    gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
+    gpt2_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
+    gpt2_lm.generate("I want to say", max_length=30)
     ```

@@ -189,8 +191,8 @@ def __init__(

         self.backbone = backbone
         self.preprocessor = preprocessor
-        self.sampler = None
         self.generate_function = None
+        self.sampler = samplers.get("top_k")

     @classproperty
     def presets(cls):

@@ -260,12 +262,30 @@ def build_empty_cache(self, batch_size, max_length):
         shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
         return tf.zeros(shape)

-    def make_generate_function(self, sampler):
+    def compile(
+        self,
+        *args,
+        run_eagerly=False,
+        jit_compile=True,
+        sampler="top_k",
+        **kwargs,
+    ):
+        jit_compile = jit_compile and is_xla_compatible(self)
+        jit_compile = jit_compile and not run_eagerly
+        super().compile(
+            *args,
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+            **kwargs,
+        )
+        # Clear the compiled generate function.
+        self.generate_function = None
+        self.sampler = samplers.get(sampler)
+
+    def make_generate_function(self):
         """Create or return the compiled generation function."""
-        # If our sampler has not changed, re-use the compiled function.
-        if self.sampler and serialize(self.sampler) == serialize(sampler):
+        if self.generate_function is not None:
             return self.generate_function
-        self.sampler = sampler

         def fn(prompt, input_mask, min_length, max_length):
             batch_size = tf.shape(prompt)[0]

@@ -284,9 +304,9 @@ def next(prompt, state, index):
                 )
                 return tf.squeeze(probs, axis=1), state

-            return sampler(
-                prompt=prompt,
+            return self.sampler(
                 next=next,
+                prompt=prompt,
                 state=cache,
                 index=min_length,
                 mask=input_mask,

@@ -306,7 +326,6 @@ def generate(
         self,
         prompt,
         max_length,
-        sampler="top_k",
     ):
         """Generate text.

@@ -327,9 +346,11 @@ def generate(
                 "`self.preprocessor` is `None`, please make sure "
                 "`preprocessor` is set before calling `generate`."
             )
-        sampler = keras_nlp.samplers.get(sampler)

         # Tokenize.
+        prompt = tf.convert_to_tensor(prompt)
+        input_is_scalar = prompt.shape.rank == 0
+        prompt = prompt[tf.newaxis] if input_is_scalar else prompt
         prompt = self.preprocessor.tokenizer(prompt)

         # Pad ragged to dense tensors.

@@ -339,12 +360,13 @@ def generate(
         prompt = prompt.to_tensor(shape=padded_shape)

         # Run the (possibly compiled) generate function on dense inputs.
-        generate_function = self.make_generate_function(sampler)
+        generate_function = self.make_generate_function()
         output = generate_function(prompt, input_mask, min_length, max_length)

         # Truncate back to ragged to account for end of sequence ids.
         end_token_id = self.preprocessor.tokenizer.end_token_id
         output = truncate_at(output, end_token_id, input_mask)

         # Detokenize.
-        return self.preprocessor.tokenizer.detokenize(output)
+        output = self.preprocessor.tokenizer.detokenize(output)
+        return tf.squeeze(output, 0) if input_is_scalar else output
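
Taken together, these changes move sampler configuration from `generate()` onto `compile()`. A minimal sketch of the resulting workflow, assuming only what the diff above shows (the `gpt2_base_en` preset and the `"top_p"`/`"top_k"` string identifiers):

```python
import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# `compile()` now owns the sampling strategy. Each call clears the cached
# generate function, so the next `generate()` traces a fresh function.
gpt2_lm.compile(sampler="top_p")
gpt2_lm.generate("I want to say", max_length=30)

# Recompile with a sampler instance to switch strategies.
gpt2_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
gpt2_lm.generate("I want to say", max_length=30)
```

Between compiles, repeated `generate()` calls reuse the cached function: `make_generate_function()` now returns early whenever `self.generate_function` is already set.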

keras_nlp/models/gpt2/gpt2_causal_lm_test.py

Lines changed: 8 additions & 6 deletions
@@ -122,25 +122,27 @@ def test_gpt2_causal_lm_fit_no_preprocessing(self, jit_compile):
         self.causal_lm_no_preprocessing.fit(self.preprocessed_dataset)

     @parameterized.named_parameters(
-        ("non_jit_compile_cache", False, True),
-        ("non_jit_compile_non_cache", False, False),
-        ("jit_compile_non_cache", True, False),
+        ("jit_compile_false", False), ("jit_compile_true", True)
     )
-    def test_gpt2_causal_lm_generate(self, jit_compile, use_cache):
+    def test_compilation(self, jit_compile):
+        # Tensor input.
         self.causal_lm.compile(jit_compile=jit_compile)
         self.causal_lm.generate(
             self.raw_batch,
             max_length=10,
         )
-
-        # String input
+        first_fn = self.causal_lm.generate_function
+        # String input.
         prompt = " airplane"
         generated = self.causal_lm.generate(
             prompt,
             max_length=10,
         )
         generated = generated.numpy().decode("utf-8")
         self.assertTrue(prompt in generated)
+        second_fn = self.causal_lm.generate_function
+        # Assert we did not recompile.
+        self.assertEqual(first_fn, second_fn)

     @parameterized.named_parameters(
         ("tf_format", "tf", "model"),

keras_nlp/samplers/beam_sampler.py

Lines changed: 36 additions & 43 deletions
@@ -42,30 +42,23 @@ class BeamSampler(Sampler):

     Examples:
     ```python
-    VOCAB_SIZE = 10
-
-    # Create a dummy model to predict the next token.
-    model = keras.Sequential(
-        [
-            keras.Input(shape=[None]),
-            keras.layers.Embedding(
-                input_dim=VOCAB_SIZE,
-                output_dim=16,
-            ),
-            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
-        ]
+    # Use a simple alphabet of lowercase characters to [0, 26).
+    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
+    char_lookup = {v: k for k, v in int_lookup.items()}
+    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+
+    def next(prompt, state, index):
+        # A uniform distribution over our alphabet.
+        probs = tf.ones((batch_size, vocab_size))
+        return probs, state
+
+    output = keras_nlp.samplers.BeamSampler()(
+        next=next,
+        prompt=tf.fill((batch_size, length,), char_lookup['z']),
+        index=5,
     )
-
-    # Define a function that outputs the next token's probability for each
-    # token in the input sequence.
-    def token_probability_fn(inputs, mask):
-        return model(inputs)
-
-    prompt = tf.fill((8, 1), 1)
-
-    sampler = keras_nlp.samplers.BeamSampler(num_beams=3)
-    # Print the generated sequence (token ids).
-    print(sampler(prompt, token_probability_fn, max_length=10))
+    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
+    # >>> "zzzzzaaaaaaa"
     ```
     """

@@ -78,8 +71,8 @@ def __init__(

     def __call__(
         self,
-        prompt,
         next,
+        prompt,
         index=0,
         state=None,
         mask=None,

@@ -99,6 +92,17 @@ def unflatten(x):
             unflat_shape = [batch_size, self.num_beams] + x.shape.as_list()[1:]
             return tf.reshape(x, shape=unflat_shape)

+        mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask
+        # `tf.while_loop` will not accept `None` as a value for `loop_vars`.
+        state = () if state is None else state
+        # Add extra sequences for each beam.
+        prompt, mask = add_beams(prompt), add_beams(mask)
+        state = tf.nest.map_structure(add_beams, state)
+        # Setup the initial beam log-likelihoods.
+        # On the first loop, make sure only the original beam is considered.
+        beam_probs = tf.constant([[0.0] + [-1e9] * (self.num_beams - 1)])
+        beam_probs = flatten(tf.repeat(beam_probs, batch_size, axis=0))
+
         def cond(prompt, state, index, beam_probs):
             if end_token_id is None:
                 return True

@@ -127,13 +131,13 @@ def body(prompt, state, index, beam_probs):
             # We need `ensure_shape` as `top_k` will change the static shape.
             beam_probs = tf.ensure_shape(flatten(next_probs), beam_probs.shape)

-            # Gather the correct prompt and state beams.
-            prompt = unflatten(prompt)
-            state = tf.nest.map_structure(unflatten, state)
-            prompt = tf.gather(prompt, beam_indices, axis=1, batch_dims=1)
-            state = tf.gather(state, beam_indices, axis=1, batch_dims=1)
-            prompt = flatten(prompt)
-            state = tf.nest.map_structure(flatten, state)
+            def gather_beams(x):
+                x = unflatten(x)
+                x = tf.gather(x, beam_indices, axis=1, batch_dims=1)
+                return flatten(x)
+
+            prompt = gather_beams(prompt)
+            state = tf.nest.map_structure(gather_beams, state)

             # Update each beam with the next token.
             next_token = tf.cast(next_token, prompt.dtype)

@@ -145,25 +149,14 @@ def body(prompt, state, index, beam_probs):
             # Return the iteration of the loop state.
             return (prompt, state, index + 1, beam_probs)

-        mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask
-        # `tf.while_loop` will not accept `None` as a value for `loop_vars`.
-        state = () if state is None else state
-        # Add extra sequences for each beam.
-        prompt, mask = add_beams(prompt), add_beams(mask)
-        state = tf.nest.map_structure(add_beams, state)
-        # Setup the initial beam log-likelihoods.
-        # On the first loop, make sure only the original beam is considered.
-        beam_probs = tf.constant([[0.0] + [-1e9] * (self.num_beams - 1)])
-        beam_probs = flatten(tf.repeat(beam_probs, batch_size, axis=0))
-
         prompt, _, _, beam_probs = tf.while_loop(
             cond=cond,
             body=body,
             loop_vars=(prompt, state, index, beam_probs),
             maximum_iterations=(max_length - index),
         )

-        # Gather the top beams for each batch index.
+        # Gather the top beam at each batch index.
         prompt, beam_probs = unflatten(prompt), unflatten(beam_probs)
         top_beams = tf.math.argmax(beam_probs, axis=-1)[:, tf.newaxis]
         prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1)
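
For reference, the new docstring example made self-contained (only the imports are added; everything else is as committed). The callback contract is `next(prompt, state, index) -> (probs, state)`, and `next` now precedes `prompt` in the call signature, matching the caller in `gpt2_causal_lm.py`:

```python
import tensorflow as tf

import keras_nlp

# Use a simple alphabet of lowercase characters mapped to [0, 26).
int_lookup = {i: chr(i + ord("a")) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, state, index):
    # A uniform distribution over our alphabet.
    probs = tf.ones((batch_size, vocab_size))
    return probs, state

output = keras_nlp.samplers.BeamSampler()(
    next=next,
    prompt=tf.fill((batch_size, length), char_lookup["z"]),
    index=5,
)
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> "zzzzzaaaaaaa"
```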
