# Import & Paths


In [1]:
import json
import openai
import os
import pandas as pd
import qdrant_client
from tqdm import tqdm

import nest_asyncio
nest_asyncio.apply()

from typing import Union, Dict, Optional
from utils import format_metadata

import llama_index
print(llama_index.__version__)
from llama_index import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.schema import Document, ImageDocument, QueryBundle
from llama_index.embeddings import OpenAIEmbedding
from llama_index.vector_stores import QdrantVectorStore
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

0.9.39


In [2]:
MAIN_DIR = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(MAIN_DIR, "data")
ARTIFACT_DIR = os.path.join(MAIN_DIR, "artifacts")
DATABASE_DIR = os.path.join(DATA_DIR, "db")
REFERENCE_DIR = os.path.join(DATA_DIR, "reference")
CONSUMER_DIR = os.path.join(DATA_DIR, "consumer")

with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)
    
os.environ["OPENAI_API_KEY"] = api_keys["OPENAI_API_KEY"]
openai.api_key = api_keys["OPENAI_API_KEY"]

metadata_df = pd.read_csv(os.path.join(DATA_DIR, "exp_metadata_1000.csv"))
master_metadata = json.load(open(os.path.join(DATA_DIR, "metadata", "master_metadata.json")))

# Load text embeddings generation

In [3]:
TEXT_EMBED_MODEL = "text-embedding-3-large"
text_embeddings = OpenAIEmbedding(model=TEXT_EMBED_MODEL)

with open(os.path.join(DATA_DIR, "reference_list_1000.txt"), "r") as f:
    reference_images = f.readlines()

with open(os.path.join(DATA_DIR, "consumer_list_1000.txt"), "r") as f:
    consumer_images = f.readlines()

reference_images = [image.strip() for image in reference_images]
consumer_images = [image.strip() for image in consumer_images]

print("Number of reference images:", len(reference_images))
print("Number of consumer images:", len(consumer_images))

Number of reference images: 2000
Number of consumer images: 5000


## Generate text embeddings from given metadata

In [None]:
# formatted_metadatas = [format_metadata(metadata) for metadata in metadata_df["metadata"]]
# print(formatted_metadatas[0])

# text_embs = text_embeddings.get_text_embedding_batch(formatted_metadatas, show_progress=True)

# text_emb_dict = {}
# for ndc11, image_file, text_emb, metadata, text_content in zip(metadata_df["ndc11"], metadata_df["first_reference"], text_embs, metadata_df["metadata"], formatted_metadatas):
#     name = master_metadata[image_file]['name']
#     text_emb_dict[ndc11] = {
#         "name": name, "text_emb": text_emb, "metadata": metadata, "text_content": text_content
#     }

# with open(os.path.join(DATA_DIR, "embeddings", f"REFERENCE_EXTRADISP_{TEXT_EMBED_MODEL}.json"), "w") as f:
#     json.dump(text_emb_dict, f)

## Generate Text Embeddings from GPT-4V descriptions

### Reference

In [17]:
with open(os.path.join(ARTIFACT_DIR, "reference_1-1000/reference-gpt-extracted-features-n=1000.json"), "r") as f:
    ref_emb_text_dicts = json.load(f)
    
for k, v in ref_emb_text_dicts.items():
    if not v:
        ref_emb_text_dicts[k] = format_metadata(master_metadata[k]["metadata"])

reference_emb_texts = [ref_emb_text_dicts[ref_image] for ref_image in reference_images] 
reference_metadatas = [master_metadata[ref_image] for ref_image in reference_images]

# reference_text_embs = text_embeddings.get_text_embedding_batch(reference_emb_texts, show_progress=True)

# reference_text_emb_dict = {}

# for reference_image, reference_emb_text, reference_metadata, reference_text_emb \
#     in zip(reference_images, reference_emb_texts, reference_metadatas, reference_text_embs):
#         reference_text_emb_dict[reference_image] = {
#             'name': reference_metadata["name"],
#             'text_emb': reference_text_emb,
#             'metadata': reference_metadata,
#             'text_content': reference_emb_text
#         }

# with open(os.path.join(DATA_DIR, "embeddings", f"REFERENCE_GPTV_{TEXT_EMBED_MODEL}.json"), "w") as f:
#     json.dump(reference_text_emb_dict, f)

### Consumers

In [22]:
with open(os.path.join(ARTIFACT_DIR, "consumer_1-1000/consumer-gpt-extracted-features-n=1000.json"), "r") as f:
    consumer_emb_text_dicts = json.load(f)

consumer_emb_texts = [consumer_emb_text_dicts[image] for image in consumer_images] 
consumer_metadatas = [master_metadata[image] for image in consumer_images]

valid_indices = [idx for idx, content in enumerate(consumer_emb_texts) if content]
print("Number of valid extracted descriptions:", len(valid_indices))

valid_consumer_emb_texts = [consumer_emb_texts[idx] for idx in valid_indices]

# consumer_text_embs = text_embeddings.get_text_embedding_batch(valid_consumer_emb_texts, show_progress=True)

consumer_text_emb_dict = {}

for consumer_image, consumer_metadata \
    in zip(consumer_images, consumer_metadatas):
        consumer_text_emb_dict[consumer_image] = {
            'name': consumer_metadata["name"], 'text_emb': None,
            'metadata': consumer_metadata, 'text_content': None
        }

