In [1]:
"""
Script to generate embeddings for images (for catalog images) and store embeddings to Qdrant database.
@File    : image_embeddings_experiments.ipynb
@Author  : Nandini Bohra
@Contact : nbohra@ucsd.edu

@References : https://www.youtube.com/watch?v=MlRkBvOCfTY
"""

'\nScript to generate embeddings for images (for catalog images) and store embeddings to Qdrant database.\n@File    : image_embeddings_experiments.ipynb\n@Author  : Nandini Bohra\n@Contact : nbohra@ucsd.edu\n\n@References : https://www.youtube.com/watch?v=MlRkBvOCfTY\n'

In [2]:
# imports
import os
import pandas as pd
import numpy as np
import re

# for image resizing to b64
from io import BytesIO
import math
import base64
from PIL import Image

# for importing in dinov2
from transformers import AutoImageProcessor, AutoModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the per-image metadata table (id, image_url, type, material, color label, avg rgb).
payloads = pd.read_csv("payloads.csv")
payloads.head(n=5)

Unnamed: 0,id,image_url,type,material,color label,avg rgb
0,0,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,JUTE,CREAM,"(np.int64(230), np.int64(223), np.int64(199))"
1,1,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,AURA,ROSE,"(np.int64(223), np.int64(186), np.int64(158))"
2,2,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,AURA,MINT,"(np.int64(162), np.int64(191), np.int64(156))"
3,3,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,SURFACE,OFFWHITE,"(np.int64(241), np.int64(233), np.int64(218))"
4,4,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,SURFACE,COFFEE,"(np.int64(144), np.int64(123), np.int64(100))"


In [4]:
def parse_rgb(s):
    """Parse a stringified RGB triple into ``[r, g, b]`` integers.

    Accepts plain tuples such as ``"(230, 223, 199)"`` as well as numpy-repr
    tuples such as ``"(np.int64(230), np.int64(223), np.int64(199))"``, for any
    numpy integer type name (``int64``, ``int32``, ``uint8``, ...).

    Args:
        s: A cell value from the ``"avg rgb"`` column.

    Returns:
        ``[r, g, b]`` as a list of ints, or ``None`` when ``s`` is not a string
        or does not contain exactly three standalone integers.
    """
    if not isinstance(s, str):
        return None
    # Only match digit runs that are not glued to a preceding word character, so
    # the digits inside dtype names ("int64", "uint8", ...) are never picked up.
    nums = [int(tok) for tok in re.findall(r"(?<!\w)\d+", s)]
    return nums if len(nums) == 3 else None

In [5]:
# Expand each stringified "avg rgb" tuple into three channel columns r / g / b.
payloads[["r", "g", "b"]] = (
    payloads["avg rgb"]
    .map(parse_rgb)
    .apply(pd.Series)
)

In [6]:
# The raw "avg rgb" string is now redundant with r/g/b, and the column holds local
# file paths rather than URLs, so rename it. Built as a chain without inplace=True, and
# errors="ignore" on drop, so re-running this cell does not raise KeyError.
payloads = (
    payloads
    .drop(columns=["avg rgb"], errors="ignore")
    .rename(columns={"image_url": "image path"})
)
payloads.dtypes

id              int64
image path     object
type           object
material       object
color label    object
r               int64
g               int64
b               int64
dtype: object

In [7]:
# Preview the cleaned frame (r/g/b split out, "image path" renamed).
payloads.head(n=5)

Unnamed: 0,id,image path,type,material,color label,r,g,b
0,0,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,JUTE,CREAM,230,223,199
1,1,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,AURA,ROSE,223,186,158
2,2,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,AURA,MINT,162,191,156
3,3,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,SURFACE,OFFWHITE,241,233,218
4,4,/Users/nandinibohra/Desktop/VSCodeFiles/ImageR...,samples,SURFACE,COFFEE,144,123,100


In [8]:
# Less for object classification and more for fine details, textures --> may be suitable for textile catalog
# https://huggingface.co/facebook/dinov2-base
MODEL_NAME = "facebook/dinov2-base"

# use_fast=False pins the slow processor this model was saved with. transformers
# warns that use_fast=True becomes the default in v4.52 and gives "minor differences
# in outputs". Vectors stored in Qdrant must stay comparable with query vectors
# computed later, so preprocessing must not change silently on a library upgrade.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, use_fast=False)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()  # eval mode for inference (e.g. disables dropout)


def load_image_for_embedding(path, max_size=224):
    """Open an image file as an RGB PIL image, downscaled to fit a square box.

    Args:
        path: Filesystem path of the image.
        max_size: Longest-side limit in pixels. ``Image.thumbnail`` keeps the
            aspect ratio and never enlarges the image. Pass ``None`` to skip this
            downscaling step and leave all resizing to the image processor.

    Returns:
        A ``PIL.Image.Image`` in RGB mode that does not depend on the open file.
    """
    # The context manager closes the file handle deterministically. convert("RGB")
    # returns a new image, so the result stays usable after the file is closed.
    with Image.open(path) as src:
        img = src.convert("RGB")
    if max_size is not None:
        # thumbnail() resizes in place and preserves the aspect ratio.
        img.thumbnail((max_size, max_size))
    return img

