# PubMed Knowledge Graph

This notebook is part of a series that walks through the process of generating a knowledge graph of PubMed articles.

This notebook will
* Define a lexical graph schema
* Populate a Neo4j instance with article chunks
* Embed the Chunk nodes for similarity search

In [1]:
# filter some Numpy warnings that pop up during ingestion
import warnings
warnings.filterwarnings('ignore', category=FutureWarning) 

In [78]:
import hashlib
from math import ceil
from typing import Any

In [3]:
# allows for async operations in notebooks
import nest_asyncio
nest_asyncio.apply()

## Lexical Graph Construction

We will use Unstructured.IO to partition and chunk our articles. 

This process breaks the articles into sensible chunks that may be used as context in our application. 

These chunks will also have relationships to the extracted entities, but we will add these later.
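The chunking parameters used later (`chunking_strategy="by_title"`, `combine_text_under_n_chars=200`, `max_characters=1000`) are easier to reason about with a toy model. The sketch below is a simplified approximation of title-based chunking, not Unstructured's implementation, which handles more cases. It starts a new section at every `Title`, folds sections shorter than the threshold into the following one, and hard-splits anything over the maximum. The `Element` class and `chunk_by_title` function are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Element:
    """Hypothetical, simplified stand-in for an Unstructured element."""
    category: str  # e.g. "Title", "NarrativeText"
    text: str


def chunk_by_title(elements: list[Element],
                   combine_under_n_chars: int = 200,
                   max_characters: int = 1000) -> list[str]:
    """Toy approximation of title-based chunking (not Unstructured's algorithm)."""
    # 1. Group elements into sections, starting a new section at every Title.
    sections: list[list[str]] = []
    for el in elements:
        if el.category == "Title" or not sections:
            sections.append([])
        sections[-1].append(el.text)
    section_texts = ["\n".join(s) for s in sections]

    # 2. Fold sections shorter than the threshold into the following section.
    merged: list[str] = []
    carry = ""
    for text in section_texts:
        if carry:
            text = carry + "\n" + text
        if len(text) < combine_under_n_chars:
            carry = text
        else:
            merged.append(text)
            carry = ""
    if carry:
        merged.append(carry)

    # 3. Hard-split any section longer than max_characters.
    chunks: list[str] = []
    for text in merged:
        chunks.extend(text[i:i + max_characters] for i in range(0, len(text), max_characters))
    return chunks
```

For example, a short "Intro" section is folded into the following "Methods" section, while a 1,500-character "Results" section is split into two chunks.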

The lexical graph is based on the data model defined in the [Neo4j Connector Documentation](https://graphrag.com/reference/knowledge-graph/lexical-graph-extracted-entities/) section of [Unstructured](https://unstructured.io).

The main difference here is that we store text, image and table elements as distinct `TextElement`, `ImageElement` and `TableElement` nodes (all sharing the `UnstructuredElement` label), instead of grouping them all into a single `UnstructuredElement` node type.

Here is the data model we will be using.


<img src="./assets/images/lexical-data-model.png" alt="lexical-data-model" width="800px">
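In this model, each element node carries the shared `UnstructuredElement` label plus one specific label. The minimal sketch below maps an Unstructured element category to its labels and mirrors the `match` statement in the parsing code, where unknown categories fall back to `TextElement`. It assumes that the node load queries in `pyneoinstance_config.yaml`, which are not shown here, apply exactly these labels. The two-labels-per-node ratio in the element load output later in the notebook is consistent with that. `SPECIFIC_LABELS` and `labels_for_category` are hypothetical names.

```python
# Category -> specific label, mirroring the `match` statement in the parser.
# Any other category (titles, captions, headers, ...) is treated as text.
SPECIFIC_LABELS = {
    "NarrativeText": "TextElement",
    "Image": "ImageElement",
    "Table": "TableElement",
}


def labels_for_category(category: str) -> list[str]:
    """Return the labels an element node of this category is assumed to receive."""
    return ["UnstructuredElement", SPECIFIC_LABELS.get(category, "TextElement")]
```

Because every element shares `UnstructuredElement`, a single `PART_OF_CHUNK` load query can match text, image and table nodes alike.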

In [4]:
import pandas as pd
from pydantic import BaseModel, Field
from unstructured.partition.pdf import partition_pdf

from unstructured.documents.elements import CompositeElement

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# ------------
# Nodes
# ------------

class Document(BaseModel):
    """
    A Document.
    This is the top level node in our knowledge graph.
    Documents are made of many Chunks.
    """
    id: str = Field(..., description="The id of the document")
    name: str = Field(..., description="The name of the document")
    source: str = Field(..., description="The source of the document")

class Chunk(BaseModel):
    """
    A Chunk.
    This is a collection of `UnstructuredElements`.
    Unstructured.IO represents Chunks as `CompositeElement` objects.
    """
    id: str = Field(..., description="The id of the chunk")
    type: str = Field(..., description="The type of the chunk")
    text: str = Field(..., description="The text of the chunk")

class ChunkWithEmbedding(Chunk):
    """
    A Chunk with an embedding.
    This is used to represent chunks that have been embedded.
    """
    embedding: list[float] = Field(..., description="The embedding of the chunk text field")

class UnstructuredElement(BaseModel):
    """
    A base class for all unstructured elements. 
    These are the smallest units in our chunking process. 
    One or more of these elements are combined to form a Chunk.
    """
    id: str = Field(..., description="The id of the element")
    text: str = Field(..., description="The text of the element")
    type: str = Field(..., description="The type of the element")
    page_number: int = Field(..., description="The page number of the element")

class TextElement(UnstructuredElement):
    """
    A TextElement. Structurally identical to the UnstructuredElement class.
    This is used to represent text elements that contain no tables or images.
    """

class ImageElement(UnstructuredElement):
    """
    An ImageElement.
    """
    image_base64: str = Field(..., description="The base64 encoded image")
    image_mime_type: str = Field(..., description="The mime type of the image")

class TableElement(UnstructuredElement):
    """
    A TableElement. 
    This may also optionally have image features.
    """
    image_base64: str | None = Field(None, description="The base64 encoded image of the table")
    image_mime_type: str | None = Field(None, description="The mime type of the table image")
    text_as_html: str | None = Field(None, description="The text of the table as HTML")
    
# -------------
# Relationships
# -------------

class ChunkPartOfDocument(BaseModel):
    """
    (:Chunk {id: $chunk_id})-[:PART_OF_DOCUMENT]->(:Document {id: $document_id})
    """
    chunk_id: str = Field(..., description="The id of the chunk")
    document_id: str = Field(..., description="The id of the document")

class UnstructuredElementPartOfChunk(BaseModel):
    """
    (:UnstructuredElement {id: $unstructured_element_id})-[:PART_OF_CHUNK]->(:Chunk {id: $chunk_id})

    This covers TextElement, ImageElement and TableElement nodes since they all share the UnstructuredElement label.
    """ 
    unstructured_element_id: str = Field(..., description="The id of the unstructured element")
    chunk_id: str = Field(..., description="The id of the chunk")

In [6]:
def create_chunk_has_next_chunk_relationship_dataframe(chunk_dataframe: pd.DataFrame) -> pd.DataFrame:
    """
    Create the DataFrame for loading (:Chunk)-[:HAS_NEXT_CHUNK]->(:Chunk) relationships.

    Parameters
    ----------
    chunk_dataframe : pd.DataFrame
        A Pandas DataFrame containing the Chunk node records.

    Returns
    -------
    pd.DataFrame
        A Pandas DataFrame containing the columns `source_id` and `target_id`
    """
    df = chunk_dataframe.copy()
    df['next_id'] = df['id'].shift(-1)
    df.dropna(inplace=True)
    res = df[['id', 'next_id']].rename({"id": "source_id", "next_id": "target_id"}, axis=1)
    return res

def extract_document_title(text_elements_dataframe: pd.DataFrame) -> str:
    """
    Extract the title of the document from the text elements.
    Here we assume that the first 'Title' element is the title of the document.

    Returns
    -------
    str
        The title of the document.
    """
    try:
        return text_elements_dataframe[text_elements_dataframe['type'] == 'Title'].iloc[0]['text']
    except Exception as e:
        print(f"Unable to extract document title: {e}")
        return 'unknown title'

def parse_node_and_relationship_from_composite_element(composite_element: CompositeElement, parent_document_id: str) -> dict[str, dict[str, Any]]:
    """
    Parse the nodes and relationships for a given chunk (CompositeElement). 
    This will find the following nodes:
    * Chunk
    * TextElement
    * ImageElement
    * TableElement
    * UnstructuredElement (Shared label for TextElement, ImageElement and TableElement)

    And the following relationships:
    * (:Chunk)-[:PART_OF_DOCUMENT]->(:Document)
    * (:UnstructuredElement)-[:PART_OF_CHUNK]->(:Chunk)

    Returns
    -------
    dict[str, dict[str, Any]]
        A dictionary containing a list of records for each node and relationship type.
    """
    chunk = Chunk(id=composite_element.id, text=composite_element.text, type=composite_element.category)
    chunk_part_of_document = ChunkPartOfDocument(chunk_id=chunk.id, document_id=parent_document_id)

    text_elements: list[TextElement] = list()
    image_elements: list[ImageElement] = list()
    table_elements: list[TableElement] = list()
    unstructured_element_part_of_chunk: list[UnstructuredElementPartOfChunk] = list()

    # Chunks (CompositeElements) are made of many smaller text chunks (UnstructuredElements)
    # We can parse what type of elements these subchunks are and load them as well
    # This will give us access to images and tables from the document
    for element in composite_element.metadata.orig_elements:
        match element.category:
            case "NarrativeText":
                text_elements.append(TextElement(id=element.id, 
                                                 text=element.text, 
                                                 type=element.category, 
                                                 page_number=element.metadata.page_number))
            case "Image":
                image_elements.append(ImageElement(id=element.id, 
                                                   text=element.text,
                                                   type=element.category, 
                                                   page_number=element.metadata.page_number,
                                                   image_base64=element.metadata.image_base64, 
                                                   image_mime_type=element.metadata.image_mime_type))
            case "Table":
                table_elements.append(TableElement(id=element.id,
                                                   text=element.text,
                                                   type=element.category,
                                                   page_number=element.metadata.page_number,
                                                   # must match the field name exactly; Pydantic ignores unknown keywords by default
                                                   image_base64=element.metadata.image_base64,
                                                   image_mime_type=element.metadata.image_mime_type,
                                                   text_as_html=element.metadata.text_as_html))
            # Assume some kind of text element if we can't match the category
            # Could be headers, figure captions, etc
            case _:
                try:
                    text_elements.append(TextElement(id=element.id,
                                                     text=element.text,
                                                     type=element.category,
                                                     page_number=element.metadata.page_number))
                except Exception as e:
                    print(f"Error parsing text element: {e}")
                    # skip the PART_OF_CHUNK record for an element node that was not created
                    continue

        unstructured_element_part_of_chunk.append(UnstructuredElementPartOfChunk(unstructured_element_id=element.id, chunk_id=chunk.id))

    # we return a list of records for each entity and relationship instead of the Pydantic classes
    return {
        "nodes": {
            "chunk": [chunk.model_dump()],
            "text_element": [el.model_dump() for el in text_elements],
            "image_element": [el.model_dump() for el in image_elements],
            "table_element": [el.model_dump() for el in table_elements],
        },
        "relationships": {
            "chunk_part_of_document": [chunk_part_of_document.model_dump()],
            "unstructured_element_part_of_chunk": [rel.model_dump() for rel in unstructured_element_part_of_chunk],
        }
    }

def parse_nodes_and_relationships_from_composite_elements(composite_elements: list[CompositeElement], parent_doc_id: str) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Parse the lexical nodes and relationships for a set of chunks (CompositeElements) belonging to one parent document.
    
    Parameters
    ----------
    composite_elements : list[CompositeElement]
        A list of CompositeElements to parse.
    parent_doc_id : str
        The id of the parent document.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
    """
    
    data = {
        "nodes": {
            "document": list(),
            "chunk": list(),
            "text_element": list(),
            "image_element": list(),
            "table_element": list(),
        },
        "relationships": {
            "chunk_part_of_document": list(),
            "unstructured_element_part_of_chunk": list(),
            "chunk_has_next_chunk": list()
        }
    }

    for composite_element in composite_elements:
        new_data = parse_node_and_relationship_from_composite_element(composite_element, parent_doc_id)

        # update the records with new nodes and relationships
        data["nodes"]["chunk"].extend(new_data["nodes"]["chunk"])
        data["relationships"]["chunk_part_of_document"].extend(new_data["relationships"]["chunk_part_of_document"])
        data["nodes"]["text_element"].extend(new_data["nodes"]["text_element"])
        data["nodes"]["image_element"].extend(new_data["nodes"]["image_element"])
        data["nodes"]["table_element"].extend(new_data["nodes"]["table_element"])
        data["relationships"]["unstructured_element_part_of_chunk"].extend(new_data["relationships"]["unstructured_element_part_of_chunk"])

    # convert to pandas dataframe for ingestion
    # node DataFrames
    data["nodes"]["chunk"] = pd.DataFrame(data["nodes"]["chunk"])
    data["nodes"]["text_element"] = pd.DataFrame(data["nodes"]["text_element"])
    data["nodes"]["image_element"] = pd.DataFrame(data["nodes"]["image_element"])
    data["nodes"]["table_element"] = pd.DataFrame(data["nodes"]["table_element"])

    document_title = extract_document_title(data["nodes"]["text_element"])
    data["nodes"]["document"] = pd.DataFrame([Document(id=parent_doc_id, name=document_title, source="pubmed").model_dump()])

    # relationship DataFrames
    data["relationships"]["chunk_part_of_document"] = pd.DataFrame(data["relationships"]["chunk_part_of_document"])
    data["relationships"]["unstructured_element_part_of_chunk"] = pd.DataFrame(data["relationships"]["unstructured_element_part_of_chunk"])
    data["relationships"]["chunk_has_next_chunk"] = create_chunk_has_next_chunk_relationship_dataframe(data["nodes"]["chunk"])

    return data

def process_pdf_article(file_name: str) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Process an article and return the nodes and relationships for ingestion into the knowledge graph.
    Assumes that the article is stored in the "articles/pdf" directory.

    Parameters
    ----------
    file_name : str
        The name of the file to process.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
    """

    doc_id = hashlib.sha256(file_name.encode()).hexdigest()

    partitioned_doc = partition_pdf("articles/pdf/" + file_name,                   # path to the article file
                                    strategy="hi_res",                             # required to extract images
                                    extract_images_in_pdf=True,                    # required to extract images
                                    extract_image_block_types=["Image", "Table"],  # extract images and tables as base64
                                    extract_image_block_to_payload=True,           # required to extract images as base64
                                    chunking_strategy="by_title",                  # chunk by title - this breaks at identified sections
                                    combine_text_under_n_chars=200,                # merge sections shorter than 200 characters
                                    max_characters=1000,                           # at most 1000 characters per chunk
                                    multipage_sections=True)                       # allow sections to span page breaks
    return parse_nodes_and_relationships_from_composite_elements(partitioned_doc, doc_id)
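The `shift(-1)` + `dropna` step in `create_chunk_has_next_chunk_relationship_dataframe` pairs each chunk with the one after it. The sketch below shows the same idea in plain Python with a hypothetical helper, `next_chunk_pairs`. The pairing runs once per article, so no `HAS_NEXT_CHUNK` relationship crosses a document boundary. As a result, n chunks spread over d documents yield n - d relationships. That matches the load output below: 209 chunks in 2 documents produce 207 relationships.

```python
def next_chunk_pairs(chunk_ids_by_document: dict[str, list[str]]) -> list[tuple[str, str]]:
    """Pair each chunk id with its successor inside the same document only."""
    pairs: list[tuple[str, str]] = []
    for chunk_ids in chunk_ids_by_document.values():
        # zip(["a", "b", "c"], ["b", "c"]) -> ("a", "b"), ("b", "c")
        pairs.extend(zip(chunk_ids, chunk_ids[1:]))
    return pairs
```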

In [7]:
def process_pdf_articles(article_file_names: list[str]) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Process a list of articles and return the nodes and relationships for ingestion into the knowledge graph.
    Assumes that the articles are stored in the "articles/pdf/" directory.

    Parameters
    ----------
    article_file_names : list[str]
        A list of the names of the files to process.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
    """

    # initialize the DataFrames
    data = {
        "nodes": {
            "document": pd.DataFrame(),
            "chunk": pd.DataFrame(),
            "text_element": pd.DataFrame(),
            "image_element": pd.DataFrame(),
            "table_element": pd.DataFrame(),
        },
        "relationships": {
            "chunk_part_of_document": pd.DataFrame(),
            "unstructured_element_part_of_chunk": pd.DataFrame(),
            "chunk_has_next_chunk": pd.DataFrame()
        }
    }

    # process each article individually
    # this will
    # * partition the article into chunks using Unstructured
    # * Identify TextElements, ImageElements, and TableElements in each chunk
    # * Create DataFrames for all lexical nodes and relationships found in the article
    # * Update the global DataFrames with the new article data
    for file_name in article_file_names:
        print(f"Processing article: {file_name}")
        # process a single article 
        article_data = process_pdf_article(file_name)

        # update the DataFrames with the new article data
        data["nodes"]["document"] = pd.concat([data["nodes"]["document"], article_data["nodes"]["document"]], ignore_index=True)
        data["nodes"]["chunk"] = pd.concat([data["nodes"]["chunk"], article_data["nodes"]["chunk"]], ignore_index=True)
        data["nodes"]["text_element"] = pd.concat([data["nodes"]["text_element"], article_data["nodes"]["text_element"]], ignore_index=True)
        data["nodes"]["image_element"] = pd.concat([data["nodes"]["image_element"], article_data["nodes"]["image_element"]], ignore_index=True)
        data["nodes"]["table_element"] = pd.concat([data["nodes"]["table_element"], article_data["nodes"]["table_element"]], ignore_index=True)
        data["relationships"]["chunk_part_of_document"] = pd.concat([data["relationships"]["chunk_part_of_document"], article_data["relationships"]["chunk_part_of_document"]], ignore_index=True)
        data["relationships"]["unstructured_element_part_of_chunk"] = pd.concat([data["relationships"]["unstructured_element_part_of_chunk"], article_data["relationships"]["unstructured_element_part_of_chunk"]], ignore_index=True)
        data["relationships"]["chunk_has_next_chunk"] = pd.concat([data["relationships"]["chunk_has_next_chunk"], article_data["relationships"]["chunk_has_next_chunk"]], ignore_index=True)
    
    return data
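Calling `pd.concat` inside the loop copies every accumulated DataFrame on each iteration, so the total work grows quadratically with the number of articles. This is negligible for two articles but can matter for a larger corpus. A common alternative is to collect plain records and build each DataFrame once. The sketch below assumes a hypothetical variant of `process_pdf_article` that returns lists of record dicts instead of DataFrames. `merge_article_records` is likewise hypothetical. `HAS_NEXT_CHUNK` records would still need to be computed per article before merging.

```python
from collections import defaultdict


def merge_article_records(
    per_article: list[dict[str, dict[str, list[dict]]]],
) -> dict[str, dict[str, list[dict]]]:
    """Merge per-article {"nodes": {...}, "relationships": {...}} record lists."""
    merged = {"nodes": defaultdict(list), "relationships": defaultdict(list)}
    for article in per_article:
        for section in ("nodes", "relationships"):
            for name, records in article.get(section, {}).items():
                merged[section][name].extend(records)
    # Return plain dicts; each table can then be built once,
    # e.g. pd.DataFrame(merged["nodes"]["chunk"]).
    return {section: dict(tables) for section, tables in merged.items()}
```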

## Load Articles

In [8]:
import os

We need to collect the article file names to pass to Unstructured for parsing.

In [9]:
article_names = [f for f in os.listdir("articles/pdf/") if f.lower().endswith(".pdf")][:2]  # PDFs only; limited to two articles here

## Data Ingestion

We have now defined 
* Lexical data model
* Partitioning and chunking logic for articles

It is now time to define our ingestion logic. We will run the ingestion in two stages:

1. Load lexical graph
2. Embed lexical graph Chunk nodes

Decoupling these stages allows us to easily make changes as we iterate on our ingestion process.

We will be using PyNeoInstance to ingest our data into Neo4j. 

This allows for easy and manageable database and query configuration.

In [10]:
import os

from pyneoinstance import Neo4jInstance, load_yaml_file

Our database credentials and all of our queries are stored in the `pyneoinstance_config.yaml` file. 

This makes it easy to manage our queries and keeps the notebook code clean. 

In [73]:
config = load_yaml_file("pyneoinstance_config.yaml")

db_info = config['db_info']

constraints = config['initializing_queries']['constraints']
indexes = config['initializing_queries']['indexes']

node_load_queries = config['loading_queries']['nodes']
relationship_load_queries = config['loading_queries']['relationships']

processing_queries = config['processing_queries']
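Every query lives in the YAML file, so a missing key only surfaces as a `KeyError` when the cell that uses it runs. The hypothetical pre-flight check below lists the keys this notebook reads: `db_info.database`, plus the constraint, index, load and embedding query names used by the functions that follow. `REQUIRED_QUERIES` and `missing_config_keys` are illustrative names, not part of PyNeoInstance.

```python
# Query names the notebook looks up, grouped by their path in the config file.
REQUIRED_QUERIES = {
    ("initializing_queries",): ["constraints", "indexes"],
    ("loading_queries", "nodes"): ["document", "chunk", "text_element",
                                   "image_element", "table_element"],
    ("loading_queries", "relationships"): ["chunk_part_of_document",
                                           "unstructured_element_part_of_chunk",
                                           "chunk_has_next_chunk"],
    ("processing_queries",): ["get_chunk_nodes_to_embed"],
}


def missing_config_keys(config: dict) -> list[str]:
    """Return dotted paths of keys the notebook reads but the config lacks."""
    missing: list[str] = []
    db_info = config.get("db_info") or {}
    if "database" not in db_info:
        missing.append("db_info.database")
    for path, names in REQUIRED_QUERIES.items():
        # walk the nested sections; anything that is not a dict counts as empty
        section = config
        for key in path:
            section = section.get(key) if isinstance(section, dict) else None
        if not isinstance(section, dict):
            section = {}
        for name in names:
            if name not in section:
                missing.append(".".join((*path, name)))
    return missing
```

Running `missing_config_keys(config)` right after `load_yaml_file` and failing fast on a non-empty result keeps a typo in the YAML from interrupting a half-finished load.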

This graph object will handle database connections and read / write transactions for us.

In [60]:
graph = Neo4jInstance(db_info.get('uri', os.getenv("NEO4J_URI", "neo4j://localhost:7687")), # use config value -> use env value -> use default value
                      db_info.get('user', os.getenv("NEO4J_USER", "neo4j")), 
                      db_info.get('password', os.getenv("NEO4J_PASSWORD", "password")))
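The fallback chain above uses `db_info.get(key, fallback)`, which only falls back when the key is absent. If the YAML contains the key with no value (for example `password:`), it typically loads as `None`, and `None` is then passed to the driver. The hypothetical helper below also treats empty values as missing, keeping the same precedence: config value, then environment variable, then default.

```python
import os
from collections.abc import Mapping


def resolve_setting(db_info: Mapping, key: str, env_var: str, default: str,
                    env: Mapping[str, str] = os.environ) -> str:
    """Config value -> environment variable -> default, skipping empty values."""
    value = db_info.get(key)
    if value:  # None or "" fall through to the next source
        return value
    return env.get(env_var) or default
```

For example, `resolve_setting(db_info, "uri", "NEO4J_URI", "neo4j://localhost:7687")` would replace the first argument passed to `Neo4jInstance`.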

This helper determines how many partitions to split a DataFrame into for batched writes with PyNeoInstance.

In [13]:
def get_partition(data: pd.DataFrame, batch_size: int = 500) -> int:
    """
    Determine the data partition based on the desired batch size.

    Parameters
    ----------
    data : pd.DataFrame
        The Pandas DataFrame to partition.
    batch_size : int
        The desired batch size.

    Returns
    -------
    int
        The number of partitions (at least 1).
    """
    
    partition = int(len(data) / batch_size)
    print("partition: "+str(partition if partition > 1 else 1))
    return partition if partition > 1 else 1
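Because `get_partition` floors `len(data) / batch_size`, a partition can hold almost twice `batch_size` rows. For example, 999 rows with `batch_size=500` give a single partition. The sketch below estimates the rows per partition. It assumes PyNeoInstance splits the DataFrame into near-equal parts the way `numpy.array_split` does, where the first `n % k` parts get one extra row. That splitting behavior is an assumption; this notebook does not show it. `partition_sizes` is a hypothetical helper.

```python
def partition_sizes(n_rows: int, batch_size: int = 500) -> list[int]:
    """Rows per partition, using get_partition's rule (floor division, minimum 1)."""
    partitions = max(1, n_rows // batch_size)
    base, extra = divmod(n_rows, partitions)
    # the first `extra` partitions take one extra row, so sizes differ by at most 1
    return [base + 1] * extra + [base] * (partitions - extra)
```

For the 1,384 text elements loaded below, this gives two partitions of 692 rows each.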

### Constraints

Here we create all the constraints and indexes we need for the entire graph.

We will be utilizing the vector index to perform similarity search over our `Chunk` nodes.

The query to set this index may be found in the `pyneoinstance_config.yaml` file and looks like this:

```cypher
CREATE VECTOR INDEX chunk_vector_index IF NOT EXISTS
    FOR (c:Chunk)
    ON c.embedding
    OPTIONS { indexConfig: {
        `vector.dimensions`: 768,
        `vector.similarity_function`: 'cosine'
    }}
```

Since we set the dimensions to 768, we must ensure that we use 768 dimensions when generating our embeddings as well.
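A cheap guard against a dimension mismatch is to check vector lengths before writing them. The sketch below uses the `(chunk_id, embedding)` tuple shape returned by the embedding functions later in this notebook. `VECTOR_INDEX_DIMENSIONS` and `check_embedding_dimensions` are hypothetical names, and the constant must be kept in sync with `vector.dimensions` by hand.

```python
VECTOR_INDEX_DIMENSIONS = 768  # must match `vector.dimensions` in the index definition


def check_embedding_dimensions(rows: list[tuple[str, list[float]]],
                               expected: int = VECTOR_INDEX_DIMENSIONS) -> list[str]:
    """Return the ids of (chunk_id, embedding) rows whose length differs from the index."""
    return [chunk_id for chunk_id, embedding in rows if len(embedding) != expected]
```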

In [61]:
def create_constraints_and_indexes() -> None:
    """
    Create constraints and indexes for the lexical, entity and patient journey graphs.
    """
    try:
        if constraints and len(constraints) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(constraints.values()))
    except Exception as e:
        print(e)

    try:
        if indexes and len(indexes) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(indexes.values()))
    except Exception as e:
        print(e)

In [15]:
create_constraints_and_indexes()

### Ingest Lexical Graph

#### Processing | Preparation

Process the articles

In [16]:
print(f"Loaded {len(article_names)} articles\n{'*'*5}")
for article_name in article_names:
    print(f"* {article_name}")
print()

Loaded 2 articles
*****
* nihms-1852972.pdf
* fendo-11-00178.pdf



In [17]:
lexical_ingest_records = process_pdf_articles(article_names)

Processing article: nihms-1852972.pdf


Cannot set gray non-stroke color because /'R50' is an invalid float value


Processing article: fendo-11-00178.pdf


Cannot set gray non-stroke color because /'R50' is an invalid float value


Inspect the extracted Document records

In [18]:
lexical_ingest_records["nodes"]["document"]

|   | id | name | source |
|---|----|------|--------|
| 0 | 753d70915e2a2a747cee355745bd17ff08c45f90d938b1... | HHS Public Access | pubmed |
| 1 | e0048a13f033f7fe71581406cd2d3bac1ffc4db7ce88a8... | OPEN ACCESS | pubmed |
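The extracted names above are page banners ("HHS Public Access", "OPEN ACCESS") rather than article titles, so the first-`Title` assumption in `extract_document_title` does not hold for these PDFs. One possible heuristic is sketched below under stated assumptions: take the first `Title` element on page 1 that is not a known banner and has at least a few words. `BANNER_TITLES`, `min_words` and `pick_document_title` are hypothetical and would need tuning on more articles.

```python
# Hypothetical banner list; the two entries are the banners seen in the output above.
BANNER_TITLES = {"hhs public access", "open access"}


def pick_document_title(title_elements: list[tuple[str, int]], min_words: int = 3) -> str:
    """Pick the first non-banner Title on page 1 with at least `min_words` words.

    title_elements: (text, page_number) pairs for Title elements, in document order.
    """
    for text, page_number in title_elements:
        cleaned = text.strip()
        if page_number != 1 or cleaned.lower() in BANNER_TITLES:
            continue
        if len(cleaned.split()) >= min_words:
            return cleaned
    return "unknown title"
```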


#### Ingestion

Load the lexical nodes (Document, Chunk, TextElement, ImageElement and TableElement) into the graph

In [19]:
def load_lexical_nodes(document_dataframe: pd.DataFrame, 
                       chunk_dataframe: pd.DataFrame, 
                       text_element_dataframe: pd.DataFrame, 
                       image_element_dataframe: pd.DataFrame, 
                       table_element_dataframe: pd.DataFrame) -> None:
    """
    Load lexical nodes into the graph. These include Document, Chunk, TextElement, ImageElement and TableElement nodes.

    Parameters
    ----------
    document_dataframe : pd.DataFrame
        A Pandas DataFrame of Document nodes to load into the graph. 
    chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of Chunk nodes to load into the graph.
    text_element_dataframe : pd.DataFrame
        A Pandas DataFrame of TextElement nodes to load into the graph. 
    image_element_dataframe : pd.DataFrame
        A Pandas DataFrame of ImageElement nodes to load into the graph. 
    table_element_dataframe : pd.DataFrame
        A Pandas DataFrame of TableElement nodes to load into the graph. 
    """
    
    lexical_nodes_ingest_iterator = list(zip([document_dataframe, 
                                              chunk_dataframe, 
                                              text_element_dataframe, 
                                              image_element_dataframe, 
                                              table_element_dataframe], 
                                              ['document', 
                                               'chunk', 
                                               'text_element', 
                                               'image_element', 
                                               'table_element']))

    for data, query in lexical_nodes_ingest_iterator:
        res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=node_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500), 
                                                    parallel=True,
                                                    workers=2)
        print(res)

In [20]:
load_lexical_nodes(lexical_ingest_records["nodes"]["document"], 
                   lexical_ingest_records["nodes"]["chunk"], 
                   lexical_ingest_records["nodes"]["text_element"], 
                   lexical_ingest_records["nodes"]["image_element"], 
                   lexical_ingest_records["nodes"]["table_element"])

partition: 1
{'labels_added': 2, 'nodes_created': 2, 'properties_set': 6}
partition: 1
{'labels_added': 209, 'nodes_created': 209, 'properties_set': 627}
partition: 2
{'labels_added': 2768, 'nodes_created': 1384, 'properties_set': 5536}
partition: 1
{'labels_added': 12, 'nodes_created': 6, 'properties_set': 36}
partition: 1
{'labels_added': 10, 'nodes_created': 5, 'properties_set': 30}


Load the lexical relationships into the graph

In [21]:
def load_lexical_relationships(chunk_part_of_document_dataframe: pd.DataFrame, 
                               unstructured_element_part_of_chunk_dataframe: pd.DataFrame, 
                               chunk_has_next_chunk_dataframe: pd.DataFrame) -> None:
    """
    Load lexical relationships into the graph.

    Parameters
    ----------
    chunk_part_of_document_dataframe : pd.DataFrame
        A Pandas DataFrame of (:Chunk)-[:PART_OF_DOCUMENT]->(:Document) relationships to load into the graph.
        Should have columns `chunk_id` and `document_id`.
    unstructured_element_part_of_chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of (:UnstructuredElement)-[:PART_OF_CHUNK]->(:Chunk) relationships to load into the graph.
        Should have columns `unstructured_element_id` and `chunk_id`.
    chunk_has_next_chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of (:Chunk)-[:HAS_NEXT_CHUNK]->(:Chunk) relationships to load into the graph.
        Should have columns `source_id` and `target_id`.
    """
    lexical_relationships_ingest_iterator = list(zip([chunk_part_of_document_dataframe, 
                                                      unstructured_element_part_of_chunk_dataframe, 
                                                      chunk_has_next_chunk_dataframe], 
                                                      ['chunk_part_of_document', 
                                                       'unstructured_element_part_of_chunk', 
                                                       'chunk_has_next_chunk']))

    for data, query in lexical_relationships_ingest_iterator:
        res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500))
        print(res)

In [22]:
load_lexical_relationships(lexical_ingest_records["relationships"]["chunk_part_of_document"], 
                           lexical_ingest_records["relationships"]["unstructured_element_part_of_chunk"],
                           lexical_ingest_records["relationships"]["chunk_has_next_chunk"])

partition: 1
{'relationships_created': 209}
partition: 2
{'relationships_created': 1428}
partition: 1
{'relationships_created': 207}


### Embed Lexical Graph

Here we will read Chunk nodes from the graph that do not have an `embedding` property yet.

We will then embed each Chunk's `text` property and store the result in its `embedding` property.

In [23]:
import asyncio

from openai import AsyncOpenAI

In [24]:
embedding_client = AsyncOpenAI()

In [79]:
def get_chunks_to_embed(min_length: int = 20) -> pd.DataFrame:
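    """
    Read Chunk nodes that do not have an embedding yet, using the `get_chunk_nodes_to_embed` query from the config file.
    `min_length` is passed to the query as the `$min_length` parameter.
    """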
    chunks_to_embed = graph.execute_read_query(database=db_info['database'], 
                                               query=processing_queries['get_chunk_nodes_to_embed'],
                                               parameters={'min_length': min_length})
    return chunks_to_embed
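The Cypher behind `get_chunk_nodes_to_embed` lives in the config file and is not shown here. Assume it returns the `id` and `text` of Chunk nodes that have no `embedding` property and at least `min_length` characters of text. Under that assumption, an in-memory equivalent looks like the sketch below. The Cypher in the comment is also a hypothetical reconstruction, and `select_chunks_to_embed` is an illustrative name.

```python
# Hypothetical Cypher equivalent (the real query is in pyneoinstance_config.yaml):
#   MATCH (c:Chunk)
#   WHERE c.embedding IS NULL AND size(c.text) >= $min_length
#   RETURN c.id AS id, c.text AS text


def select_chunks_to_embed(chunks: list[dict], min_length: int = 20) -> list[dict]:
    """Keep chunks with no embedding yet and at least `min_length` characters of text."""
    return [
        {"id": c["id"], "text": c["text"]}
        for c in chunks
        if c.get("embedding") is None and len(c.get("text") or "") >= min_length
    ]
```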

To learn more about OpenAI embedding models, see the [OpenAI embeddings guide](https://platform.openai.com/docs/guides/embeddings/embedding-models#embedding-models).

In [80]:
async def create_single_chunk_embedding(chunk_text: str, chunk_id: str, failed_cache: list[tuple[str, str]]) -> list[float] | None:
    """
    Create embedding for a single Chunk node's text.

    Parameters
    ----------
    chunk_text : str
        The text of the chunk to embed.
    chunk_id : str
        The id of the chunk to embed.
    failed_cache : list[tuple[str, str]]
        A list of tuples, where the first element is the chunk id and the second element is the text chunk.
        This is used to log failed embeddings across batches.

    Returns
    -------
    list[float] | None
        The embedding for the chunk text, or None if the request failed (the chunk is then appended to `failed_cache`).
    """

    try:
        response = await embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=chunk_text,
            encoding_format="float",
            dimensions=768, # must be the same dimensions as the vector index
        )
        return response.data[0].embedding
    except Exception as e:
        print(e)
        failed_cache.append((chunk_id, chunk_text))
        return None
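The next cell batches these calls and retries failures once. The sketch below shows that batching pattern with one swapped-in design choice. Instead of catching exceptions inside the per-chunk coroutine and appending to a shared `failed_cache`, it passes `return_exceptions=True` to `asyncio.gather` and sorts the outcomes afterwards. `embed_in_batches` and `fake_embed` are hypothetical. `fake_embed` stands in for `embedding_client.embeddings.create` so the sketch runs without an API key.

```python
import asyncio
from collections.abc import Awaitable, Callable

Embedder = Callable[[str], Awaitable[list[float]]]


async def embed_in_batches(
    rows: list[tuple[str, str]], embed: Embedder, batch_size: int = 100
) -> tuple[list[tuple[str, list[float]]], list[tuple[str, str]]]:
    """Embed (chunk_id, text) rows batch by batch; collect failures instead of raising."""
    results: list[tuple[str, list[float]]] = []
    failed: list[tuple[str, str]] = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # return_exceptions=True keeps one failure from aborting the whole batch;
        # gather preserves input order, so outcomes line up with `batch`.
        outcomes = await asyncio.gather(*(embed(text) for _, text in batch),
                                        return_exceptions=True)
        for (chunk_id, text), outcome in zip(batch, outcomes):
            if isinstance(outcome, BaseException):
                failed.append((chunk_id, text))
            else:
                results.append((chunk_id, outcome))
    return results, failed


async def fake_embed(text: str) -> list[float]:
    """Stand-in embedder: fails on demand, otherwise returns a tiny fixed-size vector."""
    if "fail" in text:
        raise RuntimeError("simulated API error")
    return [float(len(text))] * 3
```

Because `failed` has the same `(chunk_id, text)` shape as the input, it can be passed straight back into `embed_in_batches` for a retry pass.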

In [81]:
async def create_chunk_embeddings(chunk_nodes_dataframe: pd.DataFrame, batch_size: int = 100) -> list[tuple[str, list[float]]]:
    """
    Create embeddings for a Pandas DataFrame of Chunk nodes.

    Parameters
    ----------
    chunk_nodes_dataframe : pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node.
        Has columns `id` and `text`.
    batch_size : int
        The number of text chunks to process in each batch.

    Returns
    -------
    list[tuple[str, list[float]]]
        A list of tuples, where the first element is the chunk id and the second element is the embedding for the chunk text.
    """

    
    async def _create_embeddings_for_batch(batch: pd.DataFrame, failed_cache: list[tuple[str, str]]) -> list[tuple[str, list[float]]]:
        """
        Create embeddings for a batch of text chunks.
        Failed embeddings are recorded in the `failed_cache` list that is passed to the embedding creation function.

        Parameters
        ----------
        batch : pd.DataFrame
            A Pandas DataFrame where each row represents a text chunk.
            Has columns `id` and `text`.
        failed_cache : list[tuple[str, str]]
            A list of tuples, where the first element is the chunk id and the second element is the text chunk.
            This is used to log failed embeddings across batches.

        Returns
        -------
        list[tuple[str, list[float]]]
            A list of tuples, where the first element is the chunk id and the second element is the embedding for the chunk text.
        """
        
        # Create tasks for all nodes in the batch
        # order is maintained
        tasks = [create_single_chunk_embedding(row["text"], row['id'], failed_cache) for _, row in batch.iterrows()]
        # Execute all tasks concurrently
        embedding_results = await asyncio.gather(*tasks)

        # filter results to only include non-None values
        embedding_results = [(id, embedding) for id, embedding in zip(batch["id"], embedding_results) if embedding is not None]

        return embedding_results

    
    async def _create_embeddings_in_batches(chunk_nodes_dataframe: pd.DataFrame, batch_size: int) -> tuple[list[tuple[str, list[float]]], list[tuple[str, str]]]:
        """
        Create embeddings for a Pandas DataFrame of text chunks in batches.

        Parameters
        ----------
        chunk_nodes_dataframe : pd.DataFrame
            A Pandas DataFrame where each row represents a text chunk.
            Has columns `id` and `text`.
        batch_size : int
            The number of text chunks to process in each batch.

        Returns
        -------
        tuple[list[tuple[str, list[float]]], list[tuple[str, str]]]
            A tuple of two lists. The first list contains tuples of chunk id and the embedding of the chunk text.
            The second list contains tuples of chunk id and text for the chunks that failed to be processed.
        """

        results = list()
        failed_cache: list[tuple[str, str]] = list() # [(chunk_id, text_chunk), ...]
        n_batches = ceil(len(chunk_nodes_dataframe) / batch_size)
        for batch_idx, i in enumerate(range(0, len(chunk_nodes_dataframe), batch_size)):
            print(f"Processing batch {batch_idx+1} of {n_batches}")
            # iloc slicing past the end simply returns the remaining rows
            batch = chunk_nodes_dataframe.iloc[i:i+batch_size]
            batch_results = await _create_embeddings_for_batch(batch, failed_cache)

            # Add the successfully embedded chunks to the results list
            results.extend(batch_results)

        return results, failed_cache

    # first pass through chunks
    results, failed = await _create_embeddings_in_batches(chunk_nodes_dataframe, batch_size)
    print(f"Successful chunks : {len(results)}")
    print(f"Failed chunks     : {len(failed)}")
    print("--------------------------------")
    print("Retrying failed chunks...")

    # retry failed chunks once
    retry_df = pd.DataFrame(failed, columns=["id", "text"])
    retry_results, failed = await _create_embeddings_in_batches(retry_df, batch_size)
    print(f"Successful retries : {len(retry_results)}")
    print(f"Failed retries     : {len(failed)}")

    print("--------------------------------")
    total = len(chunk_nodes_dataframe)
    success_rate = (len(results) + len(retry_results)) / total * 100 if total else 0.0
    print(f"Overall Success Rate : {round(success_rate, 2)}%")

    return results + retry_results
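The batching-and-retry flow in `create_chunk_embeddings` can be sketched in isolation. This is a minimal, self-contained sketch under stated assumptions: a hypothetical `FlakyEmbedder` stands in for the notebook's `create_single_chunk_embedding` (whose implementation is not shown here), and plain `(id, text)` tuples replace the DataFrame. It illustrates two properties the notebook relies on: `asyncio.gather` returns results in the order the awaitables were passed, so zipping them with the batch ids is safe, and a single retry pass over whatever failed.

```python
import asyncio
from math import ceil
from typing import Awaitable, Callable, Optional

# (chunk_id, text) pairs; a plain-list stand-in for the notebook's DataFrame
Row = tuple[str, str]
Embedder = Callable[[str, str, list[Row]], Awaitable[Optional[list[float]]]]


class FlakyEmbedder:
    """Hypothetical embedder that fails once for any text containing 'flaky'."""

    def __init__(self) -> None:
        self.seen: set[str] = set()

    async def __call__(self, chunk_id: str, text: str, failed: list[Row]) -> Optional[list[float]]:
        await asyncio.sleep(0)  # yield to the event loop, as a real API call would
        if "flaky" in text and chunk_id not in self.seen:
            self.seen.add(chunk_id)
            failed.append((chunk_id, text))  # record the failure for a later retry
            return None
        return [float(len(text)), 1.0]  # toy 2-dimensional "embedding"


async def embed_in_batches(
    rows: list[Row], batch_size: int, embed: Embedder
) -> tuple[list[tuple[str, list[float]]], list[Row]]:
    results: list[tuple[str, list[float]]] = []
    failed: list[Row] = []
    n_batches = ceil(len(rows) / batch_size)
    for batch_idx, start in enumerate(range(0, len(rows), batch_size)):
        print(f"Processing batch {batch_idx + 1} of {n_batches}")
        batch = rows[start:start + batch_size]  # slicing past the end is safe
        embeddings = await asyncio.gather(*(embed(cid, text, failed) for cid, text in batch))
        # gather preserves input order, so each embedding stays with its chunk id
        results.extend((cid, emb) for (cid, _), emb in zip(batch, embeddings) if emb is not None)
    return results, failed


async def embed_with_one_retry(
    rows: list[Row], batch_size: int, embed: Embedder
) -> tuple[list[tuple[str, list[float]]], list[Row]]:
    results, failed = await embed_in_batches(rows, batch_size, embed)
    # single retry pass over the chunks that failed the first time
    retry_results, still_failed = await embed_in_batches(failed, batch_size, embed)
    return results + retry_results, still_failed


rows = [("a", "alpha"), ("b", "flaky beta"), ("c", "gamma")]
embedded, failed = asyncio.run(embed_with_one_retry(rows, batch_size=2, embed=FlakyEmbedder()))
print([cid for cid, _ in embedded], failed)
```

Chunk `b` fails in the first pass and succeeds on retry, so it ends up last in the combined results. Inside the notebook, where an event loop is already running, you would `await embed_with_one_retry(...)` directly, as the notebook does with `create_chunk_embeddings`.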

In [82]:
chunks = get_chunks_to_embed()

In [83]:
chunks.head(3)

|   | id | text |
|---|----|------|
| 0 | 14c257d5146e625ccbe4b0435317b173 | The adverse events associated with GLP-1RA are... |
| 1 | 18032216f80a07578785896148f43e26 | The rates of secondary-outcome events (Table 2... |
| 2 | 19f0d07d88fc2175b63d25ea52201783 | GLP-1 is secreted by L-cells found in the ileu... |


In [84]:
chunks_with_embeddings = await create_chunk_embeddings(chunks, batch_size=100)

Processing batch 1 of 2  
Processing batch 2 of 2  
Successful chunks : 196
Failed chunks     : 0
--------------------------------
Retrying failed chunks...
Successful retries : 0
Failed retries     : 0
--------------------------------
Overall Success Rate : 100.0%


Neo4j provides dedicated procedures for writing embedding properties.

These procedures store the embedding in a more space-efficient representation than a plain `SET` does.
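To get a feel for the size difference, the arithmetic below compares raw vector payloads. It is a rough sketch that *assumes* the vector-property procedures store 32-bit floats while a plain `SET` of a Cypher list stores 64-bit `FLOAT` values; the 1536-dimension size is a hypothetical model output, and store-level overhead is ignored. Only the chunk count (196) comes from the run above.

```python
from array import array

DIMENSIONS = 1536  # hypothetical embedding size; depends on the embedding model
N_CHUNKS = 196  # chunk count reported by the embedding run above

vector = [0.0] * DIMENSIONS
as_float64 = array("d", vector)  # 64-bit floats, like Cypher's FLOAT
as_float32 = array("f", vector)  # 32-bit floats (assumed procedure storage)

bytes_64 = as_float64.itemsize * len(as_float64) * N_CHUNKS
bytes_32 = as_float32.itemsize * len(as_float32) * N_CHUNKS
print(f"float64: {bytes_64 / 1024:.0f} KiB, float32: {bytes_32 / 1024:.0f} KiB")
```

Under these assumptions the raw payload halves, from about 2352 KiB to 1176 KiB.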

We can use:
* `db.create.setNodeVectorProperty` for node properties (requires Neo4j 5.13 or later)
* `db.create.setRelationshipVectorProperty` for relationship properties (requires Neo4j 5.18 or later)

The write query looks roughly like this:

```cypher
UNWIND $rows as row
MATCH (c:Chunk {id: row.id})
CALL db.create.setNodeVectorProperty(c, 'embedding', row.embedding)
```
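The `UNWIND $rows` query receives the whole payload as a single parameter, so for larger datasets it can help to send it in slices. The sketch below is an illustration, not the notebook's `graph.execute_write_query_with_data` helper (whose internals are not shown): the function name `write_vectors_in_batches`, the `run_query(query, rows)` callable, and the default batch size are all hypothetical. Each row is a `{"id": ..., "embedding": ...}` dict matching the `row.id` / `row.embedding` fields in the Cypher above.

```python
from typing import Callable, Iterable, Sequence

WRITE_EMBEDDINGS = """
UNWIND $rows AS row
MATCH (c:Chunk {id: row.id})
CALL db.create.setNodeVectorProperty(c, 'embedding', row.embedding)
"""


def write_vectors_in_batches(
    run_query: Callable[[str, list[dict]], object],
    records: Iterable[tuple[str, Sequence[float]]],
    batch_size: int = 500,  # hypothetical default; tune for your instance
) -> int:
    """Send (chunk_id, embedding) pairs to `run_query` in slices of `batch_size`.

    Returns the number of batches sent.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    # coerce each element to a plain Python float so the payload is a simple list
    rows = [{"id": cid, "embedding": [float(x) for x in emb]} for cid, emb in records]
    n_batches = 0
    for start in range(0, len(rows), batch_size):
        run_query(WRITE_EMBEDDINGS, rows[start:start + batch_size])
        n_batches += 1
    return n_batches
```

With the official Neo4j Python driver, `run_query` could be something like `lambda q, rows: driver.execute_query(q, rows=rows, database_="neo4j")`, where the database name is an assumption. Passing a callable keeps the batching logic independent of any particular connection helper.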


In [85]:
def write_embeddings_to_chunks(df: pd.DataFrame) -> None:
    """
    Write embeddings to chunks.

    Parameters
    ----------
    df : pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node.
        Has columns `id` and `embedding`.
    """
    graph.execute_write_query_with_data(database=db_info['database'],
                                        query=processing_queries['write_embeddings_by_chunk_id'],
                                        data=df)

In [86]:
embeddings_df = pd.DataFrame(chunks_with_embeddings, columns=['id', 'embedding'])

In [87]:
write_embeddings_to_chunks(embeddings_df)