# Import

In [1]:
import os

# Select the Keras backend through the environment.
# NOTE(review): Keras presumably reads KERAS_BACKEND when it is first imported,
# so this cell has to run before the cell that imports `keras` — confirm
# against the Keras backend-configuration docs.
os.environ["KERAS_BACKEND"] = "tensorflow"

In [2]:
import json
from datetime import datetime, timezone

import keras
from keras import ops

import keras_cv
from keras_cv.models import CLIP
from keras_cv.models.feature_extractor.clip import CLIPProcessor

2024-04-08 22:25:42.921695: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-08 22:25:42.928541: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-08 22:25:43.041846: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def _clip_config(
    *,
    embed_dim,
    transformer_width,
    transformer_heads,
    vision_layers,
    vision_width,
    image_resolution,
    vision_patch_size,
):
    """Build the hyperparameter dict for one CLIP variant.

    Only the values that differ between the variants listed in this cell are
    parameters. The context length (77), vocabulary size (49408) and number
    of text transformer layers (12) are the same for all of them.
    """
    return {
        "embed_dim": embed_dim,
        "context_length": 77,
        "vocab_size": 49408,
        "transformer_width": transformer_width,
        "transformer_heads": transformer_heads,
        "transformer_layers": 12,
        "vision_layers": vision_layers,
        "vision_width": vision_width,
        "image_resolution": image_resolution,
        "vision_patch_size": vision_patch_size,
    }


# Settings shared by the two "B" variants and by the two "L" variants.
_BASE_SIZE = dict(
    embed_dim=512,
    transformer_width=512,
    transformer_heads=8,
    vision_layers=12,
    vision_width=768,
)
_LARGE_SIZE = dict(
    embed_dim=768,
    transformer_width=768,
    transformer_heads=12,
    vision_layers=24,
    vision_width=1024,
)

# Hyperparameters for every supported variant, keyed by variant name.
MODEL_CONFIGS = {
    "CLIP_B32": _clip_config(
        **_BASE_SIZE, image_resolution=224, vision_patch_size=32
    ),
    "CLIP_B16": _clip_config(
        **_BASE_SIZE, image_resolution=224, vision_patch_size=16
    ),
    "CLIP_L14": _clip_config(
        **_LARGE_SIZE, image_resolution=224, vision_patch_size=14
    ),
    "CLIP_L14_336": _clip_config(
        **_LARGE_SIZE, image_resolution=336, vision_patch_size=14
    ),
}

# Variant name -> Hugging Face checkpoint identifier.
model_map_hf = dict(
    CLIP_B16="openai/clip-vit-base-patch16",
    CLIP_B32="openai/clip-vit-base-patch32",
    CLIP_L14="openai/clip-vit-large-patch14",
    CLIP_L14_336="openai/clip-vit-large-patch14-336",
)

# The variant converted by the rest of the notebook.
config_name = "CLIP_B16"
config_name_hf = model_map_hf[config_name]

# Keras 3 CLIP

In [4]:
# Hyperparameters of the variant selected by `config_name`, looked up once.
selected_config = MODEL_CONFIGS[config_name]

embed_dim = selected_config["embed_dim"]

# `vision_*` / image entries.
image_resolution = selected_config["image_resolution"]
vision_layers = selected_config["vision_layers"]
vision_width = selected_config["vision_width"]
vision_patch_size = selected_config["vision_patch_size"]

# Text / `transformer_*` entries.
context_length = selected_config["context_length"]
vocab_size = selected_config["vocab_size"]
transformer_width = selected_config["transformer_width"]
transformer_heads = selected_config["transformer_heads"]
transformer_layers = selected_config["transformer_layers"]

# The arguments are passed positionally, so their order here is significant.
clip_args = (
    embed_dim,
    image_resolution,
    vision_layers,
    vision_width,
    vision_patch_size,
    context_length,
    vocab_size,
    transformer_width,
    transformer_heads,
    transformer_layers,
)
model = CLIP(*clip_args)

In [5]:
# Display the summary of the freshly constructed Keras model as a quick check
# of the architecture before any weights are copied into it.
model.summary()

# HF CLIP

In [11]:
from PIL import Image
import requests
import torch

from transformers import CLIPProcessor as CP
from transformers import CLIPModel as CM

In [7]:
# Load the Hugging Face reference model for the selected checkpoint name.
# NOTE(review): `from_pretrained` presumably downloads the checkpoint on first
# use, so this cell may need network access — confirm.
model_hf = CM.from_pretrained(config_name_hf)
# The processor is only referenced by the commented-out numerics check at the
# end of the notebook.
processor_hf = CP.from_pretrained(config_name_hf)

# Copy weights

In [8]:
# Hugging Face weights, keyed by parameter name. The copy cells below `pop`
# every entry they transfer, so whatever is still in this dict afterwards was
# never copied. Because entries are removed as they are used, re-run this cell
# before re-running any of the copy cells.
hf_wts = model_hf.state_dict()

## Vision Encoder

