In [1]:
import argparse
import chromadb
import pysbd
import uuid
from embedding_functions import get_embedding_fn
from utils import load_dataset

In [2]:
model = "BAAI/bge-small-en"
device = "mps"
embedding_fn = get_embedding_fn(model, device, normalize=False)
print(embedding_fn)

<embedding_functions.ChromaEmbedder object at 0x32aa2ded0>


In [3]:
REVIEWS_PATH = 'data/processed_for_chroma/reviews/Astro_Reviews.json'
data = load_dataset(REVIEWS_PATH)
# Sanity check: container type, then number of records.
print(type(data))
print(len(data))

<class 'list'>
996


In [4]:
# Pipeline: record -> sentence chunking -> use embedding_fn to create insertion records -> insert into db

In [42]:
SEG = pysbd.Segmenter(language="en", clean=False)

def augmented_sentence_embeddings(record, embedding_fn, augmenting_fn, segmenter=None):
    """
    Segment a record's body into sentences and embed each one, optionally
    augmented with extra context (e.g. the record's title or abstract).

    Parameters
    ----------
    record : dict
        Must contain 'body' (the text to segment) plus whatever keys
        ``augmenting_fn`` reads (e.g. 'title', 'abstract').
    embedding_fn : callable
        Called once with the list of augmented texts; expected to return one
        embedding per text.
    augmenting_fn : callable
        ``augmenting_fn(record, sentences)`` -> list of texts to embed, one per
        sentence.
    segmenter : object, optional
        Anything with a ``segment(text)`` method returning a list of sentences.
        Defaults to the module-level ``SEG`` segmenter.

    Returns
    -------
    tuple
        ``(sentences, vectors)``: the *un-augmented* sentences (stored as the
        documents) and the embeddings of their augmented texts, index-aligned.
        Both are empty lists when the body yields no sentences, in which case
        ``embedding_fn`` is not called.
    """
    seg = SEG if segmenter is None else segmenter
    sentences = seg.segment(record['body'])
    if not sentences:
        # Nothing to embed; don't hand the embedder an empty batch.
        return [], []
    texts = augmenting_fn(record, sentences)
    vectors = embedding_fn(texts)
    return sentences, vectors

def add_abstract(record, sentences):
    """Return each sentence prefixed by the record's abstract and a newline."""
    header = record['abstract'] + '\n'
    return [header + s for s in sentences]

def add_title(record, sentences):
    """Return each sentence prefixed by the record's title and a newline."""
    header = record['title'] + '\n'
    return [header + s for s in sentences]

def add_title_and_abstract(record, sentences):
    """Return each sentence prefixed by the record's title and abstract, one per line."""
    # Equivalent to title + '\n' + abstract + '\n'
    header = '\n'.join((record['title'], record['abstract'], ''))
    return [header + s for s in sentences]

def no_augmentation(record, sentences):
    """Identity augmentation: embed the sentences exactly as segmented.

    ``record`` is accepted only so the signature matches the other
    augmenting functions; the very same ``sentences`` object is returned.
    """
    unchanged = sentences
    return unchanged

# Insert to DB

For each (augmentation strategy, embedding function, metric) combination we need to create a separate collection. The same embedding function can, however, be reused across all metrics.

In [25]:
# On-disk Chroma store holding the test collections.
VECTOR_STORE_DIR = './vector_stores/test/'
client = chromadb.PersistentClient(path=VECTOR_STORE_DIR)

In [35]:
# Delete collections left over from previous test runs so create_collection can
# recreate them. Only the "test-" prefix that create_collection uses is matched,
# so unrelated collections that merely start with "test" (e.g. "testimonials")
# are left alone.
# list_collections() returns Collection objects in chromadb < 0.6 and plain
# names in 0.6+; getattr(..., "name", ...) handles both.
for existing in client.list_collections():
    existing_name = getattr(existing, "name", existing)
    if existing_name.startswith('test-'):
        print(f"Deleting collection {existing_name}")
        client.delete_collection(name=existing_name)


