From 35b45163e2a15aaefdba42585eb3a39cfe76710a Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 16 Feb 2023 18:03:02 -0800
Subject: [PATCH 1/6] Start to add more generation strategies

---
 lib/bumblebee/text.ex                   |  11 ++
 lib/bumblebee/text/generation.ex        | 196 ++++++++++++++++++++++--
 test/bumblebee/text/generation_test.exs |  21 ++-
 3 files changed, 217 insertions(+), 11 deletions(-)

diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 5fd341d4..20fc7026 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -116,6 +116,17 @@ defmodule Bumblebee.Text do
     * `:min_new_tokens` - the minimum number of tokens to be generated,
       ignoring the number of tokens in the prompt
 
+    * `:num_beams` - the number of beams to use in a beam search. If set
+      to 1, beam search will not be used. Defaults to `1`
+
+    * `:early_stopping` - whether to stop the beam search once at least
+      `:num_beams` sentences are finished per batch. Defaults to
+      `false`
+
+    * `:sample` - whether to use random sampling. Defaults to `false`
+
+    * `:prng_key` - random key to use when sampling. Defaults to `nil`
+
     * `:compile` - compiles all computations for predefined input shapes
       during serving initialization. Should be a keyword list with the
       following keys:

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index ddbac600..0315a35a 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -121,6 +121,17 @@ defmodule Bumblebee.Text.Generation do
       (including padding). In general, prefer `:min_new_tokens`, which
       ignores the number of tokens in the prompt
 
+    * `:num_beams` - the number of beams to use in a beam search. If set
+      to 1, beam search will not be used. Defaults to `1`
+
+    * `:early_stopping` - whether to stop the beam search once at least
+      `:num_beams` sentences are finished per batch. Defaults to
+      `false`
+
+    * `:sample` - whether to use random sampling. Defaults to `false`
+
+    * `:prng_key` - random key to use when sampling. Defaults to `nil`
+
     * `:decoder_start_token_id` - the id of the initial token when
       generating from scratch, in case of encoder-decoder models
 
@@ -148,6 +159,10 @@ defmodule Bumblebee.Text.Generation do
       min_new_tokens: nil,
       max_length: nil,
       min_length: nil,
+      num_beams: 1,
+      sample: false,
+      early_stopping: false,
+      prng_key: nil,
       decoder_start_token_id: Map.get(spec, :decoder_start_token_id),
       bos_token_id: Map.get(spec, :bos_token_id),
       eos_token_id: Map.get(spec, :eos_token_id),
@@ -164,6 +179,11 @@ defmodule Bumblebee.Text.Generation do
     forced_eos_token_id = opts[:forced_eos_token_id]
     forced_token_ids = opts[:forced_token_ids]
 
+    num_beams = opts[:num_beams]
+    early_stopping = opts[:early_stopping]
+    sample = opts[:sample]
+    prng_key = opts[:prng_key]
+
     {max_length_fun, min_length_fun} = lazy_lengths_from_opts(opts)
 
     {prepare_inputs_fun, update_inputs_fun} =
@@ -188,7 +208,11 @@ defmodule Bumblebee.Text.Generation do
       prepare_inputs_fun,
       update_inputs_fun,
       pad_token_id: pad_token_id,
-      eos_token_id: eos_token_id
+      eos_token_id: eos_token_id,
+      num_beams: num_beams,
+      early_stopping: early_stopping,
+      sample: sample,
+      prng_key: prng_key
     )
   end
 
@@ -374,19 +398,43 @@ defmodule Bumblebee.Text.Generation do
         update_inputs_fun,
         opts \\ []
       ) do