In [9]:
# Logit scale lives on the CLIP head rather than on either encoder.
logit_scale_value = hf_wts.pop("logit_scale").numpy()
model.get_layer("clip_head").logit_scale.assign(logit_scale_value)

# Handles to the image-encoder sub-layers filled in by this cell.
image_encoder = model.get_layer("image_encoder")
patch_embedding = image_encoder.get_layer("clip_patch_embedding")
pre_norm = image_encoder.get_layer("ln_1")
post_norm = image_encoder.get_layer("ln_2")
vision_projector = image_encoder.get_layer("vision_projector")

# Patch embedding: class token, positional table and patch convolution.
class_embedding_value = (
    hf_wts.pop("vision_model.embeddings.class_embedding").numpy().T
)
patch_embedding.class_embedding.assign(class_embedding_value)

positional_embedding_value = hf_wts.pop(
    "vision_model.embeddings.position_embedding.weight"
).numpy()
patch_embedding.positional_embedding.assign(positional_embedding_value)

# NOTE(review): permute(3, 2, 1, 0) reverses all four axes. If the HF kernel is
# laid out (out, in, kh, kw) this yields (kw, kh, in, out), i.e. the two
# spatial axes are swapped as well; permute(2, 3, 1, 0) would keep them in
# place. Confirm against the numerics check that this is intended.
patch_kernel_value = (
    hf_wts.pop("vision_model.embeddings.patch_embedding.weight")
    .permute(3, 2, 1, 0)
    .numpy()
)
patch_embedding.conv1.weights[0].assign(patch_kernel_value)

# Layer norms before and after the transformer stack. "pre_layrnorm" (sic) is
# the key as it is spelled in the HF state dict.
pre_norm.weights[0].assign(
    hf_wts.pop("vision_model.pre_layrnorm.weight").numpy()
)
pre_norm.weights[1].assign(
    hf_wts.pop("vision_model.pre_layrnorm.bias").numpy()
)
post_norm.weights[0].assign(
    hf_wts.pop("vision_model.post_layernorm.weight").numpy()
)
post_norm.weights[1].assign(
    hf_wts.pop("vision_model.post_layernorm.bias").numpy()
)

# Final projection; the HF matrix is transposed before assignment.
vision_projection_value = hf_wts.pop("visual_projection.weight").numpy().T
vision_projector.weights[0].assign(vision_projection_value)

In [10]:
# Copy the transformer blocks of the vision tower, one HF encoder layer per
# Keras residual block. Every HF tensor is converted with `.numpy()` before it
# is assigned. Keys are popped from `hf_wts`, so this cell cannot be re-run
# without rebuilding `hf_wts` first.
vision_blocks = (
    model.get_layer("image_encoder").get_layer("clip_encoder").resblocks
)

for i in range(MODEL_CONFIGS[config_name]["vision_layers"]):
    block = vision_blocks[i]
    hf_prefix = f"vision_model.encoder.layers.{i}"

    # Attention projections and MLP layers: `weights[0]` receives the
    # transposed HF `.weight`, `weights[1]` receives the HF `.bias`.
    # (Presumably Dense-like layers whose kernels are stored (in, out),
    # whereas PyTorch Linear stores (out, in) — hence the transpose.)
    projections = {
        "self_attn.q_proj": block.attn.q_proj,
        "self_attn.k_proj": block.attn.k_proj,
        "self_attn.v_proj": block.attn.v_proj,
        "self_attn.out_proj": block.attn.out_proj,
        "mlp.fc1": block.mlp_dense_1,
        "mlp.fc2": block.mlp_dense_2,
    }
    for hf_name, layer in projections.items():
        layer.weights[0].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.weight").numpy().T
        )
        layer.weights[1].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.bias").numpy()
        )

    # Layer norms: `.weight` and `.bias` are copied without transposition.
    layer_norms = {"layer_norm1": block.ln_1, "layer_norm2": block.ln_2}
    for hf_name, layer in layer_norms.items():
        layer.weights[0].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.weight").numpy()
        )
        layer.weights[1].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.bias").numpy()
        )

## Text encoder

In [11]:
# Copy the text-encoder weights that sit outside the transformer blocks.
# Every HF tensor is converted with `.numpy()` before assignment, matching the
# vision-encoder cells.
text_encoder = model.get_layer("text_encoder")

# Output projection; the HF matrix is transposed before assignment.
text_encoder.get_layer("text_projector").weights[0].assign(
    hf_wts.pop("text_projection.weight").numpy().T
)

# Token and positional embedding tables are copied as is.
text_encoder.get_layer("token_embedding").weights[0].assign(
    hf_wts.pop("text_model.embeddings.token_embedding.weight").numpy()
)
text_encoder.get_layer("positional_embedding").weights[0].assign(
    hf_wts.pop("text_model.embeddings.position_embedding.weight").numpy()
)

# Final layer norm: `weights[0]` <- `.weight`, `weights[1]` <- `.bias`.
final_norm = text_encoder.get_layer("ln_final")
final_norm.weights[0].assign(
    hf_wts.pop("text_model.final_layer_norm.weight").numpy()
)
final_norm.weights[1].assign(
    hf_wts.pop("text_model.final_layer_norm.bias").numpy()
)

