# Embeddings

> By default, ReadNext uses Hugging Face models that it downloads locally to generate the embeddings. Optionally, it can use external embedding services. At the moment, it is only integrated with the Cohere Embedding model.

In [None]:
#| default_exp embedding

## Imports

In [None]:
#| export

import chromadb
import cohere
import os
import torch
from chromadb.errors import IDAlreadyExistsError
from functools import cache 
from pypdf import PdfReader
from readnext.arxiv_categories import exists
from readnext.arxiv_sync import get_docs_path
from rich import print
from rich.progress import Progress
from transformers import AutoTokenizer, AutoModel

## Download Embedding Model

To use a local embedding model, the first step is to download it from Hugging Face with the Transformers library and save it locally on the file system.

In [None]:
#| export

def download_embedding_model(model_path: str, model_name: str):
    """Download a Hugging Face model and its tokenizer into the specified directory.

    Nothing is downloaded when `model_path` already holds something: an
    existing non-empty directory (or an existing non-directory path) is left
    untouched. An existing but *empty* directory, e.g. one left behind by an
    interrupted download, is not considered a saved model, so the download is
    attempted again.

    Args:
        model_path: directory where the model and tokenizer files are saved.
        model_name: Hugging Face model identifier (e.g. 'prajjwal1/bert-tiny').
    """
    already_present = os.path.exists(model_path) and (
        not os.path.isdir(model_path) or os.listdir(model_path)
    )
    if already_present:
        return

    # Download before touching the file system: if this fails, no empty
    # directory is left behind that would make every later call skip the download.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Save the model and tokenizer to the specified directory
    os.makedirs(model_path, exist_ok=True)
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

### Tests

In [None]:
from shutil import rmtree

In [None]:
try:
    download_embedding_model('test-download/', 'prajjwal1/bert-tiny')

    assert os.path.exists('test-download/config.json')
    # Depending on the transformers version, weights are saved either as a
    # PyTorch pickle or as safetensors (the newer default): accept both.
    assert (os.path.exists('test-download/pytorch_model.bin')
            or os.path.exists('test-download/model.safetensors'))
    assert os.path.exists('test-download/special_tokens_map.json')
    assert os.path.exists('test-download/tokenizer_config.json')
    assert os.path.exists('test-download/vocab.txt')
finally:
    # tears down even when an assertion fails; a leftover folder would make
    # the next download_embedding_model('test-download/', ...) skip its download
    rmtree('test-download/', ignore_errors=True)

## Load Embedding Model

Once the models are available locally, the next step is to load them in memory so they can be used to create the embeddings of the PDF files. Because `load_embedding_model` can be called many times, we memoize its result to speed up the process. There is no need for an LRU cache here since only a single item should be cached anyway, so we use the simpler `functools.cache`.

In [None]:
#| export

@cache
def load_embedding_model(model_path: str):
    """Load the model and tokenizer previously saved under `model_path`.

    The result is memoized per `model_path` (via `functools.cache`): repeated
    calls with the same path return the same (model, tokenizer) objects
    without reading them from disk again. Use
    `load_embedding_model.cache_clear()` if the files under a path change.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    loaded_model = AutoModel.from_pretrained(model_path)
    pair = (loaded_model, loaded_tokenizer)
    return pair

### Tests

In [None]:
from shutil import rmtree

In [None]:
try:
    download_embedding_model('test-download/', 'prajjwal1/bert-tiny')

    model, tokenizer = load_embedding_model('test-download/')

    assert model is not None
    assert tokenizer is not None
finally:
    # tears down: remove the files and forget the model memoized for this path,
    # otherwise later cells reusing 'test-download/' would get this model back
    rmtree('test-download/', ignore_errors=True)
    load_embedding_model.cache_clear()

## Embed (Local Model)


In [None]:
#| export
def embed_text(text: str, model, tokenizer):
    """Return the L2-normalised CLS embedding of `text`.

    The text is tokenized (with padding and truncation) into PyTorch tensors
    and fed to `model` with gradient tracking disabled. The vector of the
    first token of each sequence in the model's first output is used as the
    sequence embedding (CLS pooling), then scaled to unit Euclidean norm.

    Returns:
        A tensor with one normalised embedding row per input sequence.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)
        # assumes outputs[0] is the last hidden state, (batch, seq, hidden)
        token_states = outputs[0]
        cls_vectors = token_states[:, 0]

    return torch.nn.functional.normalize(cls_vectors, p=2, dim=1)

### Tests

In [None]:
from shutil import rmtree

In [None]:
try:
    download_embedding_model('test-download/', 'BAAI/bge-base-en')

    # load_embedding_model is memoized on the path: clear the cache so we do not
    # get back a different model previously loaded from 'test-download/'
    load_embedding_model.cache_clear()
    model, tokenizer = load_embedding_model('test-download/')

    tensor = embed_text('Hello world!', model, tokenizer)

    # bge-base-en produces 768-dimensional embeddings, L2-normalised by embed_text
    assert len(tensor.tolist()[0]) == 768
    assert abs(tensor[0].norm().item() - 1.0) < 1e-5
finally:
    # tears down
    rmtree('test-download/', ignore_errors=True)
    load_embedding_model.cache_clear()

## Get Embedding System

We need to be able to easily identify the embedding system currently configured by the user. This utility function simplifies the code elsewhere in the codebase.

In [None]:
#| export

def embedding_system() -> str:
    """Return a unique identifier for the embedding system currently in use.

    The system is selected through the ``EMBEDDING_SYSTEM`` environment variable:

    - ``'BAAI/bge-base-en'`` -> ``'baai-bge-base-en'`` (local Hugging Face model)
    - ``'cohere'`` -> ``'cohere'`` (Cohere embedding service)

    Any other value, or an unset variable, yields ``''``.
    """
    configured = os.environ.get('EMBEDDING_SYSTEM')

    if configured == 'BAAI/bge-base-en':
        return 'baai-bge-base-en'
    if configured == 'cohere':
        return 'cohere'
    return ''

## Get Embeddings (From any supported system)

In [None]:
#| export

def get_embeddings(text: str) -> list:
    """Get embeddings for a text using any supported embedding system."""

    match embedding_system():
        case 'baai-bge-base-en':
            model, tokenizer = load_embedding_model(os.environ.get('MODELS_PATH'))
            return embed_text(text, model, tokenizer).tolist()
        case 'cohere':
            co = cohere.Client(os.environ.get('COHERE_API_KEY'))
            return co.embed([text]).embeddings
        case other:
            return []

## PDF to Text

The `PdfReader` class of the `pypdf` library is used to extract the text from the PDF files.

In [None]:
#| export

def pdf_to_text(file_path: str) -> str:
    """Read a PDF file and output its text as a single string.

    The text extracted from each page is concatenated in page order, with no
    separator added between pages.
    """
    with open(file_path, 'rb') as pdf_file_obj:
        pdf_reader = PdfReader(pdf_file_obj)
        # Extract while the file is still open; join is linear in the total
        # text size, unlike building the string with repeated `+=`.
        return ''.join(page.extract_text() for page in pdf_reader.pages)

### Tests

In [None]:
# The fixture ../tests/assets/test.pdf is expected to contain exactly the text
# "this is a test": pdf_to_text must return it verbatim, with nothing appended.
assert pdf_to_text("../tests/assets/test.pdf") == "this is a test"
assert pdf_to_text("../tests/assets/test.pdf") != "this is a test foo"

## Get PDF files from a folder

In [None]:
#| export

def get_pdfs_from_folder(folder_path: str) -> list:
    """Return the names (not full paths) of the entries of `folder_path` ending in ".pdf".

    The extension match is case-sensitive, and the entries are returned in
    the order given by `os.listdir`.
    """
    entries = os.listdir(folder_path)
    return list(filter(lambda entry: entry.endswith(".pdf"), entries))

### Tests

In [None]:
# The test assets folder is expected to contain exactly one PDF file, test.pdf.
assert get_pdfs_from_folder("../tests/assets/") == ['test.pdf']
assert get_pdfs_from_folder("../tests/assets/") != ['test.pdf', 'foo.pdf']

## Get Chroma Collection Name

It is important that the embeddings stored in a Chroma collection and the embeddings used to query it have the same number of dimensions. Depending on their configuration, users may use the local embedding model at one time and the Cohere embedding service at another, and these two systems produce embeddings with different numbers of dimensions. To avoid this problem, the name of the collection identifies the embedding system used to fill it. This way, a given collection only ever contains embeddings with the same number of dimensions, whatever embedding system is configured later.

In [None]:
def get_chroma_collection_name(name: str) -> str:
    """Return the ChromaDB collection name configured in `CHROMA_COLLECTION_NAME`.

    NOTE(review): the `name` argument is currently ignored, and the result is
    `None` (not a `str`) when the environment variable is unset. The notes
    above describe deriving the collection name from the embedding system in
    use; confirm the intended behaviour.
    """
    configured_name = os.environ.get('CHROMA_COLLECTION_NAME')
    return configured_name

## Embed all papers of an arXiv category

The embedding database management system ReadNext uses is [Chroma](https://www.trychroma.com/).

The embedding DBMS is organized as follows:

 - Each category (sub or top category) becomes a collection of embeddings
 - We have one global collection named `all` that contains the embeddings of every known category

When a new arXiv category is being processed, the embeddings of all the papers it contains are added to the collection related to its category, and to the global collection.

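For the category collection, we have to prefix each category with `arxiv_` because Chroma won't accept a collection name with fewer than three characters. Every collection name, including the global one, is also suffixed with the identifier of the embedding system in use (`all_<system>`, `arxiv_<category>_<system>`), so that embeddings with different numbers of dimensions never end up in the same collection.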

In [None]:
#| export

def embed_category_papers(category: str) -> bool:
    """Given an arXiv category, create the embeddings of each of its PDF papers stored locally.

    Embeddings are computed with `get_embeddings`, i.e. with whichever system
    `embedding_system()` reports. Each paper is stored in two Chroma
    collections, both suffixed with that system's identifier:

      1. ``all_<system>``: papers of every category (metadata: source, category);
      2. ``arxiv_<category>_<system>``: papers of this category only (metadata: source).

    A paper is only added to the collection(s) that do not already contain it,
    and its text is extracted and embedded only when at least one of the two
    collections is missing it.

    Args:
        category: arXiv category whose PDFs, found in `get_docs_path(category)`,
            are embedded.

    Returns:
        True if the category exists and its papers were processed, False otherwise.
    """
    if not exists(category):
        print("[red]Can't persist embeddings in local vector db, ArXiv category not existing[/red]")
        return False

    chroma_client = chromadb.PersistentClient(path=os.environ.get('CHROMA_DB_PATH'))

    system = embedding_system()
    papers_all_collection = chroma_client.get_or_create_collection(name="all_" + system)
    papers_category_collection = chroma_client.get_or_create_collection(name="arxiv_" + category + '_' + system)

    # Each target collection with the metadata it stores in addition to "source".
    targets = [
        (papers_all_collection, {"category": category}),
        (papers_category_collection, {}),
    ]

    with Progress() as progress:
        folder_path = get_docs_path(category)
        pdfs = get_pdfs_from_folder(folder_path)

        task = progress.add_task("[cyan]Embedding papers...", total=len(pdfs))

        for pdf in pdfs:
            if not progress.finished:
                progress.update(task, advance=1)

            # Check each collection on its own: a paper may already be in the
            # global collection (e.g. through another category, or after an
            # interrupted run) while still missing from this category's one.
            missing = [
                (collection, extra_metadata)
                for collection, extra_metadata in targets
                if len(collection.get(ids=[pdf])['ids']) == 0
            ]
            if not missing:
                continue

            doc = pdf_to_text(os.path.join(folder_path, pdf))
            # Computed once and shared by both collections: embedding can be a
            # model inference or a paid API call.
            embeddings = get_embeddings(doc)
            # necessary escape to prevent possible encoding errors when adding to Chroma
            escaped_doc = doc.encode("unicode_escape").decode()

            for collection, extra_metadata in missing:
                try:
                    collection.add(
                        embeddings=embeddings,
                        documents=[escaped_doc],
                        metadatas=[{"source": pdf, **extra_metadata}],
                        ids=[pdf],
                    )
                except IDAlreadyExistsError:
                    print("[yellow]ID already existing in Chroma DB, skipping...[/yellow]")
    return True