Add multinomial sampling to sequence generation #161

Merged · 7 commits · Apr 12, 2023
Changes from 3 commits
11 changes: 11 additions & 0 deletions lib/bumblebee/text.ex
@@ -116,6 +116,17 @@ defmodule Bumblebee.Text do
* `:min_new_tokens` - the minimum number of tokens to be generated,
ignoring the number of tokens in the prompt

* `:num_beams` - the number of beams to use in beam search. If set
  to 1, beam search is not used. Defaults to `1`

* `:early_stopping` - whether to stop beam search as soon as at least
  `:num_beams` sentences per batch are finished. Defaults to `false`

* `:sample` - whether to use multinomial (random) sampling instead of
  greedy decoding. Defaults to `false`

* `:prng_key` - the random key to use when sampling. When `nil`, a key
  is derived from the system time. Defaults to `nil`

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:
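For orientation, here is a minimal usage sketch combining the new options (the model, prompt, and option values are illustrative and not part of this diff):

    {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

    # Enable multinomial sampling; pass an explicit key for reproducible runs.
    serving =
      Bumblebee.Text.generation(model_info, tokenizer,
        max_new_tokens: 16,
        sample: true,
        prng_key: Nx.Random.key(42)
      )

    Nx.Serving.run(serving, "I enjoy walking with my cute dog")

The shape of the call mirrors the new test added at the bottom of this PR.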
196 changes: 186 additions & 10 deletions lib/bumblebee/text/generation.ex
@@ -121,6 +121,17 @@ defmodule Bumblebee.Text.Generation do
(including padding). In general, prefer `:min_new_tokens`, which
ignores the number of tokens in the prompt

* `:num_beams` - the number of beams to use in beam search. If set
  to 1, beam search is not used. Defaults to `1`

* `:early_stopping` - whether to stop beam search as soon as at least
  `:num_beams` sentences per batch are finished. Defaults to `false`

* `:sample` - whether to use multinomial (random) sampling instead of
  greedy decoding. Defaults to `false`

* `:prng_key` - the random key to use when sampling. When `nil`, a key
  is derived from the system time. Defaults to `nil`
Member:

Since this function is pretty high-level, what about :seed instead?

Contributor Author:

Agreed, though one thing I was wondering is whether we should pass this to the serving as well and return the updated key. They do accept a key in generate in HF, but they do not return an updated key, so maybe just allowing a key in the serving is enough and it is assumed you will split it outside the serving.

Member:

Servings are so high-level that it seems to me the seed is mostly for reproducibility (or to explore different results for the same prompt), so a number is more transparent than a tensor key. Or do you think there's a scenario where you would actually use a split key?

Contributor Author:

Sorry, what I mean is that currently you'd need to create a new serving with a new seed to get different outputs. Maybe that's desirable, but I'm not sure, because if you want something non-deterministic you may want the option to pass a key to the serving.
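For reference, a minimal sketch contrasting the two styles discussed in this thread (the :seed option is hypothetical; only :prng_key exists in this diff):

    # Key-based, as in this diff: the caller builds the key and, if desired,
    # splits it outside the serving to get different samples across runs.
    prng_key = Nx.Random.key(0)
    keys = Nx.Random.split(prng_key, parts: 2)
    {key_this_run, key_next_run} = {keys[0], keys[1]}

    # Seed-based, as suggested above: the serving would derive the key
    # internally, e.g. prng_key = Nx.Random.key(seed), so callers only pass
    # an integer.
    seed = 0

Either way the randomness is reproducible for a fixed value; the difference is who owns key management.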


* `:decoder_start_token_id` - the id of the initial token when
generating from scratch, in case of encoder-decoder models

@@ -148,6 +159,10 @@ defmodule Bumblebee.Text.Generation do
min_new_tokens: nil,
max_length: nil,
min_length: nil,
num_beams: 1,
sample: false,
early_stopping: false,
prng_key: nil,
decoder_start_token_id: Map.get(spec, :decoder_start_token_id),
bos_token_id: Map.get(spec, :bos_token_id),
eos_token_id: Map.get(spec, :eos_token_id),
@@ -164,6 +179,11 @@ defmodule Bumblebee.Text.Generation do
forced_eos_token_id = opts[:forced_eos_token_id]
forced_token_ids = opts[:forced_token_ids]

num_beams = opts[:num_beams]
early_stopping = opts[:early_stopping]
sample = opts[:sample]
prng_key = opts[:prng_key]

{max_length_fun, min_length_fun} = lazy_lengths_from_opts(opts)

{prepare_inputs_fun, update_inputs_fun} =
@@ -188,7 +208,11 @@ defmodule Bumblebee.Text.Generation do
prepare_inputs_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id,
num_beams: num_beams,
early_stopping: early_stopping,
sample: sample,
prng_key: prng_key
)
end

@@ -374,19 +398,43 @@ defmodule Bumblebee.Text.Generation do
update_inputs_fun,
opts \\ []
) do
{num_beams, opts} = Keyword.pop!(opts, :num_beams)
{early_stopping, opts} = Keyword.pop!(opts, :early_stopping)
{sample, opts} = Keyword.pop!(opts, :sample)
{prng_key, opts} = Keyword.pop!(opts, :prng_key)

{decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)

cond do
sample and num_beams == 1 ->
prng_key = prng_key || Nx.Random.key(:erlang.system_time())
prng_key = Nx.backend_copy(prng_key, Nx.BinaryBackend)

sample(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length, prng_key: prng_key] ++ opts
)

true ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length] ++ opts
)
end
end

## Greedy Generation

defnp greedy(
inputs,
decoder_input_ids,
@@ -507,6 +555,134 @@ defmodule Bumblebee.Text.Generation do
{sequences, length + 1, finished?, inputs}
end

## Sampling

defnp sample(
inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
opts \\ []
) do
max_length = opts[:max_length]
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
prng_key = opts[:prng_key]

{batch_size, length} = Nx.shape(decoder_input_ids)

if length > max_length do
raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}"
end

sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

finished? = Nx.broadcast(Nx.tensor(0, type: :u8), {batch_size})

input_length = length

# The loop works with inputs of length 1, so if the initial input
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs, prng_key} =
if length > 1 do
sample_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)
else
{sequences, length, finished?, inputs, prng_key}
end

{sequences, _length, _finished?, _inputs, _params, _key} =
while {sequences, length, finished?, inputs, params, prng_key},
sample_condition(finished?, length, max_length) do
{sequences, length, finished?, inputs, prng_key} =
sample_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)

{sequences, length, finished?, inputs, params, prng_key}
end

sequences
end

defnp sample_condition(finished?, length, max_length) do
not(Nx.all(finished?) or length == max_length)
end

defnp sample_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
opts \\ []
) do
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]

key = Nx.Random.split(prng_key)
{key, key_next} = {key[1], key[0]}

model_outputs = predict_fun.(params, inputs)

logits = model_outputs.logits[[0..-1//1, -1]]

logits =
logits_processor_fun.(logits, %{
sequences: sequences,
length: length,
input_length: input_length
})

# TODO: logits warper

vocab = Nx.iota(Nx.shape(logits), axis: -1)
probabilities = Axon.Activations.softmax(logits)
next_token = Nx.Random.choice(key, vocab, probabilities, [])

next_finished? = Nx.logical_or(finished?, Nx.equal(next_token, eos_token_id))
next_token = next_token * Nx.logical_not(next_finished?) + pad_token_id * next_finished?
next_token = Nx.new_axis(next_token, 1)

sequences = Nx.put_slice(sequences, [0, length], next_token)
inputs = update_inputs_fun.(inputs, model_outputs, next_token)

{sequences, length + 1, next_finished?, inputs, key_next}
end

## Beam Search

# Logit processors

defnp bos_token_logits_processor(logits, context, opts \\ []) do
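Editorial note on the sampling step above: in recent Nx versions, Nx.Random.choice returns a {samples, new_key} tuple and expects a 1-D probability vector, so a per-row multinomial draw over batched logits is often written with the Gumbel-max trick instead. A minimal sketch under that assumption (not the code in this PR):

    defmodule SamplingSketch do
      import Nx.Defn

      # Gumbel-max sampling: argmax(logits + g), with g drawn from Gumbel(0, 1),
      # samples a token from Categorical(softmax(logits)) independently per row.
      defn sample_next_token(key, logits) do
        {u, key} = Nx.Random.uniform(key, shape: Nx.shape(logits), type: :f32)
        gumbel = -Nx.log(-Nx.log(u))
        next_token = Nx.argmax(logits + gumbel, axis: -1)
        {next_token, key}
      end
    end

With logits of shape {batch_size, vocab_size}, next_token has shape {batch_size}, which matches what sample_step expects before the Nx.new_axis(next_token, 1) call.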
21 changes: 20 additions & 1 deletion test/bumblebee/text/generation_test.exs
@@ -6,7 +6,7 @@ defmodule Bumblebee.Text.GenerationTest do
@moduletag model_test_tags()

describe "integration" do
test "generates text" do
test "generates text with greedy generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})

@@ -21,5 +21,24 @@

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end

test "generates text with sampling" do
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

serving =
Bumblebee.Text.generation(model_info, tokenizer,
max_new_tokens: 8,
sample: true,
prng_key: Nx.Random.key(0)
)

prompt = """
I enjoy walking with my cute dog
"""

assert %{results: [%{text: "On the field, on a field trip,"}]} =
Nx.Serving.run(serving, prompt)
end
end
end