diff --git a/lib/bumblebee/audio.ex b/lib/bumblebee/audio.ex index f40a6aa3..86243ebc 100644 --- a/lib/bumblebee/audio.ex +++ b/lib/bumblebee/audio.ex @@ -40,7 +40,10 @@ defmodule Bumblebee.Audio do requires `ffmpeg` installed) """ - @type speech_to_text_whisper_input :: Nx.t() | {:file, String.t()} + @type audio :: Nx.t() | {:file, String.t()} + + @type speech_to_text_whisper_input :: + audio() | %{:audio => audio(), optional(:seed) => integer()} @type speech_to_text_whisper_output :: %{ chunks: list(speech_to_text_whisper_chunk()) } @@ -88,9 +91,6 @@ defmodule Bumblebee.Audio do supported value is `:segments`, the length of each segment is up to the model - * `:seed` - random seed to use when sampling. By default the current - timestamp is used - * `:compile` - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys: diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex index 99434aa1..41164ad7 100644 --- a/lib/bumblebee/audio/speech_to_text_whisper.ex +++ b/lib/bumblebee/audio/speech_to_text_whisper.ex @@ -17,7 +17,6 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do :chunk_num_seconds, :context_num_seconds, :language, - :seed, :compile, :timestamps, defn_options: [], @@ -50,8 +49,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do {generate_opts, generation_config} = generate_opts(generation_config, opts) generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts) - generate_fun = fn params, inputs -> + generate_fun = fn params, {inputs, seed} -> inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs) + inputs = Map.put(inputs, "seed", seed) generate_fun.(params, inputs) end @@ -62,7 +62,8 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) - [params, inputs] + seed = Nx.template({batch_size}, :s64) + [params, {inputs, seed}] end) fn inputs -> @@ -78,24 +79,17 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do Shared.validate_input_for_stream!(input) end - {inputs, multi?} = - Shared.validate_serving_input!(input, fn - %Nx.Tensor{shape: {_}} = input -> - {:ok, Nx.backend_transfer(input, Nx.BinaryBackend)} - - {:file, path} when is_binary(path) -> - ffmpeg_read_as_pcm(path, sampling_rate) - - other -> - {:error, "expected a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"} - end) + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input(&1, sampling_rate)) all_chunks = for input <- inputs do if chunk_num_seconds do - chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds) + chunks = + chunk_input(input.audio, sampling_rate, chunk_num_seconds, context_num_seconds) + + for {chunk, lengths} <- chunks, do: {{chunk, input.seed}, lengths} else - [{input, nil}] + [{{input.audio, input.seed}, nil}] end end @@ -109,19 +103,45 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do all_chunks |> Stream.chunk_every(batch_size) |> Stream.map(fn all_chunks -> + {all_chunks, seed} = Enum.unzip(all_chunks) + seed = Nx.tensor(seed, backend: Nx.BinaryBackend) inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks) - Nx.Batch.concatenate([inputs]) + Nx.Batch.concatenate([{inputs, seed}]) end) {stream, {multi?, all_num_chunks, lengths}} else + {all_chunks, seed} = Enum.unzip(all_chunks) + seed = Nx.tensor(seed, backend: Nx.BinaryBackend) inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks) - {Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}} + {Nx.Batch.concatenate([{inputs, seed}]), {multi?, all_num_chunks, lengths}} end end) |> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?) end + defp validate_input(%{audio: audio} = input, sampling_rate) do + with {:ok, audio} <- parse_audio(audio, sampling_rate) do + {:ok, %{audio: audio, seed: input[:seed] || :erlang.system_time()}} + end + end + + defp validate_input(input, sampling_rate), do: validate_input(%{audio: input}, sampling_rate) + + defp parse_audio(input, sampling_rate) do + case input do + %Nx.Tensor{shape: {_}} = input -> + {:ok, Nx.backend_transfer(input, Nx.BinaryBackend)} + + {:file, path} when is_binary(path) -> + ffmpeg_read_as_pcm(path, sampling_rate) + + other -> + {:error, + "expected audio to be a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"} + end + end + defp maybe_stream(serving, false, spec, featurizer, tokenizer, timestamps?) do Nx.Serving.client_postprocessing(serving, fn {outputs, _metadata}, {multi?, all_num_chunks, lengths} -> @@ -170,10 +190,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do [] end - opts = - opts - |> Keyword.take([:seed]) - |> Keyword.put(:logits_processors, logits_processors) + opts = [logits_processors: logits_processors] {opts, generation_config} end diff --git a/lib/bumblebee/diffusion/stable_diffusion.ex b/lib/bumblebee/diffusion/stable_diffusion.ex index 986fa931..89059085 100644 --- a/lib/bumblebee/diffusion/stable_diffusion.ex +++ b/lib/bumblebee/diffusion/stable_diffusion.ex @@ -8,7 +8,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do alias Bumblebee.Shared @type text_to_image_input :: - String.t() | %{:prompt => String.t(), optional(:negative_prompt) => String.t()} + String.t() + | %{ + :prompt => String.t(), + optional(:negative_prompt) => String.t(), + optional(:seed) => integer() + } @type text_to_image_output :: %{results: list(text_to_image_result())} @type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()} @@ -44,8 +49,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do $\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf). Defaults to `7.5` - * `:seed` - a seed for the random number generator. Defaults to `0` - * `:compile` - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys: @@ -131,7 +134,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do num_steps: 50, num_images_per_prompt: 1, guidance_scale: 7.5, - seed: 0, defn_options: [], preallocate_params: false ]) @@ -183,7 +185,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do num_images_per_prompt: opts[:num_images_per_prompt], latents_sample_size: unet.spec.sample_size, latents_channels: unet.spec.in_channels, - seed: opts[:seed], guidance_scale: opts[:guidance_scale] ) @@ -237,7 +238,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do "input_ids" => Nx.template({batch_size, sequence_length}, :u32) } - inputs = %{"unconditional" => text_inputs, "conditional" => text_inputs} + inputs = %{ + "unconditional" => text_inputs, + "conditional" => text_inputs, + "seed" => Nx.template({batch_size}, :s64) + } [encoder_params, unet_params, vae_params, inputs] end) @@ -284,6 +289,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do prompts = Enum.map(inputs, & &1.prompt) negative_prompts = Enum.map(inputs, & &1.negative_prompt) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend) conditional = Nx.with_default_backend(Nx.BinaryBackend, fn -> @@ -303,7 +309,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do ) end) - inputs = %{"unconditional" => unconditional, "conditional" => conditional} + inputs = %{"unconditional" => unconditional, "conditional" => conditional, "seed" => seed} {Nx.Batch.concatenate([inputs]), multi?} end @@ -349,9 +355,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do num_images_per_prompt = opts[:num_images_per_prompt] latents_sample_size = opts[:latents_sample_size] latents_in_channels = opts[:latents_channels] - seed = opts[:seed] guidance_scale = opts[:guidance_scale] + seed = inputs["seed"] + inputs = Bumblebee.Utils.Nx.composite_concatenate(inputs["unconditional"], inputs["conditional"]) @@ -372,8 +379,15 @@ defmodule Bumblebee.Diffusion.StableDiffusion do {scheduler_state, timesteps} = scheduler_init.(latents_shape) - key = Nx.Random.key(seed) - {latents, _key} = Nx.Random.normal(key, shape: latents_shape) + key = seed |> Nx.vectorize(:batch) |> Nx.Random.key() + + {latents, _key} = + Nx.Random.normal(key, + shape: + {num_images_per_prompt, latents_sample_size, latents_sample_size, latents_in_channels} + ) + + latents = latents |> Nx.devectorize() |> Nx.reshape(latents_shape) {_, latents, _, _} = while {scheduler_state, latents, text_embeddings, unet_params}, timestep <- timesteps do @@ -411,7 +425,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt}) defp validate_input(%{prompt: prompt} = input) do - {:ok, %{prompt: prompt, negative_prompt: input[:negative_prompt] || ""}} + {:ok, + %{ + prompt: prompt, + negative_prompt: input[:negative_prompt] || "", + seed: input[:seed] || :erlang.system_time() + }} end defp validate_input(%{} = input) do diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index 26c806d4..c5ab5348 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -126,7 +126,8 @@ defmodule Bumblebee.Text do defdelegate token_classification(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.TokenClassification - @type generation_input :: String.t() + @type generation_input :: + String.t() | %{:text => String.t(), optional(:seed) => integer()} @type generation_output :: %{results: list(generation_result())} @type generation_result :: %{text: String.t()} @@ -138,9 +139,6 @@ defmodule Bumblebee.Text do ## Options - * `:seed` - random seed to use when sampling. By default the current - timestamp is used - * `:compile` - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys: @@ -215,7 +213,11 @@ defmodule Bumblebee.Text do defdelegate generation(model_info, tokenizer, generation_config, opts \\ []), to: Bumblebee.Text.Generation - @type conversation_input :: %{text: String.t(), history: conversation_history() | nil} + @type conversation_input :: %{ + :text => String.t(), + :history => conversation_history() | nil, + optional(:seed) => integer() + } @type conversation_output :: %{text: String.t(), history: conversation_history()} @type conversation_history :: list({:user | :generated, String.t()}) @@ -233,9 +235,6 @@ defmodule Bumblebee.Text do ## Options - * `:seed` - random seed to use when sampling. By default the current - timestamp is used - * `:compile` - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys: diff --git a/lib/bumblebee/text/conversation.ex b/lib/bumblebee/text/conversation.ex index d649e054..34cbe985 100644 --- a/lib/bumblebee/text/conversation.ex +++ b/lib/bumblebee/text/conversation.ex @@ -19,7 +19,7 @@ defmodule Bumblebee.Text.Conversation do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) + opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false]) %{model: model, params: params, spec: spec} = model_info @@ -41,8 +41,7 @@ defmodule Bumblebee.Text.Conversation do batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] - generate_fun = - Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) + generate_fun = Text.Generation.build_generate(model, spec, generation_config) batch_keys = Shared.sequence_batch_keys(sequence_length) @@ -56,7 +55,8 @@ defmodule Bumblebee.Text.Conversation do inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), - "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "seed" => Nx.template({batch_size}, :s64) } [params, inputs] @@ -72,7 +72,10 @@ defmodule Bumblebee.Text.Conversation do |> Nx.Serving.batch_size(batch_size) |> Nx.Serving.process_options(batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> - {histories, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + + histories = Enum.map(inputs, & &1.history) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend) texts = for history <- histories do @@ -89,6 +92,8 @@ defmodule Bumblebee.Text.Conversation do ) end) + inputs = Map.put(inputs, "seed", seed) + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) @@ -104,9 +109,10 @@ defmodule Bumblebee.Text.Conversation do end) end - defp validate_input(%{text: text, history: history}) when is_binary(text) do + defp validate_input(%{text: text, history: history} = input) when is_binary(text) do history = history || [] - {:ok, [{:user, text} | history]} + history = [{:user, text} | history] + {:ok, %{history: history, seed: input[:seed] || :erlang.system_time()}} end defp validate_input(input) do diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 4c71dc32..27a54a44 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -87,10 +87,12 @@ defmodule Bumblebee.Text.Generation do `%Bumblebee.Text.GenerationConfig{}`, see the corresponding docs for more details. - ## Options + Returns a defn JIT-compatible anonymous function, which expects the + model params as the first argument and inputs map as the second + argument. Note that the inputs map should additionally include a + `"seed"` tensor, with one value per input in the batch. - * `:seed` - random seed to use when sampling. By default the current - timestamp is used + ## Options * `:logits_processors` - a list of numerical functions to modify predicted scores at each generation step. The functions are @@ -104,8 +106,7 @@ defmodule Bumblebee.Text.Generation do keyword() ) :: (params :: map(), inputs :: map() -> Nx.t()) def build_generate(model, spec, config, opts \\ []) do - opts = Keyword.validate!(opts, [:seed, logits_processors: []]) - seed = Keyword.get_lazy(opts, :seed, &:erlang.system_time/0) + opts = Keyword.validate!(opts, logits_processors: []) decoder_start_token_id = config.decoder_start_token_id || config.bos_token_id eos_token_id = config.eos_token_id @@ -141,7 +142,6 @@ defmodule Bumblebee.Text.Generation do traverse_cache_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id, - seed: seed, strategy: config.strategy ) end @@ -339,6 +339,8 @@ defmodule Bumblebee.Text.Generation do traverse_cache_fun, opts \\ [] ) do + {seed, inputs} = pop_seed(inputs) + {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params) length = Nx.axis_size(decoder_input_ids, 1) @@ -350,7 +352,6 @@ defmodule Bumblebee.Text.Generation do end strategy = opts[:strategy] - seed = opts[:seed] sequences = case strategy.type do @@ -381,16 +382,15 @@ defmodule Bumblebee.Text.Generation do ) :multinomial_sampling -> - prng_key = Nx.Random.key(seed) - sampling( decoder_inputs, decoder_input_ids, predict_fun, params, + seed, logits_processor_fun, update_inputs_fun, - merge_options([max_length: max_length, prng_key: prng_key], opts) + merge_options([max_length: max_length], opts) ) end @@ -398,6 +398,8 @@ defmodule Bumblebee.Text.Generation do sequences[[.., length..-1//1]] end + deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed") + deftransformp merge_options(left, right), do: left ++ right # Greedy search @@ -707,6 +709,7 @@ defmodule Bumblebee.Text.Generation do decoder_input_ids, predict_fun, params, + seed, logits_processor_fun, update_inputs_fun, opts \\ [] @@ -714,11 +717,12 @@ defmodule Bumblebee.Text.Generation do max_length = opts[:max_length] pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - prng_key = opts[:prng_key] {sequences, length = input_length, finished?} = init_sequences(decoder_input_ids, max_length, pad_token_id) + prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key() + # The loop works with inputs of length 1, so if the initial input # is longer, we make the initial pass outside {sequences, length, finished?, inputs, prng_key} = @@ -801,13 +805,10 @@ defmodule Bumblebee.Text.Generation do end deftransformp batched_choice(key, scores) do - {batch_size, vocab_size} = Nx.shape(scores) + vocab_size = Nx.axis_size(scores, 1) vocab = Nx.iota({vocab_size}) - keys = Nx.Random.split(key, parts: batch_size) - - key = Nx.vectorize(keys, :batch) probabilities = Nx.vectorize(scores, :batch) {tokens, _} = Nx.Random.choice(key, vocab, probabilities, samples: 1) @@ -823,7 +824,6 @@ defmodule Bumblebee.Text.Generation do def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do opts = Keyword.validate!(opts, [ - :seed, :compile, defn_options: [], preallocate_params: false, @@ -850,7 +850,7 @@ defmodule Bumblebee.Text.Generation do batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] - generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) + generate_fun = build_generate(model, spec, generation_config) batch_keys = Shared.sequence_batch_keys(sequence_length) @@ -864,7 +864,8 @@ defmodule Bumblebee.Text.Generation do inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), - "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "seed" => Nx.template({batch_size}, :s64) } [params, inputs] @@ -884,7 +885,10 @@ defmodule Bumblebee.Text.Generation do Shared.validate_input_for_stream!(input) end - {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + + texts = Enum.map(inputs, & &1.text) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend) inputs = Nx.with_default_backend(Nx.BinaryBackend, fn -> @@ -895,6 +899,8 @@ defmodule Bumblebee.Text.Generation do ) end) + inputs = Map.put(inputs, "seed", seed) + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) @@ -903,6 +909,20 @@ defmodule Bumblebee.Text.Generation do |> maybe_stream(opts[:stream], tokenizer) end + defp validate_input(text) when is_binary(text), do: validate_input(%{text: text}) + + defp validate_input(%{text: text} = input) do + {:ok, %{text: text, seed: input[:seed] || :erlang.system_time()}} + end + + defp validate_input(%{} = input) do + {:error, "expected the input map to have :text key, got: #{inspect(input)}"} + end + + defp validate_input(input) do + {:error, "expected either a string or a map, got: #{inspect(input)}"} + end + defp maybe_stream(serving, false, tokenizer) do Nx.Serving.client_postprocessing(serving, fn {token_ids, _metadata}, multi? -> decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) diff --git a/lib/bumblebee/vision.ex b/lib/bumblebee/vision.ex index 5b1130e9..27e82d1d 100644 --- a/lib/bumblebee/vision.ex +++ b/lib/bumblebee/vision.ex @@ -86,7 +86,7 @@ defmodule Bumblebee.Vision do defdelegate image_classification(model_info, featurizer, opts \\ []), to: Bumblebee.Vision.ImageClassification - @type image_to_text_input :: image() + @type image_to_text_input :: image() | %{:image => image(), optional(:seed) => integer()} @type image_to_text_output :: %{results: list(image_to_text_result())} @type image_to_text_result :: %{text: String.t()} @@ -98,9 +98,6 @@ defmodule Bumblebee.Vision do ## Options - * `:seed` - random seed to use when sampling. By default the current - timestamp is used - * `:compile` - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys: diff --git a/lib/bumblebee/vision/image_to_text.ex b/lib/bumblebee/vision/image_to_text.ex index 29d96b8d..08889d72 100644 --- a/lib/bumblebee/vision/image_to_text.ex +++ b/lib/bumblebee/vision/image_to_text.ex @@ -11,7 +11,7 @@ defmodule Bumblebee.Vision.ImageToText do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) + opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false]) %{model: model, params: params, spec: spec} = model_info @@ -29,11 +29,11 @@ defmodule Bumblebee.Vision.ImageToText do batch_size = compile[:batch_size] - generate_fun = - Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) + generate_fun = Text.Generation.build_generate(model, spec, generation_config) - generate_fun = fn params, inputs -> + generate_fun = fn params, {inputs, seed} -> inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs) + inputs = Map.put(inputs, "seed", seed) generate_fun.(params, inputs) end @@ -44,7 +44,8 @@ defmodule Bumblebee.Vision.ImageToText do generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) - [params, inputs] + seed = Nx.template({batch_size}, :s64) + [params, {inputs, seed}] end) fn inputs -> @@ -56,9 +57,13 @@ defmodule Bumblebee.Vision.ImageToText do ) |> Nx.Serving.batch_size(batch_size) |> Nx.Serving.client_preprocessing(fn input -> - {images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1) + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + + images = Enum.map(inputs, & &1.image) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend) + inputs = Bumblebee.Featurizer.process_input(featurizer, images) - {Nx.Batch.concatenate([inputs]), multi?} + {Nx.Batch.concatenate([{inputs, seed}]), multi?} end) |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? -> decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) @@ -68,4 +73,14 @@ defmodule Bumblebee.Vision.ImageToText do |> Shared.normalize_output(multi?) end) end + + defp validate_input(%{image: image} = input) do + if Shared.image?(image) do + {:ok, %{image: image, seed: input[:seed] || :erlang.system_time()}} + else + {:error, "expected an image, got: #{inspect(image)}"} + end + end + + defp validate_input(input), do: validate_input(%{image: input}) end diff --git a/test/bumblebee/text/bart_test.exs b/test/bumblebee/text/bart_test.exs index 73d1d9de..3f351144 100644 --- a/test/bumblebee/text/bart_test.exs +++ b/test/bumblebee/text/bart_test.exs @@ -149,7 +149,8 @@ defmodule Bumblebee.Text.BartTest do inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "seed" => Nx.tensor([[0]]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) diff --git a/test/bumblebee/text/blenderbot_test.exs b/test/bumblebee/text/blenderbot_test.exs index ddc7e6e2..c11235da 100644 --- a/test/bumblebee/text/blenderbot_test.exs +++ b/test/bumblebee/text/blenderbot_test.exs @@ -74,7 +74,8 @@ defmodule Bumblebee.Text.BlenderbotTest do inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "seed" => Nx.tensor([[0]]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 0d3d87c5..86ca3576 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -55,13 +55,13 @@ defmodule Bumblebee.Text.GenerationTest do strategy: %{type: :multinomial_sampling} ) - serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) # Note that this is just a snapshot test, we do not use any # reference value, because of PRNG difference - assert %{results: [%{text: " to fall asleep.\"\n\nThis is not Wallace's fifth"}]} = - Nx.Serving.run(serving, "I was going") + assert %{results: [%{text: " to give a speech to these execs. I don't"}]} = + Nx.Serving.run(serving, %{text: "I was going", seed: 0}) end test "contrastive search" do diff --git a/test/bumblebee/text/mbart_test.exs b/test/bumblebee/text/mbart_test.exs index 4aa0e825..2d7a1f94 100644 --- a/test/bumblebee/text/mbart_test.exs +++ b/test/bumblebee/text/mbart_test.exs @@ -147,7 +147,8 @@ defmodule Bumblebee.Text.MbartTest do inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "seed" => Nx.tensor([[0]]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) diff --git a/test/bumblebee/text/t5_test.exs b/test/bumblebee/text/t5_test.exs index 5d4a2840..270ae9a9 100644 --- a/test/bumblebee/text/t5_test.exs +++ b/test/bumblebee/text/t5_test.exs @@ -159,7 +159,8 @@ defmodule Bumblebee.Text.T5Test do inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "seed" => Nx.tensor([[0]]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) @@ -187,7 +188,8 @@ defmodule Bumblebee.Text.T5Test do inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + "seed" => Nx.tensor([[0]]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)