@@ -42,30 +42,23 @@ class BeamSampler(Sampler):
42
42
43
43
Examples:
44
44
```python
45
- VOCAB_SIZE = 10
46
-
47
- # Create a dummy model to predict the next token.
48
- model = keras.Sequential(
49
- [
50
- keras.Input(shape=[None]),
51
- keras.layers.Embedding(
52
- input_dim=VOCAB_SIZE,
53
- output_dim=16,
54
- ),
55
- keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
56
- ]
45
+ # Use a simple alphabet of lowercase characters to [0, 26).
46
+ int_lookup = {i: chr(i + ord('a')) for i in range(26)}
47
+ char_lookup = {v: k for k, v in int_lookup.items()}
48
+ batch_size, length, vocab_size = 1, 12, len(int_lookup)
49
+
50
+ def next(prompt, state, index):
51
+ # A uniform distribution over our alphabet.
52
+ probs = tf.ones((batch_size, vocab_size))
53
+ return probs, state
54
+
55
+ output = keras_nlp.samplers.BeamSampler()(
56
+ next=next,
57
+ prompt=tf.fill((batch_size, length,), char_lookup['z']),
58
+ index=5,
57
59
)
58
-
59
- # Define a function that outputs the next token's probability for each token
60
- # in the input sequence.
61
- def token_probability_fn(inputs, mask):
62
- return model(inputs)
63
-
64
- prompt = tf.fill((8, 1), 1)
65
-
66
- sampler = keras_nlp.samplers.BeamSampler(num_beams=3)
67
- # Print the generated sequence (token ids).
68
- print(sampler(prompt, token_probability_fn, max_length=10))
60
+ print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
61
+ # >>> "zzzzzaaaaaaa"
69
62
```
70
63
"""
71
64
@@ -78,8 +71,8 @@ def __init__(
78
71
79
72
def __call__ (
80
73
self ,
81
- prompt ,
82
74
next ,
75
+ prompt ,
83
76
index = 0 ,
84
77
state = None ,
85
78
mask = None ,
@@ -99,6 +92,17 @@ def unflatten(x):
99
92
unflat_shape = [batch_size , self .num_beams ] + x .shape .as_list ()[1 :]
100
93
return tf .reshape (x , shape = unflat_shape )
101
94
95
+ mask = tf .zeros_like (prompt , dtype = tf .bool ) if mask is None else mask
96
+ # `tf.while_loop` will not accept `None` as a value for `loop_vars`.
97
+ state = () if state is None else state
98
+ # Add extra sequences for each beam.
99
+ prompt , mask = add_beams (prompt ), add_beams (mask )
100
+ state = tf .nest .map_structure (add_beams , state )
101
+ # Setup the initial beam log-likelihoods.
102
+ # On the first loop, make sure only the original beam is considered.
103
+ beam_probs = tf .constant ([[0.0 ] + [- 1e9 ] * (self .num_beams - 1 )])
104
+ beam_probs = flatten (tf .repeat (beam_probs , batch_size , axis = 0 ))
105
+
102
106
def cond (prompt , state , index , beam_probs ):
103
107
if end_token_id is None :
104
108
return True
@@ -127,13 +131,13 @@ def body(prompt, state, index, beam_probs):
127
131
# We need `ensure_shape` as `top_k` will change the static shape.
128
132
beam_probs = tf .ensure_shape (flatten (next_probs ), beam_probs .shape )
129
133
130
- # Gather the correct prompt and state beams.
131
- prompt = unflatten (prompt )
132
- state = tf .nest . map_structure ( unflatten , state )
133
- prompt = tf . gather ( prompt , beam_indices , axis = 1 , batch_dims = 1 )
134
- state = tf . gather ( state , beam_indices , axis = 1 , batch_dims = 1 )
135
- prompt = flatten (prompt )
136
- state = tf .nest .map_structure (flatten , state )
134
+ def gather_beams ( x ):
135
+ x = unflatten (x )
136
+ x = tf .gather ( x , beam_indices , axis = 1 , batch_dims = 1 )
137
+ return flatten ( x )
138
+
139
+ prompt = gather_beams (prompt )
140
+ state = tf .nest .map_structure (gather_beams , state )
137
141
138
142
# Update each beam with the next token.
139
143
next_token = tf .cast (next_token , prompt .dtype )
@@ -145,25 +149,14 @@ def body(prompt, state, index, beam_probs):
145
149
# Return the iteration of the loop state.
146
150
return (prompt , state , index + 1 , beam_probs )
147
151
148
- mask = tf .zeros_like (prompt , dtype = tf .bool ) if mask is None else mask
149
- # `tf.while_loop` will not accept `None` as a value for `loop_vars`.
150
- state = () if state is None else state
151
- # Add extra sequences for each beam.
152
- prompt , mask = add_beams (prompt ), add_beams (mask )
153
- state = tf .nest .map_structure (add_beams , state )
154
- # Setup the initial beam log-likelihoods.
155
- # On the first loop, make sure only the original beam is considered.
156
- beam_probs = tf .constant ([[0.0 ] + [- 1e9 ] * (self .num_beams - 1 )])
157
- beam_probs = flatten (tf .repeat (beam_probs , batch_size , axis = 0 ))
158
-
159
152
prompt , _ , _ , beam_probs = tf .while_loop (
160
153
cond = cond ,
161
154
body = body ,
162
155
loop_vars = (prompt , state , index , beam_probs ),
163
156
maximum_iterations = (max_length - index ),
164
157
)
165
158
166
- # Gather the top beams for each batch index.
159
+ # Gather the top beam at each batch index.
167
160
prompt , beam_probs = unflatten (prompt ), unflatten (beam_probs )
168
161
top_beams = tf .math .argmax (beam_probs , axis = - 1 )[:, tf .newaxis ]
169
162
prompt = tf .gather (prompt , top_beams , axis = 1 , batch_dims = 1 )
0 commit comments