## Save embeddings into a Chroma DB

We store each embedding's image_id and item_id as metadata so that query results can be traced back to the image and product they came from.

In [1]:
import chromadb
import pandas as pd
import numpy as np
from tqdm import tqdm

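There are three versions of the BLIP-2 model, each with its own set of embeddings. `blip_2_model` selects which version's embeddings are loaded and which Chroma collection (one per version) they are written to.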

In [2]:
# BLIP-2 variant whose embeddings are indexed in this notebook. The value
# selects both the embeddings pickle to load and the Chroma collection name.
blip_2_model = "pretrain"

In [3]:
# On-disk location of the persistent Chroma database (hardcoded local path).
chroma_db_path = 'D:/chroma'
client = chromadb.PersistentClient(path=chroma_db_path)

In [4]:
# One collection per BLIP-2 variant. The client is persistent, so the
# collection survives between sessions. create_collection would raise on a
# re-run because the collection already exists. get_or_create_collection
# keeps this cell re-runnable: it returns the existing collection instead.
collection = client.get_or_create_collection(name=f"blip_2_{blip_2_model}")

Load the embeddings computed from the ABO dataset (generated by the code in Blip-2_embeddings).

In [4]:
# Embeddings for the selected BLIP-2 variant, produced by the Blip-2_embeddings code.
# NOTE: read_pickle can execute arbitrary code; only load files you generated yourself.
embeddings_path = f'D:/embeddings/embeddings_{blip_2_model}.pkl'
embeddings_df = pd.read_pickle(embeddings_path)
embeddings_df

Unnamed: 0,image_id,item_id,embedding
0,81-DuD5XzmL,B0857LSVB7,"[-0.4358139, -0.7983236, 0.45861137, -1.535490..."
1,61+woWTqkwL,B0857LSVB7,"[-0.29962516, -0.45568618, 0.20228578, -0.1655..."
2,61SE4RTPjdL,B0857LSVB7,"[-0.374703, -0.43819726, 0.14495471, -0.064711..."
3,81YCp3dcurL,B07C5FF8QS,"[0.43136543, -1.4675522, 0.37719846, -1.568843..."
4,817GQ6xx-QL,B07C5FF8QS,"[0.9552257, -0.7680965, 0.37860662, -0.7980015..."
...,...,...,...
571191,,B08DKX14VY,"[0.08170852, 0.00392688, 0.102035984, -0.38783..."
571192,,B08BJXKDD2,"[0.27779222, 0.31448266, 0.14933139, -0.330696..."
571193,,B08FMKVLJ3,"[0.06593654, 0.0071030296, 0.09782801, -0.4232..."
571194,,B08DJXR1LR,"[0.19332173, 0.19885169, 0.09635306, -0.443597..."


Chroma limits how many records can be added in a single call (a limit imposed by its underlying SQLite database). We therefore add the embeddings in batches of at most `client.get_max_batch_size()` records.

In [6]:
def add_embeddings_in_batches(collection, df, batch_size):
    """Add every row of ``df`` to ``collection`` in chunks of at most ``batch_size`` rows.

    Each row becomes one Chroma record:
      * id        -- the row's position in ``df`` as a string ("0", "1", ...),
      * embedding -- the row's ``embedding`` value (rows stacked with ``np.stack``),
      * metadata  -- ``{'image_id': ..., 'item_id': ...}`` taken from the row.

    Args:
        collection: Chroma collection to add the records to.
        df: DataFrame with ``embedding``, ``image_id`` and ``item_id`` columns.
        batch_size: maximum number of records per ``collection.add`` call.
    """
    n_rows = len(df)
    # Ceiling division: the last batch may be smaller than batch_size.
    n_batches = (n_rows + batch_size - 1) // batch_size
    for batch_idx in tqdm(range(n_batches)):
        start = batch_idx * batch_size
        end = min(start + batch_size, n_rows)
        # Positional slice, so rows stay aligned with the positional ids below
        # even when df does not have a default 0..n-1 index.
        batch = df.iloc[start:end]
        collection.add(
            embeddings=np.stack(batch['embedding'].to_list()),
            # One metadata dict per row, built in a single vectorized call
            # instead of per-row scalar lookups.
            metadatas=batch[['image_id', 'item_id']].to_dict('records'),
            ids=[str(row_pos) for row_pos in range(start, end)],
        )


add_embeddings_in_batches(collection, embeddings_df, client.get_max_batch_size())

100%|██████████| 105/105 [10:20<00:00,  5.91s/it]


In [5]:
# Re-open the already-populated collection (e.g. in a fresh kernel session)
# without running the batch-insert cell again.
collection_name = "blip_2_" + blip_2_model
collection = client.get_collection(name=collection_name)

In [6]:
# Sanity check: query with an embedding that is already stored. Its own id
# is expected to come back first, at distance 0.
query_row = 571195
embedding_test = embeddings_df.loc[query_row, 'embedding']
collection.query(
    query_embeddings=[embedding_test],
    include=["metadatas", "distances"],
)

{'ids': [['571195',
   '571094',
   '571178',
   '571175',
   '571151',
   '571062',
   '571102',
   '571194',
   '571099',
   '571092']],
 'embeddings': None,
 'documents': None,
 'uris': None,
 'data': None,
 'metadatas': [[{'item_id': 'B08B1GPF45'},
   {'item_id': 'B089ZRYCT9'},
   {'item_id': 'B08B1832RW'},
   {'item_id': 'B08B13CW15'},
   {'item_id': 'B08B18RYH7'},
   {'item_id': 'B08B12JRW4'},
   {'item_id': 'B08DKRGCG1'},
   {'item_id': 'B08DJXR1LR'},
   {'item_id': 'B08FQYHVNJ'},
   {'item_id': 'B08BJXC3DP'}]],
 'distances': [[0.0,
   1.6682718992233276,
   2.760481762637771,
   4.759489253638559,
   4.982717653424928,
   5.186148166656494,
   5.457234609832796,
   6.416917054983525,
   6.617886066436768,
   6.948668956756592]],
 'included': [<IncludeEnum.distances: 'distances'>,
  <IncludeEnum.metadatas: 'metadatas'>]}