Skip to content

Commit

Permalink
Move featurizer batch part to serving computation (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Sep 13, 2023
1 parent 391fcd0 commit 2af8cf3
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 124 deletions.
9 changes: 8 additions & 1 deletion lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,14 @@ defmodule Bumblebee do
@spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any()
def apply_featurizer(%module{} = featurizer, input, opts \\ []) do
opts = Keyword.validate!(opts, defn_options: [])
module.apply(featurizer, input, opts[:defn_options])

batch = module.process_input(featurizer, input)

if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do
Nx.Defn.jit_apply(&module.process_batch(featurizer, &1), [batch], opts[:defn_options])
else
batch
end
end

@doc """
Expand Down
12 changes: 7 additions & 5 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,18 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
{generate_opts, generation_config} = generate_opts(generation_config, opts)
generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)

generate_fun = fn params, inputs ->
inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
generate_fun.(params, inputs)
end

Nx.Serving.new(
fn defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
inputs = %{
"input_features" => Shared.input_template(spec, "input_features", [batch_size])
}

inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
[params, inputs]
end)

Expand Down Expand Up @@ -102,7 +104,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
all_chunks = List.flatten(all_chunks)
{all_chunks, lengths} = Enum.unzip(all_chunks)

inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options)
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
end)
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?)
Expand Down
20 changes: 15 additions & 5 deletions lib/bumblebee/audio/whisper_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
end

@impl true
def apply(featurizer, raw_samples, defn_options) do
def process_input(featurizer, raw_samples) do
max_length = featurizer.num_seconds * featurizer.sampling_rate

samples =
Expand All @@ -67,17 +67,27 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do
Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])
end

samples = samples |> Nx.stack() |> Nx.vectorize(:batch)
Nx.stack(samples)
end

@impl true
def batch_template(featurizer, batch_size) do
  # Every row holds num_seconds * sampling_rate samples, the same
  # max_length that process_input/2 computes for the raw samples.
  num_samples = featurizer.num_seconds * featurizer.sampling_rate
  shape = {batch_size, num_samples}

  Nx.template(shape, :f32)
end

@impl true
def process_batch(featurizer, samples) do
samples =
Nx.Defn.jit(&extract_fbank_features/2, defn_options).(samples,
samples
|> Nx.vectorize(:batch)
|> extract_fbank_features(
fft_length: featurizer.fft_length,
sampling_rate: featurizer.sampling_rate,
mel_bins: featurizer.feature_size,
hop_length: featurizer.hop_length
)

samples = Nx.devectorize(samples)
|> Nx.devectorize()

%{"input_features" => samples}
end
Expand Down
66 changes: 64 additions & 2 deletions lib/bumblebee/featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,69 @@ defmodule Bumblebee.Featurizer do
@type t :: Bumblebee.Configurable.t()

@doc """
Performs feature extraction on the given input.
Converts the given input to a batched tensor (or a tensor container).
Numerical batch processing should be moved to `c:process_batch/2`
whenever possible.
"""
@callback process_input(t(), input :: any()) :: Nx.t() | Nx.Container.t()

@doc """
Returns an input template for `c:process_batch/2`.
The shape is effectively the same as the result of `c:process_input/2`,
except for the batch size.
"""
@callback batch_template(t(), batch_size :: pos_integer()) :: Nx.t() | Nx.Container.t()

@doc """
Optional batch processing stage.
This is a numerical function. It receives the result of `c:process_input/2`,
except the batch size may differ.
When using featurizer as part of `Nx.Serving`, the batch stage can
be merged with the model computation and compiled together.
"""
@callback process_batch(t(), input :: Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()

@optional_callbacks batch_template: 2, process_batch: 2

@doc """
Turns the given input into a batched tensor (or a tensor container).

The work is delegated to the `c:process_input/2` callback of the
module that defines the featurizer struct.
"""
@spec process_input(t(), any()) :: Nx.t() | Nx.Container.t()
def process_input(%impl{} = featurizer, input), do: impl.process_input(featurizer, input)

@doc """
Builds the input template expected by `process_batch/2`.

Returns `nil` when the featurizer module does not implement the
optional `c:batch_template/2` callback.
"""
@spec batch_template(t(), pos_integer()) :: Nx.t() | Nx.Container.t() | nil
def batch_template(%impl{} = featurizer, batch_size) do
  cond do
    not Code.ensure_loaded?(impl) -> nil
    not function_exported?(impl, :batch_template, 2) -> nil
    true -> impl.batch_template(featurizer, batch_size)
  end
end

@doc """
Optional batch processing stage.
This is a numerical function. It receives the result of `c:process_input/2`,
except the batch size may differ.
If the featurizer does not define batch processing, the input is
returned as is.
"""
@callback apply(t(), input :: any(), defn_options :: keyword()) :: any()
@spec process_batch(t(), Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()
def process_batch(%impl{} = featurizer, batch) do
  # The batch stage is an optional callback. The module is loaded first
  # because function_exported?/3 does not load modules on its own.
  batch_stage? = Code.ensure_loaded?(impl) and function_exported?(impl, :process_batch, 2)

  case batch_stage? do
    true -> impl.process_batch(featurizer, batch)
    false -> batch
  end
end
end
2 changes: 1 addition & 1 deletion lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ defmodule Bumblebee.Text.Generation do
"""
@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
def extra_config_module(%module{} = spec) do
if function_exported?(module, :extra_config_module, 1) do
if Code.ensure_loaded?(module) and function_exported?(module, :extra_config_module, 1) do
module.extra_config_module(spec)
end
end
Expand Down
40 changes: 24 additions & 16 deletions lib/bumblebee/vision/blip_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,34 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
else
images
end
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
else
images
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
  # The channel count is derived from the per-channel mean list; the
  # spatial dimensions are both taken from the configured size.
  channels = length(featurizer.image_mean)
  side = featurizer.size
  shape = {batch_size, side, side, channels}

  Nx.template(shape, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =
Expand Down
46 changes: 27 additions & 19 deletions lib/bumblebee/vision/clip_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -65,32 +65,40 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

images =
if featurizer.resize do
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
else
images
end

if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

images =
if featurizer.resize do
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
else
images
end

if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
else
images
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
  num_channels = length(featurizer.image_mean)

  # process_input/2 center-crops images to crop_size x crop_size when
  # center cropping is enabled, so the template must use crop_size in
  # that case. Otherwise it falls back to the configured size, as before.
  side = if featurizer.center_crop, do: featurizer.crop_size, else: featurizer.size

  Nx.template({batch_size, side, side, num_channels}, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =
Expand Down
60 changes: 34 additions & 26 deletions lib/bumblebee/vision/convnext_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -60,36 +60,44 @@ defmodule Bumblebee.Vision.ConvNextFeaturizer do
end

@impl true
def apply(featurizer, images, _defn_options) do
def process_input(featurizer, images) do
images = List.wrap(images)

images =
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

cond do
not featurizer.resize ->
images

featurizer.size >= 384 ->
NxImage.resize(images, {featurizer.size, featurizer.size},
method: featurizer.resize_method
)

true ->
scale_size = floor(featurizer.size / featurizer.crop_percentage)

images
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
|> NxImage.center_crop({featurizer.size, featurizer.size})
end
for image <- images do
images =
image
|> Image.to_batched_tensor()
|> Nx.as_type(:f32)
|> Image.normalize_channels(length(featurizer.image_mean))

cond do
not featurizer.resize ->
images

featurizer.size >= 384 ->
NxImage.resize(images, {featurizer.size, featurizer.size},
method: featurizer.resize_method
)

true ->
scale_size = floor(featurizer.size / featurizer.crop_percentage)

images
|> NxImage.resize_short(scale_size, method: featurizer.resize_method)
|> NxImage.center_crop({featurizer.size, featurizer.size})
end
|> Nx.concatenate()
end
|> Nx.concatenate()
end

@impl true
def batch_template(featurizer, batch_size) do
  # Both resizing branches of process_input/2 yield size x size images,
  # and the channel count follows the per-channel mean list.
  channels = length(featurizer.image_mean)
  side = featurizer.size
  shape = {batch_size, side, side, channels}

  Nx.template(shape, :f32)
end

@impl true
def process_batch(featurizer, images) do
images = NxImage.to_continuous(images, 0, 1)

images =
Expand Down
Loading

0 comments on commit 2af8cf3

Please sign in to comment.