In [3]:
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import numpy as np
import os
from sklearn.metrics.pairwise import cosine_similarity

# Load the CLIP weights and the matching input processor once, at module
# level; both are created from the same checkpoint name and are shared by
# everything below.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

def get_image_embeddings(image_list, batch_size=32):
    """Encode images into CLIP image-feature vectors.

    Uses the module-level ``processor`` to preprocess the images and the
    module-level ``model`` to compute their features, with gradient tracking
    disabled.

    Args:
        image_list: Images to encode. A list/tuple is processed in chunks of
            ``batch_size``; any other input (for example a single image) is
            handed to the processor unchanged in one call.
        batch_size: Maximum number of images sent through the model at once.
            Bounding it keeps peak memory independent of how many images are
            passed in.

    Returns:
        A tensor with one row of (un-normalized) features per input image,
        in input order. An empty list/tuple yields a tensor with zero rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if isinstance(image_list, (list, tuple)):
        if not image_list:
            # Nothing to encode: return an empty result instead of passing an
            # empty batch to the processor.
            return torch.empty((0, model.config.projection_dim), dtype=model.dtype)
        batches = [
            image_list[start:start + batch_size]
            for start in range(0, len(image_list), batch_size)
        ]
    else:
        # Not a plain sequence, so there is nothing to chunk; behave exactly
        # like a single processor call.
        batches = [image_list]

    feature_chunks = []
    with torch.no_grad():
        for batch in batches:
            inputs = processor(images=batch, return_tensors="pt")
            feature_chunks.append(model.get_image_features(**inputs))
    return torch.cat(feature_chunks, dim=0)

# Collect the gallery images (and their filenames, index-aligned) that the
# query image will be compared against.
image_dir = 'data'
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
image_list = []
image_filenames = []

# os.listdir() returns entries in arbitrary order; sorting keeps the
# filename/embedding ordering reproducible between runs.
for filename in sorted(os.listdir(image_dir)):
    # Compare case-insensitively and include the dot, so 'PHOTO.JPG' is
    # accepted while a name that merely ends in the letters 'jpg' is not.
    if not filename.lower().endswith(image_extensions):
        continue
    image_path = os.path.join(image_dir, filename)
    # Image.open() is lazy and keeps the file open. Converting inside the
    # context manager produces a fully loaded in-memory copy, so the file
    # handle is released as soon as the block exits.
    with Image.open(image_path) as img:
        image_list.append(img.convert('RGB'))
    image_filenames.append(filename)

# Fail early with a clear message when the gallery folder contained no
# usable images, instead of erroring somewhere inside the model pipeline.
if not image_list:
    raise ValueError(f"No images found in {image_dir!r}; nothing to compare against.")

# Embed the gallery images and scale every vector to unit length.
image_embeddings = get_image_embeddings(image_list)
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

# Embed the query image the same way. The context manager closes the file
# handle; convert() returns an independent, fully loaded copy to keep using.
query_image_path = "./data/input/pineapple.jpg"
with Image.open(query_image_path) as query_file:
    query_image = query_file.convert("RGB")
query_embedding = get_image_embeddings([query_image])
query_embedding = query_embedding / query_embedding.norm(dim=-1, keepdim=True)

# scikit-learn operates on NumPy arrays, so pass explicit CPU arrays rather
# than relying on implicit tensor conversion. The result has one row (the
# query) and one column per gallery image.
similarity_scores = cosine_similarity(
    query_embedding.detach().cpu().numpy(),
    image_embeddings.detach().cpu().numpy(),
)

# similarity_scores has a single row: the scores of the one query image
# against every gallery image.
query_scores = similarity_scores[0]

# Debug aid: every score in ascending order.
print(sorted(query_scores))


# Take the argmax over the query's row rather than over the flattened 2-D
# matrix, so the result is always a valid index into image_filenames.
best_match_index = int(np.argmax(query_scores))
best_match_filename = image_filenames[best_match_index]

#print(f"Best match: {best_match_filename} with similarity score: {similarity_scores[0][best_match_index]:.4f}")


# Get the top 3 similar items
top_k = 3
# Sort descending first and slice afterwards: this returns fewer entries when
# the gallery is smaller than top_k and an empty result for top_k == 0
# (a trailing [-0:] slice would return everything).
top_k_indices = np.argsort(query_scores)[::-1][:top_k]
top_k_filenames = [image_filenames[i] for i in top_k_indices]
top_k_scores = [query_scores[i] for i in top_k_indices]

# Print the top 3 similar items
for rank, (filename, score) in enumerate(zip(top_k_filenames, top_k_scores), start=1):
    print(f"Top {rank} match: {filename} with similarity score: {score:.4f}")

[np.float32(0.21607408), np.float32(0.4147325), np.float32(0.44693565), np.float32(0.45196933), np.float32(0.47855628), np.float32(0.48485374), np.float32(0.4991418), np.float32(0.53800696), np.float32(0.61015093), np.float32(0.6744036), np.float32(0.6748359), np.float32(0.7826628), np.float32(0.81130695), np.float32(0.8668111)]
Top 1 match: product-img-13.jpg with similarity score: 0.8668
Top 2 match: pineapple-2.jpg with similarity score: 0.8113
Top 3 match: pine3.jpg with similarity score: 0.7827
