Skip to content

Commit

Permalink
Change image size to maps in image featurizers (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 9, 2024
1 parent 5ea0cbc commit 18bda76
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 53 deletions.
42 changes: 42 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -459,4 +459,46 @@ defmodule Bumblebee.Shared do
}
}
end

@type featurizer_image_size ::
%{height: non_neg_integer(), width: non_neg_integer()}
| %{shortest_edge: non_neg_integer()}

@doc """
Returns an exact `{height, width}` size to resize images into.
Accepts a featurizer size map.
"""
@spec featurizer_resize_size(Nx.Tensor.t(), featurizer_image_size()) ::
{height :: non_neg_integer(), width :: non_neg_integer()}
def featurizer_resize_size(images, size)

def featurizer_resize_size(_images, %{height: height, width: width}), do: {height, width}

def featurizer_resize_size(images, %{shortest_edge: size}) do
{height, width} = images_spacial_sizes(images)

{short, long} = if height < width, do: {height, width}, else: {width, height}

out_short = size
out_long = floor(size * long / short)

if height < width, do: {out_short, out_long}, else: {out_long, out_short}
end

defp images_spacial_sizes(images) do
height = Nx.axis_size(images, -3)
width = Nx.axis_size(images, -2)
{height, width}
end

@doc """
Checks whether if the given featurizer image size is fixed or depends
on the input size.
"""
@spec featurizer_size_fixed?(featurizer_image_size()) :: boolean()
def featurizer_size_fixed?(size)

def featurizer_size_fixed?(%{height: _, width: _}), do: true
def featurizer_size_fixed?(%{shortest_edge: _}), do: false
end
30 changes: 30 additions & 0 deletions lib/bumblebee/shared/converters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,34 @@ defmodule Bumblebee.Shared.Converters do
end
end
end

def image_size(opts \\ []) do
opts = Keyword.validate!(opts, single_as: :both_edges)
single_as = opts[:single_as]

true = single_as in [:both_edges, :shortest_edge]

fn name, value ->
case value do
%{"height" => height, "width" => width} ->
{:ok, %{height: height, width: width}}

[height, width] ->
{:ok, %{height: height, width: width}}

size when is_number(size) and single_as == :both_edges ->
{:ok, %{height: size, width: size}}

size when is_number(size) and single_as == :shortest_edge ->
{:ok, %{shortest_edge: size}}

%{"shortest_edge" => shortest_edge} ->
{:ok, %{shortest_edge: shortest_edge}}

