Skip to content

Commit

Permalink
Add projection heads for ClipText and ClipVision (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Nov 12, 2023
1 parent 5df9cbd commit 76a45d3
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 14 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ defmodule Bumblebee.Text do
{:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"})
serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)
serving = Bumblebee.Text.text_embedding(model_info, tokenizer)
text = "query: Cats are cute."
Nx.Serving.run(serving, text)
Expand Down
30 changes: 28 additions & 2 deletions lib/bumblebee/text/clip_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ defmodule Bumblebee.Text.ClipText do
doc:
"the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder"
],
projection_size: [
default: 512,
doc: "the dimensionality of the projection layer"
],
activation: [
default: :quick_gelu,
doc: "the activation function"
Expand Down Expand Up @@ -62,6 +66,10 @@ defmodule Bumblebee.Text.ClipText do
* `:base` - the base text model
* `:for_embedding` - the base model with a single projection layer
on top. The head returns a vector embedded in the joint text-image
CLIP space
## Inputs
* `"input_ids"` - `{batch_size, sequence_length}`
Expand Down Expand Up @@ -95,7 +103,7 @@ defmodule Bumblebee.Text.ClipText do
alias Bumblebee.Layers

@impl true
def architectures(), do: [:base]
def architectures(), do: [:base, :for_embedding]

@impl true
def config(spec, opts \\ []) do
Expand All @@ -120,6 +128,22 @@ defmodule Bumblebee.Text.ClipText do
|> Layers.output()
end

# Builds the `:for_embedding` architecture: runs the shared `core/2` graph on
# the model inputs and adds a single projection layer on top of the pooled
# state. The output container exposes the projected embedding together with
# the hidden states and attentions produced by `core/2`.
def model(%__MODULE__{architecture: :for_embedding} = spec) do
  encoded = core(inputs(), spec)

  # Bias-free dense layer of width `spec.projection_size`. The layer name is
  # the key used by the params mapping ("embedding_head.output"), so the two
  # must stay in sync.
  projected =
    Axon.dense(encoded.pooled_state, spec.projection_size,
      use_bias: false,
      name: "embedding_head.output"
    )

  Layers.output(%{
    embedding: projected,
    hidden_states: encoded.hidden_states,
    attentions: encoded.attentions
  })
end

defp inputs() do
shape = {nil, nil}

Expand Down Expand Up @@ -226,6 +250,7 @@ defmodule Bumblebee.Text.ClipText do
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
projection_size: {"projection_dim", number()},
activation: {"hidden_act", atom()},
attention_dropout_rate: {"attention_dropout", number()},
layer_norm_epsilon: {"layer_norm_eps", number()}
Expand All @@ -252,7 +277,8 @@ defmodule Bumblebee.Text.ClipText do
"encoder.blocks.{n}.ffn.intermediate" => "text_model.encoder.layers.{n}.mlp.fc1",
"encoder.blocks.{n}.ffn.output" => "text_model.encoder.layers.{n}.mlp.fc2",
"encoder.blocks.{n}.output_norm" => "text_model.encoder.layers.{n}.layer_norm2",
"norm" => "text_model.final_layer_norm"
"norm" => "text_model.final_layer_norm",
"embedding_head.output" => "text_projection"
}
end
end
Expand Down
17 changes: 13 additions & 4 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,19 @@ defmodule Bumblebee.Text.TextEmbedding do
output = encoder.(params, inputs)

output =
if is_map(output) do
output[output_attribute]
else
output
case output do
%{^output_attribute => output} ->
output

%{} ->
keys = output |> Map.keys() |> Enum.sort()

raise ArgumentError,
"key #{inspect(output_attribute)} not found in the output map," <>
" you may want to set :output_attribute to one of the map keys: #{inspect(keys)}"

_ ->
output
end

