## Save text-only information into Chroma DB
Note that this is just a hack to work around our problem of completely irrelevant results for text-only queries.

We could continue to use BLIP-2 embeddings, but BLIP-2 has a context length of only 32 tokens. Chroma's default embedding model (all-MiniLM-L6-v2) is better on that front, with a context length of 256 tokens. We will just pass the documents to the database and let it compute the embeddings itself.
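
To get a rough feel for how many documents would still be cut off at 256 tokens, a crude check like the one below may help. It is only a sketch: the helper `likely_truncated` is hypothetical, whitespace-separated words are used as a lower bound on the real token count (the model's tokenizer usually produces more tokens than words), and reserving 2 positions for special tokens is an assumption.

```python
def likely_truncated(text, max_tokens=256, reserved=2):
    """Return True if `text` has more words than the context window can hold.

    Word count is only a lower bound on the true token count, so a True
    result means the text is almost certainly truncated, while a False
    result does not guarantee that it fits.
    """
    # `reserved` is an assumed allowance for special tokens added by the tokenizer
    return len(text.split()) > max_tokens - reserved


docs = ["short product title", "word " * 300]
flags = [likely_truncated(d) for d in docs]
print(flags)
```

An exact count would need the model's own tokenizer, which this sketch deliberately avoids.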

In [1]:
import pandas as pd
import chromadb
from tqdm import tqdm
from icecream import ic

Load the metadata and keep only the relevant columns in a useful order (order of precedence).

In [2]:
metadata_file = '/mnt/d/abo-dataset/abo-listings-final-draft.pkl'
metadata = pd.read_pickle(metadata_file)

# Item keywords are dropped because there can be a ridiculous number of them per item
metadata = metadata[['item_name', 'brand', 'model_name', 'model_year',
                     'product_description', 'product_type', 'color',
                     'fabric_type', 'style', 'material',
                     'pattern', 'finish_type', 'bullet_point']]

Convert a row of metadata to a string so it's in a useful form for creating the embedding.

In [3]:
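def row_to_str(row):
    """Flatten one metadata row into a single ';'-separated string."""
    # Skip missing fields so they don't show up as 'nan' in the text
    row_filtered = row.dropna()
    text = []
    for row_item in row_filtered:
        if isinstance(row_item, list):
            # List-valued fields contribute one entry per element
            for list_item in row_item:
                text.append(str(list_item) + ';')
        else:
            text.append(str(row_item) + ';')

    # Normalise newlines, carets and comma spacing
    return ' '.join(text).replace('\n', ' ').replace('^', ' ').replace(',', ', ')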
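
As a quick sanity check, here is what the conversion does to a small made-up row. The values in `example_row` are invented for illustration and do not come from the dataset. `row_to_str` is repeated so the snippet runs on its own.

```python
import pandas as pd


def row_to_str(row):
    row_filtered = row.dropna()
    text = []
    for row_item in row_filtered:
        if isinstance(row_item, list):
            for list_item in row_item:
                text.append(str(list_item) + ';')
        else:
            text.append(str(row_item) + ';')
    return ' '.join(text).replace('\n', ' ').replace('^', ' ').replace(',', ', ')


# Hypothetical row: one missing field, one list-valued field
example_row = pd.Series({
    'item_name': 'Linen Sofa',
    'brand': None,
    'color': 'Grey,Blue',
    'bullet_point': ['Soft\nfabric', 'Easy^clean'],
})

example_text = row_to_str(example_row)
print(example_text)
# Linen Sofa; Grey, Blue; Soft fabric; Easy clean;
```

The missing `brand` is skipped, each list element becomes its own entry, and newlines and carets turn into spaces. Note that a comma already followed by a space ends up with two spaces after it, since every `,` is replaced by `, `.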

In [3]:
client = chromadb.PersistentClient(path='/mnt/d/chroma')

In [5]:
collection = client.create_collection(name="text_only")

Save the rows in batches of 1000 so we don't run out of memory and don't exceed the limit of the underlying SQLite database.

In [6]:
batch_size = 1000
n_rows = len(metadata)
n_batches = (n_rows - 1) // batch_size + 1  # ceiling division
for batch in tqdm(range(n_batches)):
    start = batch * batch_size
    end = min((batch + 1) * batch_size, n_rows)
    # Build one document string per row in this batch
    rows_to_add = [row_to_str(metadata.iloc[j]) for j in range(start, end)]
    # Only documents are passed, so Chroma embeds them with its default model
    collection.add(documents=rows_to_add, ids=list(metadata.index[start:end]))

100%|██████████| 121/121 [1:37:35<00:00, 48.39s/it]
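
The batch arithmetic used above can be checked in isolation. The sketch below uses a hypothetical helper, `batch_bounds`, and a made-up row count of 2500. It shows that the `(n_rows - 1) // batch_size + 1` formula matches the number of `(start, end)` slices, with the last batch being shorter than the rest.

```python
def batch_bounds(n_rows, batch_size):
    """Yield (start, end) index pairs that cover range(n_rows) in order."""
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


n_rows, batch_size = 2500, 1000  # made-up sizes for illustration
bounds = list(batch_bounds(n_rows, batch_size))
n_batches = (n_rows - 1) // batch_size + 1

print(bounds)     # [(0, 1000), (1000, 2000), (2000, 2500)]
print(n_batches)  # 3
```

By the same formula, the 121 batches in the progress bar above correspond to somewhere between 120,001 and 121,000 rows.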