_ ->
{:error,
"expected #{inspect(name)} to be a number, a list or a map with height and width, got: #{inspect(value)}"}
end
end
end
end
29 changes: 8 additions & 21 deletions lib/bumblebee/vision/blip_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
options = [
resize: [
default: true,
doc: "whether to resize (and optionally center crop) the input to the given `:size`"
doc: "whether to resize the input to the given `:size`"
],
size: [
default: 384,
default: %{height: 384, width: 384},
doc: """
the size to resize the input to. Either a single number or a `{height, width}` tuple.
Only has an effect if `:resize` is `true`
the size to resize the input to, given as `%{height: ..., width: ...}`. Only has
an effect if `:resize` is `true`
"""
],
resize_method: [
Expand Down Expand Up @@ -65,8 +65,8 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
%{height: height, width: width} = featurizer.size
NxImage.resize(images, {height, width}, method: featurizer.resize_method)
else
images
end
Expand All @@ -76,7 +76,7 @@ defmodule Bumblebee.Vision.BlipFeaturizer do

@impl true
def batch_template(featurizer, batch_size) do
{height, width} = Image.normalize_size(featurizer.size)
%{height: height, width: width} = featurizer.size
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, height, width, num_channels}, :f32)
end
Expand Down Expand Up @@ -106,7 +106,7 @@ defmodule Bumblebee.Vision.BlipFeaturizer do
opts =
convert!(data,
resize: {"do_resize", boolean()},
size: {"size", one_of([number(), size_as_map()])},
size: {"size", image_size()},
resize_method: {"resample", resize_method()},
normalize: {"do_normalize", boolean()},
image_mean: {"image_mean", list(number())},
Expand All @@ -115,18 +115,5 @@ defmodule Bumblebee.Vision.BlipFeaturizer do

@for.config(featurizer, opts)
end

defp size_as_map() do
fn name, value ->
case value do
%{"height" => height, "width" => width} ->
{:ok, {height, width}}

_ ->
{:error,
"expected #{inspect(name)} to be a map with height and width, got: #{inspect(value)}"}
end
end
end
end
end
48 changes: 36 additions & 12 deletions lib/bumblebee/vision/clip_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
options = [
resize: [
default: true,
doc: "whether to resize (and optionally center crop) the input to the given `:size`"
doc: "whether to resize the input to the given `:size`"
],
size: [
default: 224,
default: %{shortest_edge: 224},
doc: """
the size to resize the input to. The image is resized to (`:size`, `:size`). Only has
an effect if `:resize` is `true`
the size to resize the input to, either `%{height: ..., width: ...}` or `%{shortest_edge: ...}`.
Only has an effect if `:resize` is `true`
"""
],
resize_method: [
Expand All @@ -26,8 +26,11 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
"""
],
crop_size: [
default: 224,
doc: "the size to center crop the image to. Only has an effect if `:center_crop` is `true`"
default: %{height: 224, width: 224},
doc: """
the size to center crop the image to, given as `%{height: ..., width: ...}`. Only has an effect
if `:center_crop` is `true`
"""
],
normalize: [
default: true,
Expand Down Expand Up @@ -61,7 +64,16 @@ defmodule Bumblebee.Vision.ClipFeaturizer do

@impl true
def config(featurizer, opts) do
Shared.put_config_attrs(featurizer, opts)
featurizer = Shared.put_config_attrs(featurizer, opts)

if featurizer.resize and Shared.featurizer_size_fixed?(featurizer.size) and
not featurizer.center_crop do
raise ArgumentError,
"the resize shape depends on the input shape and cropping is disabled." <>
"You must either configure a fixed size or enable cropping"
end

featurizer
end

@impl true
Expand All @@ -77,13 +89,15 @@ defmodule Bumblebee.Vision.ClipFeaturizer do

images =
if featurizer.resize do
NxImage.resize_short(images, featurizer.size, method: featurizer.resize_method)
size = Shared.featurizer_resize_size(images, featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
else
images
end

if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
%{height: height, width: width} = featurizer.crop_size
NxImage.center_crop(images, {height, width})
else
images
end
Expand All @@ -94,7 +108,17 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
@impl true
def batch_template(featurizer, batch_size) do
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, featurizer.size, featurizer.size, num_channels}, :f32)

{height, width} =
case featurizer do
%{center_crop: true, crop_size: %{height: height, width: width}} ->
{height, width}

%{resize: true, size: %{height: height, width: width}} ->
{height, width}
end

Nx.template({batch_size, height, width, num_channels}, :f32)
end

@impl true
Expand Down Expand Up @@ -122,10 +146,10 @@ defmodule Bumblebee.Vision.ClipFeaturizer do
opts =
convert!(data,
resize: {"do_resize", boolean()},
size: {"size", number()},
size: {"size", image_size(single_as: :shortest_edge)},
resize_method: {"resample", resize_method()},
center_crop: {"do_center_crop", boolean()},
crop_size: {"crop_size", number()},
crop_size: {"crop_size", image_size()},
normalize: {"do_normalize", boolean()},
image_mean: {"image_mean", list(number())},
image_std: {"image_std", list(number())}
Expand Down
22 changes: 21 additions & 1 deletion lib/bumblebee/vision/convnext_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ defmodule Bumblebee.Vision.ConvNextFeaturizer do
opts =
convert!(data,
resize: {"do_resize", boolean()},
size: {"size", number()},
size: {"size", size()},
resize_method: {"resample", resize_method()},
crop_percentage: {"crop_pct", number()},
normalize: {"do_normalize", boolean()},
Expand All @@ -131,5 +131,25 @@ defmodule Bumblebee.Vision.ConvNextFeaturizer do

@for.config(featurizer, opts)
end

defp size() do
# Note that in contrast to other featurizers, in this case size
# is always a single number and its meaning depends on the input
# size. huggingface/transformers put it under the "shortest_edge"
# key, but we keep it as a single number as it is more clear.
fn name, value ->
case value do
%{"shortest_edge" => size} ->
{:ok, size}

size when is_number(size) ->
{:ok, size}

_ ->
{:error,
"expected #{inspect(name)} to be a number or a map with shortest_edge, got: #{inspect(value)}"}
end
end
end
end
end
28 changes: 16 additions & 12 deletions lib/bumblebee/vision/deit_featurizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ defmodule Bumblebee.Vision.DeitFeaturizer do
options = [
resize: [
default: true,
doc: "whether to resize (and optionally center crop) the input to the given `:size`"
doc: "whether to resize the input to the given `:size`"
],
size: [
default: 256,
default: %{height: 256, width: 256},
doc: """
the size to resize the input to. Either a single number or a `{height, width}` tuple.
Only has an effect if `:resize` is `true`
the size to resize the input to, given as `%{height: ..., width: ...}`. Only has
an effect if `:resize` is `true`
"""
],
resize_method: [
Expand All @@ -26,8 +26,11 @@ defmodule Bumblebee.Vision.DeitFeaturizer do
"""
],
crop_size: [
default: 224,
doc: "the size to center crop the image to. Only has an effect if `:center_crop` is `true`"
default: %{height: 224, width: 224},
doc: """
the size to center crop the image to, given as `%{height: ..., width: ...}`. Only has an effect
if `:center_crop` is `true`
"""
],
normalize: [
default: true,
Expand Down Expand Up @@ -76,8 +79,8 @@ defmodule Bumblebee.Vision.DeitFeaturizer do
|> Image.normalize_channels(length(featurizer.image_mean))

if featurizer.resize do
size = Image.normalize_size(featurizer.size)
NxImage.resize(images, size, method: featurizer.resize_method)
%{height: height, width: width} = featurizer.size
NxImage.resize(images, {height, width}, method: featurizer.resize_method)
else
images
end
Expand All @@ -87,7 +90,7 @@ defmodule Bumblebee.Vision.DeitFeaturizer do

@impl true
def batch_template(featurizer, batch_size) do
{height, width} = Image.normalize_size(featurizer.size)
%{height: height, width: width} = featurizer.size
num_channels = length(featurizer.image_mean)
Nx.template({batch_size, height, width, num_channels}, :f32)
end
Expand All @@ -96,7 +99,8 @@ defmodule Bumblebee.Vision.DeitFeaturizer do
def process_batch(featurizer, images) do
images =
if featurizer.center_crop do
NxImage.center_crop(images, {featurizer.crop_size, featurizer.crop_size})
%{height: height, width: width} = featurizer.crop_size
NxImage.center_crop(images, {height, width})
else
images
end
Expand Down Expand Up @@ -124,10 +128,10 @@ defmodule Bumblebee.Vision.DeitFeaturizer do
opts =
convert!(data,
resize: {"do_resize", boolean()},
size: {"size", one_of([number(), tuple([number(), number()])])},
size: {"size", image_size()},
resize_method: {"resample", resize_method()},
center_crop: {"do_center_crop", boolean()},
crop_size: {"crop_size", number()},
crop_size: {"crop_size", image_size()},
normalize: {"do_normalize", boolean()},
image_mean: {"image_mean", list(number())},
image_std: {"image_std", list(number())}
Expand Down
Loading

0 comments on commit 18bda76

Please sign in to comment.