Make randomization seed a serving input, rather than compile option #303

Merged (1 commit, Dec 13, 2023)

8 changes: 4 additions & 4 deletions lib/bumblebee/audio.ex
@@ -40,7 +40,10 @@ defmodule Bumblebee.Audio do
requires `ffmpeg` installed)

"""
@type speech_to_text_whisper_input :: Nx.t() | {:file, String.t()}
@type audio :: Nx.t() | {:file, String.t()}

@type speech_to_text_whisper_input ::
audio() | %{:audio => audio(), optional(:seed) => integer()}
@type speech_to_text_whisper_output :: %{
chunks: list(speech_to_text_whisper_chunk())
}
@@ -88,9 +91,6 @@ defmodule Bumblebee.Audio do
supported value is `:segments`, the length of each segment is up
to the model

* `:seed` - random seed to use when sampling. By default the current
timestamp is used

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:
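
With the updated `speech_to_text_whisper_input` type, the seed travels with each request instead of being fixed when the serving is compiled. A minimal usage sketch follows; the checkpoint, options, and file path are illustrative and not part of this diff:

```elixir
# Illustrative setup, unchanged by this PR.
{:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
  Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config)

# A bare tensor or {:file, path} still works; the map form adds an optional per-request seed.
Nx.Serving.run(serving, {:file, "/tmp/speech.wav"})
Nx.Serving.run(serving, %{audio: {:file, "/tmp/speech.wav"}, seed: 42})
```
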
61 changes: 39 additions & 22 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
@@ -17,7 +17,6 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
:chunk_num_seconds,
:context_num_seconds,
:language,
:seed,
:compile,
:timestamps,
defn_options: [],
@@ -50,8 +49,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
{generate_opts, generation_config} = generate_opts(generation_config, opts)
generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)

generate_fun = fn params, inputs ->
generate_fun = fn params, {inputs, seed} ->
inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
inputs = Map.put(inputs, "seed", seed)
generate_fun.(params, inputs)
end

@@ -62,7 +62,8 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
[params, inputs]
seed = Nx.template({batch_size}, :s64)
[params, {inputs, seed}]
end)

fn inputs ->
@@ -78,24 +79,17 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
Shared.validate_input_for_stream!(input)
end

{inputs, multi?} =
Shared.validate_serving_input!(input, fn
%Nx.Tensor{shape: {_}} = input ->
{:ok, Nx.backend_transfer(input, Nx.BinaryBackend)}

{:file, path} when is_binary(path) ->
ffmpeg_read_as_pcm(path, sampling_rate)

other ->
{:error, "expected a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"}
end)
{inputs, multi?} = Shared.validate_serving_input!(input, &validate_input(&1, sampling_rate))

all_chunks =
for input <- inputs do
if chunk_num_seconds do
chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds)
chunks =
chunk_input(input.audio, sampling_rate, chunk_num_seconds, context_num_seconds)

for {chunk, lengths} <- chunks, do: {{chunk, input.seed}, lengths}
else
[{input, nil}]
[{{input.audio, input.seed}, nil}]
end
end

@@ -109,19 +103,45 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
all_chunks
|> Stream.chunk_every(batch_size)
|> Stream.map(fn all_chunks ->
{all_chunks, seed} = Enum.unzip(all_chunks)
seed = Nx.tensor(seed, backend: Nx.BinaryBackend)
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
Nx.Batch.concatenate([inputs])
Nx.Batch.concatenate([{inputs, seed}])
end)

{stream, {multi?, all_num_chunks, lengths}}
else
{all_chunks, seed} = Enum.unzip(all_chunks)
seed = Nx.tensor(seed, backend: Nx.BinaryBackend)
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
{Nx.Batch.concatenate([{inputs, seed}]), {multi?, all_num_chunks, lengths}}
end
end)
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?)
end

defp validate_input(%{audio: audio} = input, sampling_rate) do
with {:ok, audio} <- parse_audio(audio, sampling_rate) do
{:ok, %{audio: audio, seed: input[:seed] || :erlang.system_time()}}
end
end

defp validate_input(input, sampling_rate), do: validate_input(%{audio: input}, sampling_rate)

defp parse_audio(input, sampling_rate) do
case input do
%Nx.Tensor{shape: {_}} = input ->
{:ok, Nx.backend_transfer(input, Nx.BinaryBackend)}

{:file, path} when is_binary(path) ->
ffmpeg_read_as_pcm(path, sampling_rate)

other ->
{:error,
"expected audio to be a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"}
end
end

defp maybe_stream(serving, false, spec, featurizer, tokenizer, timestamps?) do
Nx.Serving.client_postprocessing(serving, fn
{outputs, _metadata}, {multi?, all_num_chunks, lengths} ->
@@ -170,10 +190,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
[]
end

opts =
opts
|> Keyword.take([:seed])
|> Keyword.put(:logits_processors, logits_processors)
opts = [logits_processors: logits_processors]

{opts, generation_config}
end
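
The client preprocessing above pairs every audio chunk with the seed of the request it came from, then splits the pairs when building a batch, so the model receives a rank-1 `:s64` tensor with one seed per batch entry. A standalone sketch of that unzip-and-tensor step, using placeholder chunks in place of the featurizer inputs:

```elixir
# Placeholder chunks standing in for the audio produced by chunk_input/4.
chunk_a = Nx.iota({16_000}, type: :f32)
chunk_b = Nx.iota({16_000}, type: :f32)

all_chunks = [{chunk_a, 42}, {chunk_b, 7}]

# The chunks go on to the featurizer as before; only the seeds are shown here.
{_chunks, seeds} = Enum.unzip(all_chunks)
seed = Nx.tensor(seeds, backend: Nx.BinaryBackend)
# => shape {2}, matching the Nx.template({batch_size}, :s64) declared at compile time
```
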
41 changes: 30 additions & 11 deletions lib/bumblebee/diffusion/stable_diffusion.ex
@@ -8,7 +8,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
alias Bumblebee.Shared

@type text_to_image_input ::
String.t() | %{:prompt => String.t(), optional(:negative_prompt) => String.t()}
String.t()
| %{
:prompt => String.t(),
optional(:negative_prompt) => String.t(),
optional(:seed) => integer()
}
@type text_to_image_output :: %{results: list(text_to_image_result())}
@type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()}

@@ -44,8 +49,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
$\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
Defaults to `7.5`

* `:seed` - a seed for the random number generator. Defaults to `0`

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:
@@ -131,7 +134,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
num_steps: 50,
num_images_per_prompt: 1,
guidance_scale: 7.5,
seed: 0,
defn_options: [],
preallocate_params: false
])
@@ -183,7 +185,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
num_images_per_prompt: opts[:num_images_per_prompt],
latents_sample_size: unet.spec.sample_size,
latents_channels: unet.spec.in_channels,
seed: opts[:seed],
guidance_scale: opts[:guidance_scale]
)

@@ -237,7 +238,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
"input_ids" => Nx.template({batch_size, sequence_length}, :u32)
}

inputs = %{"unconditional" => text_inputs, "conditional" => text_inputs}
inputs = %{
"unconditional" => text_inputs,
"conditional" => text_inputs,
"seed" => Nx.template({batch_size}, :s64)
}

[encoder_params, unet_params, vae_params, inputs]
end)
@@ -284,6 +289,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

prompts = Enum.map(inputs, & &1.prompt)
negative_prompts = Enum.map(inputs, & &1.negative_prompt)
seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)

conditional =
Nx.with_default_backend(Nx.BinaryBackend, fn ->
@@ -303,7 +309,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
)
end)

inputs = %{"unconditional" => unconditional, "conditional" => conditional}
inputs = %{"unconditional" => unconditional, "conditional" => conditional, "seed" => seed}

{Nx.Batch.concatenate([inputs]), multi?}
end
@@ -349,9 +355,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
num_images_per_prompt = opts[:num_images_per_prompt]
latents_sample_size = opts[:latents_sample_size]
latents_in_channels = opts[:latents_channels]
seed = opts[:seed]
guidance_scale = opts[:guidance_scale]

seed = inputs["seed"]

inputs =
Bumblebee.Utils.Nx.composite_concatenate(inputs["unconditional"], inputs["conditional"])

@@ -372,8 +379,15 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

{scheduler_state, timesteps} = scheduler_init.(latents_shape)

key = Nx.Random.key(seed)
{latents, _key} = Nx.Random.normal(key, shape: latents_shape)
key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()

{latents, _key} =
Nx.Random.normal(key,
shape:
{num_images_per_prompt, latents_sample_size, latents_sample_size, latents_in_channels}
)

latents = latents |> Nx.devectorize() |> Nx.reshape(latents_shape)

{_, latents, _, _} =
while {scheduler_state, latents, text_embeddings, unet_params}, timestep <- timesteps do
@@ -411,7 +425,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt})

defp validate_input(%{prompt: prompt} = input) do
{:ok, %{prompt: prompt, negative_prompt: input[:negative_prompt] || ""}}
{:ok,
%{
prompt: prompt,
negative_prompt: input[:negative_prompt] || "",
seed: input[:seed] || :erlang.system_time()
}}
end

defp validate_input(%{} = input) do
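
Because each batch entry can now carry a different seed, the image generation defn derives one PRNG key per entry by vectorizing the seed tensor before `Nx.Random.key/1`, then devectorizes the sampled latents back into the leading batch dimension. A reduced sketch of that pattern outside the serving, with illustrative shapes:

```elixir
# Two requests in a batch, each with its own seed.
seed = Nx.tensor([0, 1], type: :s64)

key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()

# Sampling under a vectorized key yields independent noise per batch entry.
{latents, _key} = Nx.Random.normal(key, shape: {1, 64, 64, 4})

# Flatten the vectorized axis back into the leading batch dimension.
latents = latents |> Nx.devectorize() |> Nx.reshape({2, 64, 64, 4})
```

On the client side the seed is just another optional key in the input map, e.g. `%{prompt: "numbat, forest, detailed digital art", seed: 0}`.
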
15 changes: 7 additions & 8 deletions lib/bumblebee/text.ex
@@ -126,7 +126,8 @@ defmodule Bumblebee.Text do
defdelegate token_classification(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TokenClassification

@type generation_input :: String.t()
@type generation_input ::
String.t() | %{:text => String.t(), optional(:seed) => integer()}
@type generation_output :: %{results: list(generation_result())}
@type generation_result :: %{text: String.t()}

@@ -138,9 +139,6 @@ defmodule Bumblebee.Text do

## Options

* `:seed` - random seed to use when sampling. By default the current
timestamp is used

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:
@@ -215,7 +213,7 @@ defmodule Bumblebee.Text do
defdelegate generation(model_info, tokenizer, generation_config, opts \\ []),
to: Bumblebee.Text.Generation

@type conversation_input :: %{text: String.t(), history: conversation_history() | nil}
@type conversation_input :: %{
:text => String.t(),
:history => conversation_history() | nil,
optional(:seed) => integer()
}
@type conversation_output :: %{text: String.t(), history: conversation_history()}

@type conversation_history :: list({:user | :generated, String.t()})
@@ -233,9 +235,6 @@ defmodule Bumblebee.Text do

## Options

* `:seed` - random seed to use when sampling. By default the current
timestamp is used

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:
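
For text generation the input follows the same pattern: a plain string still works, while the map form adds the optional seed. A usage sketch, assuming a serving built as in the existing `Bumblebee.Text.generation/4` docs (the checkpoint is illustrative):

```elixir
# Illustrative setup, unchanged by this PR.
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

Nx.Serving.run(serving, "Elixir is a functional programming language")

# Same prompt with an explicit seed, so sampling is reproducible across runs.
Nx.Serving.run(serving, %{text: "Elixir is a functional programming language", seed: 0})
```
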
20 changes: 13 additions & 7 deletions lib/bumblebee/text/conversation.ex
@@ -19,7 +19,7 @@ defmodule Bumblebee.Text.Conversation do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])
opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false])

%{model: model, params: params, spec: spec} = model_info

@@ -41,8 +41,7 @@ defmodule Bumblebee.Text.Conversation do
batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

generate_fun =
Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))
generate_fun = Text.Generation.build_generate(model, spec, generation_config)

batch_keys = Shared.sequence_batch_keys(sequence_length)

@@ -56,7 +55,8 @@ defmodule Bumblebee.Text.Conversation do

inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
"seed" => Nx.template({batch_size}, :s64)
}

[params, inputs]
@@ -72,7 +72,10 @@ defmodule Bumblebee.Text.Conversation do
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{histories, multi?} = Shared.validate_serving_input!(input, &validate_input/1)
{inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

histories = Enum.map(inputs, & &1.history)
seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)

texts =
for history <- histories do
@@ -89,6 +92,8 @@ defmodule Bumblebee.Text.Conversation do
)
end)

inputs = Map.put(inputs, "seed", seed)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

@@ -104,9 +109,10 @@ defmodule Bumblebee.Text.Conversation do
end)
end

defp validate_input(%{text: text, history: history}) when is_binary(text) do
defp validate_input(%{text: text, history: history} = input) when is_binary(text) do
history = history || []
{:ok, [{:user, text} | history]}
history = [{:user, text} | history]
{:ok, %{history: history, seed: input[:seed] || :erlang.system_time()}}
end

defp validate_input(input) do
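
The conversation serving accepts the seed the same way, alongside the existing `:text` and `:history` keys. A usage sketch; the serving construction is assumed (e.g. via `Bumblebee.Text.conversation/4`) and the messages are illustrative:

```elixir
# `serving` is assumed to be a conversational serving built as in the existing docs.
input = %{text: "Tell me about Elixir", history: nil, seed: 7}

%{text: _reply, history: history} = Nx.Serving.run(serving, input)

# Subsequent turns pass the returned history back in, still with an optional seed.
next_input = %{text: "What makes it interesting?", history: history, seed: 7}
```
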