## Embedding with NVIDIA's model

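A `SentenceTransformer` model instantiated from `NVIDIA-Embed-v2` returns tensors or NumPy arrays, but it accepts only a single string or a list containing a single string; it cannot encode a list of several strings. The cells below run the same checks with `BAAI/bge-en-icl`, which encoded a batch of three sentences on `mps` without error.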
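
When an encoder only accepts one string per call, a list of texts can still be embedded by calling it once per item and collecting the results. The following is a minimal sketch of that workaround, not the notebook's own code: `encode_one_at_a_time` and `fake_single_input_encoder` are hypothetical names, and the fake encoder is a toy stand-in that imitates the single-input restriction so the example runs without downloading a model.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def encode_one_at_a_time(
    encode_fn: Callable[[List[str]], List[Vector]],
    texts: Sequence[str],
) -> List[Vector]:
    """Call encode_fn once per text and collect one vector per text."""
    vectors: List[Vector] = []
    for text in texts:
        # Wrap each text in a single-item list, the only list form
        # the restricted encoder is described as accepting.
        result = encode_fn([text])
        vectors.append(list(result[0]))
    return vectors


def fake_single_input_encoder(batch: List[str]) -> List[Vector]:
    """Hypothetical stand-in: rejects batches larger than one string."""
    if len(batch) != 1:
        raise ValueError("this encoder accepts exactly one string per call")
    text = batch[0]
    # Toy 2-dimensional "embedding": character count and word count.
    return [[float(len(text)), float(len(text.split()))]]


sentences = ["The weather is lovely today.", "He drove to the stadium."]
vectors = encode_one_at_a_time(fake_single_input_encoder, sentences)
print(len(vectors), len(vectors[0]))
```

With a real model, `encode_fn` could be a small function that forwards to `model.encode`. Encoding one item at a time gives up batching throughput, so it is only worth doing when batch input actually fails.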

In [1]:
from sentence_transformers import SentenceTransformer
import torch

model_name = "BAAI/bge-en-icl"
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
device = 'mps'  # mps was producing a dimension error on batch input to model.encode
print(f"Using device: {device}")
model = SentenceTransformer(
    model_name, trust_remote_code=True, device=device)
print(model)

# Encoding a list of strings
try:
    embeddings = model.encode(
        sentences, convert_to_tensor=True, normalize_embeddings=False)
    print(f"Batch embeddings shape: {embeddings.shape}")
except Exception as e:
    print(f"Error: {e}")

Using device: mps


SentenceTransformer(
  (0): Transformer({'max_seq_length': 32768, 'do_lower_case': False}) with Transformer model: MistralModel 
  (1): Pooling({'word_embedding_dimension': 4096, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': True, 'include_prompt': True})
)
Batch embeddings shape: torch.Size([3, 4096])


In [2]:
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings

client = chromadb.PersistentClient(path='./vector_stores/bge/')
print(client)

# Set up embedding model
class BgeChromaEmbedder(EmbeddingFunction):
    def __init__(self, embedding_fn):
        self._encode = embedding_fn

    def __call__(self, input: Documents) -> Embeddings:
        return self._encode(input)

def encode_documents(docs):
    # Return unnormalized NumPy embeddings, one row per document
    return model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)

embedder = BgeChromaEmbedder(encode_documents)
result = embedder(['hi there', 'hello world'])
print(f"Return object type: {type(result)}")
print(f"Inner object type: {type(result[0])}")
print(result[0].shape)

<chromadb.api.client.Client object at 0x32bf8aad0>
Return object type: <class 'list'>
Inner object type: <class 'numpy.ndarray'>
(4096,)


In [3]:
import os
import json

METRICS = ['l2', 'cosine', 'ip']
PATH_TO_DATA = 'data/processed_for_chroma/research'
FILENAMES = os.listdir(PATH_TO_DATA)
data = dict()

for filename in FILENAMES:
    with open(f'{PATH_TO_DATA}/{filename}', 'r') as file:
        data[os.path.splitext(os.path.basename(filename))[0]] = json.load(file)

print(data.keys())
data['Astro_Research'].keys()

dict_keys(['Earth_Science_Research', 'Planetary_Research', 'Astro_Research'])


dict_keys(['documents', 'metadatas', 'ids'])
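
The loading step above reads every entry in the directory, so a stray non-JSON file would make `json.load` fail. A hedged sketch of a slightly more defensive loader follows. It assumes the data files carry a `.json` extension, which the notebook does not state; `load_json_records` is a hypothetical helper, and the demonstration uses a temporary directory with made-up content.

```python
import json
import tempfile
from pathlib import Path


def load_json_records(directory: str) -> dict:
    """Map each JSON file's stem (name without extension) to its parsed content."""
    records = {}
    # Only files ending in .json are read; anything else is skipped.
    for path in sorted(Path(directory).glob("*.json")):
        with path.open("r", encoding="utf-8") as file:
            records[path.stem] = json.load(file)
    return records


# Demonstration on a throwaway directory with made-up content.
with tempfile.TemporaryDirectory() as tmp:
    sample = {"documents": ["text"], "metadatas": [{"abstract": "a"}], "ids": ["1"]}
    (Path(tmp) / "Astro_Research.json").write_text(json.dumps(sample), encoding="utf-8")
    (Path(tmp) / "notes.txt").write_text("not JSON", encoding="utf-8")
    loaded = load_json_records(tmp)

print(sorted(loaded))
print(sorted(loaded["Astro_Research"]))
```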

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
doc = data['Astro_Research']['documents'][0]
print(f"Length of doc string: {len(doc)}")


Length of doc string: 218019


In [6]:
abstract = data['Astro_Research']['metadatas'][0]['abstract']
print(abstract)
embedder([abstract])

We present a new model for computing the spectral evolution of stellar populations at ages between 1 × 10<SUP>5</SUP> and 2 × 10<SUP>10</SUP> yr at a resolution of 3 Å across the whole wavelength range from 3200 to 9500 Å for a wide range of metallicities. These predictions are based on a newly available library of observed stellar spectra. We also compute the spectral evolution across a larger wavelength range, from 91 Å to 160 μm, at lower resolution. The model incorporates recent progress in stellar evolution theory and an observationally motivated prescription for thermally pulsing stars on the asymptotic giant branch. The latter is supported by observations of surface brightness fluctuations in nearby stellar populations. We show that this model reproduces well the observed optical and near-infrared colour-magnitude diagrams of Galactic star clusters of various ages and metallicities. Stochastic fluctuations in the numbers of stars in different evolutionary phases can account for 

[array([-0.878101  ,  2.3778255 , -1.4578811 , ..., -3.2767513 ,
        -0.22019193,  3.9709563 ], dtype=float32)]

In [5]:
# Embed the first 50 characters of the document (a bare string, not a list)
res = embedder(doc[:50])
print(res)

[array([ 0.3989852 , -0.89258695, -1.0513442 , ..., -1.9740007 ,
       -0.8036554 ,  2.7621977 ], dtype=float32)]


In [6]:
# Reset collections
collections = client.list_collections()

# Delete each collection
for collection in collections:
    client.delete_collection(name=collection.name)

print("All collections have been deleted.")

All collections have been deleted.


In [7]:
collection = client.create_collection(
    name='testing-insert',
    embedding_function=embedder,
    metadata={"hnsw:space": 'l2'}
)
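
The `hnsw:space` setting selects the distance used for nearest-neighbour search. The sketch below shows, in plain Python, the formulas these three options are generally documented to use: squared L2, one minus the inner product, and one minus the cosine similarity. The function names are hypothetical and nothing here calls Chroma; treat it as an illustration of why the choice matters when embeddings are not normalized.

```python
import math
from typing import Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """'l2': sum of squared coordinate differences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """'ip': one minus the dot product."""
    return 1.0 - sum(x * y for x, y in zip(a, b))


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """'cosine': one minus the cosine similarity."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


# Two unnormalized vectors pointing in the same direction.
a = [3.0, 4.0]
b = [6.0, 8.0]

print("l2:", squared_l2(a, b))
print("ip:", inner_product_distance(a, b))
print("cosine:", cosine_distance(a, b))
```

For these two vectors the cosine distance is zero because they are parallel, while the L2 and inner-product values depend on their lengths. With unit-length embeddings the three metrics would rank neighbours the same way, which is why the unnormalized case is worth testing separately.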



In [8]:
collection.add(
    documents=data['Astro_Research']['documents'][0:1],
    metadatas=data['Astro_Research']['metadatas'][0:1],
    ids=data['Astro_Research']['ids'][0:1]
)

In [None]:
ADD_INCREMENT = 10
# Replace '/' so the model name can be used in a collection name
clean_model_name = model_name.replace('/', '-')

for metric in METRICS:
    collection_name = f"{clean_model_name}_{metric}_no-norm"
    print(f"Creating collection: {collection_name}...")

    collection = client.create_collection(
        name=collection_name,
        embedding_function=embedder,
        metadata={"hnsw:space": metric}
    )
    for journal, records in data.items():
        print(f"  Adding {journal} records...")

        # Add records ADD_INCREMENT at a time, including the final partial batch
        for start in range(0, len(records['documents']), ADD_INCREMENT):
            s = slice(start, start + ADD_INCREMENT)
            collection.add(
                documents=records['documents'][s],
                metadatas=records['metadatas'][s],
                ids=records['ids'][s]
            )
    print(f"Finished {collection_name}")
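
Batched inserts need to cover a final batch that is shorter than the increment whenever the record count is not an exact multiple of it. A minimal sketch of one way to generate the slices, using a hypothetical helper `batch_slices` and made-up ids:

```python
from typing import Iterator


def batch_slices(total: int, batch_size: int) -> Iterator[slice]:
    """Yield consecutive slices that together cover range(total)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Stepping by batch_size from 0 also yields the last, shorter batch.
    for start in range(0, total, batch_size):
        yield slice(start, min(start + batch_size, total))


ids = [f"id-{n}" for n in range(23)]
batches = [ids[s] for s in batch_slices(len(ids), 10)]
print([len(batch) for batch in batches])
```

The same slice can be applied to the documents, metadatas, and ids lists so the three stay aligned within each batch.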

In [None]:
# Test the dimensionality of the embeddings
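# Collections were created with '/' replaced by '-' in the model name
collection = client.get_collection(
    name=f"{model_name.replace('/', '-')}_cosine_no-norm",
    embedding_function=embedder
)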
chroma_embedding = collection.query(
    query_texts=["the sun is a star"],
    n_results=1,
    include=["embeddings", "documents"]
)

print(type(chroma_embedding['embeddings']))
print(type(chroma_embedding['embeddings'][0]))