## Save embeddings into a Chroma DB

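We store the `image_id` and `item_id` of each embedding as metadata (only `item_id` for the text-only embeddings), so query results can be traced back to their source.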

This notebook runs locally in a Python 3.12 environment.

In [1]:
import chromadb
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

There are four versions of the BLIP-2 model, each with its own set of embeddings; `blip_2_model` selects which one is loaded.

In [2]:
# Which BLIP-2 variant to process; it is used as part of the embedding file
# prefix and of the Chroma collection names in the cells below.
blip_2_model = "gs"

In [3]:
# Chroma client whose data is persisted under the local directory D:/chroma.
client = chromadb.PersistentClient(path="D:/chroma")

In [4]:
# One collection per BLIP-2 variant, e.g. "blip_2_gs_multimodal".
# NOTE(review): create_collection presumably errors if this collection already
# exists in the persistent store, so re-running this cell needs the old
# collection removed first — confirm.
collection = client.create_collection(name=f"blip_2_{blip_2_model}_multimodal")

Chroma limits how many records a single `add` call can insert (a limit imposed by its underlying SQLite database), so the embeddings are added in batches.

In [5]:
def embed_multimodal(collection, file, start_id, batch_size=4000, n_tokens=4):
    """Add the multimodal embeddings from one pickle file to a Chroma collection.

    The file is read with ``pd.read_pickle`` and must contain the columns
    ``embedding``, ``image_id`` and ``item_id``. The embeddings of a batch are
    stacked with ``np.stack`` (so they must all share one shape), cut to their
    first ``n_tokens`` entries along their own first axis, and flattened to
    1-D vectors.

    Args:
        collection: Chroma collection; only its
            ``add(embeddings=, metadatas=, ids=)`` method is used.
        file: path of the pickled DataFrame.
        start_id: integer id for the first row of this file. The row at
            position ``k`` is stored under the string id ``str(start_id + k)``.
        batch_size: maximum number of rows passed to one ``collection.add``
            call (Chroma limits the size of a single add).
        n_tokens: number of leading entries along the first axis of each
            embedding that are kept (presumably BLIP-2 query tokens — confirm).

    Returns:
        ``start_id`` plus the number of rows in the file, i.e. the next unused
        id, to be passed as ``start_id`` for the following file.
    """
    embeddings_df = pd.read_pickle(file)
    n_rows = len(embeddings_df)
    for start in tqdm(range(0, n_rows, batch_size)):
        # Slice by position so the vectors, metadata and ids of a batch always
        # come from the same rows, whatever index the pickled frame carries.
        batch = embeddings_df.iloc[start:start + batch_size]
        stacked = np.stack(batch['embedding'].tolist())
        vectors = list(stacked[:, :n_tokens, :].reshape(len(batch), -1))
        # .tolist() yields plain Python scalars instead of numpy scalars.
        metadatas = [
            {'image_id': image_id, 'item_id': item_id}
            for image_id, item_id in zip(batch['image_id'].tolist(), batch['item_id'].tolist())
        ]
        ids = [str(start_id + pos) for pos in range(start, start + len(batch))]
        collection.add(embeddings=vectors, metadatas=metadatas, ids=ids)
    # Derived from n_rows (not a loop variable) so an empty file returns
    # start_id unchanged instead of raising UnboundLocalError.
    return start_id + n_rows

Load the multimodal embeddings created from the ABO dataset (generated by the code in Blip-2_embeddings).

In [None]:
# start_id carries the next unused integer id across files, so ids stay unique
# within the collection.
start_id = 0
embeddings_dir = 'D:/embeddings/'
# os.listdir returns entries in arbitrary order; sorting makes the ids assigned
# to each file reproducible across machines and re-runs.
for file in sorted(os.listdir(embeddings_dir)):
    if file.startswith('embeddings_' + blip_2_model + '_multimodal'):
        print(file)
        start_id = embed_multimodal(collection, os.path.join(embeddings_dir, file), start_id)

embeddings_gs_multimodal_1.pkl


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  6.29it/s]


embeddings_gs_multimodal_1000.pkl


100%|██████████| 16/16 [02:42<00:00, 10.15s/it]


embeddings_gs_multimodal_1999.pkl


100%|██████████| 16/16 [03:06<00:00, 11.65s/it]


embeddings_gs_multimodal_2998.pkl


100%|██████████| 16/16 [03:20<00:00, 12.55s/it]


embeddings_gs_multimodal_3997.pkl


100%|██████████| 16/16 [03:34<00:00, 13.42s/it]


embeddings_gs_multimodal_4996.pkl


100%|██████████| 16/16 [03:49<00:00, 14.34s/it]


embeddings_gs_multimodal_5995.pkl


100%|██████████| 16/16 [04:04<00:00, 15.31s/it]


embeddings_gs_multimodal_6994.pkl


100%|██████████| 16/16 [04:23<00:00, 16.46s/it]


embeddings_gs_multimodal_7993.pkl


100%|██████████| 16/16 [04:35<00:00, 17.21s/it]


embeddings_gs_multimodal_8923.pkl


100%|██████████| 15/15 [04:27<00:00, 17.80s/it]


embeddings_gs_text_1000.pkl
embeddings_gs_text_1500.pkl
embeddings_gs_text_1877.pkl
embeddings_gs_text_500.pkl
old


Load the text-only embeddings into a separate collection.

In [7]:
# Separate collection for the text-only embeddings, e.g. "blip_2_gs_text".
# NOTE(review): create_collection presumably errors if this collection already
# exists in the persistent store — confirm before re-running.
collection = client.create_collection(name=f"blip_2_{blip_2_model}_text")

In [22]:
def embed_text(collection, file, start_id, batch_size=2000, n_tokens=2):
    """Add the text-only embeddings from one pickle file to a Chroma collection.

    The file is read with ``pd.read_pickle`` and must contain the columns
    ``embedding`` and ``item_id``. Each embedding is sliced on its own
    (``emb[:n_tokens, :]``) and flattened to 1-D. The embeddings are never
    stacked, so their first-axis lengths may differ, as long as each has at
    least ``n_tokens`` entries.

    Args:
        collection: Chroma collection; only its
            ``add(embeddings=, metadatas=, ids=)`` method is used.
        file: path of the pickled DataFrame.
        start_id: integer id for the first row of this file. The row at
            position ``k`` is stored under the string id ``str(start_id + k)``.
        batch_size: maximum number of rows passed to one ``collection.add``
            call (Chroma limits the size of a single add).
        n_tokens: number of leading entries along the first axis of each
            embedding that are kept (presumably token positions — confirm).

    Returns:
        ``start_id`` plus the number of rows in the file, i.e. the next unused
        id, to be passed as ``start_id`` for the following file.
    """
    embeddings_df = pd.read_pickle(file)
    n_rows = len(embeddings_df)
    for start in tqdm(range(0, n_rows, batch_size)):
        # Slice by position so vectors, metadata and ids stay aligned whatever
        # index the pickled frame carries.
        batch = embeddings_df.iloc[start:start + batch_size]
        vectors = [emb[:n_tokens, :].reshape(-1) for emb in batch['embedding']]
        # .tolist() yields plain Python scalars instead of numpy scalars.
        metadatas = [{'item_id': item_id} for item_id in batch['item_id'].tolist()]
        ids = [str(start_id + pos) for pos in range(start, start + len(batch))]
        collection.add(embeddings=vectors, metadatas=metadatas, ids=ids)
    # Derived from n_rows (not a loop variable) so an empty file returns
    # start_id unchanged instead of raising UnboundLocalError.
    return start_id + n_rows

In [None]:
# start_id carries the next unused integer id across files, so ids stay unique
# within the text collection.
start_id = 0
embeddings_dir = 'D:/embeddings/'
# os.listdir returns entries in arbitrary order; sorting makes the ids assigned
# to each file reproducible across machines and re-runs.
for file in sorted(os.listdir(embeddings_dir)):
    if file.startswith('embeddings_' + blip_2_model + '_text'):
        print(file)
        start_id = embed_text(collection, os.path.join(embeddings_dir, file), start_id)

embeddings_gs_text_1000.pkl


100%|██████████| 16/16 [00:55<00:00,  3.49s/it]


embeddings_gs_text_1500.pkl


 44%|████▍     | 7/16 [00:26<00:33,  3.77s/it]