# PubMed Knowledge Graph

This notebook walks through generating a knowledge graph from PubMed articles.

It will:
* Download a selection of articles from PubMed
* Define a knowledge graph schema
* Extract entities from the articles according to the defined schema
* Populate a Neo4j instance with articles and extracted entities
* Connect extracted entities with existing patient journey data

This notebook requires a local directory of articles. You can download a sample of 20 PubMed articles by running the following command:

```bash
python3 ./scripts/fetch_pubmed_articles.py
```

In [1]:
# filter some NumPy warnings that pop up during ingestion
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [2]:
import asyncio
import hashlib
from typing import Any, Optional, List

In [3]:
# allows for async operations in notebooks
import nest_asyncio
nest_asyncio.apply()

## Lexical Graph Construction

We will use Unstructured.IO to partition and chunk our articles. 

This process breaks the articles into sensible chunks that may be used as context in our application. 

These chunks will also have relationships to the extracted entities, but we will add these relationships later.

The lexical graph will adhere to the structure defined in the ['Lexical Graph with Extracted Entities'](https://graphrag.com/reference/knowledge-graph/lexical-graph-extracted-entities/) section of [graphrag.com](https://graphrag.com).

Here is the data model we will be using.


<img src="./assets/images/lexical-data-model.png" alt="lexical-data-model" width="400px">

In [46]:
import pandas as pd
from pydantic import BaseModel, Field, computed_field, field_validator
from unstructured.partition.pdf import partition_pdf

from unstructured.documents.elements import CompositeElement

from uuid import uuid4

In [5]:
# ------------
# Nodes
# ------------

class Document(BaseModel):
    """
    A Document.
    This is the top level node in our knowledge graph.
    Documents are made of many Chunks.
    """
    id: str = Field(..., description="The id of the document")
    name: str = Field(..., description="The name of the document")
    source: str = Field(..., description="The source of the document")

class Chunk(BaseModel):
    """
    A Chunk.
    This is a collection of `UnstructuredElements`.
    Unstructured.IO represents Chunks as `CompositeElement` objects.
    """
    id: str = Field(..., description="The id of the chunk")
    type: str = Field(..., description="The type of the chunk")
    text: str = Field(..., description="The text of the chunk")

class ChunkWithEmbedding(Chunk):
    """
    A Chunk with an embedding.
    This is used to represent chunks that have been embedded.
    """
    embedding: list[float] = Field(..., description="The embedding of the chunk text field")

class UnstructuredElement(BaseModel):
    """
    A base class for all unstructured elements. 
    These are the smallest units in our chunking process. 
    One or more of these elements are combined to form a Chunk.
    """
    id: str = Field(..., description="The id of the element")
    text: str = Field(..., description="The text of the element")
    type: str = Field(..., description="The type of the element")
    page_number: int = Field(..., description="The page number of the element")

class TextElement(UnstructuredElement):
    """
    A TextElement. Structurally identical to the UnstructuredElement class.
    This is used to represent text elements that contain no tables or images.
    """

class ImageElement(UnstructuredElement):
    """
    An ImageElement.
    """
    image_base64: str = Field(..., description="The base64 encoded image")
    image_mime_type: str = Field(..., description="The mime type of the image")

class TableElement(UnstructuredElement):
    """
    A TableElement.
    This may also have image features, so it carries the same image fields as ImageElement, but optional.
    """
    image_base64: str | None = Field(None, description="The base64 encoded table")
    image_mime_type: str | None = Field(None, description="The mime type of the table")
    text_as_html: str | None = Field(None, description="The text of the table as HTML")
    
# -------------
# Relationships
# -------------

class ChunkPartOfDocument(BaseModel):
    """
    (:Chunk {id: $chunk_id})-[:PART_OF_DOCUMENT]->(:Document {id: $document_id})
    """
    chunk_id: str = Field(..., description="The id of the chunk")
    document_id: str = Field(..., description="The id of the document")

class UnstructuredElementPartOfChunk(BaseModel):
    """
    (:UnstructuredElement {id: $unstructured_element_id})-[:PART_OF_CHUNK]->(:Chunk {id: $chunk_id})

    This covers TextElement, ImageElement and TableElement nodes since they all share the UnstructuredElement label.
    """ 
    unstructured_element_id: str = Field(..., description="The id of the unstructured element")
    chunk_id: str = Field(..., description="The id of the chunk")

In [6]:
def parse_article_file_name(file_name: str) -> tuple[str, str]:
    """
    Parse the article file name and return the PubMed ID and title.

    Parameters
    ----------
    file_name : str
        The name of the file to parse.

    Returns
    -------
    tuple[str, str]
        The PubMed ID and title of the article.
    """
    doc_id, title = file_name.split("-", 1)
    title = title.replace("_", " ")

    return doc_id, title
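
As a quick illustration of the naming convention this helper assumes, the sketch below applies the same split to made-up file names. The `parse_file_name` helper and the file names are hypothetical. Unlike the function above, this variant also strips the file extension with `pathlib`, since otherwise the extension stays attached to the title.

```python
from pathlib import Path


def parse_file_name(file_name: str) -> tuple[str, str]:
    """Hypothetical variant of the parser above that also drops the file extension."""
    stem = Path(file_name).stem  # e.g. "12345678-Example_article_title"
    doc_id, title = stem.split("-", 1)  # split on the first hyphen only
    return doc_id, title.replace("_", " ")


# made-up file names following the "{pmid}-{title}.pdf" convention
print(parse_file_name("12345678-Example_article_title.pdf"))
# hyphens inside the title survive because only the first hyphen is used as a separator
print(parse_file_name("1-GLP-1_agonists.pdf"))
```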

In [7]:
def create_chunk_has_next_chunk_relationship_dataframe(chunk_dataframe: pd.DataFrame) -> pd.DataFrame:
    """
    Create the DataFrame for loading (:Chunk)-[:HAS_NEXT_CHUNK]->(:Chunk) relationships.
    """
    df = chunk_dataframe.copy()
    df['next_id'] = df['id'].shift(-1)
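    # only the last chunk has no successor; drop it without discarding rows that have NaN in other columns
    df.dropna(subset=['next_id'], inplace=True)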
    res = df[['id', 'next_id']].rename({"id": "source_id", "next_id": "target_id"}, axis=1)
    return res

def extract_document_title(text_elements_dataframe: pd.DataFrame) -> str:
    """
    Extract the title of the document from the text elements.
    Here we assume that the first 'Title' element is the title of the document.

    Returns
    -------
    str
        The title of the document.
    """
    try:
        return text_elements_dataframe[text_elements_dataframe['type'] == 'Title'].iloc[0]['text']
    except Exception as e:
        print(f"Unable to extract document title: {e}")
        return 'unknown title'

def parse_node_and_relationship_from_composite_element(composite_element: CompositeElement, parent_document_id: str) -> dict[str, dict[str, Any]]:
    """
    Parse the nodes and relationships for a given chunk (CompositeElement). 
    This will find the following nodes:
    * Chunk
    * TextElement
    * ImageElement
    * TableElement
    * UnstructuredElement (Shared label for TextElement, ImageElement and TableElement)

    And the following relationships:
    * (:Chunk)-[:PART_OF_DOCUMENT]->(:Document)
    * (:UnstructuredElement)-[:PART_OF_CHUNK]->(:Chunk)

    Returns
    -------
    dict[str, dict[str, Any]]
        A dictionary containing a list of records for each node and relationship type.
    """
    chunk = Chunk(id=composite_element.id, text=composite_element.text, type=composite_element.category)
    chunk_part_of_document = ChunkPartOfDocument(chunk_id=chunk.id, document_id=parent_document_id)

    text_elements: list[TextElement] = list()
    image_elements: list[ImageElement] = list()
    table_elements: list[TableElement] = list()
    unstructured_element_part_of_chunk: list[UnstructuredElementPartOfChunk] = list()

    # Chunks (CompositeElements) are made of many smaller text chunks (UnstructuredElements)
    # We can parse what type of elements these subchunks are and load them as well
    # This will give us access to images and tables from the document
    for element in composite_element.metadata.orig_elements:
        match element.category:
            case "NarrativeText":
                text_elements.append(TextElement(id=element.id, 
                                                 text=element.text, 
                                                 type=element.category, 
                                                 page_number=element.metadata.page_number))
            case "Image":
                image_elements.append(ImageElement(id=element.id, 
                                                   text=element.text,
                                                   type=element.category, 
                                                   page_number=element.metadata.page_number,
                                                   image_base64=element.metadata.image_base64, 
                                                   image_mime_type=element.metadata.image_mime_type))
            case "Table":
                table_elements.append(TableElement(id=element.id, 
                                                   text=element.text,
                                                   type=element.category, 
                                                   page_number=element.metadata.page_number,
                                                   image_base64=element.metadata.image_base64,
                                                   image_mime_type=element.metadata.image_mime_type,
                                                   text_as_html=element.metadata.text_as_html))
            # Assume some kind of text element if we can't match the category
            # Could be headers, figure captions, etc
            case _:
                try:
                    text_elements.append(TextElement(id=element.id, 
                                                 text=element.text, 
                                                 type=element.category, 
                                                 page_number=element.metadata.page_number))
                except Exception as e:
                    print(f"Error parsing text element: {e}")

        unstructured_element_part_of_chunk.append(UnstructuredElementPartOfChunk(unstructured_element_id=element.id, chunk_id=chunk.id))

    # we return a list of records for each entity and relationship instead of the Pydantic classes
    return {
        "nodes": {
            "chunk": [chunk.model_dump()],
            "text_element": [el.model_dump() for el in text_elements],
            "image_element": [el.model_dump() for el in image_elements],
            "table_element": [el.model_dump() for el in table_elements],
        },
        "relationships": {
            "chunk_part_of_document": [chunk_part_of_document.model_dump()],
            "unstructured_element_part_of_chunk": [rel.model_dump() for rel in unstructured_element_part_of_chunk],
        }
    }

def parse_nodes_and_relationships_from_composite_elements(composite_elements: list[CompositeElement], parent_doc_id: str) -> dict[str, dict[str, pd.DataFrame]]:
    """Parse entity nodes and document relationships for a set of chunks (CompositeElements) and their parent document"""
    
    data = {
        "nodes": {
            "document": list(),
            "chunk": list(),
            "text_element": list(),
            "image_element": list(),
            "table_element": list(),
        },
        "relationships": {
            "chunk_part_of_document": list(),
            "unstructured_element_part_of_chunk": list(),
            "chunk_has_next_chunk": list()
        }
    }

    for composite_element in composite_elements:
        new_data = parse_node_and_relationship_from_composite_element(composite_element, parent_doc_id)

        # update the records with new nodes and relationships
        data["nodes"]["chunk"].extend(new_data["nodes"]["chunk"])
        data["relationships"]["chunk_part_of_document"].extend(new_data["relationships"]["chunk_part_of_document"])
        data["nodes"]["text_element"].extend(new_data["nodes"]["text_element"])
        data["nodes"]["image_element"].extend(new_data["nodes"]["image_element"])
        data["nodes"]["table_element"].extend(new_data["nodes"]["table_element"])
        data["relationships"]["unstructured_element_part_of_chunk"].extend(new_data["relationships"]["unstructured_element_part_of_chunk"])

    # convert to pandas dataframe for ingestion
    # node DataFrames
    data["nodes"]["chunk"] = pd.DataFrame(data["nodes"]["chunk"])
    data["nodes"]["text_element"] = pd.DataFrame(data["nodes"]["text_element"])
    data["nodes"]["image_element"] = pd.DataFrame(data["nodes"]["image_element"])
    data["nodes"]["table_element"] = pd.DataFrame(data["nodes"]["table_element"])

    document_title = extract_document_title(data["nodes"]["text_element"])
    data["nodes"]["document"] = pd.DataFrame([Document(id=parent_doc_id, name=document_title, source="pubmed").model_dump()])

    # relationship DataFrames
    data["relationships"]["chunk_part_of_document"] = pd.DataFrame(data["relationships"]["chunk_part_of_document"])
    data["relationships"]["unstructured_element_part_of_chunk"] = pd.DataFrame(data["relationships"]["unstructured_element_part_of_chunk"])
    data["relationships"]["chunk_has_next_chunk"] = create_chunk_has_next_chunk_relationship_dataframe(data["nodes"]["chunk"])

    return data

# def process_xml_article(file_name: str) -> dict[str, dict[str, pd.DataFrame]]:
#     """
#     Process an article and return the nodes and relationships for ingestion into the knowledge graph.
#     Assumes that 
#     * the article name follows the format "{pmid}-{title}.xml"
#     * the article is stored in the "articles/" directory

#     Parameters
#     ----------
#     file_name : str
#         The name of the file to process.

#     Returns
#     -------
#     dict[str, dict[str, pd.DataFrame]]
#         A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
#     """
#     doc_id, title = parse_article_file_name(file_name)
#     parent_document = Document(id=str(uuid4()), pm_id=doc_id, name=title, source="pubmed")
#     partitioned_doc = partition_xml("articles/" + file_name, 
#                                     xml_keep_tags=False, 
#                                     chunking_strategy="by_title", 
#                                     combine_text_under_n_chars=200, 
#                                     max_characters=500, 
#                                     multipage_sections=True)
#     # return partitioned_doc
#     return parse_nodes_and_relationships_from_chunk_elements(partitioned_doc, parent_document)

def process_pdf_article(file_name: str) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Process an article and return the nodes and relationships for ingestion into the knowledge graph.
    Assumes that the article is stored in the "articles/pdf/" directory.
    The document id is the SHA-256 hash of the file name.

    Parameters
    ----------
    file_name : str
        The name of the file to process.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
    """
    # doc_id, title = parse_article_file_name(file_name)
    doc_id = hashlib.sha256(file_name.encode()).hexdigest()
    # parent_document = Document(id=str(uuid4()), pm_id=file_name, name=file_name, source="pubmed")
    partitioned_doc = partition_pdf("articles/pdf/" + file_name, 
                                    strategy="hi_res",                                     
                                    extract_images_in_pdf=True,
                                    extract_image_block_types=["Image", "Table"], 
                                    extract_image_block_to_payload=True,               
                                    # extract_image_block_output_dir=f"figures/{file_name[:-4]}",
                                    chunking_strategy="by_title", 
                                    combine_text_under_n_chars=200, 
                                    max_characters=1000, 
                                    multipage_sections=True)
    # return partitioned_doc
    return parse_nodes_and_relationships_from_composite_elements(partitioned_doc, doc_id)
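
The `HAS_NEXT_CHUNK` relationships pair each chunk with the one that follows it in the same document. The sketch below shows that pairing in plain Python, as an equivalent of the `shift(-1)` approach used above rather than a replacement for it. The `pair_consecutive` helper and the chunk ids are made up for illustration.

```python
def pair_consecutive(chunk_ids: list[str]) -> list[dict[str, str]]:
    """Pair each chunk id with the id that follows it; the last chunk has no successor."""
    return [
        {"source_id": current, "target_id": following}
        for current, following in zip(chunk_ids, chunk_ids[1:])
    ]


# made-up chunk ids in reading order for a single document
pairs = pair_consecutive(["c1", "c2", "c3"])
print(pairs)
```

Because the pairing runs once per article, the last chunk of one document is never linked to the first chunk of the next.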

In [8]:
def process_pdf_articles(article_file_names: list[str]) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Process a list of articles and return the nodes and relationships for ingestion into the knowledge graph.
    Assumes that the articles are stored in the "articles/pdf/" directory

    Parameters
    ----------
    article_file_names : list[str]
        A list of the names of the files to process.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary containing the node and relationship Pandas DataFrames for ingestion into the knowledge graph.
    """

    # initialize the DataFrames
    data = {
        "nodes": {
            "document": pd.DataFrame(),
            "chunk": pd.DataFrame(),
            "text_element": pd.DataFrame(),
            "image_element": pd.DataFrame(),
            "table_element": pd.DataFrame(),
        },
        "relationships": {
            "chunk_part_of_document": pd.DataFrame(),
            "unstructured_element_part_of_chunk": pd.DataFrame(),
            "chunk_has_next_chunk": pd.DataFrame()
        }
    }

    # process each article individually
    # this will
    # * partition the article into chunks using Unstructured.IO
    # * Identify TextElements, ImageElements, and TableElements in each chunk
    # * Create DataFrames for all lexical nodes and relationships found in the article
    # * Update the global DataFrames with the new article data
    for file_name in article_file_names:
        print(f"Processing article: {file_name}")
        # process each article 
        article_data = process_pdf_article(file_name)

        # update the DataFrames with the new article data
        data["nodes"]["document"] = pd.concat([data["nodes"]["document"], article_data["nodes"]["document"]], ignore_index=True)
        data["nodes"]["chunk"] = pd.concat([data["nodes"]["chunk"], article_data["nodes"]["chunk"]], ignore_index=True)
        data["nodes"]["text_element"] = pd.concat([data["nodes"]["text_element"], article_data["nodes"]["text_element"]], ignore_index=True)
        data["nodes"]["image_element"] = pd.concat([data["nodes"]["image_element"], article_data["nodes"]["image_element"]], ignore_index=True)
        data["nodes"]["table_element"] = pd.concat([data["nodes"]["table_element"], article_data["nodes"]["table_element"]], ignore_index=True)
        data["relationships"]["chunk_part_of_document"] = pd.concat([data["relationships"]["chunk_part_of_document"], article_data["relationships"]["chunk_part_of_document"]], ignore_index=True)
        data["relationships"]["unstructured_element_part_of_chunk"] = pd.concat([data["relationships"]["unstructured_element_part_of_chunk"], article_data["relationships"]["unstructured_element_part_of_chunk"]], ignore_index=True)
        data["relationships"]["chunk_has_next_chunk"] = pd.concat([data["relationships"]["chunk_has_next_chunk"], article_data["relationships"]["chunk_has_next_chunk"]], ignore_index=True)
    
    return data
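
To call `process_pdf_articles` we need the list of file names in `articles/pdf/`. One way to collect them is sketched below. The `list_pdf_file_names` helper is hypothetical and not part of the notebook; it only lists files and does not partition them.

```python
from pathlib import Path


def list_pdf_file_names(directory: str = "articles/pdf") -> list[str]:
    """Return the sorted names of the PDF files directly inside `directory`."""
    folder = Path(directory)
    if not folder.is_dir():
        # nothing has been downloaded yet
        return []
    return sorted(
        p.name for p in folder.iterdir() if p.is_file() and p.suffix.lower() == ".pdf"
    )


article_file_names = list_pdf_file_names()
print(f"Found {len(article_file_names)} PDF articles")
# lexical_data = process_pdf_articles(article_file_names)  # function defined in the cell above
```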

## Lexical Graph Embedding

Here we will embed the text fields of our lexical graph for vector similarity search. 
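
The embedding step is still marked as a TODO. As a rough sketch of its shape, the code below attaches a vector to each chunk record and ranks the chunks by cosine similarity to a query. The `toy_embed` function is a stand-in invented for this example (a normalized letter-count vector, not a real embedding model) and would be replaced by calls to an actual embedding model. The chunk records are made up.

```python
import math


def toy_embed(text: str) -> list[float]:
    """Placeholder embedder: normalized letter counts. Swap in a real embedding model."""
    counts = [0.0] * 26
    for ch in text.lower():
        if "a" <= ch <= "z":
            counts[ord(ch) - ord("a")] += 1.0
    norm = math.sqrt(sum(c * c for c in counts)) or 1.0
    return [c / norm for c in counts]


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors; returns 0.0 if either vector has zero length."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


# made-up chunk records with the same id/text fields as the Chunk model
chunks = [
    {"id": "c1", "text": "Semaglutide reduced HbA1c."},
    {"id": "c2", "text": "Participants were recruited in Denmark."},
]
# add an embedding to every chunk, mirroring ChunkWithEmbedding
embedded = [{**chunk, "embedding": toy_embed(chunk["text"])} for chunk in chunks]

# rank the chunks against a query vector, most similar first
query = toy_embed("semaglutide HbA1c")
ranked = sorted(
    embedded, key=lambda c: cosine_similarity(query, c["embedding"]), reverse=True
)
print([c["id"] for c in ranked])
```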

In [9]:
#TODO

## Domain Graph Schema Definition

We now need to define our knowledge graph schema. This information will be passed to the entity extraction LLM to control which entities and relationships are pulled out of the text.

Constraining extraction this way prevents the schema from growing without bound as more articles are processed.

We are using Pydantic to define the schema here since it can be used to validate any returned results as well. This ensures that all data we are ingesting into Neo4j adheres to this structure.
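
The sketch below shows both uses of a Pydantic model on a deliberately tiny schema: producing the JSON schema that can be handed to the extraction LLM, and validating a returned payload. `ExampleCondition` and both payloads are made up for illustration and are not part of this notebook's schema. It assumes Pydantic v2.

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class ExampleCondition(BaseModel):
    """Hypothetical, trimmed-down node used only for this illustration."""
    name: str = Field(..., description="Name of the medical condition")
    severity: Optional[str] = Field(None, description="Severity when specified")


# the JSON schema describes the fields the LLM is allowed to return
schema = ExampleCondition.model_json_schema()
print(sorted(schema["properties"]))

# a well-formed payload validates into a typed object
condition = ExampleCondition.model_validate({"name": "Type 2 diabetes mellitus"})

# a payload missing the required field is rejected before it reaches Neo4j
try:
    ExampleCondition.model_validate({"severity": "moderate"})
    rejected = False
except ValidationError as e:
    rejected = True
    print(f"Rejected payload: {e.error_count()} validation error(s)")
```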

Here is what our domain graph data model looks like.

<img src="./assets/images/domain-data-model-v1.png" alt="domain-data-model" width="600px">
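
Several nodes in this model (for example `StudyMedication`) have no natural key, so the schema derives an id by hashing the fields that identify them. Relationship records recompute the same hash from the same fields, which is how they locate their endpoints at load time. A minimal sketch of that idea follows, with a hypothetical `make_id` helper and made-up values.

```python
import hashlib


def make_id(*parts: str) -> str:
    """Join the identifying fields with '_' and hash them with SHA-256."""
    return hashlib.sha256("_".join(parts).encode()).hexdigest()


# id computed on the node from its own fields
node_id = make_id("Study 1", "Treatment arm 1")
# id recomputed on a relationship record from the fields it carries
rel_id = make_id("Study 1", "Treatment arm 1")

print(node_id == rel_id)  # True: same inputs, same id
print(make_id("Study 1", "treatment arm 1") == node_id)  # False: case differs
```

Because the hash covers the exact strings, the extraction step has to return the study name and treatment arm consistently. A difference in case or whitespace yields a different id and an unmatched relationship.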


In [73]:
class Medication(BaseModel):
    """
    A substance used for medical treatment - a medicine or drug. 
    This is a general representation of a medication. 
    A Medication node may have relationships to StudyMedication nodes that are specific to a particular study.
    """
    
    name: str = Field(..., description="Name of the medication. Should also be uniquely identifiable.")
    medication_class: str = Field(..., description="Drug class (e.g., GLP-1 RA, SGLT2i)")
    mechanism: Optional[str] = Field(None, description="Mechanism of action")
    generic_name: Optional[str] = Field(None, description="Generic name if different from name")
    brand_names: Optional[List[str]] = Field(None, description="Commercial brand names")
    approval_status: Optional[str] = Field(None, description="FDA approval status")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "name": "Semaglutide", 
                    "medication_class": "GLP-1 receptor agonist",
                    "mechanism": "GLP-1 receptor activation",
                    "generic_name": "semaglutide",
                    "brand_names": ["Ozempic", "Wegovy", "Rybelsus"],
                    "approval_status": "FDA approved"
                }
            ]
        }


class StudyMedication(BaseModel):
    """
    Study-specific medication usage: how a medication was used in a particular study.
    This is an instance of a medication that is used in a particular study. 
    A StudyMedication node should also have a relationship with a Medication node.
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyMedication node.")
    treatment_arm: str = Field(..., description="Treatment arm of the study medication. This uniquely identifies the StudyMedication node.")
    dosage: Optional[str] = Field(None, description="Dosage used in this study")
    route: Optional[str] = Field(None, description="Route of administration")
    frequency: Optional[str] = Field(None, description="Dosing frequency")
    treatment_duration: Optional[str] = Field(None, description="Duration of treatment")
    comparator: Optional[str] = Field(None, description="What this was compared against")
    adherence_rate: Optional[float] = Field(None, description="Treatment adherence rate as a percentage (0-100)")
    formulation: Optional[str] = Field(None, description="Specific formulation used")

    @computed_field(return_type=str)
    def study_medication_id(self) -> str:
        """
        The unique id of the study medication.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_name}_{self.treatment_arm}".encode()).hexdigest()
    
    class Config:
        json_schema_extra = {
            # don't include the study_medication_id in the example since this is computed from extracted fields
            "examples": [
                {
                    "study_name": "Study 1",
                    "treatment_arm": "Treatment arm 1",
                    "dosage": "1.0 mg",
                    "route": "subcutaneous",
                    "frequency": "weekly",
                    "treatment_duration": "12 weeks",
                    "comparator": "placebo",
                    "adherence_rate": 85.5,
                    "formulation": "pre-filled pen"
                }
            ]
        }


class ClinicalOutcome(BaseModel):
    """
    Measured clinical outcomes and biomarkers.
    This node represents a clinical outcome present in a study.
    ClinicalOutcome nodes should have relationships with other entity nodes from a study.
    ClinicalOutcome nodes should not have relationships with entities that exist outside the study.
    """
    
    study_name: str = Field(..., description="Name of the study this outcome is associated with. This is used to uniquely identify the ClinicalOutcome node.")
    name: str = Field(..., description="A concise detailed name for the outcome.")

    @computed_field(return_type=str)
    def clinical_outcome_id(self) -> str:
        """
        The unique id of the clinical outcome.
        This is a sha256 hash of the study name and the name of the outcome.
        """
        return hashlib.sha256(f"{self.study_name}_{self.name}".encode()).hexdigest()
    
    class Config:
        json_schema_extra = {
            "examples": [
                # don't include the clinical_outcome_id in the example since this is computed from extracted fields
                {
                    "study_name": "Study 1",
                    "name": "A1C controlled",
                }
            ]
        }


class MedicalCondition(BaseModel):
    """Medical conditions and comorbidities studied"""
    
    name: str = Field(..., description="Name of the medical condition")
    category: str = Field(..., description="Category of condition")
    severity: Optional[str] = Field(None, description="Severity or stage when specified")
    icd10_code: Optional[str] = Field(None, description="ICD-10 code when available")
    duration: Optional[str] = Field(None, description="Duration of condition if specified")
    
    @field_validator("icd10_code")
    @classmethod
    def validate_icd10_code(cls, v: Optional[str]) -> Optional[str]:
        """
        Validate that the ICD-10 code is structurally valid.
        """
        # the field is optional, so an explicit None is passed through unchanged
        if v is None:
            return v
        # ICD-10 codes have 3-7 characters, plus a '.' when longer than 3 (8 in total at most)
        if len(v) < 3 or len(v) > 8:
            raise ValueError("ICD-10 code must be between 3 and 8 characters long, including the '.'.")
        # first character must be a letter (not case sensitive)
        # any letter is allowed, e.g. I10 (essential hypertension) and O24 (diabetes in pregnancy)
        elif not v[0].isalpha():
            raise ValueError("ICD-10 code must start with a letter.")
        # second character must be a digit
        elif not v[1].isdigit():
            raise ValueError("ICD-10 code second character must be a digit.")
        # '.' must separate the first 3 characters from the rest of the code
        # examples:
        # S52 Fracture of forearm
        # S52.5 Fracture of lower end of radius
        # S52.52 Torus fracture of lower end of radius
        # S52.521 Torus fracture of lower end of right radius
        # S52.521A Torus fracture of lower end of right radius, initial encounter, closed fracture
        elif len(v) > 3 and v[3] != '.':
            raise ValueError("ICD-10 code must have a '.' after the first 3 characters.")
        return v
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "name": "Type 2 diabetes mellitus",
                    "category": "Primary condition", 
                    "severity": "moderate",
                    "icd10_code": "E11",
                    "duration": "5-10 years",
                }
            ]
        }


class StudyPopulation(BaseModel):
    """Patient populations and demographics in research studies"""
    
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    description: str = Field(..., description="Description of the population")
    min_age: Optional[int] = Field(None, description="Minimum age in years")
    max_age: Optional[int] = Field(None, description="Maximum age in years")
    male_percentage: Optional[float] = Field(None, description="Percentage of male gender participants")
    female_percentage: Optional[float] = Field(None, description="Percentage of female gender participants")
    other_gender_percentage: Optional[float] = Field(None, description="Percentage of participants that identify as another gender")
    sample_size: Optional[int] = Field(None, description="Number of participants")
    study_type: str = Field(..., description="Type of study")
    location: Optional[str] = Field(None, description="Geographic location of study")
    inclusion_criteria: Optional[List[str]] = Field(None, description="Key inclusion criteria")
    exclusion_criteria: Optional[List[str]] = Field(None, description="Key exclusion criteria")
    study_duration: Optional[str] = Field(None, description="Duration of study")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "study_name": "Study 1",
                    "description": "Adults with T2DM and schizophrenia",
                    "min_age": 30,
                    "max_age": 39,
                    "male_percentage": 46.0,
                    "female_percentage": 53.0,
                    "other_gender_percentage": 1.0,
                    "sample_size": 100,
                    "study_type": "Observational study",
                    "location": "Denmark",
                    "inclusion_criteria": ["Type 2 diabetes diagnosis", "Schizophrenia diagnosis", "Age ≥18"],
                    "study_duration": "12 months"
                }
            ]
        }

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.description}".encode()).hexdigest()


# Relationship classes
class StudyMedicationUsesMedication(BaseModel):
    """
    Links StudyMedication to Medication nodes.
    StudyMedication nodes should have a relationship with a Medication node.
    Pattern: (:StudyMedication)-[:USES_MEDICATION]->(:Medication)
    """
    medication_name: str
    study_medication_study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyMedication node.")
    study_medication_treatment_arm: str = Field(..., description="Treatment arm of the study medication. This uniquely identifies the StudyMedication node.")

    @computed_field(return_type=str)
    def study_medication_id(self) -> str:
        """
        The unique id of the study medication.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_medication_study_name}_{self.study_medication_treatment_arm}".encode()).hexdigest()


class StudyMedicationProducesClinicalOutcome(BaseModel):
    """
    Links StudyMedication to ClinicalOutcome nodes.
    StudyMedication nodes should have a relationship with a ClinicalOutcome node.
    Pattern: (:StudyMedication)-[:PRODUCES_CLINICAL_OUTCOME]->(:ClinicalOutcome)
    """
    study_medication_study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyMedication node.")
    study_medication_treatment_arm: str = Field(..., description="Treatment arm of the study medication. This uniquely identifies the StudyMedication node.")
    clinical_outcome_name: str = Field(..., description="Name of the clinical outcome")

    @computed_field(return_type=str)
    def clinical_outcome_id(self) -> str:
        """
        The unique id of the clinical outcome.
        This is a sha256 hash of the study name and the name of the outcome.
        """
        return hashlib.sha256(f"{self.study_medication_study_name}_{self.clinical_outcome_name}".encode()).hexdigest()
    
    @computed_field(return_type=str)
    def study_medication_id(self) -> str:
        """
        The unique id of the study medication.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_medication_study_name}_{self.study_medication_treatment_arm}".encode()).hexdigest()


class StudyPopulationHasMedicalCondition(BaseModel):
    """
    Links StudyPopulation to MedicalCondition nodes.
    StudyPopulation nodes should have a relationship with a MedicalCondition node.
    Pattern: (:StudyPopulation)-[:HAS_MEDICAL_CONDITION]->(:MedicalCondition)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    study_population_description: str = Field(..., description="Description of the study population.")
    medical_condition_name: str

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_population_description}".encode()).hexdigest()


class StudyPopulationReceivesStudyMedication(BaseModel):
    """
    Links StudyPopulation to StudyMedication nodes.
    StudyPopulation nodes should have a relationship with a StudyMedication node.
    Pattern: (:StudyPopulation)-[:RECEIVES_STUDY_MEDICATION]->(:StudyMedication)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    study_population_description: str = Field(..., description="Description of the study population.")
    study_medication_treatment_arm: str = Field(..., description="Treatment arm of the study medication. This uniquely identifies the StudyMedication node.")

    @computed_field(return_type=str)
    def study_medication_id(self) -> str:
        """
        The unique id of the study medication.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_medication_treatment_arm}".encode()).hexdigest()

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_population_description}".encode()).hexdigest()


class StudyPopulationHasClinicalOutcome(BaseModel):
    """
    Links StudyPopulation to ClinicalOutcome nodes.
    StudyPopulation nodes should have a relationship with a ClinicalOutcome node.
    Pattern: (:StudyPopulation)-[:HAS_CLINICAL_OUTCOME]->(:ClinicalOutcome)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    study_population_description: str = Field(..., description="Description of the study population.")
    clinical_outcome_name: str = Field(..., description="Name of the clinical outcome to match on.")

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_population_description}".encode()).hexdigest()

    @computed_field(return_type=str)
    def clinical_outcome_id(self) -> str:
        """
        The unique id of the clinical outcome.
        This is a sha256 hash of the study name and the name of the outcome.
        """
        return hashlib.sha256(f"{self.study_name}_{self.clinical_outcome_name}".encode()).hexdigest()
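
The relationship models above never store a node id directly. Instead they recompute it from the same fields the node uses, so both sides arrive at the same hash. The sketch below illustrates that idea with the standard library only; `make_id` and the sample values are hypothetical, and it assumes the node model hashes the same underscore-joined fields.

```python
import hashlib


def make_id(*parts: str) -> str:
    """Join the parts with underscores and return the sha256 hex digest."""
    return hashlib.sha256("_".join(parts).encode()).hexdigest()


# Hypothetical values, for illustration only.
study_name = "Example Study"
treatment_arm = "Metformin + Saline"

# Id as computed on the node record ...
node_id = make_id(study_name, treatment_arm)
# ... and recomputed on a relationship record that carries the same fields.
relationship_end_id = make_id(study_name, treatment_arm)

print(node_id == relationship_end_id)  # the two ids match
print(len(node_id))                    # length of the sha256 hex digest
```

The match is exact-string based: if the LLM spells the study name or treatment arm differently in the relationship than in the node, the hashes differ and the relationship will not find its node.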

The lexical and domain knowledge graphs will be linked with `HAS_ENTITY` relationships between Chunk nodes and domain graph nodes.

This is the combined lexical and domain graph data model.

IMAGE OF DATA MODEL

## Entity Extraction via LLM

We will be using [OpenAI](https://platform.openai.com/docs/overview) and the [Instructor](https://python.useinstructor.com/) library to perform our entity extraction.

In [11]:
from openai import AsyncOpenAI
import instructor
from instructor.exceptions import IncompleteOutputException, InstructorRetryException, ValidationError

Instructor handles requesting structured outputs from the LLM. 

If the LLM fails to return output that adheres to the response models, Instructor will also handle the retry logic and pass any errors to inform corrections.
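
To make the validate-and-retry idea concrete, here is a toy sketch that uses only the standard library. It is not Instructor's implementation: `fake_llm`, `validate`, and `extract_with_retries` are hypothetical stand-ins that only illustrate how a validation error can be fed back to the model before the next attempt.

```python
import json
from dataclasses import dataclass


@dataclass
class Medication:
    """Toy response model with a single required string field."""
    name: str


def validate(raw: str) -> Medication:
    """Parse JSON and check it against the toy model; raise ValueError if invalid."""
    data = json.loads(raw)
    if not isinstance(data.get("name"), str):
        raise ValueError("field 'name' must be a string")
    return Medication(name=data["name"])


def fake_llm(messages: list[dict]) -> str:
    """Stand-in for a model call: answers correctly only after seeing an error message."""
    saw_error = any("Validation error" in m["content"] for m in messages)
    return '{"name": "Metformin"}' if saw_error else '{"name": null}'


def extract_with_retries(prompt: str, max_retries: int = 3) -> Medication:
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_retries):
        raw = fake_llm(messages)
        try:
            return validate(raw)
        except ValueError as err:
            # Feed the error back so the next attempt can correct itself.
            messages.append({"role": "assistant", "content": raw})
            messages.append({"role": "user", "content": f"Validation error: {err}"})
    raise RuntimeError(f"Failed after {max_retries} attempts")


result = extract_with_retries("Extract the medication from: metformin 500 mg")
print(result)
```

In this toy run the first reply fails validation and the second succeeds, which mirrors the behavior described above at a very small scale.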

In [12]:
client = instructor.from_openai(AsyncOpenAI())

In [74]:
# the system prompt defines the overall behavior of the LLM
system_prompt = """
You are a healthcare research expert that is responsible for extracting detailed entities from PubMed articles. 
You will be provided a graph data model schema and must extract entities and relationships to populate a knowledge graph.
"""

async def extract_entities_from_text_chunk(text_chunk: str) -> list:
    """
    Extract entities and relationships from a text chunk.

    Parameters
    ----------
    text_chunk : str
        The text chunk to extract entities from.

    Returns
    -------
    list
        A list of entities and relationships extracted from the text chunk.
        Each item is an instance of one of the node or relationship models passed to `response_model`.
        An empty list is returned if the response is truncated, retries are exhausted, or validation fails.
    """
    try:
        response = await client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": text_chunk}
            ],
            response_model=list[
                            # first test batch  
                            Medication | 
                            StudyMedication | 
                            ClinicalOutcome | 
                            StudyMedicationUsesMedication | 
                            StudyMedicationProducesClinicalOutcome |
                            # then add these
                            StudyPopulation |
                            MedicalCondition |
                            StudyPopulationHasMedicalCondition |
                            StudyPopulationReceivesStudyMedication |
                            StudyPopulationHasClinicalOutcome
                            ],
            temperature=0.0
        )
        return response
    except IncompleteOutputException:
        # Handle truncated output
        print("Response output truncated. Skipping chunk.")
        return list()
    except InstructorRetryException as e:
        # Handle retry exhaustion
        print(f"Failed after {e.n_attempts} attempts. Skipping chunk.")
        return list()
    except ValidationError as e:
        # Handle validation errors
        print(f"Validation failed. Skipping chunk.\nError: {e}")
        return list()

In [75]:
async def extract_entities_from_chunk_nodes(chunk_nodes_dataframe: pd.DataFrame, batch_size: int = 100) -> list[tuple[str, list[Any]]]:
    """
    Process a Pandas DataFrame of Chunk nodes and return the entities found in each chunk.

    Parameters
    ----------
    chunk_nodes_dataframe : pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node.
        Has columns `id` and `text`.
    batch_size : int
        The number of text chunks to process in each batch.

    Returns
    -------
    list[tuple[str, list[Any]]]
        A list of tuples, where the first element is the chunk id and the second element is the list of entities (Pydantic model instances) found in the chunk.
    """

    results = list()

    # Ceiling division so that a final partial batch is counted
    n_batches = -(-len(chunk_nodes_dataframe) // batch_size)

    for batch_idx, i in enumerate(range(0, len(chunk_nodes_dataframe), batch_size)):
        # iloc slicing stops at the end of the DataFrame, so the last batch may be shorter
        batch = chunk_nodes_dataframe.iloc[i:i+batch_size]
        print(f"Processing batch {batch_idx+1} of {n_batches}")
        # Create tasks for all nodes in the batch
        # order is maintained
        tasks = [extract_entities_from_text_chunk(row["text"]) for _, row in batch.iterrows()]
        # Execute all tasks concurrently
        extraction_results = await asyncio.gather(*tasks)
        # Add extracted records to the results list
        results.extend(extraction_results)

    # Return chunk_id paired with its entities from the results list
    return [(chunk_id, entities) for chunk_id, entities in zip(chunk_nodes_dataframe["id"], results)]
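
The batching pattern above relies on `asyncio.gather` returning results in the same order as the awaitables it receives, which is what makes the final `zip` with the chunk ids safe. Below is a minimal standard-library sketch of that pattern; `fake_extract` and `process_in_batches` are hypothetical stand-ins for the LLM call and the DataFrame-based function.

```python
import asyncio
import math


async def fake_extract(text: str) -> list[str]:
    """Stand-in for the LLM call: returns the upper-cased text as a single 'entity'."""
    await asyncio.sleep(0)
    return [text.upper()]


async def process_in_batches(ids: list[str], texts: list[str], batch_size: int) -> list[tuple[str, list[str]]]:
    results: list[list[str]] = []
    n_batches = math.ceil(len(texts) / batch_size)
    for batch_idx in range(n_batches):
        batch = texts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        # gather returns results in the order the awaitables were passed in
        results.extend(await asyncio.gather(*(fake_extract(t) for t in batch)))
    return list(zip(ids, results))


pairs = asyncio.run(process_in_batches(["c1", "c2", "c3"], ["a", "b", "c"], batch_size=2))
print(pairs)
```

When run as a script this uses `asyncio.run`; inside a notebook cell the coroutine can be awaited directly instead.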

## Load Articles

In [15]:
import os

We need to collect the article file names to pass to Unstructured for parsing.

In [16]:
article_names = os.listdir("articles/pdf/")

## Data Ingestion

We have now defined 
* Lexical and domain data models
* Partitioning and chunking logic for articles
* Entity extraction logic for chunks

It is now time to define our ingestion logic. We will run ingestion in three stages:

1. Load lexical graph
2. Embed lexical graph Chunk nodes
3. Extract domain / entity graph from lexical graph

Decoupling these stages allows us to easily make changes as we iterate on our ingestion process.

In [17]:
import os

from pyneoinstance import Neo4jInstance, load_yaml_file

Our database credentials and all of our queries are stored in the `pyneoinstance_config.yaml` file. 

This makes it easy to manage our queries and keeps the notebook code clean. 
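
For orientation, the sketch below mirrors the top-level structure the notebook reads from the config, written as a Python dictionary. Only the key names that the notebook actually accesses are taken from the code; the credentials, the `document_id` constraint name, and the query strings are placeholders, and the real file will differ.

```python
# Hypothetical structure; the real file's queries and credentials will differ.
example_config = {
    "db_info": {
        "uri": "neo4j://localhost:7687",
        "database": "neo4j",
        "user": "neo4j",
        "password": "password",
    },
    "initializing_queries": {
        "constraints": {"document_id": "<Cypher constraint query>"},
        "indexes": {},
    },
    "loading_queries": {
        "nodes": {"document": "<Cypher node load query>"},
        "relationships": {"chunk_part_of_document": "<Cypher relationship load query>"},
    },
    "processing_queries": {"get_chunk_nodes_to_process": "<Cypher read query>"},
}

# Queries are looked up by name, the same way the loading functions do it.
example_node_queries = example_config["loading_queries"]["nodes"]
print(sorted(example_node_queries))
```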

In [100]:
config = load_yaml_file("pyneoinstance_config.yaml")

db_info = config['db_info']

constraints = config['initializing_queries']['constraints']
indexes = config['initializing_queries']['indexes']

node_load_queries = config['loading_queries']['nodes']
relationship_load_queries = config['loading_queries']['relationships']

processing_queries = config['processing_queries']

This graph object will handle database connections and read / write transactions for us.

In [77]:
graph = Neo4jInstance(db_info.get('uri', os.getenv("NEO4J_URI", "neo4j://localhost:7687")), # use config value -> use env value -> use default value
                      db_info.get('user', os.getenv("NEO4J_USER", "neo4j")), 
                      db_info.get('password', os.getenv("NEO4J_PASSWORD", "password")))
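
One caveat with the `dict.get(key, default)` pattern above: the fallback is used only when the key is missing. If the key exists in the YAML file but is left empty, `get` returns `None` and the environment variable is never consulted. The sketch below shows one possible alternative that chains the three sources with `or`; `resolve_setting` and the `EXAMPLE_*` variable names are hypothetical.

```python
import os


def resolve_setting(config: dict, key: str, env_var: str, default: str) -> str:
    """Return the config value, else the environment variable, else the default."""
    return config.get(key) or os.getenv(env_var) or default


# Set up a known environment for the demonstration.
os.environ["EXAMPLE_NEO4J_URI"] = "neo4j://env-host:7687"
os.environ.pop("EXAMPLE_UNSET_VAR", None)

from_config = resolve_setting({"uri": "neo4j://config-host:7687"}, "uri", "EXAMPLE_NEO4J_URI", "neo4j://localhost:7687")
from_env = resolve_setting({"uri": None}, "uri", "EXAMPLE_NEO4J_URI", "neo4j://localhost:7687")
from_default = resolve_setting({}, "uri", "EXAMPLE_UNSET_VAR", "neo4j://localhost:7687")

print(from_config)
print(from_env)
print(from_default)
```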

This is a helper function for ingesting data using the PyNeoInstance library.

In [20]:
def get_partition(data: pd.DataFrame, batch_size: int = 500) -> int:
    """
    Determine the number of partitions needed to approximate the desired batch size.

    Parameters
    ----------
    data : pd.DataFrame
        The Pandas DataFrame to partition.
    batch_size : int
        The desired batch size.

    Returns
    -------
    int
        The number of partitions to split the data into. Always at least 1.
    """
    
    partition = int(len(data) / batch_size)
    print("partition: "+str(partition if partition > 1 else 1))
    return partition if partition > 1 else 1
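
Because the partition count is rounded down, individual partitions can end up larger than `batch_size`. The sketch below reproduces the same rule on plain integers and estimates the largest partition, assuming the rows are spread as evenly as possible across partitions (an assumption about how the data is split, not something this notebook shows). `partition_count` is a hypothetical name.

```python
import math


def partition_count(n_rows: int, batch_size: int = 500) -> int:
    """Same rule as get_partition: floor division, with a minimum of 1."""
    return max(1, n_rows // batch_size)


for n_rows in (5, 500, 999, 3538):
    parts = partition_count(n_rows)
    # Largest partition if rows are spread as evenly as possible (an assumption).
    largest = math.ceil(n_rows / parts)
    print(n_rows, parts, largest)
```

Note that 999 rows still produce a single partition, so `batch_size` acts as a target rather than a hard upper bound.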

### Constraints

Here we write all the constraints and indexes we need for both the lexical and domain graphs.

In [21]:
def create_constraints_and_indexes() -> None:
    """
    Create constraints and indexes for the lexical and domain graphs.
    """
    try:
        if constraints and len(constraints) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(constraints.values()))
    except Exception as e:
        print(e)

    try:
        if indexes and len(indexes) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(indexes.values()))
    except Exception as e:
        print(e)

In [22]:
create_constraints_and_indexes()




### Ingest Lexical Graph

#### Processing | Preparation

Process the articles

In [23]:
article_names

['nihms-1852972.pdf',
 'fendo-11-00178.pdf',
 'Diabetic Medicine - 2023 - Brønden - Effects of DPP‐4 inhibitors  GLP‐1 receptor agonists  SGLT‐2 inhibitors and.pdf',
 'jama_rosenstock_2019_oi_190026_1655321720.77793.pdf',
 'jciinsight-3-93936.pdf']

In [None]:
# lexical_ingest_records = process_pdf_article(article_names[0])
lexical_ingest_records = process_pdf_articles(article_names)

Processing article: nihms-1852972.pdf


Cannot set gray non-stroke color because /'R50' is an invalid float value


Processing article: fendo-11-00178.pdf


Cannot set gray non-stroke color because /'R50' is an invalid float value


Processing article: Diabetic Medicine - 2023 - Brønden - Effects of DPP‐4 inhibitors  GLP‐1 receptor agonists  SGLT‐2 inhibitors and.pdf
Processing article: jama_rosenstock_2019_oi_190026_1655321720.77793.pdf
Processing article: jciinsight-3-93936.pdf


Check the first few records 

In [25]:
lexical_ingest_records["nodes"]["document"]

Unnamed: 0,id,name,source
0,753d70915e2a2a747cee355745bd17ff08c45f90d938b1...,HHS Public Access,pubmed
1,e0048a13f033f7fe71581406cd2d3bac1ffc4db7ce88a8...,OPEN ACCESS,pubmed
2,7d560f487686b70e9b42d82c08ea1b9a43e804ebc75b74...,Abstract,pubmed
3,db75d65961558fccb83719eed3a308ce6f794c27511b36...,KeyPoints,pubmed
4,1e5cfcba99bcd1dfb5aeb0cbb4c95b7ca3f35c8210a04e...,Metformin-induced glucagon-like peptide-1 secr...,pubmed


#### Ingestion

Load the Document, Chunk, and element (text, image, table) nodes into the graph.

In [26]:
def load_lexical_nodes(document_dataframe: pd.DataFrame, 
                       chunk_dataframe: pd.DataFrame, 
                       text_element_dataframe: pd.DataFrame, 
                       image_element_dataframe: pd.DataFrame, 
                       table_element_dataframe: pd.DataFrame) -> None:
    """
    Load lexical nodes into the graph. These include Document, Chunk, TextElement, ImageElement, and TableElement nodes.

    Parameters
    ----------
    document_dataframe : pd.DataFrame
        A Pandas DataFrame of Document nodes to load into the graph. 
    chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of Chunk nodes to load into the graph.
    text_element_dataframe : pd.DataFrame
        A Pandas DataFrame of TextElement nodes to load into the graph. 
    image_element_dataframe : pd.DataFrame
        A Pandas DataFrame of ImageElement nodes to load into the graph. 
    table_element_dataframe : pd.DataFrame
        A Pandas DataFrame of TableElement nodes to load into the graph. 
    """
    
    lexical_nodes_ingest_iterator = list(zip([document_dataframe, 
                                              chunk_dataframe, 
                                              text_element_dataframe, 
                                              image_element_dataframe, 
                                              table_element_dataframe], 
                                              ['document', 
                                               'chunk', 
                                               'text_element', 
                                               'image_element', 
                                               'table_element']))

    for data, query in lexical_nodes_ingest_iterator:
        res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=node_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500), 
                                                    parallel=True,
                                                    workers=2)
        print(res)

In [27]:
load_lexical_nodes(lexical_ingest_records["nodes"]["document"], 
                   lexical_ingest_records["nodes"]["chunk"], 
                   lexical_ingest_records["nodes"]["text_element"], 
                   lexical_ingest_records["nodes"]["image_element"], 
                   lexical_ingest_records["nodes"]["table_element"])

partition: 1
{'labels_added': 5, 'nodes_created': 5, 'properties_set': 15}
partition: 1
{'labels_added': 500, 'nodes_created': 500, 'properties_set': 1500}
partition: 7
{'labels_added': 7076, 'nodes_created': 3538, 'properties_set': 14152}
partition: 1
{'labels_added': 74, 'nodes_created': 37, 'properties_set': 222}
partition: 1
{'labels_added': 36, 'nodes_created': 18, 'properties_set': 108}


Load the relationships into the graph

In [28]:
def load_lexical_relationships(chunk_part_of_document_dataframe: pd.DataFrame, 
                               unstructured_element_part_of_chunk_dataframe: pd.DataFrame, 
                               chunk_has_next_chunk_dataframe: pd.DataFrame) -> None:
    """
    Load lexical relationships into the graph.

    Parameters
    ----------
    chunk_part_of_document_dataframe : pd.DataFrame
        A Pandas DataFrame of Chunk - PART_OF -> Document relationships to load into the graph.
        Should have columns `chunk_id` and `document_id`.
    unstructured_element_part_of_chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of UnstructuredElement - PART_OF -> Chunk relationships to load into the graph.
        Should have columns `unstructured_element_id` and `chunk_id`.
    chunk_has_next_chunk_dataframe : pd.DataFrame
        A Pandas DataFrame of Chunk - HAS_NEXT_CHUNK -> Chunk relationships to load into the graph.
        Should have columns `source_id` and `target_id`.
    """
    lexical_relationships_ingest_iterator = list(zip([chunk_part_of_document_dataframe, 
                                                      unstructured_element_part_of_chunk_dataframe, 
                                                      chunk_has_next_chunk_dataframe], 
                                                      ['chunk_part_of_document', 
                                                       'unstructured_element_part_of_chunk', 
                                                       'chunk_has_next_chunk']))

    for data, query in lexical_relationships_ingest_iterator:
        res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500))
        print(res)

In [29]:
load_lexical_relationships(lexical_ingest_records["relationships"]["chunk_part_of_document"], 
                          lexical_ingest_records["relationships"]["unstructured_element_part_of_chunk"],
                          lexical_ingest_records["relationships"]["chunk_has_next_chunk"])

partition: 1
{'relationships_created': 500}
partition: 7
{'relationships_created': 3672}
partition: 1
{'relationships_created': 495}


### Embed Lexical Graph

Here we will read Chunk nodes from the graph that don't have embedding properties yet. 

We will then embed the Chunk text property and add the embedding as a property.

In [30]:
vector_index = ...

def create_vector_index() -> None:
    ...

In [31]:
def create_embeddings(driver) -> None:
    ...

def embed_lexical_graph(driver) -> None:
    ...

### Extract Entities from Lexical Graph

We will now perform entity extraction on the Chunk nodes to augment and connect to our patient journey graph.

In [32]:
def get_chunk_nodes_to_process_by_article_name(article_name: str) -> pd.DataFrame:
    """
    Retrieve the id and text of Chunk nodes that have a relationship to the Document with the article name provided.
    These chunks may then be used as input to the entity extraction process.

    Parameters
    ----------
    article_name : str
        The name of the article to retrieve chunks for.

    Returns
    -------
    pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node connected to the Document with the article name provided.
        Has columns `id` and `text`.
    """ 
    _, title = parse_article_file_name(article_name)
    return graph.execute_read_query(database=db_info['database'], 
                            parameters={"article_name": title}, 
                            query=processing_queries['get_chunk_nodes_to_process_by_article_name'], 
                        )

def get_chunk_nodes_to_process(min_length: int = 100) -> pd.DataFrame:
    """
    Retrieve the id and text of Chunk nodes whose text is at least `min_length` characters long.
    These chunks may then be used as input to the entity extraction process.

    Parameters
    ----------
    min_length : int
        The minimum length the text must be to be included in the DataFrame.

    Returns
    -------
    pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node that has text and is at least `min_length` characters long.
        Has columns `id` and `text`.
    """
    return graph.execute_read_query(database=db_info['database'], 
                            query=processing_queries['get_chunk_nodes_to_process'], 
                            parameters={"min_length": min_length},
                        )

In [33]:
chunks_to_process = get_chunk_nodes_to_process(min_length=20)



In [34]:
print(f"Found {len(chunks_to_process)} chunks to process\n")
print(f"First chunk:\n\n{chunks_to_process.loc[0,'text']}")

Found 494 chunks to process

First chunk:

statistical significance (metformin + Ex9-39 vs. placebo + Ex9-39, P = 0.053). The glucose iAUC after metformin + saline was significantly smaller than the iAUC for metformin + Ex9-39 (P = 0.004). Based on individual iAUC values, the relative contribution of GLP-1 to the acute glucose-lowering effect of metformin was 75% ± 35%, calculated as follows: 100% × ([iAUCplacebo + saline – iAUCmetformin + saline] – [iAUCplacebo + Ex9–39 – iAUCmetformin + Ex9–39])/(iAUCplacebo + saline – iAUCmetformin + saline) (P = 0.05). Using a 2-way ANOVA, both metformin and Ex9-39 were shown to significantly affect postprandial plasma glucose (iAUC) (P = 0.005 and P = 0.002, respectively), but no interaction between the 2 factors was evident. The time courses of the C-peptide/glucose ratios are illustrated in Figure 2B, and the AUCs for C-peptide/glucose, insulin/glucose, and insulin secretion


In [78]:
entity_ingest_records = await extract_entities_from_chunk_nodes(chunks_to_process[:200], batch_size=20)

Failed after 3 attempts10  


In [79]:
entity_ingest_records[:10]

[('0024e9c2d7afcf519d4d13871816a21d',
  [Medication(name='Metformin', medication_class='Biguanide', mechanism='Decreases hepatic glucose production and increases insulin sensitivity', generic_name='metformin', brand_names=['Glucophage', 'Fortamet', 'Glumetza'], approval_status='FDA approved'),
   Medication(name='Ex9-39', medication_class='GLP-1 receptor antagonist', mechanism='Blocks GLP-1 receptor', generic_name='Exendin 9-39', brand_names=None, approval_status=None),
   StudyMedication(study_name='Study on Metformin and Ex9-39', treatment_arm='Metformin + Ex9-39', dosage=None, route=None, frequency=None, treatment_duration=None, comparator='Placebo + Ex9-39', adherence_rate=None, formulation=None, study_medication_id='5028738cbc272b71e202b93d8188e3f3dd898fcb237b95dbe381f09f7a1c54c1'),
   StudyMedication(study_name='Study on Metformin and Ex9-39', treatment_arm='Metformin + Saline', dosage=None, route=None, frequency=None, treatment_duration=None, comparator='Metformin + Ex9-39', adh

### Ingest Entities Into Knowledge Graph

These functions load the extracted entities and relationships, then link the extracted entities to their text chunk nodes.

In [89]:
ENTITY_LABELS = {
    "Medication", 
    "StudyMedication",
    "MedicalCondition",
    "StudyPopulation",
    "ClinicalOutcome",
}

ENTITY_RELS = {
    "StudyMedicationUsesMedication",
    "StudyMedicationProducesClinicalOutcome",
    "StudyPopulationHasMedicalCondition",
    "StudyPopulationReceivesStudyMedication",
    "StudyPopulationHasClinicalOutcome",
    
}

def prepare_entities_for_ingestion(entities: list[tuple[str, list[Any]]]) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Prepare entities for ingestion into the graph.
    This function takes the results of the `extract_entities_from_chunk_nodes` function and returns a dictionary of entity label keys and pandas dataframes of entities.

    Parameters
    ----------
    entities : list[tuple[str, list[Any]]]
        A list of tuples, where the first element is the chunk id and the second element is a list of entities found in the chunk.
        Entities are Pydantic models that adhere to the domain graph data model.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary of entity label to pandas dataframe of entities.

        {
            "nodes": {
                "Medication": pd.DataFrame(...),
                "StudyMedication": pd.DataFrame(...),
                ...
            },
            "relationships": {
                "StudyMedicationUsesMedication": pd.DataFrame(...),
                "StudyMedicationProducesClinicalOutcome": pd.DataFrame(...),
                ...
            }
        }
    """

    records_node_dict = {lbl: list() for lbl in ENTITY_LABELS}
    records_rel_dict = {lbl: list() for lbl in ENTITY_RELS}

    # use a distinct loop variable so the `entities` parameter is not shadowed
    for chunk_id, chunk_entities in entities:
        for entity in chunk_entities:
            to_add = entity.model_dump()
            to_add.update({"chunk_id": chunk_id})
            # nodes
            if entity.__class__.__name__ in ENTITY_LABELS:
                records_node_dict[entity.__class__.__name__].append(to_add)
            # rels
            elif entity.__class__.__name__ in ENTITY_RELS:
                records_rel_dict[entity.__class__.__name__].append(to_add)
            else:
                print(f"Unknown entity type: {entity.__class__.__name__}")

    for key, value in records_node_dict.items():
        records_node_dict[key] = pd.DataFrame(value).replace({float('nan'): None})

    for key, value in records_rel_dict.items():
        records_rel_dict[key] = pd.DataFrame(value).replace({float('nan'): None})

    return {"nodes": records_node_dict, "relationships": records_rel_dict}
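
The grouping step can be hard to picture without data, so here is a small standard-library sketch of the same bucketing logic. It uses toy dataclasses in place of the Pydantic models and plain lists in place of DataFrames; `group_records`, `NODE_TYPES`, `REL_TYPES`, and the sample values are hypothetical.

```python
from dataclasses import asdict, dataclass


@dataclass
class Medication:  # toy node type
    name: str


@dataclass
class StudyMedicationUsesMedication:  # toy relationship type
    study_name: str
    medication_name: str


NODE_TYPES = {"Medication"}
REL_TYPES = {"StudyMedicationUsesMedication"}


def group_records(extracted):
    """Bucket extracted items by class name and tag each record with its chunk id."""
    nodes = {label: [] for label in NODE_TYPES}
    rels = {label: [] for label in REL_TYPES}
    for chunk_id, chunk_entities in extracted:
        for entity in chunk_entities:
            record = {**asdict(entity), "chunk_id": chunk_id}
            label = type(entity).__name__
            if label in nodes:
                nodes[label].append(record)
            elif label in rels:
                rels[label].append(record)
    return {"nodes": nodes, "relationships": rels}


extracted = [
    ("chunk-1", [Medication("Metformin"), StudyMedicationUsesMedication("Example Study", "Metformin")]),
    ("chunk-2", []),  # a chunk whose extraction returned nothing
]
grouped = group_records(extracted)
print(grouped["nodes"]["Medication"])
print(grouped["relationships"]["StudyMedicationUsesMedication"])
```

Keeping `chunk_id` on every record is what later allows each entity to be linked back to the chunk it was extracted from.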

In [90]:
def load_entity_nodes(medication_dataframe: pd.DataFrame, 
                      medical_condition_dataframe: pd.DataFrame, 
                      study_medication_dataframe: pd.DataFrame, 
                      study_population_dataframe: pd.DataFrame, 
                      clinical_outcome_dataframe: pd.DataFrame) -> None:
    """
    Load entity nodes into the graph.
    """
    
    entity_nodes_ingest_iterator = list(zip([medication_dataframe, 
                                             medical_condition_dataframe, 
                                             study_medication_dataframe, 
                                             study_population_dataframe, 
                                             clinical_outcome_dataframe], 
                                             ['medication', 
                                              'medical_condition', 
                                              'study_medication', 
                                              'study_population', 
                                              'clinical_outcome']))

    for data, query in entity_nodes_ingest_iterator:
        if len(data) > 0:
            print(f"Loading {len(data)} {query} nodes")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=node_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} nodes to load")

In [91]:
def load_entity_relationships(study_medication_uses_medication_dataframe: pd.DataFrame,
                              study_medication_produces_clinical_outcome_dataframe: pd.DataFrame,
                              study_population_has_medical_condition_dataframe: pd.DataFrame,
                              study_population_receives_study_medication_dataframe: pd.DataFrame,
                              study_population_has_outcome_dataframe: pd.DataFrame,
                              ) -> None:
    """
    Load entity relationships into the graph.
    """
    entity_relationships_ingest_iterator = list(zip([study_medication_uses_medication_dataframe, 
                                                      study_medication_produces_clinical_outcome_dataframe, 
                                                      study_population_has_medical_condition_dataframe, 
                                                      study_population_receives_study_medication_dataframe, 
                                                      study_population_has_outcome_dataframe], 
                                                      ['study_medication_uses_medication', 
                                                       'study_medication_produces_clinical_outcome', 
                                                       'study_population_has_medical_condition', 
                                                       'study_population_receives_study_medication', 
                                                       'study_population_has_clinical_outcome']))
    
    for data, query in entity_relationships_ingest_iterator:
        if len(data) > 0:
            print(f"Loading {len(data)} {query} relationships")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} relationships to load")

In [92]:
def link_entities_to_chunks(medication_link_dataframe: pd.DataFrame, 
                      medical_condition_link_dataframe: pd.DataFrame, 
                      study_medication_link_dataframe: pd.DataFrame, 
                      study_population_link_dataframe: pd.DataFrame, 
                      clinical_outcome_link_dataframe: pd.DataFrame) -> None:
    """
    Link entities to chunks.
    """
    entity_link_iterator = list(zip([medication_link_dataframe, 
                                     medical_condition_link_dataframe, 
                                     study_medication_link_dataframe, 
                                     study_population_link_dataframe, 
                                     clinical_outcome_link_dataframe], 
                                     ["chunk_has_entity_medication",
                                      "chunk_has_entity_medical_condition",
                                      "chunk_has_entity_study_medication",
                                      "chunk_has_entity_study_population",
                                      "chunk_has_entity_clinical_outcome"]))
    
    for data, query in entity_link_iterator:
        if len(data) > 0:
            print(f"Linking {len(data)} {query} entities to chunks")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} relationships to load")

In [93]:
ingest_records = prepare_entities_for_ingestion(entity_ingest_records)

In [94]:
ingest_records['nodes']['StudyPopulation']

Unnamed: 0,study_name,description,min_age,max_age,male_percentage,female_percentage,other_gender_percentage,sample_size,study_type,location,inclusion_criteria,exclusion_criteria,study_duration,study_population_id,chunk_id
0,Study 1,Patients with established atherosclerotic card...,,,56.19,43.81,,507.0,Randomized Controlled Trial,Not specified,,,2 years,8e2f978b6e2de6b3a215313941adbd8fb5e42e5f92e9b9...,097f93c7cc27f257d4c04b3153765e6b
1,Study 2,Patients with type 2 diabetes and high cardiov...,,,56.19,43.81,,129.0,Randomized Controlled Trial,Not specified,,,5 weeks,ebf71b271dc74f0d2faf079da304d66ce3ad27a6ae5dd4...,097f93c7cc27f257d4c04b3153765e6b
2,Study 3,Patients with type 2 diabetes and high cardiov...,,,56.19,43.81,,199.0,Randomized Controlled Trial,Not specified,,,7 weeks,c8a58b3f7946892e1ec52bac322ad74763258b2b72ab69...,097f93c7cc27f257d4c04b3153765e6b
3,Study 4,Patients with type 2 diabetes and high cardiov...,,,56.19,43.81,,182.0,Randomized Controlled Trial,Not specified,,,2 weeks,29ac0b01370f069064a1e3964872ef22f46d65bead46ca...,097f93c7cc27f257d4c04b3153765e6b
4,Current Trial,Diverse population with type 2 diabetes,,,,,,,Comparative-effectiveness trial,,[Participants with type 2 diabetes],[Coexisting conditions],Long duration,eb4b9743a7d94b8a88a04ffababde3f97ac8e55d312d4f...,09d8e609d9e0e28eb1dcb67351e66f83
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
109,Study on Semaglutide and Sitagliptin,Patients in the study on semaglutide and sitag...,,,,,,1864.0,Clinical Trial,,,,,f3268fb59f03256893d27ed1126b44ca79236535fd5d71...,6176451015999ec938bb4c38f53b667e
110,Pioneer 7,Patients with T2DM,,,,,,504.0,Randomized clinical trial,81 sites in 10 countries,,,,681d016419a23edf1780963e4d677fd5393b4c9c9e0518...,621decedae713b6e90013ab0e4f33e42
111,Study 2,Patients with established type 2 diabetes (T2D...,18.0,,,,,507.0,Randomized Controlled Trial,Unknown,"[Type 2 diabetes diagnosis, BMI ≥ 30 kg/m², Ag...",,26 weeks,85ecc31e99d6f48c59d0e25332279838ecc36ba6e789b8...,62647d63971da0c08fc7aecca4780acb
112,Ahren et al. (48),Patients with T2DM on metformin monotherapy,,,,,,,Randomized controlled trial,,[T2DM on metformin monotherapy],,104 weeks,e94f139e4ac6c873d9ff2f00e598cd72e681f32b74a8c8...,62c8f7607d7c903b4179eaba82922516


In [95]:
load_entity_nodes(ingest_records["nodes"]["Medication"], 
                  ingest_records["nodes"]["MedicalCondition"], 
                  ingest_records["nodes"]["StudyMedication"], 
                  ingest_records["nodes"]["StudyPopulation"], 
                  ingest_records["nodes"]["ClinicalOutcome"])

Loading 229 medication nodes
partition: 1
{'labels_added': 74, 'nodes_created': 74, 'properties_set': 1219}
Loading 69 medical_condition nodes
partition: 1
{'labels_added': 28, 'nodes_created': 28, 'properties_set': 373}
Loading 105 study_medication nodes
partition: 1
{'labels_added': 94, 'nodes_created': 94, 'properties_set': 1039}
Loading 114 study_population nodes
partition: 1
{'labels_added': 96, 'nodes_created': 96, 'properties_set': 1464}
Loading 146 clinical_outcome nodes
partition: 1
{'labels_added': 141, 'nodes_created': 141, 'properties_set': 433}
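
For every entity type, fewer nodes were created than rows were submitted (for example, 229 medication rows produced 74 nodes). One plausible explanation, assuming the load queries `MERGE` on the entity ID (the queries are defined outside this section), is that rows sharing an ID collapse into one node; nodes left over from an earlier run would have the same effect. The sketch below shows how the distinct-ID count could be checked before loading. The toy DataFrame and its values are made up for illustration.

```python
from typing import Tuple

import pandas as pd

# Toy stand-in for ingest_records["nodes"]["StudyPopulation"]; values are made up.
toy_nodes = pd.DataFrame(
    {
        "study_population_id": ["a1", "a1", "b2", "c3", "c3"],
        "study_name": ["Study 1", "Study 1", "Study 2", "Study 3", "Study 3"],
    }
)


def rows_and_distinct_ids(df: pd.DataFrame, id_col: str) -> Tuple[int, int]:
    """Return (rows submitted, distinct IDs), an upper bound on nodes created."""
    return len(df), int(df[id_col].nunique())


rows, distinct = rows_and_distinct_ids(toy_nodes, "study_population_id")
print(f"{rows} rows, {distinct} distinct IDs")
```

If the distinct-ID count is still higher than `nodes_created`, the remaining nodes probably existed in the database before the load.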


In [96]:
load_entity_relationships(ingest_records["relationships"]["StudyMedicationUsesMedication"], 
                          ingest_records["relationships"]["StudyMedicationProducesClinicalOutcome"], 
                          ingest_records["relationships"]["StudyPopulationHasMedicalCondition"], 
                          ingest_records["relationships"]["StudyPopulationReceivesStudyMedication"], 
                          ingest_records["relationships"]["StudyPopulationHasClinicalOutcome"])

Loading 77 study_medication_uses_medication relationships
partition: 1
{'relationships_created': 65}
Loading 76 study_medication_produces_clinical_outcome relationships
partition: 1
{'relationships_created': 66}
Loading 45 study_population_has_medical_condition relationships
partition: 1
{'relationships_created': 37}
Loading 35 study_population_receives_study_medication relationships
partition: 1
{'relationships_created': 34}
Loading 39 study_population_has_clinical_outcome relationships
partition: 1
{'relationships_created': 37}


In [97]:
link_entities_to_chunks(ingest_records["nodes"]["Medication"], 
                        ingest_records["nodes"]["MedicalCondition"], 
                        ingest_records["nodes"]["StudyMedication"], 
                        ingest_records["nodes"]["StudyPopulation"], 
                        ingest_records["nodes"]["ClinicalOutcome"])

Linking 229 chunk_has_entity_medication entities to chunks
partition: 1
{'relationships_created': 229}
Linking 69 chunk_has_entity_medical_condition entities to chunks
partition: 1
{'relationships_created': 69}
Linking 105 chunk_has_entity_study_medication entities to chunks
partition: 1
{'relationships_created': 102}
Linking 114 chunk_has_entity_study_population entities to chunks
partition: 1
{'relationships_created': 110}
Linking 146 chunk_has_entity_clinical_outcome entities to chunks
partition: 1
{'relationships_created': 146}
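
Two of the link steps created fewer relationships than rows submitted (105 vs. 102 for study medications, 114 vs. 110 for study populations). Assuming the link queries `MERGE` a relationship per (entity, chunk) pair, repeated pairs in the input would account for such a gap; a `chunk_id` with no matching chunk node would be another possible cause. The following sketch lists repeated pairs in a toy DataFrame. The column name `study_medication_id` is an assumption modeled on the `study_population_id` column shown earlier, and the values are made up.

```python
import pandas as pd

# Toy stand-in for ingest_records["nodes"]["StudyMedication"]; values are made up
# and `study_medication_id` is an assumed column name.
toy_links = pd.DataFrame(
    {
        "study_medication_id": ["m1", "m1", "m2", "m3"],
        "chunk_id": ["c1", "c1", "c1", "c2"],
    }
)

pair_cols = ["study_medication_id", "chunk_id"]

# keep=False marks every member of a duplicated group, not just the later copies.
repeated = toy_links[toy_links.duplicated(subset=pair_cols, keep=False)]

# Number of relationships a MERGE per distinct pair would be expected to create.
expected_relationships = len(toy_links.drop_duplicates(subset=pair_cols))

print(repeated)
print(expected_relationships)
```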


In [103]:
def link_domain_and_patient_journey_graph() -> None:
    """
    Link the domain graph with the patient journey graph.
    This process doesn't require any input DataFrames.
    Instead, it attempts to link existing nodes based on matching properties.
    """

    queries = ["demographic_in_study_population"]

    for q in queries:
        res = graph.execute_write_query(database=db_info['database'], 
                                        query=relationship_load_queries[q])
        # print inside the loop so that every query's result is reported
        print(res)

In [104]:
link_domain_and_patient_journey_graph()

{}
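
The empty result most likely means that no relationships were created by this run. The linking query itself is not shown in this section, so the following is only an in-memory illustration of property-based matching. It assumes that a demographic record carries an `age` and that a study population is matched through its `min_age` and `max_age` bounds; the records and the helper `age_in_range` are hypothetical.

```python
from typing import Optional


def age_in_range(age: int, min_age: Optional[float], max_age: Optional[float]) -> bool:
    """Return True when age satisfies whichever bounds are present.

    Assumption: a population with no bounds at all is never matched, to avoid
    linking every demographic to every under-specified study.
    """
    if min_age is None and max_age is None:
        return False
    if min_age is not None and age < min_age:
        return False
    if max_age is not None and age > max_age:
        return False
    return True


# Made-up records standing in for StudyPopulation and demographic nodes.
populations = [
    {"study_name": "Study A", "min_age": 18.0, "max_age": None},
    {"study_name": "Study B", "min_age": None, "max_age": None},
    {"study_name": "Study C", "min_age": 40.0, "max_age": 65.0},
]
demographics = [{"id": "d1", "age": 52}, {"id": "d2", "age": 17}]

links = [
    (d["id"], p["study_name"])
    for d in demographics
    for p in populations
    if age_in_range(d["age"], p["min_age"], p["max_age"])
]
print(links)
```

In the `StudyPopulation` table shown earlier, most rows have no `min_age` or `max_age`, so a match of this kind would find few or no candidates, which would be consistent with an empty result.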
