In [1]:
import argparse
import chromadb
import pysbd
import torch
import uuid
from tqdm import tqdm
from embedding_functions import get_embedding_fn
from utils import load_dataset
from semantic_text_splitter import TextSplitter
from sentence_validators import length_at_least_40

In [2]:
# Inputs for this session: the collection to inspect and a dataset path.
# `data_path` is not consumed by any of the cells visible in this section.
collection_arg = 'test-bge-small-en__cosine'
data_path = 'data/json/Astro_Reviews.json'

# Open the persistent Chroma store rooted at ./vector_stores/foo/ and look
# the collection up by name.
client = chromadb.PersistentClient(path='./vector_stores/foo/')
collection = client.get_collection(collection_arg)
print(collection)

Collection(name=test-bge-small-en__cosine)


In [3]:
from database import get_expected_parameters_from_collection_name
# Derive the expected embedding model and distance metric from the
# collection's name, then echo both for a quick visual check.
expected_params = get_expected_parameters_from_collection_name(collection.name)
model_name, metric = expected_params
print(f"{model_name}, {metric}")

bge-small-en, cosine


In [4]:
# Maps the short model name parsed from a collection name (see the cell
# above) to the Hugging Face repository id passed to get_embedding_fn.
CHROMA_MODEL_NAME_TO_HF = {
    'bge-small-en': 'BAAI/bge-small-en',
    'bert-base-uncased': 'bert-base-uncased',
    'NV-Embed-v2': 'nvidia/NV-Embed-v2',
}

# Same exception type as a plain dict lookup, but with a message that says
# which names are registered so the mapping can be extended.
if model_name not in CHROMA_MODEL_NAME_TO_HF:
    raise KeyError(
        f"No Hugging Face id registered for model {model_name!r}; "
        f"known models: {sorted(CHROMA_MODEL_NAME_TO_HF)}"
    )

# Choose the device at runtime instead of hard-coding 'mps', so the notebook
# also runs on machines without Apple's Metal backend. MPS is checked first,
# which keeps the previous behaviour wherever it is available.
# NOTE(review): assumes get_embedding_fn accepts any torch device string
# ('cuda', 'cpu') the same way it accepts 'mps' — confirm.
if torch.backends.mps.is_available():
    torch_device = 'mps'
elif torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'

embedding_fn = get_embedding_fn(
    model_name=CHROMA_MODEL_NAME_TO_HF[model_name],
    device=torch_device,
    normalize=False,
)
print(embedding_fn)



<embedding_functions.ChromaEmbedder object at 0x10c12ce50>


In [5]:
# Smoke test: run two short strings through the embedding function and
# report the container type it hands back.
sample_texts = ['Hello', 'how are you?']
res = embedding_fn(sample_texts)
print(type(res))

<class 'list'>


In [8]:
from sentence_transformers import SentenceTransformer
# Choose the device at runtime instead of hard-coding 'mps', so this cell
# also runs on machines without Apple's Metal backend. MPS is checked first,
# which keeps the previous behaviour wherever it is available.
if torch.backends.mps.is_available():
    torch_device = 'mps'
elif torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'

# NOTE(review): trust_remote_code=True permits code shipped in the model
# repository to run at load time. The pipeline printed below is a stock
# BertModel, so the flag is probably unnecessary here — confirm and drop it.
model = SentenceTransformer(
    'BAAI/bge-small-en',
    trust_remote_code=True,
    device=torch_device,
)
print(model)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)


In [10]:
# Encode the same two sample strings directly with the SentenceTransformer
# and inspect what comes back. This rebinds `res`, which held the output of
# `embedding_fn` in an earlier cell.
# NOTE(review): the pipeline printed for `model` ends in a Normalize() module,
# so the vectors may already be unit-length even with
# normalize_embeddings=False — confirm.
sample_texts = ['Hello', 'how are you?']
res = model.encode(
    sample_texts,
    convert_to_numpy=True,
    normalize_embeddings=False,
)
print(type(res), res.shape, sep='\n')

<class 'numpy.ndarray'>
(2, 384)


In [11]:
from embedding_functions import ChromaEmbedder

# Wrap the SentenceTransformer's encode call in a ChromaEmbedder. The callable
# forwards its input to `model.encode`, asking for NumPy output without the
# extra normalisation step.
hf_model_id = 'BAAI/bge-small-en'
embedder = ChromaEmbedder(
    lambda texts: model.encode(texts, convert_to_numpy=True, normalize_embeddings=False),
    name=hf_model_id,
)
# The recorded output for this cell is the short name 'bge-small-en'.
print(embedder.model_name)

bge-small-en


In [12]:
# Smoke test for the hand-built embedder: embed two short strings and report
# the container type of the result.
sample_texts = ['Hello', 'how are you?']
chroma_res = embedder(sample_texts)
print(type(chroma_res))

<class 'list'>
