In [4]:
from chromadb.utils import embedding_functions

embedding_function =\
    embedding_functions.SentenceTransformerEmbeddingFunction(

    model_name="BAAI/bge-m3", trust_remote_code=True
)

  from tqdm.autonotebook import tqdm, trange


In [2]:
import chromadb
from uuid import uuid4
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

def add_to_collection(
        collection: chromadb.Collection, path: str
):
    loader = PyPDFLoader(path)
    pages = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(pages)

    tmp = {
    'page_content': [],
    'metadata': []
    }

    for chunk in chunks:
        tmp['page_content'].append(chunk.page_content)
        tmp['metadata'].append(chunk.metadata)

    collection.add(
        documents=tmp['page_content'],
        ids=[str(uuid4()) for _ in range(len(chunks))],
        metadatas=tmp['metadata'],
    )
    print("Documents loaded to DB")

In [3]:
client = chromadb.PersistentClient(path="test_db/")

In [6]:
import os
import chromadb


def get_db_collection(path: str) -> chromadb.Collection:
    try:
        collection = client.get_collection(
            name=os.path.basename(path),
            embedding_function=embedding_function,
        )
    except ValueError as e:
        print(e)
        collection = client.create_collection(
            name=os.path.basename(path),
            embedding_function=embedding_function,
            metadata={"hnsw:space": "cosine"},
        )

    return collection

In [7]:
collection = get_db_collection('1512.03385v1.pdf')

Collection 1512.03385v1.pdf does not exist.


In [8]:
add_to_collection(collection, '1512.03385v1.pdf')

Documents loaded to DB


In [9]:
def generate_context(query_result: dict):
    context = ""
    for doc in query_result.documents:
        for i in doc:
            context += i
            context += "\n"
    return context


In [10]:
contexts = collection.query(query_texts="resnet")

In [12]:
texts = generate_context(contexts)

In [13]:
texts

'validation set (except†reported on the test set).\nmethod top-5 err. ( test)\nVGG [41] (ILSVRC’14) 7.32\nGoogLeNet [44] (ILSVRC’14) 6.66\nVGG [41] (v5) 6.8\nPReLU-net [13] 4.94\nBN-inception [16] 4.82\nResNet (ILSVRC’15) 3.57\nTable 5. Error rates (%) of ensembles . The top-5 error is on the\ntest set of ImageNet and reported by the test server.\nResNet reduces the top-1 error by 3.5% (Table 2), resulting\nfrom the successfully reduced training error (Fig. 4 right vs.\nleft). This comparison veriﬁes the effectiveness of residual\nlearning on extremely deep systems.\nLast, we also note that the 18-layer plain/residual nets\nare comparably accurate (Table 2), but the 18-layer ResNet\nconverges faster (Fig. 4 right vs. left). When the net is “not\noverly deep” (18 layers here), the current SGD solver is still\nable to ﬁnd good solutions to the plain net. In this case, the\nResNet eases the optimization by providing faster conver-\ngence at the early stage.able to ﬁnd good solutions to th