In [None]:
!pip install torch transformers pillow scikit-learn numpy
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from itertools import product


In [44]:

class EmbeddingComparator:
    """Embed images and text with a CLIP model and compare the embeddings.

    Both encoders return L2-normalised numpy arrays, so embeddings produced
    here can be compared directly with cosine similarity.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and processor, on the GPU if one is available.

        Args:
            model_name: Hugging Face checkpoint name, e.g.
                "openai/clip-vit-base-patch32" (faster) or
                "openai/clip-vit-large-patch14" (more accurate).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def encode_image(self, image_path):
        """Encode the image stored at ``image_path`` into a unit-length embedding.

        Args:
            image_path: path to an image file readable by PIL.

        Returns:
            numpy array of shape (1, embedding_dim), L2-normalised.
        """
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once the pixels have been handed to the processor.
        # Converting to RGB gives the processor 3-channel input regardless of
        # the file's mode (grayscale, RGBA, palette, ...).
        with Image.open(image_path) as image:
            inputs = self.processor(images=image.convert("RGB"), return_tensors="pt").to(self.device)

        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Normalise so that cosine similarity reduces to a dot product.
        image_embedding = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_embedding.cpu().numpy()

    def encode_text(self, text):
        """Encode a string (or a list of strings) into unit-length embedding(s).

        Args:
            text: a caption, or a list of captions.

        Returns:
            numpy array of shape (n_texts, embedding_dim), L2-normalised;
            (1, embedding_dim) for a single string.
        """
        # CLIP's text encoder has a fixed maximum sequence length; truncate
        # long prompts instead of failing inside the position embeddings.
        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)

        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)

        # Normalise so that cosine similarity reduces to a dot product.
        text_embedding = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_embedding.cpu().numpy()

    def compare_embeddings(self, embedding1, embedding2):
        """Return the cosine similarity (in [-1, 1]) of two embeddings.

        Accepts the 2-D arrays returned by the encoders (only the first row of
        each is compared) or plain 1-D vectors.
        """
        return cosine_similarity(np.atleast_2d(embedding1), np.atleast_2d(embedding2))[0][0]


In [None]:
def testing(imagepath, simple_text, detailed_text, different_text, random_text):

    image_path = imagepath
    comparator = EmbeddingComparator()
    image_embedding = comparator.encode_image(image_path) 

    # Example 1: Compare image with text 
    text = simple_text
    text_embedding = comparator.encode_text(text)
    similarity = comparator.compare_embeddings(image_embedding, text_embedding)
    print(f"Similarity between image and '{text}': {similarity:.4f}")

    # Example 2: Compare with another text
    text = detailed_text
    text_embedding = comparator.encode_text(text)
    similarity = comparator.compare_embeddings(image_embedding, text_embedding)
    print(f"Similarity between image and '{text}': {similarity:.4f}")

    # Example 3: Compare with another text
    text = different_text
    text_embedding = comparator.encode_text(text)
    similarity = comparator.compare_embeddings(image_embedding, text_embedding)
    print(f"Similarity between image and '{text}': {similarity:.4f}")

    # Example 4: Compare with random text
    text = random_text
    text_embedding = comparator.encode_text(text)
    similarity = comparator.compare_embeddings(image_embedding, text_embedding)
    print(f"Similarity between image and '{text}': {similarity:.4f}")
    print(f"")

In [92]:
# Sanity check: an input compared with itself should have similarity 1.0.
print("-------BASIC TEST-------")
print()
comparator = EmbeddingComparator()

embedding1 = comparator.encode_image("dog.jpg")
embedding2 = comparator.encode_image("dog.jpg")
similarity = comparator.compare_embeddings(embedding1, embedding2)
print(f"Similarity between the image and itself: {similarity:.4f}")
assert np.isclose(similarity, 1.0, atol=1e-4), "image self-similarity should be 1.0"

embedding1 = comparator.encode_text("dog")
embedding2 = comparator.encode_text("dog")
similarity = comparator.compare_embeddings(embedding1, embedding2)
print(f"Similarity between the text and itself: {similarity:.4f}")
assert np.isclose(similarity, 1.0, atol=1e-4), "text self-similarity should be 1.0"
print()


# First we will use the simple image of a dog
print("-------SIMPLE IMAGE OF A DOG-------")
print()
testing("dog.jpg", "a picture of a dog", "a picture of a golden retriever sticking its tongue out", "a cat", "??dasd;'12p[]")
print()

# Next we will use the picture of a dog and a ball
print("-------IMAGE OF A DOG AND A BALL-------")
print()
testing("dog_ball.jpg", "a picture of a dog and a ball", "a picture of a small brown dog and a yellow ball", "a cat", "??dasd;'12p[]")
print()

-------BASIC TEST-------

Similarity between the image and itself: 1.0000
Similarity between the text and itself: 1.0000

-------SIMPLE IMAGE OF A DOG-------

Similarity between image and 'a picture of a dog': 0.2776
Similarity between image and 'a picture of a golden retriever sticking its tongue out': 0.3076
Similarity between image and 'a cat': 0.1926
Similarity between image and '??dasd;'12p[]': 0.2000


-------IMAGE OF A DOG AND A BALL-------

Similarity between image and 'a picture of a dog and a ball': 0.2975
Similarity between image and 'a picture of a small brown dog and a yellow ball': 0.3001
Similarity between image and 'a cat': 0.1954
Similarity between image and '??dasd;'12p[]': 0.2089


