# PubMed Knowledge Graph

This notebook is part of a series that walks through the process of generating a knowledge graph of PubMed articles.

This notebook will:
* Extract entities from the Chunk nodes in a Neo4j graph according to the defined schema
* Load the entities and connect them with their respective Chunk nodes
* Connect the entities as defined by the Domain Graph Data Model
* Connect the extracted entities with the existing patient journey graph

In [1]:
# filter some Numpy warnings that pop up during ingestion
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [2]:
import asyncio
import hashlib
from typing import Any, Optional, List

In [3]:
# allows for async operations in notebooks
import nest_asyncio
nest_asyncio.apply()

In [4]:
import pandas as pd
from pydantic import BaseModel, Field, computed_field, field_validator

## Domain Graph Schema Definition

We now need to define our knowledge graph schema. This information will be passed to the entity extraction LLM to control which entities and relationships are pulled out of the text.

This is necessary to prevent our schema from growing too large with an unbounded extraction process.

We are using Pydantic to define the schema here since it can be used to validate any returned results as well. This ensures that all data we are ingesting into Neo4j adheres to this structure.

Here is what our domain graph data model looks like.

<img src="./assets/images/domain-data-model-v1.png" alt="domain-data-model" width="600px">


In [5]:
# -------------
# Nodes
# -------------

class Medication(BaseModel):
    """
    A substance used for medical treatment - a medicine or drug. 
    This is a general representation of a medication. 
    A Medication node may have relationships to TreatmentArm nodes that are specific to a particular study.
    """
    
    name: str = Field(..., description="Name of the medication. Should also be uniquely identifiable. Do not include dosage, administration, frequency, or other details.")
    medication_class: str = Field(..., description="Drug class (e.g., GLP-1 RA, SGLT2i)")
    mechanism: Optional[str] = Field(None, description="Mechanism of action")
    generic_name: Optional[str] = Field(None, description="Generic name of the medication")
    brand_names: Optional[List[str]] = Field(None, description="Commercial brand names")
    approval_status: Optional[str] = Field(None, description="FDA approval status")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "name": "semaglutide", 
                    "medication_class": "GLP-1 receptor agonist",
                    "mechanism": "GLP-1 receptor activation",
                    "generic_name": "semaglutide",
                    "brand_names": ["ozempic", "wegovy", "rybelsus"],
                    "approval_status": "FDA approved"
                }
            ]
        }

    @field_validator("name", "medication_class")
    def validate_lower_case(cls, v: str) -> str:
        """
        Validate that the field value is all lower case.
        """
        return v.lower()
    
    @field_validator("generic_name")
    def validate_generic_name(cls, v: str | None) -> str | None:
        """
        Validate that the generic name is all lower case.
        """
        if v is not None:
            return v.lower()
        return v
    
    @field_validator("brand_names")
    def validate_brand_names(cls, v: list[str] | None) -> list[str] | None:
        """
        Validate that the brand names are all lower case.
        """
        if v is not None:
            return [name.lower() for name in v]
        return v


class TreatmentArm(BaseModel):
    """
    A treatment arm is an explicit instance of a participant group in a study that receives the same treatment.
    A treatment arm should have relationships to Medication and ClinicalOutcome nodes.
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the TreatmentArm node.")
    name: str = Field(..., description="Name of the treatment arm")

    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "study_name": "Study 1",
                    "name": "Treatment arm 1",
                }
            ]
        }

    @computed_field(return_type=str)
    def treatment_arm_id(self) -> str:
        """
        The unique id of the treatment arm.
        This is a sha256 hash of the study name and treatment arm name.
        """
        return hashlib.sha256(f"{self.study_name}_{self.name}".encode()).hexdigest()
    

class ClinicalOutcome(BaseModel):
    """
    Measured clinical outcomes and biomarkers.
    This node represents a clinical outcome present in a study.
    ClinicalOutcome nodes should have relationships with other entity nodes from a study.
    ClinicalOutcome nodes should not have relationships with entities that exist outside the study.
    """
    
    study_name: str = Field(..., description="Name of the study this outcome is associated with. This is used to uniquely identify the ClinicalOutcome node.")
    name: str = Field(..., description="A concise detailed name for the outcome.")

    @computed_field(return_type=str)
    def clinical_outcome_id(self) -> str:
        """
        The unique id of the clinical outcome.
        This is a sha256 hash of the study name and the name of the outcome.
        """
        return hashlib.sha256(f"{self.study_name}_{self.name}".encode()).hexdigest()
    
    class Config:
        json_schema_extra = {
            "examples": [
                # don't include the clinical_outcome_id in the example since this is computed from extracted fields
                {
                    "study_name": "Study 1",
                    "name": "A1C controlled",
                }
            ]
        }


class MedicalCondition(BaseModel):
    """Medical conditions and comorbidities studied"""
    
    name: str = Field(..., description="Name of the medical condition")
    category: str = Field(..., description="Category of condition")
    icd10_code: Optional[str] = Field(None, description="ICD-10 code when available")
    
    @field_validator("icd10_code")
    def validate_icd10_code(cls, v: str | None) -> str | None:
        """
        Validate that the ICD-10 code is well formed.
        """
        # the field is optional, so an explicit null passes through
        if v is None:
            return v
        # ICD-10 codes are 3-7 characters long, not counting the '.' separator
        code_length = len(v.replace(".", ""))
        if code_length < 3 or code_length > 7:
            raise ValueError("ICD-10 code must be between 3 and 7 characters long, excluding the '.'.")
        # first character must be a letter
        # any letter is accepted: e.g. I10 (hypertension) and O24 (diabetes in pregnancy) are valid codes
        elif not v[0].isalpha():
            raise ValueError("ICD-10 code must start with a letter.")
        # second character must be a digit
        elif not v[1].isdigit():
            raise ValueError("ICD-10 code second character must be a digit.")
        # '.' must separate the first 3 characters from the rest of the code
        # examples:
        # S52 Fracture of forearm
        # S52.5 Fracture of lower end of radius
        # S52.52 Torus fracture of lower end of radius
        # S52.521 Torus fracture of lower end of right radius
        # S52.521A Torus fracture of lower end of right radius, initial encounter, closed fracture
        elif len(v) > 3 and not v[3] == '.':
            raise ValueError("ICD-10 code must have a '.' after the first 3 characters.")
        return v
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "name": "Type 2 diabetes mellitus",
                    "category": "diabetes",
                    "icd10_code": "E11",
                }
            ]
        }


class StudyPopulation(BaseModel):
    """Patient populations and demographics in research studies"""
    
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    description: str = Field(..., description="Description of the population")
    min_age: Optional[int] = Field(None, description="Minimum age in years")
    max_age: Optional[int] = Field(None, description="Maximum age in years")
    male_percentage: Optional[float] = Field(None, description="Percentage of male gender participants")
    female_percentage: Optional[float] = Field(None, description="Percentage of female gender participants")
    other_gender_percentage: Optional[float] = Field(None, description="Percentage of participants that identify as another gender")
    sample_size: Optional[int] = Field(None, description="Number of participants")
    study_type: str = Field(..., description="Type of study")
    location: Optional[str] = Field(None, description="Geographic location of study")
    inclusion_criteria: Optional[List[str]] = Field(None, description="Key inclusion criteria")
    exclusion_criteria: Optional[List[str]] = Field(None, description="Key exclusion criteria")
    study_duration: Optional[str] = Field(None, description="Duration of study")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "study_name": "Study 1",
                    "description": "Adults with T2DM and schizophrenia",
                    "min_age": 30,
                    "max_age": 39,
                    "male_percentage": 46.0,
                    "female_percentage": 53.0,
                    "other_gender_percentage": 1.0,
                    "sample_size": 100,
                    "study_type": "Observational study",
                    "location": "Denmark",
                    "inclusion_criteria": ["Type 2 diabetes diagnosis", "Schizophrenia diagnosis", "Age ≥18"],
                    "study_duration": "12 months"
                }
            ]
        }

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.description}".encode()).hexdigest()


# -------------
# Relationships
# -------------

class MedicationUsedInTreatmentArm(BaseModel):
    """
    Study-specific medication usage - how a Medication was used in a particular TreatmentArm.
    This describes an instance of a medication that is used in a particular treatment arm. 
    All treatment arms should have a relationship with at least one Medication node.
    """
    study_name: str = Field(..., description="Name of the study.")
    treatment_arm_name: str = Field(..., description="Name of the treatment arm.")
    medication_name: str = Field(..., description="Name of the medication.")
    dosage: Optional[str] = Field(None, description="Dosage used in this study")
    route: Optional[str] = Field(None, description="Route of administration")
    frequency: Optional[str] = Field(None, description="Dosing frequency")
    treatment_duration: Optional[str] = Field(None, description="Duration of treatment")
    comparator: Optional[str] = Field(None, description="What this was compared against")
    adherence_rate: Optional[float] = Field(None, description="Treatment adherence rate")
    formulation: Optional[str] = Field(None, description="Specific formulation used")

    @computed_field(return_type=str)
    def treatment_arm_id(self) -> str:
        """
        The unique id of the treatment arm.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_name}_{self.treatment_arm_name}".encode()).hexdigest()
    
    @field_validator("medication_name")
    def validate_medication_name(cls, v: str) -> str:
        """
        Validate that the medication name is all lower case.
        """
        return v.lower()
    
    class Config:
        json_schema_extra = {
            # don't include the treatment_arm_id in the example since this is computed from extracted fields
            "examples": [
                {
                    "study_name": "Study 1",
                    "treatment_arm_name": "Treatment arm 1",
                    "medication_name": "Medication 1",
                    "dosage": "1.0 mg",
                    "route": "subcutaneous",
                    "frequency": "weekly",
                    "treatment_duration": "12 weeks",
                    "comparator": "placebo",
                    "adherence_rate": 85.5,
                    "formulation": "pre-filled pen"
                }
            ]
        }


class TreatmentArmHasClinicalOutcome(BaseModel):
    """
    Links TreatmentArm to ClinicalOutcome nodes.
    TreatmentArm nodes should have a relationship with a ClinicalOutcome node.
    Pattern: (:TreatmentArm)-[:HAS_CLINICAL_OUTCOME]->(:ClinicalOutcome)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the TreatmentArm node.")
    treatment_arm_name: str = Field(..., description="Name of the treatment arm.")
    clinical_outcome_name: str = Field(..., description="Name of the clinical outcome")

    @computed_field(return_type=str)
    def clinical_outcome_id(self) -> str:
        """
        The unique id of the clinical outcome.
        This is a sha256 hash of the study name and the name of the outcome.
        """
        return hashlib.sha256(f"{self.study_name}_{self.clinical_outcome_name}".encode()).hexdigest()
    
    @computed_field(return_type=str)
    def treatment_arm_id(self) -> str:
        """
        The unique id of the treatment arm.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_name}_{self.treatment_arm_name}".encode()).hexdigest()


class StudyPopulationHasMedicalCondition(BaseModel):
    """
    Links StudyPopulation to MedicalCondition nodes.
    StudyPopulation nodes should have a relationship with a MedicalCondition node.
    Pattern: (:StudyPopulation)-[:HAS_MEDICAL_CONDITION]->(:MedicalCondition)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    study_population_description: str = Field(..., description="Description of the study population.")
    medical_condition_name: str = Field(..., description="Name of the medical condition.")

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_population_description}".encode()).hexdigest()


class StudyPopulationInTreatmentArm(BaseModel):
    """
    Links StudyPopulation to TreatmentArm nodes.
    StudyPopulation nodes should have a relationship with a TreatmentArm node.
    Pattern: (:StudyPopulation)-[:IN_TREATMENT_ARM]->(:TreatmentArm)
    """
    study_name: str = Field(..., description="Name of the study. This is used to uniquely identify the StudyPopulation node.")
    study_population_description: str = Field(..., description="Description of the study population.")
    treatment_arm_name: str = Field(..., description="Name of the treatment arm.")

    @computed_field(return_type=str)
    def treatment_arm_id(self) -> str:
        """
        The unique id of the treatment arm.
        This is a sha256 hash of the study name and treatment arm.
        """
        return hashlib.sha256(f"{self.study_name}_{self.treatment_arm_name}".encode()).hexdigest()

    @computed_field(return_type=str)
    def study_population_id(self) -> str:
        """
        The unique id of the study population.
        This is a sha256 hash of the study name and population description.
        """
        return hashlib.sha256(f"{self.study_name}_{self.study_population_description}".encode()).hexdigest()
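
The computed ids are what tie relationship records to node records: both sides hash the same study name and entity name, so they resolve to the same value without the LLM ever producing an id. A minimal standalone sketch of that idea, using only `hashlib` (the helper name `make_id` is hypothetical and not part of the notebook):

```python
import hashlib


def make_id(study_name: str, name: str) -> str:
    """Hash the two identifying fields the same way the schema's computed fields do."""
    return hashlib.sha256(f"{study_name}_{name}".encode()).hexdigest()


# id computed on the node side (TreatmentArm)
node_id = make_id("Study 1", "Treatment arm 1")
# id computed on the relationship side (MedicationUsedInTreatmentArm)
rel_id = make_id("Study 1", "Treatment arm 1")

print(node_id == rel_id)  # True: identical inputs always hash to the same id

# The hash is sensitive to exact spelling, so any difference in how the
# study or arm name is written produces a different id.
print(make_id("Study 1", "treatment arm 1") == node_id)  # False
```

Because the match is exact, the ids only line up when the extracted names are spelled consistently across the entities and relationships of a study.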

The lexical and domain knowledge graphs will be linked with `HAS_ENTITY` relationships between Chunk nodes and domain graph nodes.

This is the combined lexical and domain graph data model.

IMAGE OF DATA MODEL

## Entity Extraction via LLM

We will be using [OpenAI](https://platform.openai.com/docs/overview) and the [Instructor](https://python.useinstructor.com/) library to perform our entity extraction.

In [6]:
from openai import AsyncOpenAI
import instructor
from instructor.exceptions import IncompleteOutputException, InstructorRetryException, ValidationError

Instructor handles requesting structured outputs from the LLM. 

If the LLM fails to return output that adheres to the response models, Instructor will also handle the retry logic and pass any errors to inform corrections.
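
Conceptually, this validate-and-retry behavior can be pictured as the loop below. This is only an illustration of the idea, not Instructor's actual implementation: `fake_llm`, `validate`, and `extract_with_retries` are hypothetical stand-ins, and no real model is called.

```python
import json


def fake_llm(prompt: str, attempt: int) -> str:
    """Hypothetical stand-in for an LLM call: incomplete on the first attempt, valid after."""
    if attempt == 0:
        return '{"name": "Semaglutide"}'  # missing the required medication_class
    return '{"name": "semaglutide", "medication_class": "glp-1 receptor agonist"}'


def validate(raw: str) -> dict:
    """Parse the output and check required fields; raise ValueError describing the problem."""
    record = json.loads(raw)
    missing = [key for key in ("name", "medication_class") if key not in record]
    if missing:
        raise ValueError(f"missing required fields: {missing}")
    return record


def extract_with_retries(prompt: str, max_retries: int = 3) -> tuple[dict, int]:
    """Call the model until validation passes, feeding each error back into the prompt."""
    for attempt in range(max_retries):
        raw = fake_llm(prompt, attempt)
        try:
            return validate(raw), attempt + 1
        except ValueError as error:
            # append the validation error so the next attempt can correct it
            prompt = f"{prompt}\n\nPrevious attempt failed validation: {error}"
    raise RuntimeError(f"failed after {max_retries} attempts")


record, attempts = extract_with_retries("Extract the medication.")
print(record["name"], attempts)
```

In this sketch, the second attempt succeeds because the first attempt's validation error was added to the prompt before retrying.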

In [7]:
client = instructor.from_openai(AsyncOpenAI())

In [8]:
# the system prompt defines the overall behavior of the LLM
system_prompt = """You are a healthcare research expert that is responsible for extracting detailed entities from PubMed articles. 
You will be provided a graph data model schema and must extract entities and relationships to populate a knowledge graph."""

user_prompt = """Rules:
* Use the provided schema to extract entities and relationships from the provided text.
* Follow the schema descriptions strictly.
* If a field is not provided, do not include it in the response. It should be null.
* If no entities are found, return an empty list.

Text Chunk:
{text_chunk}"""

async def extract_entities_from_text_chunk(text_chunk: str, chunk_id: str) -> list:
    """
    Extract entities and relationships from a text chunk.

    Parameters
    ----------
    text_chunk : str
        The text chunk to extract entities from.
    chunk_id : str
        The id of the text chunk. Used for debugging.

    Returns
    -------
    list[
        Medication | 
        TreatmentArm | 
        ClinicalOutcome | 
        MedicationUsedInTreatmentArm | 
        TreatmentArmHasClinicalOutcome |
        StudyPopulation |
        MedicalCondition |
        StudyPopulationHasMedicalCondition |
        StudyPopulationInTreatmentArm
        ]
        A list of entities and relationships extracted from the text chunk.
        If the response is truncated, an empty list is returned.
        If retries are exhausted, an empty list is returned.
        If the response is invalid, an empty list is returned.
    """
    try:
        response = await client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt.format(text_chunk=text_chunk)}
            ],
            response_model=list[
                            Medication | 
                            TreatmentArm | 
                            ClinicalOutcome | 
                            MedicationUsedInTreatmentArm | 
                            TreatmentArmHasClinicalOutcome |
                            StudyPopulation |
                            MedicalCondition |
                            StudyPopulationHasMedicalCondition |
                            StudyPopulationInTreatmentArm
                            ],
            temperature=0.0
        )
        return response
    except IncompleteOutputException as e:
        # Handle truncated output
        print(f"Response output truncated. Skipping chunk {chunk_id}.")
        return list()
    except InstructorRetryException as e:
        # Handle retry exhaustion
        print(f"Failed after {e.n_attempts} attempts. Skipping chunk {chunk_id}.")
        return list()
    except ValidationError as e:
        # Handle validation errors
        print(f"Validation failed. Skipping chunk {chunk_id}.\nError: {e}")
        return list()

In [9]:
async def extract_entities_from_chunk_nodes(chunk_nodes_dataframe: pd.DataFrame, batch_size: int = 100) -> list[tuple[str, list[Any]]]:
    """
    Process a Pandas DataFrame of Chunk nodes and return the entities found in each chunk.

    Parameters
    ----------
    chunk_nodes_dataframe : pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node.
        Has columns `id` and `text`.
    batch_size : int
        The number of text chunks to process in each batch.

    Returns
    -------
    list[tuple[str, list[Any]]]
        A list of tuples, where the first element is the chunk id and the second element is a list of entities found in the chunk.
    """

    results = list()

    # ceiling division so that a final partial batch is counted
    total_batches = -(-len(chunk_nodes_dataframe) // batch_size)

    for batch_idx, i in enumerate(range(0, len(chunk_nodes_dataframe), batch_size)):
        # iloc slicing stops at the end of the DataFrame, so the last batch may be smaller
        batch = chunk_nodes_dataframe.iloc[i:i+batch_size]
        print(f"Processing batch {batch_idx+1} of {total_batches}")
        # TODO: implement cache of failed chunks
        
        # Create tasks for all nodes in the batch
        # order is maintained
        tasks = [extract_entities_from_text_chunk(row["text"], row['id']) for _, row in batch.iterrows()]
        # Execute all tasks concurrently
        extraction_results = await asyncio.gather(*tasks)
        # Add extracted records to the results list
        results.extend(extraction_results)

    # Return chunk_id paired with its entities from the results list
    return [(chunk_id, entities) for chunk_id, entities in zip(chunk_nodes_dataframe["id"], results)]
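
Pairing chunk ids with results by position relies on `asyncio.gather` returning results in the order the awaitables were passed in, regardless of which one finishes first. A small self-contained sketch of the same batching pattern, with a hypothetical `fake_extract` coroutine in place of the LLM call:

```python
import asyncio


async def fake_extract(text: str) -> list[str]:
    """Hypothetical stand-in for the LLM extraction call."""
    # shorter texts finish first, so completion order differs from input order
    await asyncio.sleep(0.001 * len(text))
    return [text.upper()]


async def run_in_batches(chunks: list[tuple[str, str]], batch_size: int) -> list[tuple[str, list[str]]]:
    results = []
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start:start + batch_size]
        # gather returns results in the order the awaitables were passed in
        results.extend(await asyncio.gather(*(fake_extract(text) for _, text in batch)))
    # pairing by position is safe because input order is preserved
    return [(chunk_id, entities) for (chunk_id, _), entities in zip(chunks, results)]


chunks = [("c1", "a long piece of text"), ("c2", "short"), ("c3", "mid length")]
paired = asyncio.run(run_in_batches(chunks, batch_size=2))
print(paired)
```

Inside a notebook, `asyncio.run` only works here because `nest_asyncio.apply()` was called earlier; otherwise `await run_in_batches(...)` would be used directly.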

## Data Ingestion

We have now defined:
* Domain data model
* Entity extraction logic for chunks

It is now time to define our ingestion logic. We will run ingestion in two stages:

1. Extract Domain / Entity Graph from lexical graph Chunk nodes
2. Ingest entities into Domain Graph

Decoupling these stages allows us to easily make changes as we iterate on our ingestion process.

In [10]:
import os

from pyneoinstance import Neo4jInstance, load_yaml_file

Our database credentials and all of our queries are stored in the `pyneoinstance_config.yaml` file. 

This makes it easy to manage our queries and keeps the notebook code clean. 
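
Judging from the keys the notebook reads, the loaded configuration presumably has a shape like the dictionary below. This is only a sketch: the key names come from the notebook, but every Cypher string, the constraint name, and the `$rows` parameter are illustrative assumptions rather than the contents of the actual `pyneoinstance_config.yaml`.

```python
# Hypothetical shape of the dictionary returned by load_yaml_file(...).
# Only the key names are taken from the notebook; every query is a placeholder.
example_config = {
    "db_info": {
        "uri": "neo4j://localhost:7687",
        "user": "neo4j",
        "password": "password",
        "database": "neo4j",
    },
    "initializing_queries": {
        "constraints": {
            "medication_name": (
                "CREATE CONSTRAINT medication_name IF NOT EXISTS "
                "FOR (m:Medication) REQUIRE m.name IS UNIQUE"
            ),
        },
        "indexes": {},
    },
    "loading_queries": {
        "nodes": {
            "medication": (
                # assumes each batch of records is passed as a $rows parameter
                "UNWIND $rows AS row "
                "MERGE (m:Medication {name: row.name}) "
                "SET m.medication_class = row.medication_class"
            ),
        },
        "relationships": {},
    },
    "processing_queries": {
        "get_chunk_nodes_to_process": (
            "MATCH (c:Chunk) WHERE size(c.text) >= $min_length "
            "RETURN c.id AS id, c.text AS text"
        ),
    },
}

# the same style of lookups the notebook performs on the loaded config
example_constraints = example_config["initializing_queries"]["constraints"]
example_node_queries = example_config["loading_queries"]["nodes"]
print(sorted(example_config))
```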

In [42]:
config = load_yaml_file("pyneoinstance_config.yaml")

db_info = config['db_info']

constraints = config['initializing_queries']['constraints']
indexes = config['initializing_queries']['indexes']

node_load_queries = config['loading_queries']['nodes']
relationship_load_queries = config['loading_queries']['relationships']

processing_queries = config['processing_queries']

This graph object will handle database connections and read / write transactions for us.

In [12]:
graph = Neo4jInstance(db_info.get('uri', os.getenv("NEO4J_URI", "neo4j://localhost:7687")), # use config value -> use env value -> use default value
                      db_info.get('user', os.getenv("NEO4J_USER", "neo4j")), 
                      db_info.get('password', os.getenv("NEO4J_PASSWORD", "password")))
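
The connection settings are resolved in the order config value, then environment variable, then hard-coded default. One detail worth knowing is that `dict.get` only falls back when the key is absent, so a key that is present but empty in the YAML yields `None`. A small sketch of a stricter resolver follows; the helper `resolve_setting` and the `EXAMPLE_*` variable names are hypothetical.

```python
import os


def resolve_setting(config: dict, key: str, env_var: str, default: str) -> str:
    """Return the config value, else the environment variable, else the default."""
    value = config.get(key)
    if value:  # skips both a missing key and an empty / null value
        return value
    return os.getenv(env_var, default)


# environment prepared only for this demonstration
os.environ["EXAMPLE_NEO4J_USER"] = "env_user"
os.environ.pop("EXAMPLE_NEO4J_PASSWORD", None)

settings = {"uri": "neo4j://db.example:7687", "user": None}

print(resolve_setting(settings, "uri", "EXAMPLE_NEO4J_URI", "neo4j://localhost:7687"))
print(resolve_setting(settings, "user", "EXAMPLE_NEO4J_USER", "neo4j"))
print(resolve_setting(settings, "password", "EXAMPLE_NEO4J_PASSWORD", "password"))
```

Here the URI comes from the config, the user falls through to the environment because the config value is null, and the password falls through to the default.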

This is a helper function for ingesting data using the PyNeoInstance library.

In [13]:
def get_partition(data: pd.DataFrame, batch_size: int = 500) -> int:
    """
    Determine the data partition based on the desired batch size.

    Parameters
    ----------
    data : pd.DataFrame
        The Pandas DataFrame to partition.
    batch_size : int
        The desired batch size.

    Returns
    -------
    int
        The number of partitions to split the data into (at least 1).
    """
    
    partition = int(len(data) / batch_size)
    print("partition: "+str(partition if partition > 1 else 1))
    return partition if partition > 1 else 1
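
To make the arithmetic concrete: the partition count is the floor of `len(data) / batch_size`, with a minimum of 1. Assuming `partitions` is the number of pieces the data is split into, a single partition can therefore hold more rows than the requested batch size. The sketch below reproduces that rule on plain integers (the helper names are hypothetical) next to a ceiling-based variant that keeps every partition at or under the batch size.

```python
import math


def floor_partitions(n_rows: int, batch_size: int = 500) -> int:
    """Same rule as get_partition: floor division, never less than 1."""
    return max(1, n_rows // batch_size)


def ceil_partitions(n_rows: int, batch_size: int = 500) -> int:
    """Alternative: round up so no partition exceeds batch_size rows."""
    return max(1, math.ceil(n_rows / batch_size))


for n_rows in (120, 999, 1200):
    print(n_rows, floor_partitions(n_rows), ceil_partitions(n_rows))
# 120 1 1
# 999 1 2
# 1200 2 3
```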

### Constraints

Here we create all the constraints and indexes we need for both the lexical and domain graphs.

In [14]:
def create_constraints_and_indexes() -> None:
    """
    Create constraints and indexes for the lexical and domain graphs.
    """
    try:
        if constraints and len(constraints) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(constraints.values()))
    except Exception as e:
        print(e)

    try:
        if indexes and len(indexes) > 0:
            graph.execute_write_queries(database=db_info['database'], queries=list(indexes.values()))
    except Exception as e:
        print(e)

In [15]:
create_constraints_and_indexes()

### Extract Entities from Lexical Graph

We will now perform entity extraction on the Chunk nodes to augment and connect to our patient journey graph.

In [16]:
def get_chunk_nodes_to_process(min_length: int = 100) -> pd.DataFrame:
    """
    Retrieve Chunk node id and text from the database that don't have an embedding.
    These chunks may then be used as input to the entity extraction process.

    Parameters
    ----------
    min_length : int
        The minimum length the text must be to be included in the DataFrame.

    Returns
    -------
    pd.DataFrame
        A Pandas DataFrame where each row represents a Chunk node that has text and is at least `min_length` characters long.
        Has columns `id` and `text`.
    """
    return graph.execute_read_query(database=db_info['database'], 
                            query=processing_queries['get_chunk_nodes_to_process'], 
                            parameters={"min_length": min_length},
                        )

In [17]:
chunks_to_process = get_chunk_nodes_to_process(min_length=20)



In [18]:
print(f"Found {len(chunks_to_process)} chunks to process\n")
print(f"First chunk:\n\n{chunks_to_process.loc[0,'text']}")

Found 494 chunks to process

First chunk:

statistical significance (metformin + Ex9-39 vs. placebo + Ex9-39, P = 0.053). The glucose iAUC after metformin + saline was significantly smaller than the iAUC for metformin + Ex9-39 (P = 0.004). Based on individual iAUC values, the relative contribution of GLP-1 to the acute glucose-lowering effect of metformin was 75% ± 35%, calculated as follows: 100% × ([iAUCplacebo + saline – iAUCmetformin + saline] – [iAUCplacebo + Ex9–39 – iAUCmetformin + Ex9–39])/(iAUCplacebo + saline – iAUCmetformin + saline) (P = 0.05). Using a 2-way ANOVA, both metformin and Ex9-39 were shown to significantly affect postprandial plasma glucose (iAUC) (P = 0.005 and P = 0.002, respectively), but no interaction between the 2 factors was evident. The time courses of the C-peptide/glucose ratios are illustrated in Figure 2B, and the AUCs for C-peptide/glucose, insulin/glucose, and insulin secretion


In [19]:
extracted_entities_with_chunk_ids = await extract_entities_from_chunk_nodes(chunks_to_process[:200], batch_size=20)

Processing batch 1 of 10  
Failed after 3 attempts. Skipping chunk.
Processing batch 2 of 10  
Processing batch 3 of 10  
Processing batch 4 of 10  
Failed after 3 attempts. Skipping chunk.
Processing batch 5 of 10  
Processing batch 6 of 10  
Processing batch 7 of 10  
Processing batch 8 of 10  
Processing batch 9 of 10  
Processing batch 10 of 10  


### Ingest Entities Into Knowledge Graph

The functions below prepare the extracted entities and relationships for loading, write them to the graph, and link the extracted entities to their text Chunk nodes.

In [20]:
ENTITY_LABELS = {
    "Medication", 
    "TreatmentArm",
    "MedicalCondition",
    "StudyPopulation",
    "ClinicalOutcome",
}

ENTITY_RELS = {
    "MedicationUsedInTreatmentArm",
    "TreatmentArmHasClinicalOutcome",
    "StudyPopulationInTreatmentArm",
    "StudyPopulationHasMedicalCondition",
}

def prepare_entities_for_ingestion(entities: list[tuple[str, list[Any]]]) -> dict[str, dict[str, pd.DataFrame]]:
    """
    Prepare entities for ingestion into the graph.
    This function takes the results of the `extract_entities_from_chunk_nodes` function and returns a dictionary of node and relationship DataFrames keyed by entity label.

    Parameters
    ----------
    entities : list[tuple[str, list[Any]]]
        A list of tuples, where the first element is the chunk id and the second element is a list of entities found in the chunk.
        Entities are Pydantic models that adhere to the domain graph data model.

    Returns
    -------
    dict[str, dict[str, pd.DataFrame]]
        A dictionary of entity label to pandas dataframe of entities.

        {
            "nodes": {
                "Medication": pd.DataFrame(...),
                "TreatmentArm": pd.DataFrame(...),
                ...
            },
            "relationships": {
                "MedicationUsedInTreatmentArm": pd.DataFrame(...),
                "TreatmentArmHasClinicalOutcome": pd.DataFrame(...),
                ...
            }
        }
    """

    records_node_dict = {lbl: list() for lbl in ENTITY_LABELS}
    records_rel_dict = {lbl: list() for lbl in ENTITY_RELS}

    # use a distinct loop variable so the `entities` parameter is not shadowed
    for chunk_id, chunk_entities in entities:
        for entity in chunk_entities:
            to_add = entity.model_dump()
            to_add.update({"chunk_id": chunk_id})
            # nodes
            if entity.__class__.__name__ in ENTITY_LABELS:
                records_node_dict[entity.__class__.__name__].append(to_add)
            # rels
            elif entity.__class__.__name__ in ENTITY_RELS:
                records_rel_dict[entity.__class__.__name__].append(to_add)
            else:
                print(f"Unknown entity type: {entity.__class__.__name__}")

    for key, value in records_node_dict.items():
        records_node_dict[key] = pd.DataFrame(value).replace({float('nan'): None})

    for key, value in records_rel_dict.items():
        records_rel_dict[key] = pd.DataFrame(value).replace({float('nan'): None})

    return {"nodes": records_node_dict, "relationships": records_rel_dict}
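
The core of this function is routing each extracted object by its class name while stamping it with the id of the chunk it came from. A standalone sketch of that routing step, using plain dataclasses as simplified stand-ins for the Pydantic models and lists of dicts instead of DataFrames (all names and values here are illustrative):

```python
from collections import defaultdict
from dataclasses import dataclass, asdict


@dataclass
class Medication:  # simplified stand-in for the Pydantic node model
    name: str


@dataclass
class MedicationUsedInTreatmentArm:  # simplified stand-in for the relationship model
    study_name: str
    treatment_arm_name: str
    medication_name: str


NODE_LABELS = {"Medication"}
REL_LABELS = {"MedicationUsedInTreatmentArm"}

extracted = [
    ("chunk-1", [Medication("metformin"),
                 MedicationUsedInTreatmentArm("Study 1", "Arm A", "metformin")]),
    ("chunk-2", [Medication("semaglutide")]),
]

nodes, rels = defaultdict(list), defaultdict(list)
for chunk_id, chunk_entities in extracted:
    for entity in chunk_entities:
        label = type(entity).__name__
        # keep the source chunk id on every record so it can be linked later
        record = {**asdict(entity), "chunk_id": chunk_id}
        if label in NODE_LABELS:
            nodes[label].append(record)
        elif label in REL_LABELS:
            rels[label].append(record)

print(nodes["Medication"])
print(rels["MedicationUsedInTreatmentArm"])
```

Keeping `chunk_id` on every record is what later allows each entity to be connected back to its Chunk node with a `HAS_ENTITY` relationship.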

In [21]:
def load_entity_nodes(medication_dataframe: pd.DataFrame, 
                      medical_condition_dataframe: pd.DataFrame, 
                      treatment_arm_dataframe: pd.DataFrame, 
                      study_population_dataframe: pd.DataFrame, 
                      clinical_outcome_dataframe: pd.DataFrame) -> None:
    """
    Load entity nodes into the graph.
    """
    
    entity_nodes_ingest_iterator = list(zip([medication_dataframe, 
                                             medical_condition_dataframe, 
                                             treatment_arm_dataframe, 
                                             study_population_dataframe, 
                                             clinical_outcome_dataframe], 
                                             ['medication', 
                                              'medical_condition', 
                                              'treatment_arm', 
                                              'study_population', 
                                              'clinical_outcome']))

    for data, query in entity_nodes_ingest_iterator:
        if len(data) > 0:
            print(f"Loading {len(data)} {query} nodes")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=node_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} nodes to load")

In [22]:
def load_entity_relationships(medication_used_in_treatment_arm_dataframe: pd.DataFrame,
                              treatment_arm_has_clinical_outcome_dataframe: pd.DataFrame,
                              study_population_in_treatment_arm_dataframe: pd.DataFrame,
                              study_population_has_medical_condition_dataframe: pd.DataFrame,
                              ) -> None:
    """
    Load entity relationships into the graph.
    """
    entity_relationships_ingest_iterator = list(zip([medication_used_in_treatment_arm_dataframe, 
                                                      treatment_arm_has_clinical_outcome_dataframe, 
                                                      study_population_in_treatment_arm_dataframe, 
                                                      study_population_has_medical_condition_dataframe, 
                                                      ], 
                                                      ['medication_used_in_treatment_arm', 
                                                       'treatment_arm_has_clinical_outcome', 
                                                       'study_population_in_treatment_arm', 
                                                       'study_population_has_medical_condition', 
                                                       ]))
    
    for data, query in entity_relationships_ingest_iterator:
        if len(data) > 0:
            print(f"Loading {len(data)} {query} relationships")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} relationships to load")

In [23]:
def link_entities_to_chunks(medication_link_dataframe: pd.DataFrame,
                            medical_condition_link_dataframe: pd.DataFrame,
                            treatment_arm_link_dataframe: pd.DataFrame,
                            study_population_link_dataframe: pd.DataFrame,
                            clinical_outcome_link_dataframe: pd.DataFrame) -> None:
    """
    Link entities to chunks.
    """
    entity_link_iterator = list(zip([medication_link_dataframe, 
                                     medical_condition_link_dataframe, 
                                     treatment_arm_link_dataframe, 
                                     study_population_link_dataframe, 
                                     clinical_outcome_link_dataframe], 
                                     ["chunk_has_entity_medication",
                                      "chunk_has_entity_medical_condition",
                                      "chunk_has_entity_treatment_arm",
                                      "chunk_has_entity_study_population",
                                      "chunk_has_entity_clinical_outcome"]))
    
    for data, query in entity_link_iterator:
        if len(data) > 0:
            print(f"Linking {len(data)} {query} entities to chunks")
            res = graph.execute_write_query_with_data(database=db_info['database'], 
                                                    data=data, 
                                                    query=relationship_load_queries[query], 
                                                    partitions=get_partition(data, batch_size=500),
                                                    parallel=False)
            print(res)
        else:
            print(f"No {query} relationships to load")

In [24]:
ingest_records = prepare_entities_for_ingestion(extracted_entities_with_chunk_ids)

In [27]:
load_entity_nodes(ingest_records["nodes"]["Medication"], 
                  ingest_records["nodes"]["MedicalCondition"], 
                  ingest_records["nodes"]["TreatmentArm"], 
                  ingest_records["nodes"]["StudyPopulation"], 
                  ingest_records["nodes"]["ClinicalOutcome"])

Loading 271 medication nodes
partition: 1
{'properties_set': 1355}
Loading 66 medical_condition nodes
partition: 1
{'labels_added': 32, 'nodes_created': 32, 'properties_set': 296}
Loading 232 treatment_arm nodes
partition: 1
{'labels_added': 211, 'nodes_created': 211, 'properties_set': 675}
Loading 84 study_population nodes
partition: 1
{'labels_added': 75, 'nodes_created': 75, 'properties_set': 1167}
No clinical_outcome nodes to load


In [48]:
ingest_records["relationships"]["TreatmentArmHasClinicalOutcome"]

Unnamed: 0,study_name,treatment_arm_name,clinical_outcome_name,clinical_outcome_id,treatment_arm_id,chunk_id
0,Exenatide Once Weekly vs. Sitagliptin Study,Exenatide Once Weekly Arm,A1C Reduction,d30cc00f1297fdb298bb9cc23cf88667959d6e4c603da5...,69685cf6a8d172ad3a798c63a0b7012d96108bc1e09a8a...,072f6fc5f143b89bf521a5b75867f080
1,Exenatide Once Weekly vs. Sitagliptin Study,Exenatide Once Weekly Arm,Weight Loss,e58e46dddf01a2ccfa430c114b5334b0142c067024af4d...,69685cf6a8d172ad3a798c63a0b7012d96108bc1e09a8a...,072f6fc5f143b89bf521a5b75867f080
2,Exenatide Once Weekly vs. Sitagliptin Study,Sitagliptin Arm,A1C Reduction,d30cc00f1297fdb298bb9cc23cf88667959d6e4c603da5...,fac4659b2549c2320f07dfd4cb459332a11db9e7b82cbe...,072f6fc5f143b89bf521a5b75867f080
3,Exenatide Once Weekly vs. Sitagliptin Study,Sitagliptin Arm,Weight Loss,e58e46dddf01a2ccfa430c114b5334b0142c067024af4d...,fac4659b2549c2320f07dfd4cb459332a11db9e7b82cbe...,072f6fc5f143b89bf521a5b75867f080
4,Network Meta-Analysis of Exenatide,Exenatide Arm,Glycemic control improvement,11b3ba7c3a0f9f07966bf715a1a4636df87c4e0eae85fa...,a055949ab18ff5de56a6221bd9cbf4d29ae8ec6fdc6d71...,0ae41eb129a57859103e320c0cafe3d7
...,...,...,...,...,...,...
114,Ahren et al. (48),Placebo,Change in HbA1c for Placebo,64f0e6ca962a82b233e2752c307114b749abb4623daf54...,ffc331a0ef346b0180835f8330294ab826904e0827b7b1...,62c8f7607d7c903b4179eaba82922516
115,Ahren et al. (48),Albiglutide 30mg,Weight change for Albiglutide,e2e6a80ffb14582bba41184fa8baece63cee7cfe058772...,31fd566f895b4ef5ae947546b1d90e77aa2db0db3675f9...,62c8f7607d7c903b4179eaba82922516
116,Ahren et al. (48),Sitagliptin 100mg,Weight change for Sitagliptin,07fb2b384d0e8e90c31aec165967ece04ebd445e9fdf03...,aeca001e5e7406c604998f575cb11ae64d5db4ba1693e3...,62c8f7607d7c903b4179eaba82922516
117,Ahren et al. (48),Glimepiride 2mg,Weight change for Glimepiride,f1f7103127a3d2573c6a07763fd1a832b7cfb68fa4a5b5...,9bdcba3126d14f3bcdd63988e4430a7bab92e010039e8e...,62c8f7607d7c903b4179eaba82922516


In [38]:
load_entity_relationships(ingest_records["relationships"]["MedicationUsedInTreatmentArm"], 
                          ingest_records["relationships"]["TreatmentArmHasClinicalOutcome"], 
                          ingest_records["relationships"]["StudyPopulationInTreatmentArm"], 
                          ingest_records["relationships"]["StudyPopulationHasMedicalCondition"])

Loading 71 medication_used_in_treatment_arm relationships
partition: 1
{'relationships_created': 55, 'properties_set': 550}
Loading 119 treatment_arm_has_clinical_outcome relationships
partition: 1
{}
Loading 24 study_population_in_treatment_arm relationships
partition: 1
{}
Loading 31 study_population_has_medical_condition relationships
partition: 1
{}


### Link Entities to Lexical Graph

Now we link the loaded entities to their respective Chunk nodes.

In [29]:
link_entities_to_chunks(ingest_records["nodes"]["Medication"], 
                        ingest_records["nodes"]["MedicalCondition"], 
                        ingest_records["nodes"]["TreatmentArm"], 
                        ingest_records["nodes"]["StudyPopulation"], 
                        ingest_records["nodes"]["ClinicalOutcome"])

Linking 271 chunk_has_entity_medication entities to chunks
partition: 1
{'relationships_created': 271}
Linking 66 chunk_has_entity_medical_condition entities to chunks
partition: 1
{'relationships_created': 66}
Linking 232 chunk_has_entity_treatment_arm entities to chunks
partition: 1
{'relationships_created': 232}
Linking 84 chunk_has_entity_study_population entities to chunks
partition: 1
{'relationships_created': 80}
No chunk_has_entity_clinical_outcome relationships to load


### Link Entity Graph to the Rest of Domain Graph

Now we execute custom Cypher to link the extracted entities with the existing patient journey graph.

This will create the following relationships:

* (:Demographic)-[:IN_STUDY_POPULATION]->(:StudyPopulation)

Other links already exist because the Medication and MedicalCondition nodes we extract from the text correspond to entities that are already present in the patient journey graph.
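The `demographic_in_study_population` Cypher itself lives in `relationship_load_queries` and is not reproduced in this section. As a rough illustration only of what linking on matching properties can mean, the sketch below pairs demographic records with study populations in plain Python. The property names (`age`, `sex`, `min_age`, `max_age`) and the matching rule are assumptions made up for this example, not the notebook's actual schema or query.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical records; the real Demographic and StudyPopulation properties may differ.
demographics: List[Dict] = [
    {"id": "d1", "age": 45, "sex": "female"},
    {"id": "d2", "age": 72, "sex": "male"},
]
study_populations: List[Dict] = [
    {"id": "sp1", "min_age": 18, "max_age": 65, "sex": "female"},
    {"id": "sp2", "min_age": 60, "max_age": 90, "sex": None},  # None = no sex restriction
]


def match_demographics(demos: List[Dict], populations: List[Dict]) -> List[Tuple[str, str]]:
    """Return (demographic_id, study_population_id) pairs with compatible properties."""
    pairs = []
    for d in demos:
        for sp in populations:
            age_ok = sp["min_age"] <= d["age"] <= sp["max_age"]
            sex: Optional[str] = sp["sex"]
            sex_ok = sex is None or sex == d["sex"]
            if age_ok and sex_ok:
                pairs.append((d["id"], sp["id"]))
    return pairs


links = match_demographics(demographics, study_populations)
print(links)
```

Each resulting pair stands in for one `(:Demographic)-[:IN_STUDY_POPULATION]->(:StudyPopulation)` relationship. In the graph the same comparison would be expressed in Cypher rather than in nested Python loops.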

In [30]:
def link_domain_and_patient_journey_graph() -> None:
    """
    Link the domain graph with the patient journey graph. 
    This process doesn't require any input DataFrames. 
    Instead it attempts to link nodes based on matching properties.
    """

    queries = ["demographic_in_study_population"]

    for q in queries:
        res = graph.execute_write_query(database=db_info['database'], 
                                        query=relationship_load_queries[q])
        # report the result of every query, not only the last one
        print(res)

In [31]:
link_domain_and_patient_journey_graph()

{'relationships_created': 150, 'properties_set': 150}


### Resolve Entities

We can now perform some entity resolution.

The entity extraction process may find entities that are slight variations of existing entities.

We can merge these entities together with some Cypher.

The following entities will be resolved:
* Medication
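The `resolve_medications` query is stored in `processing_queries` and is not shown here. To make the idea of "slight variations" concrete, here is a minimal sketch in plain Python that groups medication names differing only in case or whitespace. The helpers `normalize_name` and `group_variants` are hypothetical names introduced for this example, and the real Cypher may use a different matching rule.

```python
from collections import defaultdict
from typing import Dict, List


def normalize_name(name: str) -> str:
    """Lowercase and collapse whitespace so trivial variants compare equal."""
    return " ".join(name.lower().split())


def group_variants(names: List[str]) -> Dict[str, List[str]]:
    """Group raw names by their normalized form, preserving input order."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for raw in names:
        groups[normalize_name(raw)].append(raw)
    return dict(groups)


# Illustrative input only; these are not actual extraction results.
extracted = ["Exenatide", "exenatide ", "Sitagliptin", "SITAGLIPTIN", "Glimepiride"]
groups = group_variants(extracted)

# Only groups with more than one spelling are candidates for merging.
merge_candidates = {key: variants for key, variants in groups.items() if len(variants) > 1}
print(merge_candidates)
```

Every group in `merge_candidates` would correspond to a set of Medication nodes collapsed into one. A name with a single spelling, such as `Glimepiride` above, is left untouched.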

In [46]:
def resolve_entities() -> None:
    """
    Resolve extracted entities.
    This process doesn't require any input DataFrames.
    """

    queries = ["resolve_medications"]

    for q in queries:
        res = graph.execute_write_query(database=db_info['database'], 
                                        query=processing_queries[q])
        # report the result of every query, not only the last one
        print(res)

In [47]:
resolve_entities()

{}
