In [1]:
"""
Script to generate embeddings for images (for catalog images) and store embeddings to Qdrant database.
@File    : image_embeddings_experiments.ipynb
@Author  : Nandini Bohra
@Contact : nbohra@ucsd.edu

@References : https://www.youtube.com/watch?v=MlRkBvOCfTY
"""

'\nScript to generate embeddings for images (for catalog images) and store embeddings to Qdrant database.\n@File    : image_embeddings_experiments.ipynb\n@Author  : Nandini Bohra\n@Contact : nbohra@ucsd.edu\n\n@References : https://www.youtube.com/watch?v=MlRkBvOCfTY\n'

In [None]:
# imports
import os
import pandas as pd

# for image resizing to b64
from io import BytesIO
import math
import base64

In [None]:
# Load the per-image metadata table from the local payloads.csv file.
# Later cells read its "image_url", "material" and "color" columns.
payloads = pd.read_csv("payloads.csv")

In [None]:
# Open every catalog image listed in payloads["image_url"] as a PIL image.
# NOTE(review): `Image` is not imported in any visible cell -- presumably
# `from PIL import Image` is required for a fresh-kernel run; confirm.

images = [Image.open(path) for path in payloads["image_url"]]
images[:5]

In [None]:
# Resizing originals to smaller images and converting to base-64 rep if needed to show on front end

target_width = 256

