# Embeddings

ReadNext currently uses Cohere's embedding web service to generate the embedding of each arXiv paper. We will eventually extend that to other services, including some local ones.

In [None]:
#| default_exp embedding

## Imports

In [None]:
#| export

import chromadb
import cohere
import os
from chromadb.errors import IDAlreadyExistsError
from pypdf import PdfReader
from readnext.arxiv_categories import exists
from readnext.arxiv_sync import get_docs_path
from rich import print
from rich.progress import Progress

# TODO Default embedding to use BGE via Hugging Face: https://blog.gopenai.com/bge-embeddings-langchain-and-chroma-for-retrieval-qa-9c684206d8f3

## PDF to Text

`PdfReader`, from the `pypdf` library, is used to extract the text from the PDF files.

In [None]:
#| export

def pdf_to_text(file_path: str) -> str:
    """Read a PDF file and output it as a text string.

    The text of every page is extracted with `pypdf.PdfReader` and the
    page texts are concatenated in page order, without any separator.

    Args:
        file_path: path of the PDF file to read.

    Returns:
        The concatenated text of all the pages of the PDF.
    """
    # The page extraction stays inside the `with` block: PdfReader may read
    # from the underlying stream on demand, so the handle must remain open
    # until every page is processed, and `with` guarantees it is closed
    # afterwards (even if extraction raises).
    with open(file_path, 'rb') as pdf_file_obj:
        pdf_reader = PdfReader(pdf_file_obj)
        return ''.join(page.extract_text() for page in pdf_reader.pages)

### Tests

In [None]:
# Extract the fixture once and check both the exact match and a near miss.
extracted_text = pdf_to_text("../tests/assets/test.pdf")
assert extracted_text == "this is a test"
assert extracted_text != "this is a test foo"

## Get PDF files from a folder

In [None]:
#| export

def get_pdfs_from_folder(folder_path: str) -> list:
    """Given a folder path, return the names of all the PDF files in that folder.

    Only regular files whose name ends with ``.pdf`` are returned; a
    sub-directory that happens to be named ``something.pdf`` is ignored, as
    it could not be read as a PDF anyway. The returned names do not include
    the folder prefix.

    Args:
        folder_path: path of the folder to scan (not recursive).

    Returns:
        The PDF file names, sorted alphabetically so the result does not
        depend on the arbitrary order in which ``os.listdir`` lists entries.
    """
    return sorted(
        entry
        for entry in os.listdir(folder_path)
        if entry.endswith(".pdf") and os.path.isfile(os.path.join(folder_path, entry))
    )

### Tests

In [None]:
# Scan the fixture folder once; it should contain exactly one PDF.
found_pdfs = get_pdfs_from_folder("../tests/assets/")
assert found_pdfs == ['test.pdf']
assert found_pdfs != ['test.pdf', 'foo.pdf']

## Embed all papers of an arXiv category

The embedding database management system ReadNext uses is [Chroma](https://www.trychroma.com/).

The embedding DBMS is organized as follows:

 - Each category (sub- or top-level) becomes its own collection of embeddings
 - We have one global collection named `all` that contains all the embeddings of every known category

When a new arXiv category is processed, the embeddings of all the papers it contains are added both to the collection of that category and to the global collection.

For the category collection, we prefix each category name with `arxiv_`, because Chroma won't accept a collection name with fewer than three characters.

In [None]:
#| export

def _is_indexed(collection, doc_id: str) -> bool:
    """Return True if a document with ID `doc_id` is already stored in the Chroma `collection`."""
    return len(collection.get(ids=[doc_id])['ids']) > 0


def _add_to_collection(collection, embeddings, document: str, metadata: dict, doc_id: str) -> None:
    """Add one embedded document to the Chroma `collection`, warning (not failing) on a duplicate ID."""
    try:
        collection.add(
            embeddings=embeddings,
            documents=[document],
            metadatas=[metadata],
            ids=[doc_id],
        )
    except IDAlreadyExistsError:
        print("[yellow]ID already existing in Chroma DB, skipping...[/yellow]")


def embed_category_papers(category: str) -> bool:
    """Given an ArXiv category, create the embeddings for each of the PDF papers existing locally.

    Embeddings are currently computed with Cohere's embedding service
    (API key read from the ``COHERE_API_KEY`` environment variable) and
    persisted in a Chroma database located at ``CHROMA_DB_PATH``.

    Each paper is stored in two collections, using its PDF file name as ID:
      1. the global ``all`` collection (metadata: source file and category);
      2. the category collection ``arxiv_<category>`` (metadata: source file).
    Each collection is checked independently, so a paper already present in
    ``all`` (e.g. cross-listed in another category, or left over from an
    interrupted run) is still added to this category's collection. Cohere
    is only called when at least one of the two collections lacks the paper.

    Args:
        category: the ArXiv category whose local PDFs should be embedded.

    Returns:
        True if successful, False if the category does not exist (in which
        case no client is created and nothing is persisted).
    """
    # Validate first so an invalid category doesn't create clients or touch the DB.
    if not exists(category):
        print("[red]Can't persist embeddings in local vector db, ArXiv category not existing[/red]")
        return False

    co = cohere.Client(os.environ.get('COHERE_API_KEY'))
    chroma_client = chromadb.PersistentClient(path=os.environ.get('CHROMA_DB_PATH'))

    papers_all_collection = chroma_client.get_or_create_collection(name="all")
    # The "arxiv_" prefix works around Chroma's minimum collection-name length.
    papers_category_collection = chroma_client.get_or_create_collection(name="arxiv_" + category)

    folder_path = get_docs_path(category)
    pdfs = get_pdfs_from_folder(folder_path)

    with Progress() as progress:
        task = progress.add_task("[cyan]Embedding papers...", total=len(pdfs))

        for pdf in pdfs:
            in_all = _is_indexed(papers_all_collection, pdf)
            in_category = _is_indexed(papers_category_collection, pdf)

            # Skip the costly text extraction + Cohere call when both are present.
            if not (in_all and in_category):
                doc = pdf_to_text(os.path.join(folder_path, pdf))

                # get the embedding of the paper from Cohere
                embedding = co.embed([doc])

                # necessary escape to prevent possible encoding errors when adding to Chroma
                escaped_doc = doc.encode("unicode_escape").decode()

                if not in_all:
                    _add_to_collection(
                        papers_all_collection,
                        embedding.embeddings,
                        escaped_doc,
                        {"source": pdf, "category": category},
                        pdf,
                    )
                if not in_category:
                    _add_to_collection(
                        papers_category_collection,
                        embedding.embeddings,
                        escaped_doc,
                        {"source": pdf},
                        pdf,
                    )

            # Advance only once the paper has actually been handled.
            progress.update(task, advance=1)

    return True