In [12]:
# Copy the transformer blocks of the text tower, one HF encoder layer per
# Keras residual block. Every HF tensor is converted with `.numpy()` before it
# is assigned. Keys are popped from `hf_wts`, so this cell cannot be re-run
# without rebuilding `hf_wts` first.
text_blocks = (
    model.get_layer("text_encoder").get_layer("clip_encoder").resblocks
)

for i in range(MODEL_CONFIGS[config_name]["transformer_layers"]):
    block = text_blocks[i]
    hf_prefix = f"text_model.encoder.layers.{i}"

    # Attention projections and MLP layers: `weights[0]` receives the
    # transposed HF `.weight`, `weights[1]` receives the HF `.bias`.
    # (Presumably Dense-like layers whose kernels are stored (in, out),
    # whereas PyTorch Linear stores (out, in) — hence the transpose.)
    projections = {
        "self_attn.k_proj": block.attn.k_proj,
        "self_attn.q_proj": block.attn.q_proj,
        "self_attn.v_proj": block.attn.v_proj,
        "self_attn.out_proj": block.attn.out_proj,
        "mlp.fc1": block.mlp_dense_1,
        "mlp.fc2": block.mlp_dense_2,
    }
    for hf_name, layer in projections.items():
        layer.weights[0].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.weight").numpy().T
        )
        layer.weights[1].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.bias").numpy()
        )

    # Layer norms: `.weight` and `.bias` are copied without transposition.
    layer_norms = {"layer_norm1": block.ln_1, "layer_norm2": block.ln_2}
    for hf_name, layer in layer_norms.items():
        layer.weights[0].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.weight").numpy()
        )
        layer.weights[1].assign(
            hf_wts.pop(f"{hf_prefix}.{hf_name}.bias").numpy()
        )

In [13]:
# Verify that we copied all weights: the copy cells pop every key they
# transfer, so any key still present here was never copied. Fail loudly so a
# Run All cannot go on to save an incompletely converted model.
assert not hf_wts, f"HF weights that were not copied: {sorted(hf_wts)}"
hf_wts.keys()

odict_keys([])

# Save Weights

In [14]:
# Write the converted weights to `<config_name>/model.weights.h5`, creating
# the per-variant output directory first if it does not exist yet.
weights_path = os.path.join(config_name, "model.weights.h5")
os.makedirs(config_name, exist_ok=True)
model.save_weights(weights_path)

In [15]:
# Description of the saved model, written next to the weights file.
# NOTE(review): the keys presumably mirror Keras's serialized-object layout so
# a loader can rebuild the model from this file — confirm against the consumer.
config = {
    "module": "keras_cv.models.feature_extractor.clip.clip_model",
    "class_name": "CLIP",
    "config": model.get_config(),
    "registered_name": "keras_cv>CLIP",
    "weights": "model.weights.h5",
}

with open(os.path.join(config_name, "config.json"), "w") as config_file:
    json.dump(config, config_file)

# Provenance of the conversion: library versions, model size and UTC timestamp.
metadata = {
    "keras_version": keras.__version__,
    "keras_cv_version": keras_cv.__version__,
    "parameter_count": model.count_params(),
    # Timezone-aware replacement for `datetime.utcnow()`, which is deprecated
    # since Python 3.12; the formatted string is unchanged.
    "date_saved": datetime.now(timezone.utc).strftime("%Y-%m-%d@%H:%M:%S"),
}

with open(os.path.join(config_name, "metadata.json"), "w") as metadata_file:
    json.dump(metadata, metadata_file)

# Verify numerics

In [16]:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# inputs = processor_hf(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

# outputs = model_hf(**inputs)
# logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
# probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

In [17]:
# import matplotlib.pyplot as plt

# plt.imshow(image)

In [18]:
# probs

In [19]:
# VOCAB_PATH = keras.utils.get_file(None, "https://storage.googleapis.com/keras-cv/models/clip/vocab.json")
# MERGE_PATH = keras.utils.get_file(None, "https://storage.googleapis.com/keras-cv/models/clip/merges.txt")

In [20]:
# processor = CLIPProcessor(224, VOCAB_PATH, MERGE_PATH)
# text_processed = processor(["a photo of a cat", "a photo of a dog"])

In [21]:
# image_processed = ops.convert_to_tensor(inputs['pixel_values'].detach().cpu().permute(0, 2, 3, 1).numpy())

In [22]:
# outputs = model({
#     "images": image_processed,
#     **text_processed
# })

In [23]:
# ops.softmax(outputs["image_logits"], axis=1)  # we can take the softmax to get the label probabilities

In [24]:
# model.load_weights("model.weights.h5")

In [25]:
# outputs = model({
#     "images": image_processed,
#     **text_processed
# })

In [26]:
# ops.softmax(outputs["image_logits"], axis=1)  # we can take the softmax to get the label probabilities