In [1]:
import pandas as pd

# Read the stashed table; every column other than 'title' and 'y' is kept as
# a feature column.
Xy = pd.read_csv('./stash/Xy.csv')

t = Xy['title']
y = Xy['y']
X = Xy.drop(columns=['title', 'y'])

# Shapes of the full table, the features, the titles and the labels.
Xy.shape, X.shape, t.shape, y.shape

((90, 1538), (90, 1536), (90,), (90,))

In [2]:
# One record per row of Xy: a running integer id, the title as the document
# text, the label wrapped in a metadata dict, and the feature row converted
# to a list of plain Python floats.
df = pd.DataFrame({
    'id': list(range(len(Xy))),
    'document': t,
    'metadata': [{'subject': label} for label in y],
    'embedding': [list(map(float, row)) for row in X.values],
})
df.shape

(90, 4)

In [3]:
# Preview the first rows to check the four reshaped columns.
df.head()

Unnamed: 0,id,document,metadata,embedding
0,0,Java: The Complete Reference,{'subject': 'java'},"[0.0022508783731609, -0.0042110048234462, 0.01..."
1,1,Core Java An Integrated Approach (Black Book),{'subject': 'java'},"[0.0083073740825057, -0.0205343309789896, 0.01..."
2,2,Head First Java,{'subject': 'java'},"[0.0109464712440967, 0.0101075824350118, -0.00..."
3,3,Effective Java,{'subject': 'java'},"[-0.0103182280436158, 0.0071102487854659, 0.00..."
4,4,Thinking in Java,{'subject': 'java'},"[-0.0012556132860481, 0.0025670316535979, 0.01..."


In [4]:
# Lookup table from document text to its stored embedding. The document text
# is the key, so rows that share the same text collapse into a single entry
# (the later row wins), which is why the count can be below the row count.
t2e = dict(zip(df['document'], df['embedding']))
len(t2e)

88

In [5]:
import pathlib
import shutil
from langchain_core.documents.base import Document
from langchain.vectorstores import Chroma
from openai import OpenAI

class MockEmbedding:
    """Embedding function that answers from a precomputed lookup table first.

    Texts found in ``t2e`` are returned straight from the table; any other
    text is sent to the OpenAI embeddings endpoint using ``embedding_model``.

    The OpenAI client is created lazily, on the first lookup miss, so the
    object can be constructed and used for table hits without any API
    credentials being configured.

    Parameters
    ----------
    t2e : mapping
        Maps a text to its embedding (a list of floats).
    embedding_model : str
        Model name passed to the embeddings endpoint on a lookup miss.
    """

    def __init__(self, t2e, embedding_model='text-embedding-ada-002'):
        self.t2e = t2e
        self.embedding_model = embedding_model
        # Created on demand by the `client` property.
        self._client = None

    @property
    def client(self):
        """The OpenAI client, instantiated on first access."""
        if self._client is None:
            self._client = OpenAI()
        return self._client

    @client.setter
    def client(self, value):
        # Keeps `obj.client = ...` working as it did when this was a plain attribute.
        self._client = value

    def __embed(self, t):
        """Return the embedding for a single text, preferring the lookup table."""
        if t in self.t2e:
            return self.t2e[t]

        # Lookup miss: newlines are replaced by spaces before the request.
        docs = [t.replace('\n', ' ')]
        res = self.client.embeddings.create(input=docs, model=self.embedding_model)
        return res.data[0].embedding

    def embed_documents(self, texts):
        """Return one embedding per text, in input order."""
        return [self.__embed(t) for t in texts]

    def embed_query(self, query):
        """Return the embedding for a single query string."""
        return self.__embed(query)
        
def get_documents():
    """Build one ``Document`` per row of the module-level ``df``.

    Each row's 'document' value becomes ``page_content`` and its 'metadata'
    value becomes ``metadata``. The two columns are walked in parallel rather
    than going through a row-wise ``DataFrame.apply``, so no per-row Series
    is built and an empty frame yields an empty list.

    Returns
    -------
    list
        The documents, in row order.
    """
    return [
        Document(page_content=text, metadata=meta)
        for text, meta in zip(df['document'], df['metadata'])
    ]

def get_db(db_path=pathlib.Path('./book_vdb')):
    """Open the Chroma store persisted at ``db_path``, building it if absent.

    If ``db_path`` already exists it is opened as-is. Otherwise the directory
    is created and the store is populated from ``get_documents()``. Both
    paths use a ``MockEmbedding`` backed by the module-level ``t2e`` table.

    Parameters
    ----------
    db_path : pathlib.Path or str
        Directory used as Chroma's ``persist_directory``.

    Returns
    -------
    Chroma
        The opened or newly built vector store.

    Raises
    ------
    Exception
        Whatever the build raises; the freshly created directory is removed
        first so a later call does not mistake it for a finished store.
    """
    db_path = pathlib.Path(db_path)  # also accept plain string paths
    embedding_function = MockEmbedding(t2e=t2e)

    if db_path.exists():
        return Chroma(
            persist_directory=str(db_path),
            embedding_function=embedding_function,
        )

    db_path.mkdir(parents=True, exist_ok=True)
    try:
        # Documents are only needed when building, so they are created here
        # and not on every call.
        return Chroma.from_documents(
            documents=get_documents(),
            embedding=embedding_function,
            persist_directory=str(db_path),
        )
    except Exception:
        # The existence check above would treat a leftover directory as a
        # finished store on the next call; remove it before re-raising.
        shutil.rmtree(db_path, ignore_errors=True)
        raise

def get_retriever(db_path=pathlib.Path('./book_vdb'), retriever_params=None):
    """Return a retriever over the vector store located at ``db_path``.

    Parameters
    ----------
    db_path : pathlib.Path
        Forwarded to ``get_db``.
    retriever_params : dict or None
        Keyword arguments for ``as_retriever``. When ``None``, an MMR
        configuration is used (k=5, fetch_k=100, lambda_mult=0.5,
        score_threshold=0.2).

    NOTE(review): 'score_threshold' is passed inside ``search_kwargs`` next to
    ``search_type='mmr'``; confirm it has any effect for that search type.
    """
    default_params = dict(
        search_type='mmr',
        search_kwargs=dict(
            k=5,
            fetch_k=100,
            lambda_mult=0.5,
            score_threshold=0.2,
        ),
    )
    params = default_params if retriever_params is None else retriever_params

    vector_db = get_db(db_path)
    return vector_db.as_retriever(**params)
    
# Build the retriever with its default settings (get_db opens the on-disk
# store or creates it) and keep a direct handle to the underlying store.
retriever = get_retriever()
vectorstore = retriever.vectorstore

In [6]:
# Query the store directly with a plain similarity search; this does not go
# through the retriever, so its MMR settings are not applied here.
vectorstore.search('java', search_type='similarity')

[Document(page_content='The Java™ Programming Language', metadata={'subject': 'java'}),
 Document(page_content='Effective Java', metadata={'subject': 'java'}),
 Document(page_content='Thinking in Java', metadata={'subject': 'java'}),
 Document(page_content='Head First Java', metadata={'subject': 'java'})]