def get_avg_emb(path):
    """Embed one image as the L2-normalized mean of its DINOv2 patch tokens.

    Uses the module-level ``processor`` and ``model``.

    Args:
        path: Filesystem path of the image to embed.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``hidden_size`` with unit L2 norm.
        The length is 768 for the dinov2-base model loaded in this notebook.
    """
    image = load_image_for_embedding(path)
    # Send inputs to the model's device, so this keeps working after model.to("cuda").
    inputs = processor(image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Drop token 0 (the CLS token) and average the remaining patch tokens.
    patches = outputs.last_hidden_state[:, 1:, :]
    emb = patches.mean(dim=1)

    # L2-normalize. F.normalize clamps the norm at a small eps, so an all-zero
    # vector yields zeros instead of NaNs.
    emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
    return emb.squeeze(0).cpu().numpy()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [10]:
# Embed every catalog image in payload order, so row i of `embeddings` lines up
# with row i of `payloads` (the record/point id used later is that same index).
image_paths = payloads["image path"]
embeddings = np.array([get_avg_emb(image_path) for image_path in image_paths])
print(f"Embedding shape: {embeddings.shape}")

Embedding shape: (50, 768)


In [11]:
# Vector width for the Qdrant collection, derived from the embeddings rather than
# hard-coded. Fail loudly here if the embedding step produced nothing usable.
assert embeddings.ndim == 2 and embeddings.shape[0] > 0, (
    f"expected a non-empty (n_images, dim) array, got shape {embeddings.shape}"
)
embedding_len = embeddings.shape[1]
embedding_len

768

### Loading embeddings to Qdrant

In [12]:
# Currently holding embeddings from DINOv2 + sample information in payloads.
# Next: read the Qdrant connection settings (QDRANT_DB_URL, QDRANT_API_KEY) from a
# local .env file into the process environment. The returned bool is the cell output.
from dotenv import load_dotenv

load_dotenv(override=False)

True

In [13]:
# Initializing Qdrant client object.
# os.environ[...] (not os.getenv) makes a missing .env entry fail right here, with a
# KeyError naming the variable, instead of passing url=None / api_key=None on.
from qdrant_client import QdrantClient

qclient = QdrantClient(
    url=os.environ["QDRANT_DB_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
)
qclient

  qclient = QdrantClient(


<qdrant_client.qdrant_client.QdrantClient at 0x28ff253a0>

In [14]:
# Creating collection in Qdrant database.
# recreate_collection is deprecated in qdrant-client. Drop and create explicitly,
# so re-running the notebook still starts from an empty collection.
from qdrant_client.models import Distance, VectorParams

collection_name = "TextileProductRec"
if qclient.collection_exists(collection_name=collection_name):
    qclient.delete_collection(collection_name=collection_name)

collection = qclient.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(
        size=embedding_len,
        # get_avg_emb L2-normalizes every embedding, so COSINE and DOT produce the
        # same ranking here. COSINE stays correct even for un-normalized query vectors.
        distance=Distance.COSINE,
    ),
)
collection

  collection = qclient.recreate_collection(


True

In [15]:
# Turn each payloads row into a {column: value} dict, one per image, to serve as
# the metadata payload of the matching Qdrant point.
payload_dicts = payloads.to_dict("records")
payload_dicts[0:1]

[{'id': 0,
  'image path': '/Users/nandinibohra/Desktop/VSCodeFiles/ImageRecommendation_ProductMatching/Product_Catalog/all_product_images/sample_images/sample_00.jpg',
  'type': 'samples',
  'material': 'JUTE',
  'color label': 'CREAM',
  'r': 230,
  'g': 223,
  'b': 199}]

In [16]:
# Creating records of payloads to load into Qdrant.
# Record id i pairs row i of payload_dicts with row i of embeddings. Both follow the
# payloads row order, so check that they line up before zipping them together.
from qdrant_client import models

assert len(payload_dicts) == len(embeddings), (
    f"{len(payload_dicts)} payloads vs {len(embeddings)} embeddings"
)

records = [
    models.Record(
        id=idx,
        payload=payload,
        # .tolist() hands the client plain Python floats instead of a numpy array.
        vector=vector.tolist(),
    )
    for idx, (payload, vector) in enumerate(zip(payload_dicts, embeddings))
]

In [17]:
# Sending records to Qdrant database

qclient.upload_records(
    collection_name=collection_name,
    records=records
)

  qclient.upload_records(
