## Save embeddings into a Chroma DB

Each embedding is stored together with its `image_id` and `item_id` as metadata, so search results can be mapped back to the original images and items.

Run locally in a Python 3.12 environment.

In [1]:
import chromadb
import pandas as pd
import numpy as np
from tqdm import tqdm

There are three versions of the BLIP-2 model, each with its own set of embeddings; `blip_2_model` selects which one to load and which Chroma collection to write to.

In [2]:
# Which BLIP-2 variant's embeddings to index; it selects both the input
# pickle and the name of the Chroma collection.
blip_2_model = "gs"

In [3]:
# Directory backing the persistent Chroma store.
# NOTE(review): machine-specific absolute path -- adjust for your setup.
CHROMA_PATH = 'D:/chroma'
client = chromadb.PersistentClient(path=CHROMA_PATH)

In [4]:
# One collection per BLIP-2 variant. get_or_create (instead of create) lets
# the notebook be re-run against the persistent store: create_collection
# raises if the collection was already created by an earlier run.
collection = client.get_or_create_collection(name=f"blip_2_{blip_2_model}")

Load the embeddings computed from the ABO (Amazon Berkeley Objects) dataset; the code that generates them is in `Blip-2_embeddings`.

In [6]:
# One row per embedding, with columns image_id, item_id and embedding.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
embeddings_path = f'D:/embeddings/embeddings_{blip_2_model}.pkl'
embeddings_df = pd.read_pickle(embeddings_path)
embeddings_df

Unnamed: 0,image_id,item_id,embedding
0,81-DuD5XzmL,B0857LSVB7,"[0.07762626, -0.5834237, 0.07436231, -1.293449..."
1,61+woWTqkwL,B0857LSVB7,"[-0.5496451, -0.5162424, -0.2742843, 0.2406437..."
2,61SE4RTPjdL,B0857LSVB7,"[-0.59508175, -0.39230114, -0.30060068, 0.3181..."
3,81YCp3dcurL,B07C5FF8QS,"[0.6407373, -0.6000935, 0.81145144, -1.0256921..."
4,817GQ6xx-QL,B07C5FF8QS,"[1.1870946, -0.38225535, 0.66963655, -0.698766..."
...,...,...,...
571191,,B08DKX14VY,"[-0.3429492, -0.5522551, 0.065248385, -0.11148..."
571192,,B08BJXKDD2,"[0.19367197, 0.24559124, 0.20168683, 0.0394387..."
571193,,B08FMKVLJ3,"[0.013649158, -0.34035835, 0.011010983, 0.0603..."
571194,,B08DJXR1LR,"[0.2649222, -0.09033289, 0.21277864, -0.189302..."


Chroma limits how many records a single call can insert (the limit comes from its underlying SQLite database), so the embeddings are inserted in batches of at most `client.get_max_batch_size()` records.

In [6]:
# Insert every embedding into the collection in chunks no larger than
# Chroma's per-call limit. The Chroma id of each record is its row
# position (as a string), and its metadata holds image_id / item_id.
batch_size = client.get_max_batch_size()
n_rows = len(embeddings_df)
n_batches = (n_rows - 1) // batch_size + 1  # ceiling division
for batch_idx in tqdm(range(n_batches)):
    start = batch_idx * batch_size
    end = min(start + batch_size, n_rows)
    # Positional slicing: the ids below are row positions, so select rows by
    # position too rather than relying on the index being 0..n-1.
    batch = embeddings_df.iloc[start:end]
    embeddings = np.stack(batch['embedding'].to_list())
    # Some rows have no image_id (see the tail of the frame above); omit
    # missing values instead of writing None/NaN into Chroma metadata.
    metadatas = [
        {key: value for key, value in record.items() if not pd.isna(value)}
        for record in batch[['image_id', 'item_id']].to_dict('records')
    ]
    # upsert (not add) keeps the cell idempotent: on a re-run, existing ids
    # are overwritten instead of being skipped with stale data.
    collection.upsert(
        embeddings=embeddings,
        metadatas=metadatas,
        ids=[str(row) for row in range(start, end)],
    )

100%|██████████| 105/105 [10:20<00:00,  5.91s/it]


In [4]:
# Reopen the already-populated collection (e.g. in a fresh kernel session)
# without repeating the insertion step.
collection_name = f"blip_2_{blip_2_model}"
collection = client.get_collection(name=collection_name)

In [11]:
# Sanity check: query with a vector that is already stored in the collection.
# Its own id is expected to come back first, at distance 0.
QUERY_ROW = 571195
embedding_test = embeddings_df.loc[QUERY_ROW, 'embedding']
collection.query(
    query_embeddings=[embedding_test],
    include=["metadatas", "distances"],
    n_results=15,
)

{'ids': [['571195',
   '571094',
   '571178',
   '571194',
   '571175',
   '571164',
   '571115',
   '571098',
   '571134',
   '571092',
   '571062',
   '571181',
   '571113',
   '571090',
   '571099']],
 'embeddings': None,
 'documents': None,
 'uris': None,
 'data': None,
 'metadatas': [[{'item_id': 'B08B1GPF45'},
   {'item_id': 'B089ZRYCT9'},
   {'item_id': 'B08B1832RW'},
   {'item_id': 'B08DJXR1LR'},
   {'item_id': 'B08B13CW15'},
   {'item_id': 'B08DKVQRFJ'},
   {'item_id': 'B08BJX637R'},
   {'item_id': 'B08BJXS9KY'},
   {'item_id': 'B08DTQVN7R'},
   {'item_id': 'B08BJXC3DP'},
   {'item_id': 'B08B12JRW4'},
   {'item_id': 'B08BJXKCCL'},
   {'item_id': 'B08DKT7QK3'},
   {'item_id': 'B08DKVNXK8'},
   {'item_id': 'B08FQYHVNJ'}]],
 'distances': [[0.0,
   10.761358261108398,
   16.386934499279306,
   22.819978855059706,
   23.64416374343949,
   24.726531129690464,
   25.441836602259766,
   26.24988555908203,
   26.405441758722727,
   26.522689819335938,
   26.619218826293945,
   27.60114