Commit

Move logit processors to a separate module
jonatanklosko committed Apr 12, 2023
1 parent b3f29bd commit 2ad5aad
Showing 5 changed files with 388 additions and 146 deletions.
159 changes: 16 additions & 143 deletions lib/bumblebee/text/generation.ex
@@ -324,31 +324,36 @@ defmodule Bumblebee.Text.Generation do
end

defp get_logits_processor(min_length_fun, eos_token_id, config) do
import Bumblebee.Text.Generation.LogitsProcessing

processors =
[
if config.no_repeat_ngram_length && config.no_repeat_ngram_length > 0 do
&no_repeat_ngram_logits_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
&no_repeat_ngram_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
end,
if min_length_fun && eos_token_id do
&min_length_logits_processor(&1, &2,
&min_length_processor(&1, &2,
min_length_fun: min_length_fun,
eos_token_id: eos_token_id
)
end,
if config.forced_bos_token_id do
&bos_token_logits_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
&bos_token_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
end,
if config.forced_eos_token_id do
&eos_token_logits_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
&eos_token_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
end,
if config.forced_token_ids do
&forced_tokens_logits_processor(&1, &2, forced_token_ids: config.forced_token_ids)
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end
] ++
if config.strategy.type == :multinomial_sampling do
[
if top_k = config.strategy[:top_k] do
&top_k_logits_warper(&1, &2, top_k: top_k)
&top_k_processor(&1, &2, top_k: top_k)
end,
if top_p = config.strategy[:top_p] do
&top_p_processor(&1, &2, top_p: top_p)
end
]
else
@@ -412,7 +417,7 @@ defmodule Bumblebee.Text.Generation do
:multinomial_sampling ->
prng_key = Nx.Random.key(seed)

sample(
sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
@@ -735,7 +740,7 @@ defmodule Bumblebee.Text.Generation do

# Multinomial sampling

defnp sample(
defnp sampling(
inputs,
decoder_input_ids,
predict_fun,
@@ -756,7 +761,7 @@ defmodule Bumblebee.Text.Generation do
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs, prng_key} =
if length > 1 do
sample_step(
sampling_step(
sequences,
length,
finished?,
@@ -778,7 +783,7 @@ defmodule Bumblebee.Text.Generation do
while {sequences, length, finished?, inputs, params, prng_key},
continue?(finished?, length, max_length) do
{sequences, length, finished?, inputs, prng_key} =
sample_step(
sampling_step(
sequences,
length,
finished?,
@@ -799,7 +804,7 @@ defmodule Bumblebee.Text.Generation do
sequences
end

defnp sample_step(
defnp sampling_step(
sequences,
length,
finished?,
@@ -829,8 +834,6 @@ defmodule Bumblebee.Text.Generation do
input_length: input_length
})

# TODO: logits warper

scores = Axon.Activations.softmax(logits)
token_id = batched_choice(key, scores)

@@ -876,134 +879,4 @@ defmodule Bumblebee.Text.Generation do
# |> Nx.squeeze()
# |> Nx.devectorize()
# end

# Logit processors

defnp bos_token_logits_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:bos_token_id])
bos_token_id = opts[:bos_token_id]

if context.length == 1 do
force_token_id(logits, token_id: bos_token_id)
else
logits
end
end

defnp eos_token_logits_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:eos_token_id])
eos_token_id = opts[:eos_token_id]

max_length = Nx.axis_size(context.sequences, 1)

if context.length == max_length - 1 do
force_token_id(logits, token_id: eos_token_id)
else
logits
end
end

deftransformp forced_tokens_logits_processor(logits, context, opts \\ []) do
opts = Keyword.validate!(opts, [:forced_token_ids])
forced_token_ids = opts[:forced_token_ids]

clauses =
for {idx, token_id} <- forced_token_ids do
{Nx.equal(context.length, idx), force_token_id(logits, token_id: token_id)}
end

# Note that we can't use defn ifs inside transform, so we build
# the expression directly
Nx.Defn.Expr.cond(clauses, logits)
end

defnp min_length_logits_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:eos_token_id, :min_length_fun])
eos_token_id = opts[:eos_token_id]
min_length_fun = opts[:min_length_fun]

min_length = min_length_fun.(context.input_length)

if context.length < min_length do
ignore_token_id(logits, token_id: eos_token_id)
else
logits
end
end

defnp no_repeat_ngram_logits_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:ngram_length])
ngram_length = opts[:ngram_length]

if context.length + 1 < ngram_length do
logits
else
# Given a sequence of last {ngram_length - 1} tokens, we look
# for prior occurrences of that sequence and we want to make the
# subsequent token ignored. This way the n-gram is not repeated
# this time around

ngram_but_one_length = ngram_length - 1

last_ngram_but_one =
Nx.slice_along_axis(
context.sequences,
context.length - ngram_but_one_length,
ngram_but_one_length,
axis: 1
)

{_, _, _, _, logits} =
while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
logits},
i + ngram_but_one_length < length do
ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)

batch_size = Nx.axis_size(logits, 0)

token_id = sequences[[.., i + ngram_but_one_length]]
indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)

match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)

logits = Nx.indexed_add(logits, indices, updates)

{i + 1, last_ngram_but_one, sequences, length, logits}
end

logits
end
end

defnp force_token_id(logits, opts \\ []) do
token_id = opts[:token_id]

batch_size = Nx.axis_size(logits, 0)

Nx.Constants.neg_infinity()
|> Nx.broadcast(logits)
|> Nx.put_slice([0, token_id], Nx.broadcast(0, {batch_size, 1}))
end

defnp ignore_token_id(logits, opts \\ []) do
token_id = opts[:token_id]

batch_size = Nx.axis_size(logits, 0)

Nx.put_slice(
logits,
[0, token_id],
Nx.broadcast(Nx.Constants.neg_infinity(), {batch_size, 1})
)
end

defnp top_k_logits_warper(logits, _context, opts \\ []) do
opts = keyword!(opts, [:top_k])
top_k = opts[:top_k]

{top_k_logits, _} = Nx.top_k(logits, k: top_k)
kth_logit = top_k_logits[[.., -1]]
Nx.select(logits < kth_logit, Nx.Constants.neg_infinity(), logits)
end
end
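
For orientation, each entry built in get_logits_processor/3 above is a plain (logits, context) -> logits function. The sketch below is illustrative only and not part of this commit (the actual composition in generation.ex lies outside the visible hunks); it shows one way such a list, including the nil entries left by unmatched if clauses, could be folded into a single processor:

compose_processors = fn processors ->
  # Drop the nil entries produced by `if` clauses whose condition was false
  processors = Enum.reject(processors, &is_nil/1)

  fn logits, context ->
    # Apply each processor in order, threading the logits through
    Enum.reduce(processors, logits, fn processor, logits ->
      processor.(logits, context)
    end)
  end
end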
159 changes: 159 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -0,0 +1,159 @@
defmodule Bumblebee.Text.Generation.LogitsProcessing do
@moduledoc false

import Nx.Defn

defn bos_token_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:bos_token_id])
bos_token_id = opts[:bos_token_id]

if context.length == 1 do
force_token_id(logits, bos_token_id)
else
logits
end
end

defn eos_token_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:eos_token_id])
eos_token_id = opts[:eos_token_id]

max_length = Nx.axis_size(context.sequences, 1)

if context.length == max_length - 1 do
force_token_id(logits, eos_token_id)
else
logits
end
end

defn forced_tokens_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:forced_token_ids])
forced_token_ids(logits, context, opts[:forced_token_ids])
end

deftransformp forced_token_ids(logits, context, forced_token_ids) do
clauses =
for {idx, token_id} <- forced_token_ids do
{Nx.equal(context.length, idx), force_token_id(logits, token_id)}
end

# Note that we can't use defn ifs inside transform, so we build
# the expression directly
Nx.Defn.Expr.cond(clauses, logits)
end

defn min_length_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:eos_token_id, :min_length_fun])
eos_token_id = opts[:eos_token_id]
min_length_fun = opts[:min_length_fun]

min_length = min_length_fun.(context.input_length)

if context.length < min_length do
ignore_token_id(logits, eos_token_id)
else
logits
end
end

defn no_repeat_ngram_processor(logits, context, opts \\ []) do
opts = keyword!(opts, [:ngram_length])
ngram_length = opts[:ngram_length]

if context.length + 1 < ngram_length do
logits
else
# Given a sequence of last {ngram_length - 1} tokens, we look
# for prior occurrences of that sequence and we want to make the
# subsequent token ignored. This way the n-gram is not repeated
# this time around

ngram_but_one_length = ngram_length - 1

last_ngram_but_one =
Nx.slice_along_axis(
context.sequences,
context.length - ngram_but_one_length,
ngram_but_one_length,
axis: 1
)

{_, _, _, _, logits} =
while {i = 0, last_ngram_but_one, sequences = context.sequences, length = context.length,
logits},
i + ngram_but_one_length < length do
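          # For each position, compare the historical (ngram_length - 1)-gram with
          # the trailing one; on a match, ban the token that followed it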
ngram_but_one = Nx.slice_along_axis(sequences, i, ngram_but_one_length, axis: 1)

batch_size = Nx.axis_size(logits, 0)

token_id = sequences[[.., i + ngram_but_one_length]]
indices = Nx.stack([Nx.iota({batch_size}), token_id], axis: -1)

match? = Nx.all(ngram_but_one == last_ngram_but_one, axes: [1])
updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)

logits = Nx.indexed_add(logits, indices, updates)

{i + 1, last_ngram_but_one, sequences, length, logits}
end

logits
end
end

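  # force_token_id/2 leaves only the given token available (all other logits
  # become -inf); ignore_token_id/2 bans just that one token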
deftransformp force_token_id(logits, token_id) do
batch_size = Nx.axis_size(logits, 0)

Nx.Constants.neg_infinity()
|> Nx.broadcast(logits)
|> Nx.put_slice([0, token_id], Nx.broadcast(0, {batch_size, 1}))
end

deftransformp ignore_token_id(logits, token_id) do
batch_size = Nx.axis_size(logits, 0)

Nx.put_slice(
logits,
[0, token_id],
Nx.broadcast(Nx.Constants.neg_infinity(), {batch_size, 1})
)
end

# Processors manipulating the probability distribution

defn top_k_processor(logits, _context, opts \\ []) do
opts = keyword!(opts, [:top_k])
top_k = opts[:top_k]

{top_k_logits, _} = Nx.top_k(logits, k: top_k)
kth_logit = top_k_logits[[.., -1]]
Nx.select(logits < kth_logit, Nx.Constants.neg_infinity(), logits)
end

defn top_p_processor(logits, _context, opts \\ []) do
opts = keyword!(opts, [:top_p])
top_p = opts[:top_p]

sorted_idx = Nx.argsort(logits, axis: 1)

cumulative_scores =
logits
|> Nx.take_along_axis(sorted_idx, axis: 1)
|> Axon.Activations.softmax()
|> Nx.cumulative_sum(axis: 1)

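    # Tokens in the low-probability tail, whose cumulative mass stays within
    # 1 - top_p, fall outside the nucleus and are ignored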
ordered_ignore_mask = cumulative_scores <= 1 - top_p

# Arrange the mask back into the original logits order
ignore_mask =
Nx.indexed_put(
Nx.broadcast(0.0, Nx.shape(sorted_idx)),
Nx.stack([Nx.iota(Nx.shape(sorted_idx), axis: 0), sorted_idx], axis: -1)
|> Nx.reshape({:auto, 2}),
Nx.flatten(ordered_ignore_mask)
)

Nx.select(ignore_mask, Nx.Constants.neg_infinity(), logits)
end
end
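
As a rough usage sketch (not part of the commit): the module is internal to Bumblebee, but its defn processors can be exercised directly with tensors. The context map below is only a placeholder, since top_k_processor/3 ignores it:

alias Bumblebee.Text.Generation.LogitsProcessing

logits = Nx.tensor([[1.0, 3.0, 2.0, 0.5]])
# Placeholder context; top_k_processor/3 does not read it
context = %{length: Nx.tensor(1)}

LogitsProcessing.top_k_processor(logits, context, top_k: 2)
# Every logit below the 2nd highest is replaced with -Inf,
# e.g. [[-Inf, 3.0, 2.0, -Inf]] for the input above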