In [1]:
import chromadb
from chromadb.utils import embedding_functions
from chromadb.api.types import Embeddable, EmbeddingFunction, Documents, Embeddings
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

In [2]:
modelPath = "src/models/all-MiniLM-L6-v2"
model = SentenceTransformer("src/models/all-MiniLM-L6-v2")

In [3]:
class SentenceEmbedding(EmbeddingFunction[Documents]):
    def __init__(self, model: SentenceTransformer) -> None:
        self.model = model
        
    def __call__(self, input: Documents) -> Embeddings:
        # embed the documents somehow
        embeddings = []
        for doc in tqdm(input):
            embedding = model.encode(doc)
            print(type(embedding))
            embeddings.append(embedding.tolist())
        return embeddings

In [4]:
sentence_transformer_ef = SentenceEmbedding()

In [5]:
client = chromadb.PersistentClient(path='./data/chroma')
collection = client.get_or_create_collection(name='test3', embedding_function=sentence_transformer_ef)

In [6]:
collection.add(
    documents=[
        "Hello World",
        "Hello World 2",
        "Hello World 3"
    ],
    metadatas=[
        {"name": "name1"},
        {"name": "name2"},
        {"name": "name3"}
    ],
    ids=[
        "1",
        "2",
        "3"
    ]
)

  0%|          | 0/3 [00:00<?, ?it/s]

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


In [7]:
all_embeddings = collection.get(ids=['1', '2', '3'], include=['embeddings'])

In [8]:
# Create a dataframe mapping message_id to embedding, embedding is a list, but should be saved as one column
df = pd.DataFrame(columns=['ids', 'embeddings'])
df['ids'] = all_embeddings['ids']
df['embeddings'] = all_embeddings['embeddings']
df

Unnamed: 0,ids,embeddings
0,1,"[-0.034477315843105316, 0.031023172661662102, ..."
1,2,"[-0.023039061576128006, 0.009830151684582233, ..."
2,3,"[-0.0509931854903698, -0.012073464691638947, -..."


In [9]:
collection.peek()

{'ids': ['1', '2', '3'],
 'embeddings': [[-0.034477315843105316,
   0.031023172661662102,
   0.006734910886734724,
   0.02610892429947853,
   -0.03936195746064186,
   -0.1603025197982788,
   0.06692396104335785,
   -0.006441440898925066,
   -0.04745054617524147,
   0.014758836477994919,
   0.07087532430887222,
   0.055527545511722565,
   0.01919332519173622,
   -0.026251299306750298,
   -0.01010951679199934,
   -0.026940451934933662,
   0.022307397797703743,
   -0.022226639091968536,
   -0.1496926248073578,
   -0.01749303936958313,
   0.007676327601075172,
   0.054352276027202606,
   0.0032544792629778385,
   0.03172592446208,
   -0.08462144434452057,
   -0.029405953362584114,
   0.05159562826156616,
   0.048124104738235474,
   -0.003314818488433957,
   -0.05827919766306877,
   0.04196928068995476,
   0.02221069671213627,
   0.12818878889083862,
   -0.02233896404504776,
   -0.011656257323920727,
   0.06292840093374252,
   -0.03287629410624504,
   -0.09122602641582489,
   -0.03117538616