for valid_idx, valid_consumer_emb_text, consumer_text_emb in zip(valid_indices, valid_consumer_emb_texts, consumer_text_embs):
    consumer_image = consumer_images[valid_idx]
    consumer_text_emb_dict[consumer_image]['text_emb'] = consumer_text_emb
    consumer_text_emb_dict[consumer_image]['text_content'] = valid_consumer_emb_text

with open(os.path.join(DATA_DIR, "embeddings", f"CONSUMER_GPTV_{TEXT_EMBED_MODEL}.json"), "w") as f:
    json.dump(consumer_text_emb_dict, f)

Number of valid extracted descriptions: 4895


In [16]:
with open(os.path.join(DATA_DIR, "embeddings", f"CONSUMER_GPTV_{TEXT_EMBED_MODEL}.json"), "w") as f:
    json.dump(consumer_text_emb_dict, f)

# Database Creation

## Text Embeddings Only

In [14]:
# with open(os.path.join(DATA_DIR, "embeddings", "REFERENCE_EXTRADISP_text-embedding-3-large.json"), "r") as f:
#     text_emb_dict = json.load(f)

with open(os.path.join(DATA_DIR, "embeddings", "REFERENCE_GPTV_text-embedding-3-large.json"), "r") as f:
    text_emb_dict = json.load(f)
    
reference_text_embs = [text_emb_dict[ref_image]['text_emb'] for ref_image in reference_images]

In [5]:
client = qdrant_client.QdrantClient(path = os.path.join(DATABASE_DIR, "qdrant"))

In [18]:
text_document_list = []
for reference_image, reference_emb_text, reference_metadata, reference_text_emb \
    in zip(reference_images, reference_emb_texts, reference_metadatas, reference_text_embs):
    reference_metadata.update({"image_path": reference_image})
    text_document_list.append(
        Document(
            text=reference_emb_text,
            metadata=reference_metadata,
            embedding=reference_text_emb
            )
    )
    
print(len(text_document_list))

2000


In [21]:
text_store = QdrantVectorStore(client=client, collection_name="text_only")
text_service_context = ServiceContext.from_defaults(embed_model = text_embeddings)
text_storage_context = StorageContext.from_defaults(vector_store = text_store)

text_index = VectorStoreIndex.from_documents(
    documents = text_document_list,
    storage_context=text_storage_context,
    service_context=text_service_context
)

In [25]:
text_retriever = text_index.as_retriever(similarity_top_k=2000)

## Image Embeddings Only

In [39]:
with open(os.path.join(DATA_DIR, "embeddings", "clip-ViT-L14@336px-image-all.json"), "r") as f:
    image_emb_dict = json.load(f)

In [None]:
image_document_list = []
for reference_image, reference_emb_text, reference_metadata, reference_text_emb \
    in zip(reference_images, reference_emb_texts, reference_metadatas, reference_text_embs):
    reference_metadata.update({"image_path": reference_image})
    image_document_list.append(
        ImageDocument(
            image_path=reference_image,
            text=reference_emb_text,
            metadata=reference_metadata,
            embedding=image_emb_dict[reference_image],
            text_embedding=reference_text_emb
            )
    )
    
print(len(image_document_list))
len(image_document_list[0].embedding)

In [41]:
image_store = QdrantVectorStore(client=client, collection_name="image_only")
image_service_context = ServiceContext.from_defaults(embed_model = text_embeddings)
image_storage_context = StorageContext.from_defaults(image_store = image_store)

image_index = VectorStoreIndex.from_documents(
    documents = image_document_list,
    storage_context=image_storage_context,
    service_context=image_service_context
)

In [48]:
image_retriever = image_index.as_retriever(similarity_top_k = 2000)

In [73]:
rank = 0

for input_query in tqdm(reference_images):
    input_query_bundle = QueryBundle(
        query_str = "dummy",
        embedding = image_emb_dict[input_query]
    )

    retrieved_images = image_retriever.retrieve(input_query_bundle)
    
    for idx, retrieved_image in enumerate(retrieved_images):
        if retrieved_image.node.metadata["ndc11"] == master_metadata[input_query]["ndc11"]:
            if idx != 0:
                rank += (idx+1)
                break
            


  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [20:35<00:00,  1.62it/s]

6200.67





## Hybrid Text and Image Embeddings

In [None]:
image_document_list = []
for reference_image, reference_emb_text, reference_metadata, reference_text_emb \
    in zip(reference_images, reference_emb_texts, reference_metadatas, reference_text_embs):
    reference_metadata.update({"image_path": reference_image})
    image_document_list.append(
        ImageDocument(
            image_path=reference_image,
            text=reference_emb_text,
            metadata=reference_metadata,
            embedding=image_emb_dict[reference_image],
            text_embedding=reference_text_emb
            )
    )
    
print(len(image_document_list))
len(image_document_list[0].embedding)

In [None]:
image_document_list = []
for reference_image, reference_emb_text, reference_metadata, reference_text_emb \
    in zip(reference_images, reference_emb_texts, reference_metadatas, reference_text_embs):
    reference_metadata.update({"image_path": reference_image})
    image_document_list.append(
        ImageDocument(
            image_path=reference_image,
            text=reference_emb_text,
            metadata=reference_metadata,
            embedding=image_emb_dict[reference_image],
            text_embedding=reference_text_emb
            )
    )
    
print(len(image_document_list))
len(image_document_list[0].embedding)