Support chunking to enable long-form transcription (#236)
jonatanklosko authored Aug 31, 2023
1 parent 56ab13c commit 32ce7b3
Showing 14 changed files with 268 additions and 45 deletions.
13 changes: 13 additions & 0 deletions lib/bumblebee/audio.ex
@@ -26,6 +26,19 @@ defmodule Bumblebee.Audio do
## Options
* `:chunk_num_seconds` - enables long-form transcription by splitting
the input into chunks of the given length. Models generally have
a limit on the input length, so by chunking we can feed smaller
bits into the model, then merge the individual outputs into a
single result at the end. By default chunking is disabled
* `:context_num_seconds` - specifies the amount of overlap between
chunks on both sides of split points. The context is effectively
discarded when merging the chunks at the end, but it improves
the results at the chunk edges. Note that the context is included
in the total `:chunk_num_seconds`. Defaults to 1/6 of
`:chunk_num_seconds`
* `:seed` - random seed to use when sampling. By default the current
timestamp is used
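
As a usage illustration, here is a minimal sketch of a long-form transcription serving (the checkpoint, option values and the EXLA compiler are assumptions for the example, not part of this commit):

{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
  Bumblebee.Audio.speech_to_text(model_info, featurizer, tokenizer, generation_config,
    chunk_num_seconds: 30,
    context_num_seconds: 5,
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, {:file, "/path/to/recording.wav"})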
153 changes: 145 additions & 8 deletions lib/bumblebee/audio/speech_to_text.ex
@@ -11,12 +11,22 @@ defmodule Bumblebee.Audio.SpeechToText do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
-    opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])
+    opts =
+      Keyword.validate!(opts, [
+        :chunk_num_seconds,
+        :context_num_seconds,
+        :seed,
+        :compile,
+        defn_options: [],
+        preallocate_params: false
+      ])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [:for_conditional_generation])

chunk_num_seconds = opts[:chunk_num_seconds]
context_num_seconds = opts[:context_num_seconds]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

@@ -68,15 +78,142 @@
{:error, "expected a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"}
end)

-      inputs = Bumblebee.apply_featurizer(featurizer, inputs, defn_options: defn_options)
-      {Nx.Batch.concatenate([inputs]), multi?}
+      all_chunks =
+        for input <- inputs do
+          if chunk_num_seconds do
+            chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds)
+          else
+            [input]
+          end
+        end
+
+      all_num_chunks = Enum.map(all_chunks, &length/1)
+
+      all_chunks = List.flatten(all_chunks)
+      inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options)
+      {Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks}}
end)
+    |> Nx.Serving.client_postprocessing(fn {results, _metadata}, {multi?, all_num_chunks} ->
+      all_special_tokens = Bumblebee.Tokenizer.all_special_tokens(tokenizer)
+
+      sequences =
+        results
+        |> Bumblebee.Utils.Nx.to_list()
+        |> Enum.map(fn sequence ->
+          sequence
+          |> Enum.filter(fn token_id ->
+            if token = Bumblebee.Tokenizer.id_to_token(tokenizer, token_id) do
+              token not in all_special_tokens
+            end
+          end)
+          |> Nx.tensor()
+        end)
+
+      {outputs, []} =
+        Enum.map_reduce(all_num_chunks, sequences, fn num_chunks, sequences ->
+          {sequences, rest} = Enum.split(sequences, num_chunks)
+          token_ids = merge_overlapping_sequences(sequences)
+          text = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
+          output = %{results: [%{text: normalize_text(text)}]}
+          {output, rest}
+        end)
+
+      Shared.normalize_output(outputs, multi?)
+    end)
-    |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
-      decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
end
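
To make the bookkeeping concrete: the featurizer batch is flat, so `all_num_chunks` is what lets postprocessing slice the decoded sequences back out per input. A toy run with made-up placeholder values:

# Two inputs split into 3 and 1 chunks, so four sequences come back flat:
Enum.map_reduce([3, 1], [:s1, :s2, :s3, :s4], fn num_chunks, sequences ->
  Enum.split(sequences, num_chunks)
end)
#=> {[[:s1, :s2, :s3], [:s4]], []}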

defp chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds) do
context_num_seconds = context_num_seconds || chunk_num_seconds / 6

chunk_length = floor(chunk_num_seconds * sampling_rate)
context_left = floor(context_num_seconds * sampling_rate)
context_right = context_left

input_length = Nx.axis_size(input, 0)
step = chunk_length - context_left - context_right

0..(input_length - 1)//step
|> Enum.reduce_while([], fn chunk_start_idx, chunks ->
chunk_end_idx = chunk_start_idx + chunk_length

# All right contexts must be full, otherwise it is the last item
last? =
if context_right > 0 do
chunk_end_idx > input_length
else
chunk_end_idx >= input_length
end

chunk = input[chunk_start_idx..(min(chunk_end_idx, input_length) - 1)]
chunks = [chunk | chunks]

{if(last?, do: :halt, else: :cont), chunks}
end)
|> Enum.reverse()
end
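
For intuition, the chunk arithmetic under assumed example values (16 kHz audio, `chunk_num_seconds: 30`, `context_num_seconds: 5`; none of these numbers come from the commit):

# chunk_length = floor(30 * 16_000)    #=> 480_000 samples
# context_left = floor(5 * 16_000)     #=> 80_000 samples (same on the right)
# step         = 480_000 - 2 * 80_000  #=> 320_000 samples, i.e. 20 s
#
# Chunks start at samples 0, 320_000, 640_000, ..., and consecutive
# chunks share 10 s of context that is discarded again when the
# transcripts are merged.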

defp merge_overlapping_sequences(sequences) do
# We have a number of consecutive, overlapping sequences and we
# want to merge them into a single sequence. To merge a pair of
# consecutive sequences we slide the sequences and compare the
# overlap:
#
#     abcd    (left)
#        cde  (right)
#     => compare c = d
#
#     abcd   (left)
#       cde  (right)
#     => compare cd = cd
#
# We find the best alignment, then cut the overlap in half and
# concatenate the left and right part accordingly. In the example
# above, we would use the second alignment, taking `abc` from the
# left sequence and `de` from the right one.

{[left_sequence], right_sequences} = Enum.split(sequences, 1)

{acc, left_sequence} =
for right_sequence <- right_sequences, reduce: {[], left_sequence} do
{acc, left_sequence} ->
left_length = Nx.size(left_sequence)
right_length = Nx.size(right_sequence)

{_max_match_score, overlap_indices} =
for i <- 1..(left_length + right_length - 1),
reduce: {0.0, {left_length, left_length, 0, 0}} do
{max_match_score, overlap_indices} ->
left_start = max(0, left_length - i)
left_stop = min(left_length, left_length + right_length - i)
left_overlap = left_sequence[left_start..(left_stop - 1)]

right_start = max(0, i - left_length)
right_stop = min(right_length, i)
right_overlap = right_sequence[right_start..(right_stop - 1)]

num_matches = Nx.equal(left_overlap, right_overlap) |> Nx.sum() |> Nx.to_number()

# Epsilon to favor long perfect matches
eps = i / 10000.0
match_score = num_matches / i + eps

if num_matches > 1 and match_score > max_match_score do
overlap_indices = {left_start, left_stop, right_start, right_stop}
{match_score, overlap_indices}
else
{max_match_score, overlap_indices}
end
end

# Cut in the middle of the overlap
{left_start, left_stop, right_start, right_stop} = overlap_indices
left_mid = div(left_stop + left_start, 2)
right_mid = div(right_stop + right_start, 2)
{[left_sequence[0..(left_mid - 1)] | acc], right_sequence[right_mid..-1//1]}
end

-      decoded
-      |> Enum.map(&%{results: [%{text: normalize_text(&1)}]})
-      |> Shared.normalize_output(multi?)
+    Enum.reduce([left_sequence | acc], [], fn sequence, acc ->
+      Nx.to_flat_list(sequence) ++ acc
+    end)
end
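
A toy run of the merging logic with hypothetical token ids (not from the commit): the two sequences align on the `[5, 6, 7]` overlap, the overlap is cut in half, and the halves are concatenated:

left = Nx.tensor([1, 2, 3, 5, 6, 7])
right = Nx.tensor([5, 6, 7, 8, 9])

merge_overlapping_sequences([left, right])
#=> [1, 2, 3, 5, 6, 7, 8, 9]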

25 changes: 15 additions & 10 deletions lib/bumblebee/audio/whisper_featurizer.ex
@@ -56,25 +56,28 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
def apply(featurizer, raw_samples, defn_options) do
max_length = featurizer.num_seconds * featurizer.sampling_rate

-    transformed_samples =
+    samples =
       for sample <- List.wrap(raw_samples) do
         unless Nx.rank(sample) == 1 do
           raise ArgumentError,
                 "expected sample to be a 1-rank tensor, got: #{Nx.rank(sample)}-rank"
         end

         pad_size = max_length - Nx.axis_size(sample, 0)
-        sample = Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])
-
-        Nx.Defn.jit(&extract_fbank_features/2, defn_options).(sample,
-          fft_length: featurizer.fft_length,
-          sampling_rate: featurizer.sampling_rate,
-          mel_bins: featurizer.feature_size,
-          hop_length: featurizer.hop_length
-        )
+        Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])
       end

-    samples = Nx.stack(transformed_samples)
+    samples = samples |> Nx.stack() |> Nx.vectorize(:batch)

+    samples =
+      Nx.Defn.jit(&extract_fbank_features/2, defn_options).(samples,
+        fft_length: featurizer.fft_length,
+        sampling_rate: featurizer.sampling_rate,
+        mel_bins: featurizer.feature_size,
+        hop_length: featurizer.hop_length
+      )

+    samples = Nx.devectorize(samples)

%{"input_features" => samples}
end
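
The featurizer now pads all samples first and runs a single vectorized computation, instead of JIT-compiling and invoking `extract_fbank_features/2` once per sample. A small sketch of the `Nx.vectorize/2` idea, with assumed shapes rather than code from the commit:

samples = Nx.iota({4, 480_000})
# The :batch axis is lifted out of the shape, so a computation written
# for a single rank-1 sample transparently runs across all four samples.
vectorized = Nx.vectorize(samples, :batch)
# Afterwards the batch axis is restored in the result.
devectorized = Nx.devectorize(vectorized)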
@@ -92,6 +95,8 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
window_padding: :reflect
)

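# Drop the last frame, as in the reference implementation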
stft = stft[0..-2//1]

# Magic numbers taken from the reference implementation. This yields
# max_mel ~ 3016
frequency_spacing = 200.0 / 3
39 changes: 33 additions & 6 deletions lib/bumblebee/shared.ex
@@ -340,16 +340,23 @@ defmodule Bumblebee.Shared do
def load_special_tokens(special_tokens, data) do
for {key, default_token} <- special_tokens, into: %{} do
token =
-        case data["#{key}_token"] do
-          nil -> default_token
-          %{"content" => token} when is_binary(token) -> token
-          token when is_binary(token) -> token
+        if token = data["#{key}_token"] do
+          load_token(token)
+        else
+          default_token
end

{key, token}
end
end

@doc """
Normalizes a persisted token into a token string.
"""
@spec load_token(String.t() | map()) :: String.t()
def load_token(token) when is_binary(token), do: token
def load_token(%{"content" => token}) when is_binary(token), do: token
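
Both persisted shapes occur in Hugging Face `special_tokens_map.json` files, for example (illustrative values):

Bumblebee.Shared.load_token("<|endoftext|>")
#=> "<|endoftext|>"

Bumblebee.Shared.load_token(%{"content" => "<|endoftext|>", "lstrip" => false})
#=> "<|endoftext|>"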

@doc """
Converts logits to scores as per the given scores function.
@@ -427,7 +434,8 @@
quote do
defstruct [
:tokenizer,
-        special_tokens: unquote(special_tokens)
+        special_tokens: unquote(special_tokens),
+        additional_special_tokens: []
]

@behaviour Bumblebee.Tokenizer
@@ -457,6 +465,11 @@
tokenizer.special_tokens
end

@impl true
def additional_special_tokens(tokenizer) do
tokenizer.additional_special_tokens
end

defimpl Bumblebee.HuggingFace.Transformers.Config do
def load(tokenizer, %{
"tokenizer_file" => path,
@@ -467,7 +480,21 @@
special_tokens =
Bumblebee.Shared.load_special_tokens(tokenizer.special_tokens, special_tokens_map)

-        %{tokenizer | tokenizer: native_tokenizer, special_tokens: special_tokens}
+        additional_special_tokens =
+          case special_tokens_map do
+            %{"additional_special_tokens" => tokens} ->
+              for token <- tokens, do: Bumblebee.Shared.load_token(token), into: MapSet.new()
+
+            _ ->
+              []
+          end
+
+        %{
+          tokenizer
+          | tokenizer: native_tokenizer,
+            special_tokens: special_tokens,
+            additional_special_tokens: additional_special_tokens
+        }
end
end
end
16 changes: 16 additions & 0 deletions lib/bumblebee/tokenizer.ex
@@ -64,6 +64,12 @@ defmodule Bumblebee.Tokenizer do
"""
@callback special_tokens(t()) :: %{special_token_type() => token()}

@doc """
Returns a list with extra special tokens, in addition to the named
`special_tokens/1`.
"""
@callback additional_special_tokens(t()) :: MapSet.t(token())

@doc """
Decodes a list of token ids into a sentence.
"""
Expand Down Expand Up @@ -111,4 +117,14 @@ defmodule Bumblebee.Tokenizer do
token_to_id(tokenizer, token)
end
end

@doc """
Returns all special tokens, including any extra tokens.
"""
@spec all_special_tokens(t()) :: Enumerable.t(token())
def all_special_tokens(%module{} = tokenizer) do
special_tokens = module.special_tokens(tokenizer)
additional_special_tokens = module.additional_special_tokens(tokenizer)
for {_type, token} <- special_tokens, do: token, into: additional_special_tokens
end
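
For Whisper this is what allows postprocessing to drop task and language tokens, which are persisted as additional special tokens. A sketch, with `whisper_tokenizer` standing in for a loaded tokenizer:

special = Bumblebee.Tokenizer.all_special_tokens(whisper_tokenizer)
"<|transcribe|>" in special
#=> true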
end
14 changes: 7 additions & 7 deletions mix.exs
@@ -32,12 +32,12 @@ defmodule Bumblebee.MixProject do
[
{:axon, "~> 0.6.0", axon_opts()},
{:tokenizers, "~> 0.4"},
{:nx, "~> 0.6.0"},
{:exla, "~> 0.6.0", only: [:dev, :test]},
{:torchx, "~> 0.6.0", only: [:dev, :test]},
# {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
# {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
# {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
# {:nx, "~> 0.6.0"},
# {:exla, "~> 0.6.0", only: [:dev, :test]},
# {:torchx, "~> 0.6.0", only: [:dev, :test]},
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
{:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
{:nx_image, "~> 0.1.0"},
{:unpickler, "~> 0.1.0"},
{:safetensors, "~> 0.1.1"},
@@ -48,7 +48,7 @@
{:stb_image, "~> 0.6.0", only: :test},
{:bypass, "~> 2.1", only: :test},
{:ex_doc, "~> 0.28", only: :dev, runtime: false},
{:nx_signal, "~> 0.1.0"}
{:nx_signal, "~> 0.2.0"}
]
end
