In [57]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Load the CLIP weights and the matching processor (used below for both tokenization
# and image preprocessing) from one checkpoint so preprocessing agrees with the model.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [7]:
def get_text_embedding(text: str):
    """Return CLIP text features for a single caption.

    Args:
        text: Caption to encode.

    Returns:
        The tensor produced by ``clip_model.get_text_features`` for a batch
        containing just this one string (leading batch dimension of 1).
    """
    # Truncate to the tokenizer's max length: CLIP's text encoder has a fixed
    # context window (77 tokens for the OpenAI checkpoints) and errors on longer input.
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    # Inference only: skip autograd bookkeeping (as get_joint_embedding does) so the
    # result does not require grad and converts directly with np.array / .numpy().
    with torch.no_grad():
        text_embeddings = clip_model.get_text_features(**inputs)
    return text_embeddings

In [37]:
def get_image_embedding(image_path: str):
    """Return CLIP image features for the image stored at ``image_path``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The tensor produced by ``clip_model.get_image_features`` for a batch
        containing just this one image.
    """
    # PIL opens files lazily; use a context manager so the handle is closed
    # deterministically instead of lingering until garbage collection.
    with Image.open(image_path) as image:
        inputs = clip_processor(images=image, return_tensors="pt")
    # Inference only: no autograd graph, result converts cleanly to NumPy.
    with torch.no_grad():
        image_embeddings = clip_model.get_image_features(**inputs)
    return image_embeddings


In [79]:
def get_joint_embedding(image_path: str, text: str, normalize: bool = False):
    """Concatenate CLIP image features and text features into one vector.

    Args:
        image_path: Path to an image file readable by PIL.
        text: Caption encoded alongside the image.
        normalize: If True, L2-normalize the image and text features separately
            before concatenating, so neither modality dominates dot products /
            cosine similarity merely by having a larger norm. Defaults to False,
            which keeps the raw features.

    Returns:
        Tensor with a leading batch dimension of 1 and the image features
        followed by the text features along the last axis (1024 values for
        clip-vit-base-patch32).
    """
    # Preprocessing. Close the image file once its pixels have been consumed.
    with Image.open(image_path) as image:
        image_inputs = clip_processor(images=image, return_tensors="pt")
    # Truncate long captions instead of exceeding CLIP's fixed text context length.
    text_inputs = clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)

    # Generate embeddings without tracking gradients (inference only).
    with torch.no_grad():
        image_embedding = clip_model.get_image_features(**image_inputs)
        text_embedding = clip_model.get_text_features(**text_inputs)

    if normalize:
        image_embedding = torch.nn.functional.normalize(image_embedding, dim=-1)
        text_embedding = torch.nn.functional.normalize(text_embedding, dim=-1)

    return torch.cat([image_embedding, text_embedding], dim=-1)

In [80]:
# Verify the joint embedding is 2 x 512: one image row and one text row concatenated.
# Assert rather than eyeball the output so a regression fails Restart & Run All.
joint_check = get_joint_embedding("../data/motorcycle_1.jpg", "a bike")
assert tuple(joint_check.shape) == (1, 1024), tuple(joint_check.shape)
joint_check.shape[-1]

1024

In [100]:
# Examples: two motorcycle photos and one cat photo, each paired with the same caption.
# Because the caption is shared, the text half of every joint vector is identical, so
# the similarities computed below can differ only through the image half.
EXAMPLE_CAPTION = "a bike outside"
EXAMPLE_IMAGES = ("../data/motorcycle_1.jpg", "../data/motorcycle_2.jpg", "../data/cat_2.jpeg")
ex1, ex2, ex3 = (get_joint_embedding(image, EXAMPLE_CAPTION)[0] for image in EXAMPLE_IMAGES)

# Cosine Similarity

In [101]:
import numpy as np

def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two 1-D vectors.

    Args:
        vec1, vec2: Array-likes of equal length (lists, NumPy arrays, or
            anything ``np.asarray`` accepts).

    Returns:
        A NumPy floating scalar in [-1, 1] (up to rounding). Float32 inputs
        yield a float32 result.

    Raises:
        ValueError: if either vector has zero norm. Cosine similarity is
            undefined there, and the unguarded division would silently
            return nan with a RuntimeWarning.
    """
    a = np.asarray(vec1)
    b = np.asarray(vec2)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        raise ValueError("cosine similarity is undefined for a zero-norm vector")
    return np.dot(a, b) / norm_product

In [102]:
# Convert the joint-embedding tensors to NumPy and compare ex1 against the other two.
# np.array(...) works here because get_joint_embedding computes under torch.no_grad(),
# so the tensors do not require grad.
ex1_embed = np.array(ex1)
ex2_embed = np.array(ex2)
ex3_embed = np.array(ex3)
sim_ex1_ex2 = cosine_similarity(ex1_embed, ex2_embed)
sim_ex1_ex3 = cosine_similarity(ex1_embed, ex3_embed)


In [103]:
# Report each pairwise similarity under a plain-text heading; display() gives the rich repr.
similarity_report = [
    ("Cosine similarity between ex1_embeded and ex2_embeded is:", sim_ex1_ex2),
    ("Cosine similarity between ex1_embeded and ex3_embeded is:", sim_ex1_ex3),
]
for heading, score in similarity_report:
    print(heading)
    display(score)

Cosine similarity between ex1_embeded and ex2_embeded is:


np.float32(0.8063979)

Cosine similarity between ex1_embeded and ex3_embeded is:


np.float32(0.7158405)