5 changes: 5 additions & 0 deletions keras_hub/src/models/clip/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
from keras_hub.src.models.clip.clip_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, CLIPBackbone)
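Registering the presets above is what makes them resolvable by name through `from_preset`. A minimal usage sketch (it assumes the Kaggle weights listed in `clip_presets.py` below are reachable; input shapes follow the test file in this PR):

import keras
import keras_hub

# Load a registered preset by name and run a forward pass.
backbone = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")
outputs = backbone(
    {
        "images": keras.ops.ones((1, 224, 224, 3)),
        "token_ids": keras.ops.ones((1, 77), dtype="int32"),
    }
)
print(outputs["vision_logits"].shape)  # (1, 1): one image scored against one text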
72 changes: 58 additions & 14 deletions keras_hub/src/models/clip/clip_backbone.py
@@ -19,8 +19,10 @@ class CLIPVisionPooler(layers.Layer):
"""

def call(self, vision_embeddings):
- pooled_outputs = vision_embeddings[:, 0, :]
- return pooled_outputs
+ return vision_embeddings[:, 0, :]
+
+ def compute_output_shape(self, input_shape):
+ return (input_shape[0], input_shape[-1])


class CLIPTextPooler(layers.Layer):
@@ -37,11 +39,15 @@ class CLIPTextPooler(layers.Layer):
"""

def call(self, text_embeddings, token_ids):
- eos_index = ops.argmax(token_ids, axis=-1, keepdims=True)
+ # `keepdims` is not supported in `keras<=3.1`.
+ eos_index = ops.argmax(token_ids, axis=-1)
eos_index = ops.expand_dims(eos_index, axis=-1)
+ eos_index = ops.expand_dims(eos_index, axis=-1)
pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
- pooled_outputs = ops.squeeze(pooled_outputs, axis=1)
- return pooled_outputs
+ return ops.squeeze(pooled_outputs, axis=1)
+
+ def compute_output_shape(self, input_shape):
+ return (input_shape[0], input_shape[-1])
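For reference, a standalone sketch of the EOS pooling trick above: `argmax` without `keepdims`, two `expand_dims` calls, then `take_along_axis`, which picks out each sequence's embedding at the position of its largest token id (CLIP's end-of-text token has the largest id in the vocabulary). The tensors are made up for illustration:

import numpy as np
from keras import ops

# (batch, sequence, hidden) embeddings and token ids ending in an EOS token.
text_embeddings = ops.convert_to_tensor(
    np.arange(2 * 5 * 4, dtype="float32").reshape((2, 5, 4))
)
token_ids = ops.convert_to_tensor([[10, 3, 49407, 0, 0], [10, 49407, 0, 0, 0]])

eos_index = ops.argmax(token_ids, axis=-1)       # (2,), positions 2 and 1
eos_index = ops.expand_dims(eos_index, axis=-1)  # (2, 1)
eos_index = ops.expand_dims(eos_index, axis=-1)  # (2, 1, 1), broadcasts over hidden
pooled = ops.take_along_axis(text_embeddings, eos_index, axis=1)  # (2, 1, 4)
pooled = ops.squeeze(pooled, axis=1)             # (2, 4)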


class CLIPHead(layers.Layer):
@@ -86,6 +92,19 @@ def call(self, vision_embedding, text_embedding):
vision_logits = ops.transpose(text_logits)
return vision_logits, text_logits

def compute_output_shape(
self, vision_embedding_shape, text_embedding_shape
):
vision_logits_shape = (
vision_embedding_shape[0],
text_embedding_shape[0],
)
text_logits_shape = (
text_embedding_shape[0],
vision_embedding_shape[0],
)
return vision_logits_shape, text_logits_shape
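The shapes returned above follow from the standard CLIP contrastive head: text embeddings are matmul'd with the transposed vision embeddings and scaled by an exponentiated learnable temperature. The head's constructor is not shown in this diff, so the sketch below assumes the usual log(1 / 0.07) initialization from the CLIP paper; only the shape bookkeeping is the point here:

import math

from keras import ops

vision_embedding = ops.ones((4, 64))  # (num_images, projection_dim)
text_embedding = ops.ones((3, 64))    # (num_texts, projection_dim)
logit_scale = ops.convert_to_tensor(math.log(1.0 / 0.07))

scale = ops.exp(logit_scale)
text_logits = scale * ops.matmul(text_embedding, ops.transpose(vision_embedding))
vision_logits = ops.transpose(text_logits)
print(vision_logits.shape, text_logits.shape)  # (4, 3) (3, 4)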


@keras_hub_export("keras_hub.models.CLIPBackbone")
class CLIPBackbone(Backbone):
@@ -119,7 +138,7 @@ class CLIPBackbone(Backbone):
}

# Pretrained CLIP model.
model = keras_hub.models.CLIPBackbone.from_preset("clip-vit-base-patch32")
model = keras_hub.models.CLIPBackbone.from_preset("clip_vit_base_patch32")
model(input_data)

# Randomly initialized CLIP model with custom config.
@@ -140,8 +159,8 @@ class CLIPBackbone(Backbone):
intermediate_dim=2048,
)
model = keras_hub.models.CLIPBackbone(
- vision_encoder=50257,
- text_encoder=12,
+ vision_encoder=vision_encoder,
+ text_encoder=text_encoder,
projection_dim=256,
)
model(input_data)
@@ -183,12 +202,8 @@ def __init__(
token_id_input = layers.Input(
shape=(None,), dtype="int32", name="token_ids"
)
- vision_outputs = self.vision_encoder({"images": image_input})
- text_outputs = self.text_encoder({"token_ids": token_id_input})
- vision_outputs = self.vision_pooler(vision_outputs)
- text_outputs = self.text_pooler(text_outputs, token_id_input)
- vision_embeddings = self.vision_projection(vision_outputs)
- text_embeddings = self.text_projection(text_outputs)
+ vision_embeddings = self.get_vision_embeddings(image_input)
+ text_embeddings = self.get_text_embeddings(token_id_input)
vision_logits, text_logits = self.clip_head(
vision_embeddings, text_embeddings
)
@@ -202,13 +217,42 @@ def __init__(
"vision_logits": vision_logits,
"text_logits": text_logits,
},
dtype=dtype,
name=name,
**kwargs,
)

# === Config ===
self.projection_dim = projection_dim

def get_vision_embeddings(self, images):
"""Get the embeddings from the vision encoder.

Args:
images: The input tensor for the vision encoder.

Returns:
The output embeddings, obtained by applying the projection layer to
the pooled output of the vision encoder.
"""
vision_outputs = self.vision_encoder({"images": images})
vision_outputs = self.vision_pooler(vision_outputs)
return self.vision_projection(vision_outputs)

def get_text_embeddings(self, token_ids):
"""Get the embeddings from the text encoder.

Args:
token_ids: The input int tensor for the text encoder.

Returns:
The output embeddings, obtained by applying the projection layer to
the pooled output of the text encoder.
"""
text_outputs = self.text_encoder({"token_ids": token_ids})
text_outputs = self.text_pooler(text_outputs, token_ids)
return self.text_projection(text_outputs)

def get_config(self):
config = super().get_config()
config.update(
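A short sketch of the two helper methods added above, reusing the small encoder configuration from the test file below (the positional encoder arguments are copied from that test, not independently verified):

from keras import ops

from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder

vision_encoder = CLIPVisionEncoder(16, 64, 2, 2, 128, name="vision_encoder")
text_encoder = CLIPTextEncoder(64, 64, 64, 2, 2, 128, name="text_encoder")
backbone = CLIPBackbone(
    vision_encoder=vision_encoder,
    text_encoder=text_encoder,
    projection_dim=64,
)

# Project images and texts into the shared embedding space separately.
image_embeddings = backbone.get_vision_embeddings(ops.ones((2, 224, 224, 3)))
text_embeddings = backbone.get_text_embeddings(ops.ones((2, 77), dtype="int32"))
print(image_embeddings.shape, text_embeddings.shape)  # (2, 64) (2, 64)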
54 changes: 54 additions & 0 deletions keras_hub/src/models/clip/clip_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from keras import ops

from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder
from keras_hub.src.tests.test_case import TestCase


class CLIPBackboneTest(TestCase):
def setUp(self):
vision_encoder = CLIPVisionEncoder(
16, 64, 2, 2, 128, name="vision_encoder"
)
text_encoder = CLIPTextEncoder(
64, 64, 64, 2, 2, 128, name="text_encoder"
)
self.init_kwargs = {
"vision_encoder": vision_encoder,
"text_encoder": text_encoder,
"projection_dim": 64,
}
self.input_data = {
"images": ops.ones((2, 224, 224, 3)),
"token_ids": ops.ones((2, 77), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=CLIPBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape={
"vision_logits": (2, 2),
"text_logits": (2, 2),
},
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=CLIPBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in CLIPBackbone.presets:
self.run_preset_test(
cls=CLIPBackbone,
preset=preset,
input_data=self.input_data,
)
33 changes: 30 additions & 3 deletions keras_hub/src/models/clip/clip_encoder_block.py
@@ -7,6 +7,33 @@ def quick_gelu(x):
return x * ops.sigmoid(1.702 * x)


# TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
# dtype compatibility issue is resolved.
class CLIPMultiHeadAttention(layers.MultiHeadAttention):
def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
query = ops.multiply(
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
)
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
# Cast to `value.dtype` to fix the dtype compatibility under mixed precision.
attention_scores = ops.cast(attention_scores, value.dtype)
if self.dropout:
final_attn_scores = self._dropout_layer(
attention_scores, training=training
)
else:
final_attn_scores = attention_scores
attention_output = ops.einsum(
self._combine_equation, final_attn_scores, value
)
return attention_output, attention_scores


class CLIPEncoderBlock(layers.Layer):
def __init__(
self,
@@ -33,16 +60,16 @@ def __init__(
intermediate_activation = quick_gelu

self.layer_norm_1 = layers.LayerNormalization(
epsilon=1e-5, dtype="float32", name="layer_norm_1"
epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_1"
)
- self.attention = layers.MultiHeadAttention(
+ self.attention = CLIPMultiHeadAttention(
num_heads,
hidden_dim // num_heads,
dtype=self.dtype_policy,
name="attention",
)
self.layer_norm_2 = layers.LayerNormalization(
epsilon=1e-5, dtype="float32", name="layer_norm_2"
epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_2"
)
self.dense_1 = layers.Dense(
self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
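The subclass above exists only to cast the attention probabilities to `value.dtype` before the final einsum: under a mixed-precision policy the masked softmax can come back as float32 while `value` is float16, and the combine step then fails. A standalone sketch of the same pattern, deliberately avoiding the private Keras internals used above and omitting masking and dropout:

from keras import ops


def attention_with_cast(query, key, value, scale):
    # query/key/value: (batch, heads, seq, head_dim). Softmax runs in float32
    # for stability, then the probabilities are cast back to the value dtype
    # before the weighted sum, mirroring the cast added above.
    scores = ops.einsum("bhqd,bhkd->bhqk", query * scale, key)
    probs = ops.softmax(ops.cast(scores, "float32"), axis=-1)
    probs = ops.cast(probs, value.dtype)
    return ops.einsum("bhqk,bhkd->bhqd", probs, value)


q = k = v = ops.ones((2, 4, 8, 16), dtype="float16")
out = attention_with_cast(q, k, v, scale=1.0 / 16**0.5)
print(out.dtype)  # float16, matching `value`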
109 changes: 109 additions & 0 deletions keras_hub/src/models/clip/clip_presets.py
@@ -0,0 +1,109 @@
"""CLIP model preset configurations."""

# Metadata for loading pretrained model weights.
backbone_presets = {
"clip_vit_base_patch16": {
"metadata": {
"description": (
"150 million parameter, 12-layer for vision and 12-layer for "
"text, patch size of 16, CLIP model."
),
"params": 149620934,
"official_name": "CLIP",
"path": "clip",
"model_card": "https://github.com/openai/CLIP/blob/main/model-card.md",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_base_patch16/1",
},
"clip_vit_base_patch32": {
"metadata": {
"description": (
"151 million parameter, 12-layer for vision and 12-layer for "
"text, patch size of 32, CLIP model."
),
"params": 151277363,
"official_name": "CLIP",
"path": "clip",
"model_card": "https://github.com/openai/CLIP/blob/main/model-card.md",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_base_patch32/1",
},
"clip_vit_large_patch14": {
"metadata": {
"description": (
"428 million parameter, 24-layer for vision and 12-layer for "
"text, patch size of 14, CLIP model."
),
"params": 427616770,
"official_name": "CLIP",
"path": "clip",
"model_card": "https://github.com/openai/CLIP/blob/main/model-card.md",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_large_patch14/1",
},
"clip_vit_large_patch14_336": {
"metadata": {
"description": (
"428 million parameter, 24-layer for vision and 12-layer for "
"text, patch size of 14, image size of 336, CLIP model."
),
"params": 427944770,
"official_name": "CLIP",
"path": "clip",
"model_card": "https://github.com/openai/CLIP/blob/main/model-card.md",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_large_patch14_336/1",
},
"clip_vit_b_32_laion2b_s34b_b79k": {
"metadata": {
"description": (
"151 million parameter, 12-layer for vision and 12-layer for "
"text, patch size of 32, Open CLIP model."
),
"params": 151277363,
"official_name": "Open CLIP",
"path": "clip",
"model_card": "https://github.com/mlfoundations/open_clip",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_b_32_laion2b_s34b_b79k/1",
},
"clip_vit_h_14_laion2b_s32b_b79k": {
"metadata": {
"description": (
"986 million parameter, 32-layer for vision and 24-layer for "
"text, patch size of 14, Open CLIP model."
),
"params": 986109698,
"official_name": "Open CLIP",
"path": "clip",
"model_card": "https://github.com/mlfoundations/open_clip",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_h_14_laion2b_s32b_b79k/1",
},
"clip_vit_g_14_laion2b_s12b_b42k": {
"metadata": {
"description": (
"1.4 billion parameter, 40-layer for vision and 24-layer for "
"text, patch size of 14, Open CLIP model."
),
"params": 1366678530,
"official_name": "Open CLIP",
"path": "clip",
"model_card": "https://github.com/mlfoundations/open_clip",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_g_14_laion2b_s12b_b42k/1",
},
"clip_vit_bigg_14_laion2b_39b_b160k": {
"metadata": {
"description": (
"2.5 billion parameter, 48-layer for vision and 32-layer for "
"text, patch size of 14, Open CLIP model."
),
"params": 2539567362,
"official_name": "Open CLIP",
"path": "clip",
"model_card": "https://github.com/mlfoundations/open_clip",
},
"kaggle_handle": "kaggle://kerashub/clip/keras/clip_vit_bigg_14_laion2b_39b_b160k/1",
},
}
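`backbone_presets` is a plain dict, so the registered configurations can be inspected without downloading any weights; after registration in `__init__.py` the same names also show up under `CLIPBackbone.presets`, which the preset test above iterates over. A small illustrative snippet:

from keras_hub.src.models.clip.clip_presets import backbone_presets

# List every registered CLIP preset with its parameter count.
for name, preset in backbone_presets.items():
    meta = preset["metadata"]
    print(f"{name}: {meta['params'] / 1e6:.0f}M params ({meta['official_name']})")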
3 changes: 2 additions & 1 deletion keras_hub/src/models/clip/clip_text_encoder.py
@@ -82,7 +82,7 @@ def __init__(
for i in range(num_layers)
]
self.layer_norm = layers.LayerNormalization(
epsilon=1e-6, dtype="float32", name=f"{prefix}layer_norm"
epsilon=1e-6, dtype=dtype, name=f"{prefix}layer_norm"
)

# === Functional Model ===
@@ -108,6 +108,7 @@ def __init__(
super().__init__(
inputs={"token_ids": token_id_input},
outputs=outputs,
dtype=dtype,
name=name,
**kwargs,
)