Add multinomial sampling to sequence generation (#161)
Co-authored-by: José Valim <jose.valim@dashbit.co>
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
3 people committed Apr 12, 2023
1 parent 6c4d885 commit 1f968a5
Showing 13 changed files with 670 additions and 172 deletions.
10 changes: 2 additions & 8 deletions lib/bumblebee/audio.ex
@@ -24,14 +24,10 @@ defmodule Bumblebee.Audio do
   The serving accepts `t:speech_to_text_input/0` and returns
   `t:speech_to_text_output/0`. A list of inputs is also supported.
-  Note that either `:max_new_tokens` or `:max_length` must be specified.
-  The generation should generally finish based on the audio input,
-  however you still need to specify the upper limit.
   ## Options
-    * `:max_new_tokens` - the maximum number of tokens to be generated,
-      ignoring the number of tokens in the prompt
+    * `:seed` - random seed to use when sampling. By default the current
+      timestamp is used
     * `:compile` - compiles all computations for predefined input shapes
       during serving initialization. Should be a keyword list with the
@@ -46,8 +42,6 @@ defmodule Bumblebee.Audio do
     * `:defn_options` - the options for JIT compilation. Defaults to `[]`
-  Also accepts all the other options of `Bumblebee.Text.Generation.build_generate/3`.
   ## Examples
       {:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
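The `:seed` option documented above exists because sampling draws random numbers, so a fixed seed is what makes a sampled sequence repeatable. The following is a minimal sketch of that idea only, assuming plain Erlang `:rand` rather than the tensor-based code in this commit; `SamplingSketch`, `sample/2`, and `generate/3` are hypothetical names used for illustration and are not part of Bumblebee.

```elixir
defmodule SamplingSketch do
  # Hypothetical helper, not Bumblebee code: draws one index from a list of
  # probabilities (assumed to sum to 1) by walking the cumulative distribution.
  def sample(probs, rand_state) do
    {u, rand_state} = :rand.uniform_s(rand_state)
    {pick(probs, u, 0), rand_state}
  end

  # The last entry absorbs any floating-point remainder.
  defp pick([_last], _u, idx), do: idx

  defp pick([p | rest], u, idx) do
    if u < p, do: idx, else: pick(rest, u - p, idx + 1)
  end

  # Draws `count` indices, threading the random state explicitly so the
  # result depends only on `seed`.
  def generate(probs, count, seed) do
    state = :rand.seed_s(:exsss, seed)

    {tokens, _state} =
      Enum.map_reduce(1..count, state, fn _step, state -> sample(probs, state) end)

    tokens
  end
end

probs = [0.1, 0.6, 0.3]
first = SamplingSketch.generate(probs, 5, 42)
second = SamplingSketch.generate(probs, 5, 42)

# The same seed yields the same sequence of sampled indices.
IO.inspect(first == second)
```

With a timestamp-derived default seed, as the option text describes, two runs would generally differ; passing an explicit seed is what pins the output.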
10 changes: 8 additions & 2 deletions lib/bumblebee/audio/speech_to_text.ex
@@ -4,7 +4,7 @@ defmodule Bumblebee.Audio.SpeechToText do
   alias Bumblebee.Shared

   def speech_to_text(model_info, featurizer, tokenizer, generation_config, opts \\ []) do
-    opts = Keyword.validate!(opts, [:compile, defn_options: []])
+    opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])

     compile = opts[:compile]
     defn_options = opts[:defn_options]
@@ -20,7 +20,13 @@ defmodule Bumblebee.Audio.SpeechToText do

     %{model: model, params: params, spec: spec} = model_info

-    generate_fun = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
+    generate_fun =
+      Bumblebee.Text.Generation.build_generate(
+        model,
+        spec,
+        generation_config,
+        Keyword.take(opts, [:seed])
+      )

     Nx.Serving.new(
       fn defn_options ->
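The change above follows a common pattern: widen the list of accepted keys, then forward only the relevant subset to the inner function. A minimal sketch using only the standard `Keyword` module, with no Bumblebee calls; the value `42` and the variable names `generate_opts` and `rejected?` are illustrative assumptions.

```elixir
# Options as a caller might pass them to the serving builder.
opts = [seed: 42]

# Mirrors the updated validation: :seed is accepted and :defn_options
# receives its default of [].
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])

# Only :seed is forwarded to the generation builder. When the caller did
# not pass :seed, Keyword.take/2 returns [] and nothing is forwarded.
generate_opts = Keyword.take(opts, [:seed])

# With the previous key list, the same input raises ArgumentError.
rejected? =
  try do
    Keyword.validate!([seed: 42], [:compile, defn_options: []])
    false
  rescue
    ArgumentError -> true
  end

IO.inspect(generate_opts)
IO.inspect(rejected?)
```

Using `Keyword.take/2` keeps serving-only options such as `:compile` and `:defn_options` from reaching the generation function.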
20 changes: 4 additions & 16 deletions lib/bumblebee/text.ex
@@ -121,15 +121,10 @@ defmodule Bumblebee.Text do
   The serving accepts `t:generation_input/0` and returns `t:generation_output/0`.
   A list of inputs is also supported.
-  Note that either `:max_new_tokens` or `:max_length` must be specified.
   ## Options
-    * `:max_new_tokens` - the maximum number of tokens to be generated,
-      ignoring the number of tokens in the prompt
-    * `:min_new_tokens` - the minimum number of tokens to be generated,
-      ignoring the number of tokens in the prompt
+    * `:seed` - random seed to use when sampling. By default the current
+      timestamp is used
     * `:compile` - compiles all computations for predefined input shapes
       during serving initialization. Should be a keyword list with the
@@ -147,8 +142,6 @@ defmodule Bumblebee.Text do
     * `:defn_options` - the options for JIT compilation. Defaults to `[]`
-  Also accepts all the other options of `Bumblebee.Text.Generation.build_generate/3`.
   ## Examples
       {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
@@ -194,11 +187,8 @@ defmodule Bumblebee.Text do
   ## Options
-    * `:max_new_tokens` - the maximum number of tokens to be generated,
-      ignoring the number of tokens in the prompt
-    * `:min_new_tokens` - the minimum number of tokens to be generated,
-      ignoring the number of tokens in the prompt
+    * `:seed` - random seed to use when sampling. By default the current
+      timestamp is used
     * `:compile` - compiles all computations for predefined input shapes
       during serving initialization. Should be a keyword list with the
@@ -219,8 +209,6 @@ defmodule Bumblebee.Text do
     * `:defn_options` - the options for JIT compilation. Defaults to `[]`
-  Also accepts all the other options of `Bumblebee.Text.Generation.build_generate/3`.
   ## Examples
       {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
5 changes: 3 additions & 2 deletions lib/bumblebee/text/conversation.ex
@@ -13,7 +13,7 @@ defmodule Bumblebee.Text.Conversation do

   @doc false
   def conversation(model_info, tokenizer, generation_config, opts \\ []) do
-    opts = Keyword.validate!(opts, [:compile, defn_options: []])
+    opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])

     %{params: params, spec: spec} = model_info

@@ -39,7 +39,8 @@ defmodule Bumblebee.Text.Conversation do
       Bumblebee.Text.Generation.build_generate(
         model_info.model,
         model_info.spec,
-        generation_config
+        generation_config,
+        Keyword.take(opts, [:seed])
       )

     Nx.Serving.new(