Deleting collection test-bge-small-en__cosine__no_augmentation
Deleting collection test-bge-small-en__l2__no_augmentation
Deleting collection test-bge-small-en__cosine__add_title
Deleting collection test-bge-small-en__l2__add_title


In [53]:
# Collection name will be the name of the embedding model-metric-name of augmentation function
from tqdm import tqdm

def create_collection(client, embedding_fn, metric, augmenting_fn, prefix="test-"):
    """
    Create the Chroma collection for one (embedding model, metric, augmentation) combination.

    The collection name has the form
    ``{prefix}{embedding_fn.model_name}__{metric}__{augmenting_fn.__name__}``,
    the layout get_expected_parameters_from_collection_name parses back.

    Parameters
    ----------
    client : chromadb client
        Object whose ``create_collection`` is called.
    embedding_fn : object
        Embedding function attached to the collection; must expose ``model_name``.
    metric : str
        HNSW distance space stored as ``metadata["hnsw:space"]``:
        one of 'l2', 'ip', 'cosine'.
    augmenting_fn : callable
        Augmentation function; only its ``__name__`` is used (in the name).
    prefix : str, default "test-"
        Prepended to the collection name.

    Returns
    -------
    The collection returned by ``client.create_collection``.

    Raises
    ------
    ValueError
        If ``metric`` is not one of Chroma's supported HNSW spaces.
    """
    supported_spaces = ("l2", "ip", "cosine")
    if metric not in supported_spaces:
        # Fail fast, before anything is created on the client.
        raise ValueError(f"Unsupported metric {metric!r}; expected one of {supported_spaces}")

    collection_name = f"{prefix}{embedding_fn.model_name}__{metric}__{augmenting_fn.__name__}"

    # NOTE(review): client.create_collection presumably raises if the name
    # already exists — run the cleanup cell before re-running this.
    print(f"Creating collection: {collection_name}...", end="")
    collection = client.create_collection(
        name=collection_name,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": metric})
    print("created.")
    return collection

def get_expected_parameters_from_collection_name(collection_name, prefix="test-"):
    """
    Recover ``(model_name, metric, augmenting_fn_name)`` from a collection name
    of the form ``{prefix}{model_name}__{metric}__{augmenting_fn_name}``.

    ``prefix`` is stripped only when the name actually starts with it, so names
    created without the "test-" prefix are parsed correctly as well.

    Raises
    ------
    ValueError
        If the (prefix-stripped) name does not split into exactly three
        "__"-separated parts.
    """
    name = collection_name
    if prefix and name.startswith(prefix):
        name = name[len(prefix):]
    parts = name.split("__")
    if len(parts) != 3:
        raise ValueError(
            f"Collection name {collection_name!r} is not of the form "
            f"'{prefix}<model>__<metric>__<augmentation>'"
        )
    model_name, metric, augmenting_fn_name = parts
    return model_name, metric, augmenting_fn_name

def insert_records(collection, records, embedding_fn, augmenting_fn):
    """
    Segment, embed and add each record's sentences to ``collection``.

    Every sentence becomes one entry: the un-augmented sentence is the stored
    document, the embedding of its augmented text is the vector, and the
    record's first DOI is stored as ``{'doi': ...}`` metadata. Each entry gets
    a random UUID4 id. Records whose body yields no sentences are skipped.

    Parameters
    ----------
    collection : chromadb Collection
        Target collection; its name must encode the model and augmentation
        (as produced by create_collection).
    records : iterable of dict
        Each needs 'body', 'doi' (indexed with [0]) and whatever keys
        ``augmenting_fn`` reads.
    embedding_fn : object
        Embedder exposing ``model_name``; must match the collection's model.
    augmenting_fn : callable
        Augmentation function; its ``__name__`` must match the collection's.

    Raises
    ------
    AssertionError
        If ``embedding_fn`` or ``augmenting_fn`` does not match the collection name.
    """
    # Refuse to mix vectors from a different model/augmentation into this collection.
    expected_model, _metric, expected_aug = get_expected_parameters_from_collection_name(collection.name)
    assert expected_model == embedding_fn.model_name, f"Expected embedding model '{expected_model}' for collection {collection.name} but got '{embedding_fn.model_name}'"
    assert expected_aug == augmenting_fn.__name__, f"Expected augmentation function '{expected_aug}' for {collection.name} but got '{augmenting_fn.__name__}'"

    count_before = collection.count()
    n_inserted = 0
    n_skipped = 0

    for record in tqdm(records):
        sentences, vectors = augmented_sentence_embeddings(record, embedding_fn, augmenting_fn)
        if len(sentences) == 0:
            # Empty body: Chroma's add() rejects an empty batch, so skip it.
            n_skipped += 1
            continue
        # presumably record['doi'] is a list of DOIs; only the first is stored
        doi = record['doi'][0]
        collection.add(
            documents=sentences,
            # one dict per entry rather than N references to a single shared dict
            metadatas=[{'doi': doi} for _ in sentences],
            embeddings=vectors,
            ids=[str(uuid.uuid4()) for _ in sentences],
        )
        n_inserted += 1

    # Entries are sentences, not input records.
    n_added = collection.count() - count_before
    print(f"Added {n_added} sentences from {n_inserted} records to collection {collection.name}")
    if n_skipped:
        print(f"Skipped {n_skipped} records with no sentences")
    

In [37]:
import itertools
augmenting_functions = [no_augmentation, add_title]
metrics = ['cosine', 'l2']
# One collection per (augmentation, metric) pair, augmentation-major —
# the same order itertools.product(augmenting_functions, metrics) yields.
collections = [
    create_collection(client, embedding_fn, metric, augmenting_fn)
    for augmenting_fn in augmenting_functions
    for metric in metrics
]


Creating collection: test-bge-small-en__cosine__no_augmentation...created.
Creating collection: test-bge-small-en__l2__no_augmentation...created.
Creating collection: test-bge-small-en__cosine__add_title...created.
Creating collection: test-bge-small-en__l2__add_title...created.


In [54]:
# Smoke-test insertion into the last collection created above (l2 metric,
# add_title augmentation); the augmenting function passed must match it.
SMOKE_TEST_RECORDS = 2
collection = collections[len(collections) - 1]
print(collection)
insert_records(collection, data[:SMOKE_TEST_RECORDS], embedding_fn, add_title)

Collection(name=test-bge-small-en__l2__add_title)


100%|██████████| 2/2 [00:11<00:00,  5.53s/it]

Added 1941 records to collection test-bge-small-en__l2__add_title





2640

In [37]:
# Scratch sample for a manual insertion test: embed one record's sentences
# without augmentation and keep the first few.
# `sentence_only_embeddings` is never defined in this notebook (NameError on a
# fresh kernel), and augmented_sentence_embeddings returns a
# (sentences, vectors) pair, so the sample is built from that pair directly.
N_SAMPLE = 3
sample_sentences, sample_vectors = augmented_sentence_embeddings(data[0], embedding_fn, no_augmentation)
dummy_docs = list(sample_sentences[:N_SAMPLE])
vectors = list(sample_vectors[:N_SAMPLE])
ids = [str(uuid.uuid4()) for _ in range(len(vectors))]

In [None]:
# Inspect the element type of the embeddings (e.g. list vs. numpy array)
# before passing them to collection.add.
type(vectors[0])

In [None]:
# with vectors provided
# Add documents together with precomputed embeddings, so the collection's own
# embedding function should not be needed to embed them.
# NOTE(review): make sure `dummy_docs` is defined (one document per entry in
# `vectors`/`ids`) before running this cell.
collection.add(
    documents=dummy_docs,
    ids=ids,
    embeddings=vectors
)