# Resizing images to target width
# Returns PIL image
def resize_img(url):
    """Open the image at ``url`` and scale it to ``target_width`` pixels wide.

    The height is scaled by the same factor as the width, so the original
    aspect ratio is preserved.

    Args:
        url: Path (or file object) accepted by ``Image.open``.

    Returns:
        A new PIL image of size ``(target_width, scaled_height)``.
    """
    # The context manager releases the source file once the resized copy exists.
    with Image.open(url) as pil_img:
        # new_height / target_width must equal height / width. Integer
        # arithmetic keeps the floor exact; the clamp stops extremely wide
        # images from producing a zero-pixel height.
        target_height = max(1, (target_width * pil_img.height) // pil_img.width)
        return pil_img.resize((target_width, target_height))

# Converting PIL image to base64 string
def img_to_base64(pil_img):
    """Encode a PIL image as a base64 string of its JPEG bytes.

    Args:
        pil_img: PIL image to encode.

    Returns:
        UTF-8 ``str`` holding the base64-encoded JPEG data.
    """
    # JPEG cannot store alpha or palette data; Pillow raises OSError when asked
    # to write e.g. RGBA / LA / P images, so convert those to RGB first.
    if pil_img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        pil_img = pil_img.convert("RGB")

    buffer = BytesIO()
    pil_img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

# Resize every catalog image, encode it, and attach the result to the
# payloads dataframe as a new "base64" column.
resized_images = [resize_img(path) for path in payloads["image_url"]]
base64_images = [img_to_base64(img) for img in resized_images]
payloads["base64"] = base64_images
payloads.head()


In [None]:
# Trial with Microsoft Resnet-50 model
# https://huggingface.co/microsoft/resnet-50

from transformers import AutoImageProcessor, ResNetForImageClassification
import torch

processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

inputs = processor(
    images,
    return_tensors="pt",
    # padding=True
)

# Inference only: skip autograd bookkeeping, as the DINOv2 cell does.
with torch.no_grad():
    outputs = model(**inputs)

# The classification-head logits are used here as a stand-in embedding.
embeddings = outputs.logits
embeddings

# Evaluated embeddings... not sure if this is the right fit 
# Researching and trying other models

In [None]:
# Trying DINO V2 model 
# Less for object classification and more for fine details, textures --> may be suitable for textile catalog
# https://huggingface.co/facebook/dinov2-base

from transformers import AutoImageProcessor, AutoModel
import torch

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base')

inputs = processor(images, return_tensors="pt")

# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    outputs = model(**inputs)

# Per-token features: [batch, 1 CLS token + patch tokens, hidden_size]
# (the original note gives 50 x 257 x 768 for this dataset -- TODO confirm)
hidden_states = outputs.last_hidden_state

# Drop the CLS token (index 0) and keep only the patch tokens: the CLS token
# summarises the image for classification/semantics, whereas the mean over
# patches is meant to retain finer-grained detail of the image.
all_patch_embedding = hidden_states[:, 1:, :]
print(all_patch_embedding.shape)

# One vector per image: mean over the patch-token axis.
avg_patch_embeddings = all_patch_embedding.mean(dim=1)
print(avg_patch_embeddings.shape)

avg_patch_embeddings


In [None]:
# DINOv2 Model
# Embedding dimensionality = size of the second axis of the (images, dim) tensor;
# reused later as the Qdrant vector size.
embedding_len = avg_patch_embeddings.shape[1]
embedding_len

In [None]:
# DINOv2 Model
# Visualizing cosine similarity matrix between embeddings...
import numpy as np
from numpy.linalg import norm

embeddings = avg_patch_embeddings.detach().cpu().numpy()

# L2 norm of every embedding (one value per image)
norm_factor = np.linalg.norm(embeddings, axis=1)

# cos(i, j) = <e_i, e_j> / (|e_i| * |e_j|).
# The denominator must be the OUTER product of the norms: an elementwise
# product of the two 1-D norm vectors broadcasts over columns only and
# yields <e_i, e_j> / |e_j|^2, which is not a cosine similarity.
cos_sim = np.dot(embeddings, embeddings.T) / np.outer(norm_factor, norm_factor)
print(cos_sim.shape)

import matplotlib.pyplot as plt

plt.imshow(cos_sim)
plt.title("Cosine Similarity Heatmap for DINOv2 Image Embeddings")
plt.colorbar()
plt.xlabel("Image Index")
plt.ylabel("Image Index")
plt.show()

In [None]:
# Diagnosing issues with image embeddings in DINOv2
# Per-dimension mean of the embeddings: values far from 0 mean the
# embeddings are not zero-centered.
print(embeddings.mean(axis=0))

# Histogram of pairwise similarities (strict upper triangle only, so each
# pair is counted once and the diagonal is skipped).
# Ideally the values spread across the range; tight clustering in a small
# range can be problematic.
pair_idx = np.triu_indices_from(cos_sim, k=1)
cos_sim_values = cos_sim[pair_idx]
plt.hist(cos_sim_values, bins=50, alpha=0.75, color='blue')
plt.show()

In [None]:
# Zero-center the embeddings, then re-check the cosine similarity spread
mean_embedding = embeddings.mean(axis=0)
centered_embeddings = embeddings - mean_embedding

# Scale every centered row to unit L2 norm so a plain dot product is the cosine
row_norms = np.linalg.norm(centered_embeddings, axis=1, keepdims=True)
norm_embeddings = centered_embeddings / row_norms

cos_sim = norm_embeddings @ norm_embeddings.T

# Same histogram as before: one value per unordered image pair
pair_idx = np.triu_indices_from(cos_sim, k=1)
cos_sim_values = cos_sim[pair_idx]
plt.hist(cos_sim_values, bins=50, alpha=0.75, color='blue')
plt.show()

# Not super helpful, just contrasts embeddings more...

In [None]:
# Now working with CLIP instead to implement multimodal similarity search

from transformers import CLIPProcessor, CLIPModel
import torch

model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Text description per image: "<material> <color>"
payloads['info'] = payloads['material'] + " " + payloads['color']
captions = payloads['info'].tolist()

# Tokenize all captions as one padded batch and move it to the model's device
tokens = processor(text=captions, padding=True, return_tensors='pt').to(device)
tokens.keys()

In [None]:
# CLIP Model
# Getting text embeddings from CLIP for the tokenized `info` strings

# no_grad: inference only, so no autograd graph is needed
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

# Sanity checks: shape and value range of the raw (not yet normalized) features
print(text_emb.shape)
print(text_emb.min(), text_emb.max())


In [None]:
# CLIP Model
import numpy as np

# detach text emb from graph, move to CPU, and convert to numpy array
# NOTE: this rebinds `text_emb` from a torch tensor to an ndarray, so running
# this cell a second time (without re-running the previous cell) fails because
# ndarrays have no .detach().
text_emb = text_emb.detach().cpu().numpy()

# calculate value to normalize each vector by (L2 norm per row);
# this overwrites the `norm_factor` computed in the DINOv2 cell.
# `text_emb` itself is NOT normalized here.
norm_factor = np.linalg.norm(text_emb, axis=1)
norm_factor.shape

In [None]:
# CLIP Model
# Preprocess the PIL images for CLIP's image encoder
img_batch = processor(text=None, images=images, return_tensors='pt')

# Only the pixel tensor is needed; move it to the same device as the model
img_inputs = img_batch['pixel_values'].to(device)
img_inputs.shape

In [None]:
# CLIP Model
# Getting image embeddings from CLIP.
# Inference only, so skip gradient tracking (consistent with the text-embedding cell).
with torch.no_grad():
    img_features = model.get_image_features(img_inputs)
print(img_features.shape)
print(img_features.min(), img_features.max())

# NORMALIZE
# Move the image embeddings to the CPU and convert to a numpy array. The raw
# tensor keeps its own name (`img_features`) so this cell can be re-run safely.
img_emb = img_features.detach().cpu().numpy()

# Scale every row to unit L2 norm; keepdims=True makes the (n, 1) norms
# broadcast across each row, so no transpose round-trip is needed.
img_emb = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
print(img_emb.shape)
print(img_emb.min(), img_emb.max())

In [None]:
# CLIP Model
from numpy.linalg import norm

# cos(text_i, image_j) = <t_i, v_j> / (|t_i| * |v_j|).
# Rows of the similarity matrix are texts and columns are images, so the
# denominator must be the outer product of the two norm vectors. Multiplying
# the 1-D norm vectors elementwise applied the text norms along the image
# axis (and only ran at all because both sets have the same length).
cos_sim = np.dot(text_emb, img_emb.T) / np.outer(
    norm(text_emb, axis=1), norm(img_emb, axis=1)
)

import matplotlib.pyplot as plt

plt.imshow(cos_sim)
plt.title("Cosine Similarity Heatmap for CLIP Text-Image Embeddings")
plt.colorbar()
plt.xlabel("Image Embeddings")
plt.ylabel("Text Embeddings")
plt.show()

In [None]:

# Bring both embedding sets to unit L2 norm (row-wise)
def _unit_rows(mat):
    """Return `mat` with every row divided by its L2 norm."""
    return mat / np.linalg.norm(mat, axis=1, keepdims=True)

image_embeddings = _unit_rows(img_emb)
text_embeddings = _unit_rows(text_emb)

# Cosine similarity of every image (rows) against every text (columns)
image_text_sim = image_embeddings @ text_embeddings.T

# Visualize the similarity matrix
plt.imshow(image_text_sim, cmap="viridis")
plt.show()


In [None]:
# All image-vs-text cosine similarities, flattened to 1-D for the histogram
similarities = (image_embeddings @ text_embeddings.T).flatten()

plt.hist(similarities, bins=100, color="blue")
plt.title("Distribution of Image-Text Similarities")
plt.xlabel("Cosine Similarity")
plt.ylabel("Frequency")
plt.show()

In [None]:
# Currently holding embeddings from DINOv2 + sample information in payloads
# Loading Qdrant database access tokens from .env file
# (QDRANT_DB_URL and QDRANT_API_KEY are read from the environment in the next cell)

from dotenv import load_dotenv
load_dotenv()

In [None]:
# Connect to Qdrant using the credentials loaded from the environment.
# Either value is None when its variable is not set.

from qdrant_client import QdrantClient

qdrant_url = os.environ.get("QDRANT_DB_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")

qclient = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
qclient

In [None]:
# (Re)create the Qdrant collection that will hold the image vectors

from qdrant_client.models import Distance, VectorParams

collection_name = "sample_images_2"

# Vector size must match the DINOv2 embedding length computed earlier.
# Previously tried DOT distance, but cosine distance is more suitable for image embeddings
vector_params = VectorParams(size=embedding_len, distance=Distance.COSINE)

collection = qclient.recreate_collection(
    collection_name=collection_name,
    vectors_config=vector_params,
)
collection

In [None]:
# JSONifying the payloads dataframe to format metadata for each point:
# one dict per row, keyed by column name. This includes the "base64" and
# "info" columns added to `payloads` in earlier cells.

payload_dicts = payloads.to_dict(orient="records")
# Preview the first record only
payload_dicts[:1]

In [None]:
# Creating records of payloads to load into Qdrant:
# one record per image, pairing its metadata with its DINOv2 embedding.

from qdrant_client import models

records = [
    models.Record(
        id=idx,
        payload=payload,
        # Qdrant expects a plain list of floats, not a torch tensor row
        vector=avg_patch_embeddings[idx].tolist(),
    )
    for idx, payload in enumerate(payload_dicts)
]

In [None]:
# Sending records to Qdrant database:
# uploads every entry of `records` into the collection named by `collection_name`.

qclient.upload_records(
    collection_name=collection_name,
    records=records
)

In [None]:
# from qdrant_client import QdrantClient
# from qdrant_client.models import Distance, VectorParams
# from qdrant_client.models import PointStruct



# client = QdrantClient(url="http://localhost:6333")

# # client.create_collection(
# #     collection_name="test_collection",
# #     vectors_config=VectorParams(size=4, distance=Distance.DOT),
# # )

# # operation_info = client.upsert(
# #     collection_name="test_collection",
# #     wait=True,
# #     points=[
# #         PointStruct(id=1, vector=[0.05, 0.61, 0.76, 0.74], payload={"city": "Berlin"}),
# #         PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}),
# #         PointStruct(id=3, vector=[0.36, 0.55, 0.47, 0.94], payload={"city": "Moscow"}),
# #         PointStruct(id=4, vector=[0.18, 0.01, 0.85, 0.80], payload={"city": "New York"}),
# #         PointStruct(id=5, vector=[0.24, 0.18, 0.22, 0.44], payload={"city": "Beijing"}),
# #         PointStruct(id=6, vector=[0.35, 0.08, 0.11, 0.44], payload={"city": "Mumbai"}),
# #     ],
# # )

# # print(operation_info)

# # search_result = client.query_points(
# #     collection_name="test_collection",
# #     query=[0.2, 0.1, 0.9, 0.7],
# #     with_payload=False,
# #     limit=3
# # ).points

# # print(search_result)

# client.delete_collection(collection_name="test_collection")
# print(f"Collection 'test_collection' deleted.")
