## Save embeddings into a Chroma DB

We save the image_id and item_id as metadata for easy access later.

Uses Python 3.12 environment locally.

In [1]:
import chromadb
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

We have 3 versions of the BLIP-2 model, so we have separate embeddings for each of them.

In [2]:
blip_2_model = 'gs'

In [3]:
client = chromadb.PersistentClient(path='/mnt/d/chroma')

In [4]:
collection = client.create_collection(name="blip_2_"+blip_2_model+'_multimodal')

Chroma has a batch size limit due to the underlying sqlite database. Therefore, we need to add the embeddings in batches.

In [9]:
def embed_multimodal(collection, file, start_id):
    batch_size = 1000
    embeddings_df = pd.read_pickle(file)
    n_rows = len(embeddings_df)
    n_batches = (n_rows-1)//batch_size + 1
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = min((i+1) * batch_size, n_rows)
        embeddings = np.stack(embeddings_df.iloc[start:end]['embedding'])
        embeddings = list(embeddings.reshape((len(embeddings), -1)))
        metadatas = []
        for i in range(start, end):
            image_id = embeddings_df.loc[i, 'image_id']
            item_id = embeddings_df.loc[i, 'item_id']
            metadatas.append({'image_id': image_id, 'item_id': item_id})
        collection.add(embeddings=embeddings, metadatas=metadatas, ids=[str(i + start_id) for i in range(start, end)])
    return end + start_id

Load the embeddings created from the ABO dataset (code in Blip-2_embeddings)

In [10]:
start_id = 0
embeddings_dir = '/mnt/d/embeddings/'
for file in os.listdir(embeddings_dir):
    print(file)
    if file.startswith('embeddings_'+blip_2_model+'_multimodal'):
        start_id = embed_multimodal(collection, embeddings_dir + file, start_id)

embeddings_gs_multimodal_1.pkl


100%|██████████| 1/1 [00:01<00:00,  1.05s/it]


embeddings_gs_multimodal_1000.pkl


 98%|█████████▊| 63/64 [39:21<00:37, 37.49s/it]


KeyError: 63936

In [None]:
collection = client.create_collection(name="blip_2_"+blip_2_model+'_text')

In [None]:
def embed_text(collection, file, start_id):
    batch_size = 500
    embeddings_df = pd.read_pickle(file)
    n_rows = len(embeddings_df)
    n_batches = (n_rows-1)//batch_size + 1
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = min((i+1) * batch_size, n_rows)
        embeddings = np.stack(embeddings_df.iloc[start:end]['embedding'])
        embeddings = list(embeddings.reshape((len(embeddings), -1)))
        metadatas = []
        for i in range(start, end):
            item_id = embeddings_df.loc[i, 'item_id']
            metadatas.append({'item_id': item_id})
        collection.add(embeddings=embeddings, metadatas=metadatas, ids=[str(i + start_id) for i in range(start, end)])
    return end + start_id

In [None]:
start_id = 0
embeddings_dir = '/mnt/d/embeddings/'
for file in os.listdir(embeddings_dir):
    print(file)
    if file.startswith('embeddings_'+blip_2_model+'_text'):
        start_id = embed_text(collection, embeddings_dir + file, start_id)

In [4]:
# collection = client.get_collection(name="blip_2_"+blip_2_model)

In [11]:
# embedding_test = embeddings_df.loc[571195, 'embedding']
# collection.query(query_embeddings=[embedding_test], include=["metadatas", "distances"], n_results=15)

{'ids': [['571195',
   '571094',
   '571178',
   '571194',
   '571175',
   '571164',
   '571115',
   '571098',
   '571134',
   '571092',
   '571062',
   '571181',
   '571113',
   '571090',
   '571099']],
 'embeddings': None,
 'documents': None,
 'uris': None,
 'data': None,
 'metadatas': [[{'item_id': 'B08B1GPF45'},
   {'item_id': 'B089ZRYCT9'},
   {'item_id': 'B08B1832RW'},
   {'item_id': 'B08DJXR1LR'},
   {'item_id': 'B08B13CW15'},
   {'item_id': 'B08DKVQRFJ'},
   {'item_id': 'B08BJX637R'},
   {'item_id': 'B08BJXS9KY'},
   {'item_id': 'B08DTQVN7R'},
   {'item_id': 'B08BJXC3DP'},
   {'item_id': 'B08B12JRW4'},
   {'item_id': 'B08BJXKCCL'},
   {'item_id': 'B08DKT7QK3'},
   {'item_id': 'B08DKVNXK8'},
   {'item_id': 'B08FQYHVNJ'}]],
 'distances': [[0.0,
   10.761358261108398,
   16.386934499279306,
   22.819978855059706,
   23.64416374343949,
   24.726531129690464,
   25.441836602259766,
   26.24988555908203,
   26.405441758722727,
   26.522689819335938,
   26.619218826293945,
   27.60114