+    {num_beams, opts} = Keyword.pop!(opts, :num_beams)
+    {early_stopping, opts} = Keyword.pop!(opts, :early_stopping)
+    {sample, opts} = Keyword.pop!(opts, :sample)
+    {prng_key, opts} = Keyword.pop!(opts, :prng_key)
+
     {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
 
-    greedy(
-      decoder_inputs,
-      decoder_input_ids,
-      predict_fun,
-      params,
-      logits_processor_fun,
-      update_inputs_fun,
-      [max_length: max_length] ++ opts
-    )
+    cond do
+      sample and num_beams == 1 ->
+        prng_key = prng_key || Nx.Random.key(:erlang.system_time())
+        prng_key = Nx.backend_copy(prng_key, Nx.BinaryBackend)
+
+        sample(
+          inputs,
+          decoder_input_ids,
+          predict_fun,
+          params,
+          logits_processor_fun,
+          update_inputs_fun,
+          [max_length: max_length, prng_key: prng_key] ++ opts
+        )
+
+      true ->
+        greedy(
+          decoder_inputs,
+          decoder_input_ids,
+          predict_fun,
+          params,
+          logits_processor_fun,
+          update_inputs_fun,
+          [max_length: max_length] ++ opts
+        )
+    end
   end
 
+  ## Greedy Generation
+
   defnp greedy(
          inputs,
          decoder_input_ids,
@@ -507,6 +555,134 @@ defmodule Bumblebee.Text.Generation do
     {sequences, length + 1, finished?, inputs}
   end
 
+  ## Sampling
+
+  defnp sample(
+         inputs,
+         decoder_input_ids,
+         predict_fun,
+         params,
+         logits_processor_fun,
+         update_inputs_fun,
+         opts \\ []
+       ) do
+    max_length = opts[:max_length]
+    pad_token_id = opts[:pad_token_id]
+    eos_token_id = opts[:eos_token_id]
+    prng_key = opts[:prng_key]
+
+    {batch_size, length} = Nx.shape(decoder_input_ids)
+
+    if length > max_length do
+      raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}"
+    end
+
+    sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
+    sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)
+
+    finished? = Nx.broadcast(Nx.tensor(0, type: :u8), {batch_size})
+
+    input_length = length
+
+    # The loop works with inputs of length 1, so if the initial input
+    # is longer, we make the initial pass outside
+    {sequences, length, finished?, inputs, prng_key} =
+      if length > 1 do
+        sample_step(
+          sequences,
+          length,
+          finished?,
+          inputs,
+          input_length,
+          predict_fun,
+          params,
+          prng_key,
+          logits_processor_fun,
+          update_inputs_fun,
+          pad_token_id: pad_token_id,
+          eos_token_id: eos_token_id
+        )
+      else
+        {sequences, length, finished?, inputs, prng_key}
+      end
+
+    {sequences, _length, _finished?, _inputs, _params, _key} =
+      while {sequences, length, finished?, inputs, params, prng_key},
+            sample_condition(finished?, length, max_length) do
+        {sequences, length, finished?, inputs, prng_key} =
+          sample_step(
+            sequences,
+            length,
+            finished?,
+            inputs,
+            input_length,
+            predict_fun,
+            params,
+            prng_key,
+            logits_processor_fun,
+            update_inputs_fun,
+            pad_token_id: pad_token_id,
+            eos_token_id: eos_token_id
+          )
+
+        {sequences, length, finished?, inputs, params, prng_key}
+      end
+
+    sequences
+  end
+
+  defnp sample_condition(finished?, length, max_length) do
+    Nx.logical_not(Nx.logical_or(Nx.all(finished?), Nx.equal(length, max_length)))
+  end
+
+  defnp sample_step(
+         sequences,
+         length,
+         finished?,
+         inputs,
+         input_length,
+         predict_fun,
+         params,
+         prng_key,
+         logits_processor_fun,
+         update_inputs_fun,
+         opts \\ []
+       ) do
+    pad_token_id = opts[:pad_token_id]
+    eos_token_id = opts[:eos_token_id]
+
+    key = Nx.Random.split(prng_key)
+    {key, key_next} = {key[1], key[0]}
+
+    model_outputs = predict_fun.(params, inputs)
+
+    logits = model_outputs.logits[[0..-1//1, -1]]
+
+    logits =
+      logits_processor_fun.(logits, %{
+        sequences: sequences,
+        length: length,
+        input_length: input_length
+      })
+
+    # TODO: logits warper
+
+    vocab = Nx.iota(Nx.shape(logits), axis: -1)
+    probabilities = Axon.Activations.softmax(logits)
+    next_token = Nx.Random.choice(key, vocab, probabilities)
+
+    next_finished? = Nx.logical_or(finished?, Nx.equal(next_token, eos_token_id))
+    next_token = next_token * Nx.logical_not(next_finished?) + pad_token_id * next_finished?
+    next_token = Nx.new_axis(next_token, 1)
+
+    sequences = Nx.put_slice(sequences, [0, length], next_token)
+    inputs = update_inputs_fun.(inputs, model_outputs, next_token)
+
+    {sequences, length + 1, next_finished?, inputs, key_next}
+  end
+
+  ## Beam Search
+
   # Logit processors
 
   defnp bos_token_logits_processor(logits, context, opts \\ []) do

diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 8fb7e5da..6d4c0f6a 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -6,7 +6,7 @@ defmodule Bumblebee.Text.GenerationTest do
   @moduletag model_test_tags()
 
   describe "integration" do
-    test "generates text" do
+    test "generates text with greedy generation" do
       {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
       {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})
 
@@ -21,5 +21,24 @@ defmodule Bumblebee.Text.GenerationTest do
       assert %{results: [%{text: "PG&E scheduled the black"}]} =
                Nx.Serving.run(serving, article)
     end
+
+    test "generates text with sampling" do
+      {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
+
+      serving =
+        Bumblebee.Text.generation(model_info, tokenizer,
+          max_new_tokens: 8,
+          sample: true,
+          prng_key: Nx.Random.key(0)
+        )
+
+      prompt = """
+      I enjoy walking with my cute dog
+      """
+
+      assert %{results: [%{text: "On the field, on a field trip,"}]} =
+               Nx.Serving.run(serving, prompt)
+    end
   end
 end

From d3a6e3c0d5cf0c107fbea6f6e1e8deae09aa5078 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Fri, 17 Feb 2023 06:02:05 -0800
Subject: [PATCH 2/6] Update lib/bumblebee/text/generation.ex
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Valim
---
 lib/bumblebee/text/generation.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 0315a35a..131accdb 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -669,7 +669,7 @@ defmodule Bumblebee.Text.Generation do
 
     vocab = Nx.iota(Nx.shape(logits), axis: -1)
     probabilities = Axon.Activations.softmax(logits)
-    next_token = Nx.Random.choice(key, vocab, probabilities)
+    next_token = Nx.Random.choice(key, vocab, probabilities, [])
 
     next_finished? = Nx.logical_or(finished?, Nx.equal(next_token, eos_token_id))
     next_token = next_token * Nx.logical_not(next_finished?) + pad_token_id * next_finished?

From 47a58a663153b11d17b6c22fddcc5b5f5994717b Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Fri, 17 Feb 2023 06:02:09 -0800
Subject: [PATCH 3/6] Update lib/bumblebee/text/generation.ex
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Valim
---
 lib/bumblebee/text/generation.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 131accdb..1667bf8a 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -632,7 +632,7 @@ defmodule Bumblebee.Text.Generation do
   end
 
   defnp sample_condition(finished?, length, max_length) do
-    Nx.logical_not(Nx.logical_or(Nx.all(finished?), Nx.equal(length, max_length)))
+    not(Nx.all(finished?) or length == max_length)
   end
 
   defnp sample_step(

From f2cc3f8f3d209d2a993786abd4600f764b8747f3 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Fri, 17 Feb 2023 07:34:49 -0800
Subject: [PATCH 4/6] Pass generation tests

---
 lib/bumblebee/text.ex                   |  9 +---
 lib/bumblebee/text/generation.ex        | 56 ++++++++++++-------------
 mix.lock                                | 16 +++----
 test/bumblebee/text/generation_test.exs |  9 ++--
 4 files changed, 42 insertions(+), 48 deletions(-)

diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 20fc7026..3e4190e2 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -116,16 +116,9 @@ defmodule Bumblebee.Text do
     * `:min_new_tokens` - the minimum number of tokens to be generated,
       ignoring the number of tokens in the prompt
 
-    * `:num_beams` - the number of beams to use in a beam search. If set
-      to 1, beam search will not be used. Defaults to `1`
-
-    * `:early_stopping` - whether to stop the beam search once at least
-      `:num_beams` sentences are finished per batch. Defaults to
-      `false`
-
     * `:sample` - whether to use random sampling. Defaults to `false`
 
-    * `:prng_key` - random key to use when sampling. Defaults to `nil`
+    * `:seed` - random seed to use when sampling. Defaults to `nil`
 
     * `:compile` - compiles all computations for predefined input shapes
       during serving initialization. Should be a keyword list with the

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 1667bf8a..5d78101f 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -48,7 +48,7 @@ defmodule Bumblebee.Text.Generation do
     generate_fun = build_generate(model_info.model, model_info.spec, opts)
 
     Nx.Serving.new(
-      fn ->
+      fn defn_options ->
         generate_fun =
           Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
             inputs = %{
@@ -64,7 +64,7 @@ defmodule Bumblebee.Text.Generation do
           generate_fun.(params, inputs)
         end
       end,
-      batch_size: batch_size
+      [batch_size: batch_size] ++ defn_options
     )
     |> Nx.Serving.client_preprocessing(fn input ->
       {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)
@@ -121,16 +121,9 @@ defmodule Bumblebee.Text.Generation do
       (including padding). In general, prefer `:min_new_tokens`, which
       ignores the number of tokens in the prompt
 
-    * `:num_beams` - the number of beams to use in a beam search. If set
-      to 1, beam search will not be used. Defaults to `1`
-
-    * `:early_stopping` - whether to stop the beam search once at least
-      `:num_beams` sentences are finished per batch. Defaults to
-      `false`
-
     * `:sample` - whether to use random sampling. Defaults to `false`
 
-    * `:prng_key` - random key to use when sampling. Defaults to `nil`
+    * `:seed` - random seed to use when sampling. Defaults to `nil`
 
     * `:decoder_start_token_id` - the id of the initial token when
       generating from scratch, in case of encoder-decoder models
@@ -159,10 +152,8 @@ defmodule Bumblebee.Text.Generation do
       min_new_tokens: nil,
       max_length: nil,
       min_length: nil,
-      num_beams: 1,
       sample: false,
-      early_stopping: false,
-      prng_key: nil,
+      seed: nil,
       decoder_start_token_id: Map.get(spec, :decoder_start_token_id),
       bos_token_id: Map.get(spec, :bos_token_id),
       eos_token_id: Map.get(spec, :eos_token_id),
@@ -179,10 +170,8 @@ defmodule Bumblebee.Text.Generation do
     forced_eos_token_id = opts[:forced_eos_token_id]
     forced_token_ids = opts[:forced_token_ids]
 
-    num_beams = opts[:num_beams]
-    early_stopping = opts[:early_stopping]
     sample = opts[:sample]
-    prng_key = opts[:prng_key]
+    seed = opts[:seed]
 
     {max_length_fun, min_length_fun} = lazy_lengths_from_opts(opts)
 
@@ -208,10 +197,8 @@ defmodule Bumblebee.Text.Generation do
       update_inputs_fun,
       pad_token_id: pad_token_id,
       eos_token_id: eos_token_id,
-      num_beams: num_beams,
-      early_stopping: early_stopping,
       sample: sample,
-      prng_key: prng_key
+      seed: seed
     )
   end
 
@@ -398,20 +385,18 @@ defmodule Bumblebee.Text.Generation do
         update_inputs_fun,
         opts \\ []
       ) do
-    {num_beams, opts} = Keyword.pop!(opts, :num_beams)
-    {early_stopping, opts} = Keyword.pop!(opts, :early_stopping)
     {sample, opts} = Keyword.pop!(opts, :sample)
-    {prng_key, opts} = Keyword.pop!(opts, :prng_key)
+    {seed, opts} = Keyword.pop!(opts, :seed)
 
     {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
 
     cond do
-      sample and num_beams == 1 ->
-        prng_key = prng_key || Nx.Random.key(:erlang.system_time())
-        prng_key = Nx.backend_copy(prng_key, Nx.BinaryBackend)
+      sample ->
+        seed = seed || :erlang.system_time()
+        prng_key = Nx.Random.key(seed)
 
         sample(
-          inputs,
+          decoder_inputs,
           decoder_input_ids,
           predict_fun,
           params,
@@ -632,7 +617,7 @@ defmodule Bumblebee.Text.Generation do
   end
 
   defnp sample_condition(finished?, length, max_length) do
-    not(Nx.all(finished?) or length == max_length)
+    not (Nx.all(finished?) or length == max_length)
   end
 
   defnp sample_step(
@@ -669,11 +654,10 @@ defmodule Bumblebee.Text.Generation do
 
     vocab = Nx.iota(Nx.shape(logits), axis: -1)
     probabilities = Axon.Activations.softmax(logits)
-    next_token = Nx.Random.choice(key, vocab, probabilities, [])
+    next_token = batched_choice(key, vocab, probabilities)
 
     next_finished? = Nx.logical_or(finished?, Nx.equal(next_token, eos_token_id))
     next_token = next_token * Nx.logical_not(next_finished?) + pad_token_id * next_finished?
-    next_token = Nx.new_axis(next_token, 1)
 
     sequences = Nx.put_slice(sequences, [0, length], next_token)
     inputs = update_inputs_fun.(inputs, model_outputs, next_token)
@@ -681,6 +665,20 @@ defmodule Bumblebee.Text.Generation do
     {sequences, length + 1, next_finished?, inputs, key_next}
   end
 
+  deftransformp batched_choice(key, vocab, probabilities) do
+    {batch_size, _} = Nx.shape(vocab)
+
+    tokens =
+      for i <- 0..(batch_size - 1) do
+        candidate_tokens = vocab[[i, 0..-1//1]]
+        candidate_probs = probabilities[[i, 0..-1//1]]
+        {token, _} = Nx.Random.choice(key, candidate_tokens, candidate_probs, samples: 1)
+        Nx.new_axis(token, 0)
+      end
+
+    Nx.concatenate(tokens, axis: 0)
+  end
+
   ## Beam Search
 
   # Logit processors

diff --git a/mix.lock b/mix.lock
index 77ac8c22..f83d8964 100644
--- a/mix.lock
+++ b/mix.lock
@@ -3,35 +3,35 @@
   "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
   "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"},
   "cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
-  "complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
+  "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
   "cowboy": {:hex, :cowboy, "2.9.0", "865dd8b6607e14cf03282e10e934023a1bd8be6f6bacf921a7e2a96d800cd452", [:make, :rebar3], [{:cowlib, "2.11.0", [hex: :cowlib, repo: "hexpm", optional: false]}, {:ranch, "1.8.0", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "2c729f934b4e1aa149aff882f57c6372c15399a20d54f65c8d67bef583021bde"},
   "cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"},
   "cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"},
   "decimal": {:hex, :decimal, "2.0.0", "a78296e617b0f5dd4c6caf57c714431347912ffb1d0842e998e9792b5642d697", [:mix], [], "hexpm", "34666e9c55dea81013e77d9d87370fe6cb6291d1ef32f46a1600230b1d44f577"},
   "dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
-  "earmark_parser": {:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"},
-  "elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
+  "earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
+  "elixir_make": {:hex, :elixir_make, "0.7.4", "5439110c964ffdd8212ca919b5b8beac423085a77ad33d5e394abe812c2d2d75", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "70c33052f7b00c813fd66d15a3cf1f7d1e122860c572ec81b8181b1276074157"},
   "ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
-  "exla": {:git, "https://github.com/elixir-nx/nx.git", "13027a000f31fd196e50bcce54f754a19c24a1d2", [sparse: "exla"]},
+  "exla": {:git, "https://github.com/elixir-nx/nx.git", "597103b22837db56e2ba9c2c06030935ab592077", [sparse: "exla"]},
   "jason": {:hex, :jason, "1.4.0", "e855647bc964a44e2f67df589ccf49105ae039d4179db7f6271dfd3843dc27e6", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "79a3791085b2a0f743ca04cec0f7be26443738779d09302e01318f97bdb82121"},
   "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
   "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
   "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
   "mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"},
   "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
-  "nx": {:git, "https://github.com/elixir-nx/nx.git", "13027a000f31fd196e50bcce54f754a19c24a1d2", [sparse: "nx"]},
+  "nx": {:git, "https://github.com/elixir-nx/nx.git", "597103b22837db56e2ba9c2c06030935ab592077", [sparse: "nx"]},
   "nx_image": {:hex, :nx_image, "0.1.0", "ae10fa41fa95126f934d6160ef4320f7db583535fb868415f2562fe19969d245", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "60a2928164cdca540b4c180ff25579b97a5f2a650fc890d40db3e1a7798c93ad"},
-  "nx_signal": {:git, "https://github.com/polvalente/nx-signal.git", "d6f0b87daa31d49bae09cec06c9f6e7747cf8701", [branch: "main"]},
+  "nx_signal": {:git, "https://github.com/polvalente/nx-signal.git", "11626601db097d3731151883c3a2e4844e2ee2ee", [branch: "main"]},
   "plug": {:hex, :plug, "1.14.0", "ba4f558468f69cbd9f6b356d25443d0b796fbdc887e03fa89001384a9cac638f", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bf020432c7d4feb7b3af16a0c2701455cbbbb95e5b6866132cb09eb0c29adc14"},
   "plug_cowboy": {:hex, :plug_cowboy, "2.6.0", "d1cf12ff96a1ca4f52207c5271a6c351a4733f413803488d75b70ccf44aebec2", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "073cf20b753ce6682ed72905cd62a2d4bd9bad1bf9f7feb02a1b8e525bd94fa6"},
   "plug_crypto": {:hex, :plug_crypto, "1.2.3", "8f77d13aeb32bfd9e654cb68f0af517b371fb34c56c9f2b58fe3df1235c1251a", [:mix], [], "hexpm", "b5672099c6ad5c202c45f5a403f21a3411247f164e4a8fab056e5cd8a290f4a2"},
   "progress_bar": {:hex, :progress_bar, "2.0.1", "7b40200112ae533d5adceb80ff75fbe66dc753bca5f6c55c073bfc122d71896d", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "2519eb58a2f149a3a094e729378256d8cb6d96a259ec94841bd69fdc71f18f87"},
   "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
-  "rustler_precompiled": {:hex, :rustler_precompiled, "0.5.5", "a075a92c8e748ce5c4f7b2cf573a072d206a6d8d99c53f627e81d3f2b10616a3", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "e8a7f1abfec8d68683bb25d14efc88496f091ef113f7f4c45d39f3606f7223f6"},
+  "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.1", "160b545bce8bf9a3f1b436b2c10f53574036a0db628e40f393328cbbe593602f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "0dd269fa261c4e3df290b12031c575fff07a542749f7b0e8b744d72d66c43600"},
   "stb_image": {:hex, :stb_image, "0.6.0", "f08de87e3481249d1a96860e64739e378cd65ed40505c4383682974864dd209c", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "5ae107f35891b32679238e2d549c9f2a211fbfa163a0c7d0d8326b24b3bc34e4"},
   "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
   "tokenizers": {:hex, :tokenizers, "0.2.0", "3aa9811396680f849803f6a3978a310a653059613592710ce5f883d67ff17a33", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.5", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "2496fd44cf96bcefc70e75cf7e34126de8b63ccb9ad35967d6d5d8661cbdb6b7"},
-  "torchx": {:git, "https://github.com/elixir-nx/nx.git", "13027a000f31fd196e50bcce54f754a19c24a1d2", [sparse: "torchx"]},
+  "torchx": {:git, "https://github.com/elixir-nx/nx.git", "597103b22837db56e2ba9c2c06030935ab592077", [sparse: "torchx"]},
   "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
   "unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"},
   "xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},

diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 6d4c0f6a..80691e6e 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -30,15 +30,18 @@ defmodule Bumblebee.Text.GenerationTest do
         Bumblebee.Text.generation(model_info, tokenizer,
           max_new_tokens: 8,
           sample: true,
-          prng_key: Nx.Random.key(0)
+          seed: 0
         )
 
       prompt = """
       I enjoy walking with my cute dog
       """
 
-      assert %{results: [%{text: "On the field, on a field trip,"}]} =
-               Nx.Serving.run(serving, prompt)
+      assert %{
+               results: [
+                 %{text: "I enjoy walking with my cute dog\n\nThey are always there for me\n"}
+               ]
+             } = Nx.Serving.run(serving, prompt)
     end
   end
 end

From fbd10cef13b325241b60c8005836ca135793be0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Wed, 12 Apr 2023 14:54:53 +0200
Subject: [PATCH 5/6] Add sampling logits processors and tests

---
 lib/bumblebee/text/generation.ex              | 188 ++++------------
 .../text/generation/logits_processing.ex      | 159 ++++++++++++++
 lib/bumblebee/text/generation_config.ex       |  60 ++++--
 .../text/zero_shot_classification.ex          |   2 +-
 .../generation/logits_processing_test.exs     | 200 ++++++++++++++++++
 .../bumblebee/text/generation_config_test.exs |  15 +-
 6 files changed, 458 insertions(+), 166 deletions(-)
 create mode 100644 lib/bumblebee/text/generation/logits_processing.ex
 create mode 100644 test/bumblebee/text/generation/logits_processing_test.exs

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index eb0c7d36..609f3f02 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -324,26 +324,41 @@ defmodule Bumblebee.Text.Generation do
   end
 
   defp get_logits_processor(min_length_fun, eos_token_id, config) do
-    processors = [
-      if config.no_repeat_ngram_length && config.no_repeat_ngram_length > 0 do
-        &no_repeat_ngram_logits_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
-      end,
-      if min_length_fun && eos_token_id do
-        &min_length_logits_processor(&1, &2,
-          min_length_fun: min_length_fun,
-          eos_token_id: eos_token_id
-        )
-      end,
-      if config.forced_bos_token_id do
-        &bos_token_logits_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
-      end,
-      if config.forced_eos_token_id do
-        &eos_token_logits_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
-      end,
-      if config.forced_token_ids do
-        &forced_tokens_logits_processor(&1, &2, forced_token_ids: config.forced_token_ids)
-      end
-    ]
+    import Bumblebee.Text.Generation.LogitsProcessing
+
+    processors =
+      [
+        if config.no_repeat_ngram_length && config.no_repeat_ngram_length > 0 do
+          &no_repeat_ngram_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
+        end,
+        if min_length_fun && eos_token_id do
+          &min_length_processor(&1, &2,
+            min_length_fun: min_length_fun,
+            eos_token_id: eos_token_id
+          )
+        end,
+        if config.forced_bos_token_id do
+          &bos_token_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
+        end,
+        if config.forced_eos_token_id do
+          &eos_token_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
+        end,
+        if config.forced_token_ids do
+          &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
+        end
+      ] ++
+        if config.strategy.type == :multinomial_sampling do
+          [
+            if top_k = config.strategy[:top_k] do
+              &top_k_processor(&1, &2, top_k: top_k)
+            end,
+            if top_p = config.strategy[:top_p] do
+              &top_p_processor(&1, &2, top_p: top_p)
+            end
+          ]
+        else
+          []
+        end
 
     fn logits, context ->
       for processor <- processors, processor, reduce: logits do
@@ -402,7 +417,7 @@ defmodule Bumblebee.Text.Generation do
       :multinomial_sampling ->
         prng_key = Nx.Random.key(seed)
 
-        sample(
+        sampling(
           decoder_inputs,
           decoder_input_ids,
           predict_fun,
@@ -725,7 +740,7 @@ defmodule Bumblebee.Text.Generation do
 
   # Multinomial sampling
 
-  defnp sample(
+  defnp sampling(
          inputs,
          decoder_input_ids,
          predict_fun,
@@ -746,7 +761,7 @@ defmodule Bumblebee.Text.Generation do
     # is longer, we make the initial pass outside
     {sequences, length, finished?, inputs, prng_key} =
       if length > 1 do
-        sample_step(
+        sampling_step(
           sequences,
           length,
           finished?,
@@ -768,7 +783,7 @@ defmodule Bumblebee.Text.Generation do
       while {sequences, length, finished?, inputs, params, prng_key},
             continue?(finished?, length, max_length) do
         {sequences, length, finished?, inputs, prng_key} =
-          sample_step(
+          sampling_step(
             sequences,
             length,
             finished?,
@@ -789,7 +804,7 @@ defmodule Bumblebee.Text.Generation do
     sequences
   end
 
-  defnp sample_step(
+  defnp sampling_step(
          sequences,
          length,
          finished?,
@@ -819,8 +834,6 @@ defmodule Bumblebee.Text.Generation do
         input_length: input_length
       })
 
-    # TODO: logits warper
-
     scores = Axon.Activations.softmax(logits)
 
     token_id = batched_choice(key, scores)
@@ -866,125 +879,4 @@ defmodule Bumblebee.Text.Generation do
 #     |> Nx.squeeze()
 #     |> Nx.devectorize()
 #   end
-
-  # Logit processors
-
-  defnp bos_token_logits_processor(logits, context, opts \\ []) do
-    opts = keyword!(opts, [:bos_token_id])
-    bos_token_id = opts[:bos_token_id]
-
-    if context.length == 1 do
-      force_token_id(logits, token_id: bos_token_id)
-    else
-      logits
-    end
-  end
-
-  defnp eos_token_logits_processor(logits, context, opts \\ []) do
-    opts = keyword!(opts, [:eos_token_id])
-    eos_token_id = opts[:eos_token_id]
-
-    max_length = Nx.axis_size(context.sequences, 1)
-
-    if context.length == max_length - 1 do
-      force_token_id(logits, token_id: eos_token_id)
-    else
-      logits
-    end
-  end
-
-  deftransformp forced_tokens_logits_processor(logits, context, opts \\ []) do
-    opts = Keyword.validate!(opts, [:forced_token_ids])
-    forced_token_ids = opts[:forced_token_ids]
-
-    clauses =
-      for {idx, token_id} <- forced_token_ids do
-        {Nx.equal(context.length, idx), force_token_id(logits, token_id: token_id)}
-      end
-
-    # Note that we can't use defn ifs inside transform, so we build
-    # the expression directly
-    Nx.Defn.Expr.cond(clauses, logits)
-  end
-
-  defnp min_length_logits_processor(logits, context, opts \\ []) do
-    opts = keyword!(opts, [:eos_token_id, :min_length_fun])
-    eos_token_id = opts[:eos_token_id]
-    min_length_fun = opts[:min_length_fun]
-
-    min_length = min_length_fun.(context.input_length)
-
-    if context.length < min_length do
-      ignore_token_id(logits, token_id: eos_token_id)
-    else
-      logits
-    end
-  end
-
-  defnp no_repeat_ngram_logits_processor(logits, context, opts \\ []) do
-    opts = keyword!(opts, [:ngram_length])
-    ngram_length = opts[:ngram_length]
-
-    if context.length + 1 < ngram_length do
-      logits
-    else
-      # Given a sequence of last {ngram_length - 1} tokens, we look
-      # for prior occurrences of that sequence and we want to make the
-      # subsequent token ignored. This way the n-gram is not repeated
-      # this time around
-
-      ngram_but_one_length = ngram_length - 1
-
-      last_ngram_but_one =
-        Nx.slice_along_axis(
-          context.sequences,
-          context.length - ngram_but_one_length,
-          ngram_but_one_length,
-          axis: 1
-        )
-
-      {_, _, _, _, logits} =
-        while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
-               logits},
-              i + ngram_but_one_length < length do
-          ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)
-
-          batch_size = Nx.axis_size(logits, 0)
-
-          token_id = sequences[[.., i + ngram_but_one_length]]
-          indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)
-
-          match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
-          updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)
-
-          logits = Nx.indexed_add(logits, indices, updates)
-
-          {i + 1, last_ngram_but_one, sequences, length, logits}
-        end
-
-      logits
-    end
-  end
-
-  defnp force_token_id(logits, opts \\ []) do
-    token_id = opts[:token_id]
-
-    batch_size = Nx.axis_size(logits, 0)
-
-    Nx.Constants.neg_infinity()
-    |> Nx.broadcast(logits)
-    |> Nx.put_slice([0, token_id], Nx.broadcast(0, {batch_size, 1}))
-  end
-
-  defnp ignore_token_id(logits, opts \\ []) do
-    token_id = opts[:token_id]
-
-    batch_size = Nx.axis_size(logits, 0)
-
-    Nx.put_slice(
-      logits,
-      [0, token_id],
-      Nx.broadcast(Nx.Constants.neg_infinity(), {batch_size, 1})
-    )
-  end
 end

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
new file mode 100644
index 00000000..35e5afdf
--- /dev/null
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -0,0 +1,159 @@
+defmodule Bumblebee.Text.Generation.LogitsProcessing do
+  @moduledoc false
+
+  import Nx.Defn
+
+  defn bos_token_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:bos_token_id])
+    bos_token_id = opts[:bos_token_id]
+
+    if context.length == 1 do
+      force_token_id(logits, bos_token_id)
+    else
+      logits
+    end
+  end
+
+  defn eos_token_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:eos_token_id])
+    eos_token_id = opts[:eos_token_id]
+
+    max_length = Nx.axis_size(context.sequences, 1)
+
+    if context.length == max_length - 1 do
+      force_token_id(logits, eos_token_id)
+    else
+      logits
+    end
+  end
+
+  defn forced_tokens_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:forced_token_ids])
+    forced_token_ids(logits, context, opts[:forced_token_ids])
+  end
+
+  deftransformp forced_token_ids(logits, context, forced_token_ids) do
+    clauses =
+      for {idx, token_id} <- forced_token_ids do
+        {Nx.equal(context.length, idx), force_token_id(logits, token_id)}
+      end
+
+    # Note that we can't use defn ifs inside transform, so we build
+    # the expression directly
+    Nx.Defn.Expr.cond(clauses, logits)
+  end
+
+  defn min_length_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:eos_token_id, :min_length_fun])
+    eos_token_id = opts[:eos_token_id]
+    min_length_fun = opts[:min_length_fun]
+
+    min_length = min_length_fun.(context.input_length)
+
+    if context.length < min_length do
+      ignore_token_id(logits, eos_token_id)
+    else
+      logits
+    end
+  end
+
+  defn no_repeat_ngram_processor(logits, context, opts \\ []) do
+    opts = keyword!(opts, [:ngram_length])
+    ngram_length = opts[:ngram_length]
+
+    if context.length + 1 < ngram_length do
+      logits
+    else
+      # Given a sequence of last {ngram_length - 1} tokens, we look
+      # for prior occurrences of that sequence and we want to make the
+      # subsequent token ignored. This way the n-gram is not repeated
+      # this time around
+
+      ngram_but_one_length = ngram_length - 1
+
+      last_ngram_but_one =
+        Nx.slice_along_axis(
+          context.sequences,
+          context.length - ngram_but_one_length,
+          ngram_but_one_length,
+          axis: 1
+        )
+
+      {_, _, _, _, logits} =
+        while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
+               logits},
+              i + ngram_but_one_length < length do
+          ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)
+
+          batch_size = Nx.axis_size(logits, 0)
+
+          token_id = sequences[[.., i + ngram_but_one_length]]
+          indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)
+
+          match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
+          updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)
+
+          logits = Nx.indexed_add(logits, indices, updates)
+
+          {i + 1, last_ngram_but_one, sequences, length, logits}
+        end
+
+      logits
+    end
+  end
+
+  deftransformp force_token_id(logits, token_id) do
+    batch_size = Nx.axis_size(logits, 0)
+
+    Nx.Constants.neg_infinity()
+    |> Nx.broadcast(logits)
+    |> Nx.put_slice([0, token_id], Nx.broadcast(0, {batch_size, 1}))
+  end
+
+  deftransformp ignore_token_id(logits, token_id) do
+    batch_size = Nx.axis_size(logits, 0)
+
+    Nx.put_slice(
+      logits,
+      [0, token_id],
+      Nx.broadcast(Nx.Constants.neg_infinity(), {batch_size, 1})
+    )
+  end
+
+  # Processors manipulating the probability distribution
+
+  defn top_k_processor(logits, _context, opts \\ []) do
+    opts = keyword!(opts, [:top_k])
+    top_k = opts[:top_k]
+
+    {top_k_logits, _} = Nx.top_k(logits, k: top_k)
+    kth_logit = top_k_logits[[.., -1]]
+    Nx.select(logits < kth_logit, Nx.Constants.neg_infinity(), logits)
+  end
+
+  defn top_p_processor(logits, _context, opts \\ []) do
+    opts = keyword!(opts, [:top_p])
+    top_p = opts[:top_p]
+
+    sorted_idx = Nx.argsort(logits, axis: 1)
+
+    cumulative_scores =
+      logits
+      |> Nx.take_along_axis(sorted_idx, axis: 1)
+      |> Axon.Activations.softmax()
+      |> Nx.cumulative_sum(axis: 1)
+
+    ordered_ignore_mask = cumulative_scores <= 1 - top_p
+
+    # Arrange the mask back into the original logits order
+    ignore_mask =
+      Nx.indexed_put(
+        Nx.broadcast(0.0, Nx.shape(sorted_idx)),
+        Nx.stack([Nx.iota(Nx.shape(sorted_idx), axis: 0), sorted_idx], axis: -1)
+        |> Nx.reshape({:auto, 2}),
+        Nx.flatten(ordered_ignore_mask)
+      )
+
+    Nx.select(ignore_mask, Nx.Constants.neg_infinity(), logits)
+  end
+end

diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index d44b154f..46410275 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -44,18 +44,27 @@ defmodule Bumblebee.Text.GenerationConfig do
       Example: `%{type: :greedy_search}`.
 
     * `:contrastive_search` - state-of-the-art decoding method, capable
-      of producing high quality, coherent sequences. This method gives
-      deterministic results. See [this article](https://huggingface.co/blog/introducing-csearch)
+      of producing high quality, coherent sequences. The results are
+      deterministic. See [this article](https://huggingface.co/blog/introducing-csearch)
       for more details.
 
-      * `:top_k` - the number of highest probability vocabulary tokens considered as a continuation
+      * `:top_k` (required) - the number of highest probability vocabulary tokens considered
+        as a continuation
 
-      * `:alpha` - the weight of degeneration penalty. It balances the model confidence
-        and the penalty
+      * `:alpha` (required) - the weight of degeneration penalty. It balances the model
+        confidence and the penalty
 
       Example: `%{type: :contrastive_search, top_k: 4, penalty_alpha: 0.6}`.
 
-    All strategy options must be present in the map\
+    * `:multinomial_sampling` - this method samples tokens according to the probability
+      distribution given by the model. The results are nondeterministic, unless a seed
+      is specified.
+
+      * `:top_k` (optional) - when specified, restricts sampling to top-k most probable
+        candidates
+
+      * `:top_p` (optional) - when specified, restricts sampling to tokens whose probabilities
+        add up to top-p
       """
     ]
   ]
@@ -175,15 +184,15 @@ defmodule Bumblebee.Text.GenerationConfig do
   end
 
   defp validate_strategy!(%{type: :greedy_search} = strategy) do
-    validate_strategy_keys!(strategy, [:type])
+    validate_strategy_keys!(strategy, [:type], [])
   end
 
   defp validate_strategy!(%{type: :contrastive_search} = strategy) do
-    validate_strategy_keys!(strategy, [:type, :top_k, :alpha])
+    validate_strategy_keys!(strategy, [:type, :top_k, :alpha], [])
   end
 
   defp validate_strategy!(%{type: :multinomial_sampling} = strategy) do
-    validate_strategy_keys!(strategy, [:type])
+    validate_strategy_keys!(strategy, [:type], [:top_k, :top_p])
   end
 
   defp validate_strategy!(%{type: type}) do
@@ -200,14 +209,21 @@ defmodule Bumblebee.Text.GenerationConfig do
     raise ArgumentError, "expected strategy to be a map, but got: #{inspect(other)}"
   end
 
-  defp validate_strategy_keys!(strategy, keys) do
-    case {Enum.sort(Map.keys(strategy)), Enum.sort(keys)} do
-      {keys, keys} ->
-        :ok
+  defp validate_strategy_keys!(strategy, required_keys, optional_keys) do
+    actual = strategy |> Map.keys() |> Enum.sort()
+
+    missing_keys = Enum.sort(required_keys -- actual)
+
+    if missing_keys != [] do
+      raise ArgumentError,
+            "missing keys #{inspect(missing_keys)} for strategy #{inspect(strategy.type)}"
+    end
+
+    extra_keys = Enum.sort((actual -- required_keys) -- optional_keys)
 
-      {expected, actual} ->
-        raise ArgumentError,
-          "expected #{inspect(strategy.type)} strategy to have keys #{inspect(expected)}, but got: #{inspect(actual)}"
+    if extra_keys != [] do
+      raise ArgumentError,
+            "unexpected keys #{inspect(extra_keys)} for strategy #{inspect(strategy.type)}"
     end
   end
 
@@ -264,11 +280,23 @@ defmodule Bumblebee.Text.GenerationConfig do
 
     strategy_opts =
       data
      |> convert!(
+        sample: {"do_sample", boolean()},
         top_k: {"top_k", number()},
+        top_p: {"top_p", number()},
         alpha: {"penalty_alpha", number()}
       )
      |> Map.new()
      |> case do
+        %{sample: true} = opts ->
+          options =
+            Map.filter(opts, fn
+              {:top_k, k} when k > 0 -> true
+              {:top_p, p} when p < 1.0 -> true
+              _ -> false
+            end)
+
+          [strategy: Map.merge(%{type: :multinomial_sampling}, options)]
+
         %{top_k: top_k, alpha: alpha} when top_k > 1 and alpha > 0 ->
           [strategy: %{type: :contrastive_search, top_k: top_k, alpha: alpha}]

diff --git a/lib/bumblebee/text/zero_shot_classification.ex b/lib/bumblebee/text/zero_shot_classification.ex
index 7c3ba7e2..25e4d763 100644
--- a/lib/bumblebee/text/zero_shot_classification.ex
+++ b/lib/bumblebee/text/zero_shot_classification.ex
@@ -91,7 +91,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do
     end)
     |> Nx.Serving.client_postprocessing(fn scores, _metadata, multi? ->
       for scores <- Utils.Nx.batch_to_list(scores) do
-        scores = Axon.Layers.softmax(scores[[.., entailment_id]])
+        scores = Axon.Activations.softmax(scores[[.., entailment_id]])
         k = min(top_k, Nx.size(scores))
         {top_scores, top_indices} = Nx.top_k(scores, k: k)
 

diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs
new file mode 100644
index 00000000..3f5d30b3
--- /dev/null
+++ b/test/bumblebee/text/generation/logits_processing_test.exs
@@ -0,0 +1,200 @@
+defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
+  use ExUnit.Case, async: true
+
+  import Bumblebee.TestHelpers
+
+  alias Bumblebee.Text.Generation.LogitsProcessing
+
+  describe "bos_token_processor/3" do
+    test "forces BOS token at position 1" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.bos_token_processor(logits, context, bos_token_id: 1),
+        Nx.tensor([[:neg_infinity, 0.0, :neg_infinity, :neg_infinity]])
+      )
+    end
+
+    test "leaves logits unchanged for further positions" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 1, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.bos_token_processor(logits, context, bos_token_id: 1),
+        logits
+      )
+    end
+  end
+
+  describe "eos_token_processor/3" do
+    test "forces EOS token at last position" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 1, 1, 0])
+
+      assert_equal(
+        LogitsProcessing.eos_token_processor(logits, context, eos_token_id: 2),
+        Nx.tensor([[:neg_infinity, :neg_infinity, 0.0, :neg_infinity]])
+      )
+    end
+
+    test "leaves logits unchanged for other positions" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 1, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.eos_token_processor(logits, context, eos_token_id: 1),
+        logits
+      )
+    end
+  end
+
+  describe "forced_tokens_processor/3" do
+    test "forces tokens at the specified positions" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.forced_tokens_processor(logits, context,
+          forced_token_ids: [{1, 2}, {2, 1}]
+        ),
+        Nx.tensor([[:neg_infinity, :neg_infinity, 0.0, :neg_infinity]])
+      )
+
+      context = context([1, 1, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.forced_tokens_processor(logits, context,
+          forced_token_ids: [{1, 2}, {2, 1}]
+        ),
+        Nx.tensor([[:neg_infinity, 0.0, :neg_infinity, :neg_infinity]])
+      )
+
+      context = context([1, 1, 1, 0])
+
+      assert_equal(
+        LogitsProcessing.forced_tokens_processor(logits, context,
+          forced_token_ids: [{1, 2}, {2, 1}]
+        ),
+        logits
+      )
+    end
+  end
+
+  describe "min_length_processor/3" do
+    test "ignores EOS token when the sequence is not long enough" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 1, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.min_length_processor(logits, context,
+          eos_token_id: 2,
+          min_length_fun: fn _ -> 3 end
+        ),
+        Nx.tensor([[1.0, 2.0, :neg_infinity, 4.0]])
+      )
+    end
+
+    test "leaves logits unchanged if the sequence is long enough" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 1, 1, 0])
+
+      assert_equal(
+        LogitsProcessing.min_length_processor(logits, context,
+          eos_token_id: 2,
+          min_length_fun: fn _ -> 3 end
+        ),
+        logits
+      )
+    end
+  end
+
+  describe "no_repeat_ngram_processor/3" do
+    test "ignores token that would produce duplicated n-gram" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([2, 3, 2, 0])
+
+      assert_equal(
+        LogitsProcessing.no_repeat_ngram_processor(logits, context, ngram_length: 2),
+        Nx.tensor([[1.0, 2.0, 3.0, :neg_infinity]])
+      )
+    end
+
+    test "leaves logits unchanged otherwise" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([2, 3, 1, 0])
+
+      assert_equal(
+        LogitsProcessing.no_repeat_ngram_processor(logits, context, ngram_length: 2),
+        logits
+      )
+    end
+  end
+
+  describe "top_k_processor/3" do
+    test "keeps top-k highest logits" do
+      logits = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.top_k_processor(logits, context, top_k: 2),
+        Nx.tensor([[:neg_infinity, :neg_infinity, 3.0, 4.0]])
+      )
+    end
+
+    test "keeps all logits that tie" do
+      logits = Nx.tensor([[3.0, 2.0, 3.0, 4.0]])
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.top_k_processor(logits, context, top_k: 2),
+        Nx.tensor([[3.0, :neg_infinity, 3.0, 4.0]])
+      )
+    end
+  end
+
+  describe "top_p_processor/3" do
+    test "keeps logits adding up to top-p probability" do
+      # We take a log (inverse of softmax) on a probability distribution
+      logits = Nx.tensor([[0.1, 0.2, 0.3, 0.4]]) |> Nx.log()
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.top_p_processor(logits, context, top_p: 0.7) |> Nx.exp(),
+        # Zeros mean negative infinity logits
+        Nx.tensor([[0.0, 0.0, 0.3, 0.4]])
+      )
+    end
+
+    test "surpasses top-p if there is no exact match" do
+      logits = Nx.tensor([[0.1, 0.2, 0.3, 0.4]]) |> Nx.log()
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.top_p_processor(logits, context, top_p: 0.6) |> Nx.exp(),
+        Nx.tensor([[0.0, 0.0, 0.3, 0.4]])
+      )
+    end
+  end
+
+  defp context(sequence) do
+    %{
+      sequences: Nx.tensor([sequence]),
+      length: Enum.count(sequence, &(&1 != 0)),
+      input_length: 1
+    }
+  end
+end

diff --git a/test/bumblebee/text/generation_config_test.exs b/test/bumblebee/text/generation_config_test.exs
index 9d0dc86a..742ec629 100644
--- a/test/bumblebee/text/generation_config_test.exs
+++ b/test/bumblebee/text/generation_config_test.exs
@@ -58,12 +58,25 @@ defmodule Bumblebee.Text.GenerationConfigTest do
       end
 
       assert_raise ArgumentError,
-                   "expected :contrastive_search strategy to have keys [:type], but got: [:alpha, :top_k, :type]",
+                   "missing keys [:alpha, :top_k] for strategy :contrastive_search",
                    fn ->
                      GenerationConfig.config(%GenerationConfig{},
                        strategy: %{type: :contrastive_search}
                      )
                    end
+
+      assert_raise ArgumentError,
+                   "unexpected keys [:unexpected] for strategy :contrastive_search",
+                   fn ->
+                     GenerationConfig.config(%GenerationConfig{},
+                       strategy: %{
+                         type: :contrastive_search,
+                         top_k: 4,
+                         alpha: 0.6,
+                         unexpected: true
+                       }
+                     )
+                   end
     end
   end
 end

From 0c6f053672e7fb62cf98539f46462577f6e33ec3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Wed, 12 Apr 2023 18:45:07 +0200
Subject: [PATCH 6/6] Add note

---
 test/bumblebee/text/generation_test.exs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index fdbbe449..9a65ec83 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -67,6 +67,9 @@ defmodule Bumblebee.Text.GenerationTest do
           defn_options: [compiler: EXLA]
         )
 
+      # Note that this is just a snapshot test, we do not use any
+      # reference value, because of PRNG difference
+
       assert %{
                results: [
                  %{text: "I was going to fall asleep.\"\n\nThis is not Wallace's fifth