output =
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision.ex
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ defmodule Bumblebee.Vision do
module: Bumblebee.Vision.ClipVision
)
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"})
serving = Bumblebee.Vision.ImageEmbedding.image_embedding(clip, featurizer)
serving = Bumblebee.Vision.image_embedding(clip, featurizer)
image = StbImage.read_file!(path)
Nx.Serving.run(serving, image)
#=> %{
Expand Down
30 changes: 28 additions & 2 deletions lib/bumblebee/vision/clip_vision.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ defmodule Bumblebee.Vision.ClipVision do
doc:
"the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder"
],
projection_size: [
default: 512,
doc: "the dimensionality of the projection layer"
],
activation: [
default: :quick_gelu,
doc: "the activation function"
Expand All @@ -57,6 +61,10 @@ defmodule Bumblebee.Vision.ClipVision do
* `:base` - the base image model
* `:for_embedding` - the base model with a single projection layer
on top. The head returns a vector embedded in the joint text-image
CLIP space
## Inputs
* `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
Expand All @@ -78,7 +86,7 @@ defmodule Bumblebee.Vision.ClipVision do
alias Bumblebee.Layers

@impl true
def architectures(), do: [:base]
def architectures(), do: [:base, :for_embedding]

@impl true
def config(spec, opts \\ []) do
Expand All @@ -102,6 +110,22 @@ defmodule Bumblebee.Vision.ClipVision do
|> Layers.output()
end

# Builds the `:for_embedding` architecture: runs the shared `core/2` graph on
# the model inputs and adds a single projection layer on top of the pooled
# state. The output container exposes the projected embedding together with
# the hidden states and attentions produced by `core/2`.
def model(%__MODULE__{architecture: :for_embedding} = spec) do
  encoded = core(inputs(spec), spec)

  # Bias-free dense layer of width `spec.projection_size`. The layer name is
  # the key used by the params mapping ("projection_head.output"), so the two
  # must stay in sync.
  projected =
    Axon.dense(encoded.pooled_state, spec.projection_size,
      use_bias: false,
      name: "projection_head.output"
    )

  Layers.output(%{
    embedding: projected,
    hidden_states: encoded.hidden_states,
    attentions: encoded.attentions
  })
end

defp inputs(spec) do
shape = {nil, spec.image_size, spec.image_size, spec.num_channels}

Expand Down Expand Up @@ -220,6 +244,7 @@ defmodule Bumblebee.Vision.ClipVision do
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
intermediate_size: {"intermediate_size", number()},
projection_size: {"projection_dim", number()},
activation: {"hidden_act", atom()},
attention_dropout_rate: {"attention_dropout", number()},
layer_norm_epsilon: {"layer_norm_eps", number()}
Expand Down Expand Up @@ -253,7 +278,8 @@ defmodule Bumblebee.Vision.ClipVision do
"encoder.blocks.{n}.ffn.output" => "vision_model.encoder.layers.{n}.mlp.fc2",
"encoder.blocks.{n}.output_norm" => "vision_model.encoder.layers.{n}.layer_norm2",
"pre_norm" => "vision_model.pre_layrnorm",
"post_norm" => "vision_model.post_layernorm"
"post_norm" => "vision_model.post_layernorm",
"projection_head.output" => "visual_projection"
}
end
end
Expand Down
17 changes: 13 additions & 4 deletions lib/bumblebee/vision/image_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,19 @@ defmodule Bumblebee.Vision.ImageEmbedding do
output = encoder.(params, inputs)

output =
if is_map(output) do
output[output_attribute]
else
output
case output do
%{^output_attribute => output} ->
output

%{} ->
keys = output |> Map.keys() |> Enum.sort()

raise ArgumentError,
"key #{inspect(output_attribute)} not found in the output map," <>
" you may want to set :output_attribute to one of the map keys: #{inspect(keys)}"

_ ->
output
end

output =
Expand Down
28 changes: 28 additions & 0 deletions test/bumblebee/text/clip_text_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,33 @@ defmodule Bumblebee.Text.ClipTextTest do
atol: 1.0e-4
)
end

test "embedding model" do
  # Load the text tower of the checkpoint with the projection head attached.
  load_opts = [module: Bumblebee.Text.ClipText, architecture: :for_embedding]

  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"}, load_opts)

  assert %Bumblebee.Text.ClipText{architecture: :for_embedding} = spec

  # A single sequence of seven token ids, all positions attended to.
  input_ids = Nx.tensor([[49406, 320, 1125, 539, 320, 2368, 49407]])
  attention_mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])

  outputs =
    Axon.predict(model, params, %{
      "input_ids" => input_ids,
      "attention_mask" => attention_mask
    })

  embedding = outputs.embedding

  # One embedding vector of the projection size for the single input row.
  assert Nx.shape(embedding) == {1, 512}

  # Compare a small slice of the embedding against reference values.
  expected_slice = Nx.tensor([[0.0733, -0.2448, -0.2212]])
  assert_all_close(embedding[[.., 1..3]], expected_slice, atol: 1.0e-4)
end
end
end
24 changes: 24 additions & 0 deletions test/bumblebee/vision/clip_vision_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,29 @@ defmodule Bumblebee.Vision.ClipVisionTest do
atol: 1.0e-4
)
end

test "embedding model" do
  # Load the vision tower of the checkpoint with the projection head attached.
  load_opts = [module: Bumblebee.Vision.ClipVision, architecture: :for_embedding]

  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"}, load_opts)

  assert %Bumblebee.Vision.ClipVision{architecture: :for_embedding} = spec

  # A single constant-valued input of shape {1, 224, 224, 3}.
  pixel_values = Nx.broadcast(0.5, {1, 224, 224, 3})

  outputs = Axon.predict(model, params, %{"pixel_values" => pixel_values})

  embedding = outputs.embedding

  # One embedding vector of the projection size for the single input.
  assert Nx.shape(embedding) == {1, 512}

  # Compare a small slice of the embedding against reference values.
  expected_slice = Nx.tensor([[-0.3381, -0.0196, -0.4053]])
  assert_all_close(embedding[[.., 1..3]], expected_slice, atol: 1.0e-4)
end
end
end

0 comments on commit 76a45d3

Please sign in to comment.