In [1]:
"""
Script to generate embeddings for images (for input and catalog images) and store catalog embeddings to Qdrant database.
@File    : embeddings.py
@Date    : 2025-03-04
@Author  : Nandini Bohra
@Contact : nbohra@ucsd.edu

"""

'\nScript to generate embeddings for images (for input and catalog images) and store catalog embeddings to Qdrant database.\n@File    : embeddings.py\n@Date    : 2025-03-04\n@Author  : Nandini Bohra\n@Contact : nbohra@ucsd.edu\n\n'

In [2]:
import os

base_directory = "/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images"
all_img_files = os.listdir(base_directory)
all_img_files[:10]


['sample_48.jpg',
 'sample_49.jpg',
 'sample_11.jpg',
 'sample_39.jpg',
 'sample_38.jpg',
 'sample_10.jpg',
 'sample_12.jpg',
 'sample_13.jpg',
 'sample_17.jpg',
 'sample_16.jpg']

In [3]:
all_img_urls = list(map(lambda x: os.path.join(base_directory, x), all_img_files))
all_img_urls[:10]

['/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_48.jpg',
 '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_49.jpg',
 '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_11.jpg',
 '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_39.jpg',
 '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_38.jpg',
 '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sa

In [4]:
from pandas import DataFrame
from PIL import Image

payloads = DataFrame.from_records({"image_url": all_img_urls})
payloads["type"] = "samples"
payloads.head()


Unnamed: 0,image_url,type
0,/Users/nandinibohra/Desktop/VSCodeFiles/Arthta...,samples
1,/Users/nandinibohra/Desktop/VSCodeFiles/Arthta...,samples
2,/Users/nandinibohra/Desktop/VSCodeFiles/Arthta...,samples
3,/Users/nandinibohra/Desktop/VSCodeFiles/Arthta...,samples
4,/Users/nandinibohra/Desktop/VSCodeFiles/Arthta...,samples


In [5]:
images = list(map(lambda x: Image.open(x), payloads["image_url"]))
images[:10]

[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1700x920>]

In [None]:
# resize images and conver to base 64 rep if needed to show on front end

from io import BytesIO
import math
import base64




In [None]:
# Trial with Microsoft Resnet-50 model
# https://huggingface.co/microsoft/resnet-50

from transformers import AutoImageProcessor, ResNetForImageClassification

processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

inputs = processor(
    images, 
    return_tensors="pt", 
    # padding=True
)

outputs = model(**inputs)
embeddings = outputs.logits
embeddings

# Evaluated embeddings... not sure if this is the right fit 
# Researching and trying other models

  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tensor([[-13.0327, -12.8485, -11.5226,  ..., -13.6024, -11.4868,  -9.8092],
        [ -8.8716, -10.5682, -10.8089,  ..., -12.5314,  -8.5624,  -8.4759],
        [ -8.2020,  -8.5568,  -8.1728,  ..., -10.2424,  -7.6316,  -6.6405],
        ...,
        [-12.1440, -12.1543, -10.0716,  ..., -11.9698,  -7.9782,  -8.2962],
        [-12.3966, -11.1451, -12.2452,  ..., -12.0119,  -9.8162, -10.8850],
        [ -9.0621,  -7.6530,  -6.6977,  ...,  -9.7274,  -7.3580,  -6.0283]],
       grad_fn=<AddmmBackward0>)

In [None]:
# Trying with DINO V2 model
# https://huggingface.co/facebook/dinov2-base

from transformers import AutoImageProcessor, AutoModel
import torch

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base')

inputs = processor(
    images, 
    return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

hidden_states = outputs.last_hidden_state # Shape: [batch_size=50, num_tokens=1+256, hidden_size=768]

# Removing CLS Token and taking average of all patch embeddings
all_patch_embedding = hidden_states[:, 1:, :]
# print(all_patch_embedding.shape)

avg_patch_embeddings = torch.mean(all_patch_embedding, dim=1)
# print(avg_patch_embeddings.shape)


torch.Size([50, 256, 768])
torch.Size([50, 768])


In [24]:
embedding_len = len(avg_patch_embeddings[0])
embedding_len

768

In [25]:
from dotenv import load_dotenv
load_dotenv()

True

In [26]:
from qdrant_client import QdrantClient

qclient = QdrantClient(
    url= os.getenv("QDRANT_DB_URL"),
    api_key= os.getenv("QDRANT_API_KEY")
)
qclient

<qdrant_client.qdrant_client.QdrantClient at 0x36c426000>

In [27]:
from qdrant_client.models import Distance, VectorParams

collection_name = "sample_images_2"
collection = qclient.recreate_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(
        size=embedding_len,
        distance=Distance.COSINE
    )
)
collection

  collection = qclient.recreate_collection(


True

In [28]:
payload_dicts = payloads.to_dict(orient="records")
payload_dicts[:10]

[{'image_url': '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_48.jpg',
  'type': 'samples'},
 {'image_url': '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_49.jpg',
  'type': 'samples'},
 {'image_url': '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_11.jpg',
  'type': 'samples'},
 {'image_url': '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_39.jpg',
  'type': 'samples'},
 {'image_url': '/Users/nandinibohra/Desktop/VSCodeFiles/Arthtattva_InternAssignment_Mar2025/ImageReco-ProductMatching/Product_Catalog/all_product_images/sample_images/sample_38.jpg

In [29]:
from qdrant_client import models

records = [
    models.Record(
        id=idx,
        payload=payload_dicts[idx],
        vector=avg_patch_embeddings[idx]
    )
    for idx, _ in enumerate(payload_dicts)
]

In [30]:
qclient.upload_records(
    collection_name=collection_name,
    records=records
)

  qclient.upload_records(


In [None]:
# from qdrant_client import QdrantClient
# from qdrant_client.models import Distance, VectorParams
# from qdrant_client.models import PointStruct



# client = QdrantClient(url="http://localhost:6333")

# # client.create_collection(
# #     collection_name="test_collection",
# #     vectors_config=VectorParams(size=4, distance=Distance.DOT),
# # )

# # operation_info = client.upsert(
# #     collection_name="test_collection",
# #     wait=True,
# #     points=[
# #         PointStruct(id=1, vector=[0.05, 0.61, 0.76, 0.74], payload={"city": "Berlin"}),
# #         PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}),
# #         PointStruct(id=3, vector=[0.36, 0.55, 0.47, 0.94], payload={"city": "Moscow"}),
# #         PointStruct(id=4, vector=[0.18, 0.01, 0.85, 0.80], payload={"city": "New York"}),
# #         PointStruct(id=5, vector=[0.24, 0.18, 0.22, 0.44], payload={"city": "Beijing"}),
# #         PointStruct(id=6, vector=[0.35, 0.08, 0.11, 0.44], payload={"city": "Mumbai"}),
# #     ],
# # )

# # print(operation_info)

# # search_result = client.query_points(
# #     collection_name="test_collection",
# #     query=[0.2, 0.1, 0.9, 0.7],
# #     with_payload=False,
# #     limit=3
# # ).points

# # print(search_result)

# client.delete_collection(collection_name="test_collection")
# print(f"Collection 'test_collection' deleted.")
