# PubMed Knowledge Graph

This notebook walks through the process of generating a knowledge graph from PubMed articles.

This notebook will:
* Download a selection of articles from PubMed
* Define a knowledge graph schema
* Extract entities from the articles according to the defined schema
* Populate a Neo4j instance with articles and extracted entities
* Connect extracted entities with existing patient journey data

This notebook requires a local copy of the articles. To download a sample of 20 PubMed articles, run the following command.

```bash
python3 ./scripts/fetch_pubmed_articles.py
```

In [61]:
from typing import Any, Optional, List
import asyncio

In [60]:
# allows for async operations in notebooks
import nest_asyncio
nest_asyncio.apply()

## Lexical Graph Construction

We will use Unstructured.IO to partition and chunk our articles. 

This process breaks the articles into sensible chunks that may be used as context in our application. 

These chunks will also have relationships to the extracted entities.

IMAGE OF LEXICAL DATA MODEL

In [17]:
from pydantic import BaseModel, Field
from unstructured.partition.auto import partition
from unstructured.documents.elements import CompositeElement

from uuid import uuid4

In [18]:
class Document(BaseModel):
    id: str = Field(..., description="The id of the document")
    name: str = Field(..., description="The name of the document")
    source: str = Field(..., description="The source of the document")

class Chunk(BaseModel):
    id: str = Field(..., description="The id of the chunk")
    text: str = Field(..., description="The text of the chunk")

class ChunkWithEmbedding(Chunk):
    embedding: list[float] = Field(..., description="The embedding of the chunk text field")

class ChunkPartOfDocument(BaseModel):
    chunk_id: str = Field(..., description="The id of the chunk")
    document_id: str = Field(..., description="The id of the document")

class ChunkHasEntity(BaseModel):
    chunk_id: str = Field(..., description="The id of the chunk")
    entity_id: str = Field(..., description="The id of the entity")

In [19]:
def element_to_node_and_relationship(element: CompositeElement, parent_document_id: str) -> dict[str, list[Chunk | ChunkPartOfDocument]]:
    """Parse the chunk node and its document relationship for a given element"""
    chunk = Chunk(id=element.id, text=element.text)
    chunk_part_of_document = ChunkPartOfDocument(chunk_id=chunk.id, document_id=parent_document_id)
    return {
        "nodes": [chunk],
        "relationships": [chunk_part_of_document],
    }

def elements_to_nodes_and_relationships(elements: list[CompositeElement], parent_document: Document) -> dict[str, list[Document | Chunk | ChunkPartOfDocument]]:
    """Parse chunk nodes and document relationships for a set of elements and their parent document"""
    
    data = {
        "nodes": [parent_document],
        "relationships": list(),
    }

    for element in elements:
        new_data = element_to_node_and_relationship(element, parent_document.id)
        data["nodes"].extend(new_data["nodes"])
        data["relationships"].extend(new_data["relationships"])

    return data

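def process_article(file_name: str) -> dict[str, list[Document | Chunk | ChunkPartOfDocument]]:
    """Partition an article file into chunks and return its Document node, Chunk nodes and PART_OF relationships"""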
    parent_document = Document(id=str(uuid4()), name=file_name, source="pubmed")
    elements = partition(file_name, chunking_strategy="by_title")
    return elements_to_nodes_and_relationships(elements, parent_document)

## Lexical Graph Embedding

Here we will embed the text fields of our lexical graph for vector similarity search. 

In [None]:
#TODO

## Schema Definition

We now need to define our knowledge graph schema. This information will be passed to the entity extraction LLM to control which entities and relationships are pulled out of the text.

Without a fixed schema, an unbounded extraction process would keep introducing new labels and relationship types, and the graph schema would grow too large to be useful.

We are using Pydantic to define the schema here since it can be used to validate any returned results as well. This ensures that all data we are ingesting into Neo4j adheres to this structure.
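
As a rough illustration of the validation step described above, the sketch below checks raw records against a small Pydantic model before they would be ingested. `MiniMedication` and `validate_records` are hypothetical names introduced for this example only, the model is a trimmed-down stand-in for the schema defined in the next cell, and the second record is made-up data that is deliberately incomplete.

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class MiniMedication(BaseModel):
    """Hypothetical, trimmed-down stand-in for the Medication model."""

    medication_id: str = Field(..., description="Unique identifier for the medication")
    name: str = Field(..., description="Name of the medication")
    mechanism: Optional[str] = Field(None, description="Mechanism of action")


def validate_records(raw_records: list[dict]) -> tuple[list[MiniMedication], list[str]]:
    """Split raw dicts into validated models and readable error messages."""
    valid, errors = [], []
    for i, raw in enumerate(raw_records):
        try:
            valid.append(MiniMedication.model_validate(raw))
        except ValidationError as exc:
            # Each error carries the location of the offending field
            bad_fields = [".".join(str(part) for part in err["loc"]) for err in exc.errors()]
            errors.append(f"record {i}: invalid fields {bad_fields}")
    return valid, errors


raw = [
    {"medication_id": "MED001", "name": "Semaglutide"},
    {"name": "Semaglutide"},  # illustrative bad record: required medication_id is missing
]
valid, errors = validate_records(raw)
print(valid)
print(errors)
```

Records that fail validation never reach the database, so a malformed LLM response surfaces as an error message rather than as a partially populated node.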

In [169]:
class Medication(BaseModel):
    """a substance used for medical treatment, especially a medicine or drug. This is a base medication, not a medication implemented in a study."""
    
    medication_id: str = Field(..., description="Unique identifier for the medication")
    name: str = Field(..., description="Name of the medication. Should also be uniquely identifiable.")
    medication_class: str = Field(..., description="Drug class (e.g., GLP-1 RA, SGLT2i)")
    mechanism: Optional[str] = Field(None, description="Mechanism of action")
    generic_name: Optional[str] = Field(None, description="Generic name if different from name")
    brand_names: Optional[List[str]] = Field(None, description="Commercial brand names")
    approval_status: Optional[str] = Field(None, description="FDA approval status")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "medication_id": "MED001",
                    "name": "Semaglutide", 
                    "medication_class": "GLP-1 receptor agonist",
                    "mechanism": "GLP-1 receptor activation",
                    "generic_name": "semaglutide",
                    "brand_names": ["Ozempic", "Wegovy", "Rybelsus"],
                    "approval_status": "FDA approved"
                }
            ]
        }


class StudyMedication(BaseModel):
    """Study-specific medication usage - how a medication was used in a particular study"""
    
    study_medication_id: str = Field(..., description="Unique identifier for this study medication instance")
    dosage: Optional[str] = Field(None, description="Dosage used in this study")
    route: Optional[str] = Field(None, description="Route of administration")
    frequency: Optional[str] = Field(None, description="Dosing frequency")
    treatment_duration: Optional[str] = Field(None, description="Duration of treatment")
    treatment_arm: Optional[str] = Field(None, description="Treatment arm description")
    comparator: Optional[str] = Field(None, description="What this was compared against")
    adherence_rate: Optional[float] = Field(None, description="Treatment adherence rate")
    formulation: Optional[str] = Field(None, description="Specific formulation used")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "study_medication_id": "STUDY_MED001",
                    "dosage": "1.0 mg",
                    "route": "subcutaneous",
                    "frequency": "weekly",
                    "treatment_duration": "12 weeks",
                    "treatment_arm": "Active treatment group",
                    "comparator": "placebo",
                    "adherence_rate": 85.5,
                    "formulation": "pre-filled pen"
                }
            ]
        }


class ClinicalOutcome(BaseModel):
    """Measured clinical outcomes and biomarkers"""
    
    clinical_outcome_id: str = Field(..., description="Unique identifier for the outcome")
    name: str = Field(..., description="Name of the clinical outcome")
    category: str = Field(..., description="Category of outcome")
    measurement_unit: Optional[str] = Field(None, description="Unit of measurement")
    normal_range: Optional[str] = Field(None, description="Normal or target range when applicable")
    baseline_value: Optional[float] = Field(None, description="Baseline measurement value")
    post_treatment_value: Optional[float] = Field(None, description="Post-treatment measurement value")
    change_from_baseline: Optional[float] = Field(None, description="Change from baseline")
    p_value: Optional[float] = Field(None, description="Statistical significance if reported")
    confidence_interval: Optional[str] = Field(None, description="95% confidence interval")
    effect_size: Optional[float] = Field(None, description="Standardized effect size")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "clinical_outcome_id": "OUT001",
                    "name": "HbA1c",
                    "category": "Glycemic control",
                    "measurement_unit": "%",
                    "normal_range": "<7.0%",
                    "baseline_value": 8.5,
                    "post_treatment_value": 7.2,
                    "change_from_baseline": -1.3,
                    "p_value": 0.001,
                    "confidence_interval": "[-1.8, -0.8]",
                    "effect_size": -0.8
                }
            ]
        }


class MedicalCondition(BaseModel):
    """Medical conditions and comorbidities studied"""
    
    medical_condition_id: str = Field(..., description="Unique identifier for the condition")
    name: str = Field(..., description="Name of the medical condition")
    category: str = Field(..., description="Category of condition")
    severity: Optional[str] = Field(None, description="Severity or stage when specified")
    icd10_code: Optional[str] = Field(None, description="ICD-10 code when available")
    duration: Optional[str] = Field(None, description="Duration of condition if specified")
    prevalence: Optional[float] = Field(None, description="Prevalence in study population")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "medical_condition_id": "COND001",
                    "name": "Type 2 diabetes mellitus",
                    "category": "Primary condition", 
                    "severity": "moderate",
                    "icd10_code": "E11",
                    "duration": "5-10 years",
                    "prevalence": 100.0
                }
            ]
        }


class StudyPopulation(BaseModel):
    """Patient populations and demographics in research studies"""
    
    study_population_id: str = Field(..., description="Unique identifier for the population")
    description: str = Field(..., description="Description of the population")
    age_range: Optional[str] = Field(None, description="Age range")
    mean_age: Optional[float] = Field(None, description="Mean age in years")
    male_percentage: Optional[float] = Field(None, description="Percentage of male gender participants")
    female_percentage: Optional[float] = Field(None, description="Percentage of female gender participants")
    other_gender_percentage: Optional[float] = Field(None, description="Percentage of participants that identify as another gender")
    sample_size: Optional[int] = Field(None, description="Number of participants")
    study_type: str = Field(..., description="Type of study")
    location: Optional[str] = Field(None, description="Geographic location of study")
    inclusion_criteria: Optional[List[str]] = Field(None, description="Key inclusion criteria")
    exclusion_criteria: Optional[List[str]] = Field(None, description="Key exclusion criteria")
    study_duration: Optional[str] = Field(None, description="Duration of study")
    
    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "study_population_id": "POP001",
                    "description": "Adults with T2DM and schizophrenia",
                    "age_range": "18-65 years",
                    "mean_age": 43.8,
                    "female_percentage": 47.0,
                    "male_percentage": 53.0,
                    "sample_size": 354,
                    "study_type": "Observational study",
                    "location": "Denmark",
                    "inclusion_criteria": ["Type 2 diabetes diagnosis", "Schizophrenia diagnosis", "Age ≥18"],
                    "study_duration": "12 months"
                }
            ]
        }


# Relationship classes
class StudyMedicationUsesMedication(BaseModel):
    """Links study medication to base medication"""
    study_medication_id: str
    medication_name: str


class StudyMedicationProducesClinicalOutcome(BaseModel):
    """Links study medication usage to clinical outcomes"""
    study_medication_id: str
    clinical_outcome_name: str


class StudyPopulationHasMedicalCondition(BaseModel):
    """Relationship between study population and medical conditions"""
    study_population_id: str
    medical_condition_name: str


class StudyPopulationReceivesStudyMedication(BaseModel):
    """Relationship between study population and study medication"""
    study_population_id: str
    study_medication_id: str


class StudyPopulationHasOutcome(BaseModel):
    """Direct relationship between population and outcomes (for population-level measurements)"""
    study_population_id: str
    clinical_outcome_name: str

Our knowledge graph data model looks like this:

IMAGE OF DATA MODEL

## Entity Extraction LLM

We will be using [OpenAI](https://platform.openai.com/docs/overview) and the [Instructor](https://python.useinstructor.com/) library to perform our entity extraction.

In [149]:
from openai import AsyncOpenAI
import instructor

In [150]:
client = instructor.from_openai(AsyncOpenAI())

In [151]:
async def extract_entities_from_text_chunk(text_chunk: str) -> list:
    """Extract the entities and relationships allowed by `response_model` from a single text chunk."""
    response = await client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a healthcare research expert that is responsible for extracting detailed entities from PubMed articles. You are provided a graph data model schema and must extract entities and relationships to populate a knowledge graph."},
            {"role": "user", "content": text_chunk}
        ],
        response_model=list[Medication | StudyMedication | StudyMedicationUsesMedication],
    )

    return response

In [153]:
async def extract_entities_from_chunk_nodes(chunk_nodes: list[Chunk]) -> list[tuple[str, list[Any]]]:
    """Process a list of Chunk nodes and return the entities found in each chunk."""

    # Create tasks for all nodes
    # order is maintained
    tasks = [extract_entities_from_text_chunk(chunk.text) for chunk in chunk_nodes]

    # Execute all tasks concurrently
    extraction_results = await asyncio.gather(*tasks)

    # Return chunk_id paired with its entities
    return [(chunk.id, entities) for chunk, entities in zip(chunk_nodes, extraction_results)]
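
The function above relies on `asyncio.gather` returning results in the same order as its inputs. The sketch below demonstrates that property and adds one possible refinement: a semaphore that caps how many requests are in flight at once, which may help when an article yields many chunks. `fake_extract` is a hypothetical stand-in for `extract_entities_from_text_chunk` that makes no API call, and `gather_with_limit` and `max_concurrency` are names assumed for this example.

```python
import asyncio


async def fake_extract(text: str) -> list[str]:
    """Hypothetical stand-in for extract_entities_from_text_chunk (no API call)."""
    # Shorter texts sleep longer, so completion order differs from input order
    await asyncio.sleep(0.01 * (5 - len(text)))
    return [text.upper()]


async def gather_with_limit(texts: list[str], max_concurrency: int = 2) -> list[list[str]]:
    """Run fake_extract over all texts with at most max_concurrency in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(text: str) -> list[str]:
        async with semaphore:
            return await fake_extract(text)

    # gather returns results in the order the awaitables were passed in,
    # regardless of which one finishes first
    return await asyncio.gather(*(bounded(t) for t in texts))


texts = ["a", "bb", "ccc", "dddd"]
results = asyncio.run(gather_with_limit(texts))
print(results)
```

Because the results line up with the inputs, pairing each chunk id with its entities through `zip` stays correct even when the semaphore changes the order in which calls complete.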

### Test Extraction

In [133]:
with open("pubmed_abstracts.txt", "r") as f:
    text = f.read()[1500:2500]

print(text)

al Science, Faculty of Medicine, Universitas Airlangga, 
Surabaya, Indonesia. fahrul.nurkolis.mail@gmail.com.
(11)Medical Research Center of Indonesia, Surabaya, East Java, Indonesia. 
fahrul.nurkolis.mail@gmail.com.

BACKGROUND: The global rise in obesity and type 2 diabetes highlights the need 
for safe and effective therapeutic interventions. Enhalus acoroides is a 
tropical seagrass rich in carotenoids and other bioactives. Its potential for 
metabolic regulation has been suggested in vitro, but in vivo efficacy and 
molecular mechanisms remain unexplored. This study aimed to evaluate the 
anti-obesity and anti-diabetic effects of Enhalus acoroides extract (SEAE) in a 
zebrafish model of diet- and glucose-induced metabolic dysfunction.
METHODS: Adult zebrafish were subjected to overfeeding and glucose immersion, 
after overfeeding and 14 days of glucose immersion to induce diabetes, adult 
zebrafish were randomized into three groups: untreated diabetic, SEAE-treated 
(5 mg/L), and 

In [134]:
ents = await extract_entities_from_text_chunk(text)

In [135]:
ents

[Medication(medication_id='MED002', name='Enhalus acoroides extract', medication_class='Natural extract', mechanism=None, generic_name='SEAE', brand_names=None, approval_status=None),
 StudyMedication(study_medication_id='STUDY_MED002', medication_id='MED002', dosage='5 mg/L', route=None, frequency=None, treatment_duration=None, treatment_arm='SEAE-treated group', comparator=None, adherence_rate=None, formulation=None)]

## Data Ingestion

We have now defined 
* Lexical and domain data models
* Partitioning and chunking logic for articles
* Entity extraction logic for chunks

It is now time to define our ingestion logic. We will run ingestion in three stages:

1. Load lexical graph
2. Embed lexical graph Chunk nodes
3. Extract domain / entity graph from lexical graph

Decoupling these stages allows us to easily make changes as we iterate on our ingestion process.

In [74]:
import os

from neo4j import GraphDatabase, Driver, RoutingControl

In [28]:
driver = GraphDatabase.driver(os.getenv("NEO4J_URI", "bolt://localhost:7687"), 
                              auth=(os.getenv("NEO4J_USER", "neo4j"), os.getenv("NEO4J_PASSWORD", "password")))

### Ingest Lexical Graph

In [49]:
lexical_constraints = [
    "CREATE CONSTRAINT document_id IF NOT EXISTS FOR (n:Document) REQUIRE n.id IS NODE KEY",
    "CREATE CONSTRAINT document_name IF NOT EXISTS FOR (n:Document) REQUIRE n.name IS NODE KEY",
    "CREATE CONSTRAINT chunk_id IF NOT EXISTS FOR (n:Chunk) REQUIRE n.id IS NODE KEY"
]


def create_lexical_constraints(driver: Driver) -> None:
    for constraint in lexical_constraints:
        driver.execute_query(constraint, database_=os.getenv("NEO4J_DATABASE", "neo4j"), routing_=RoutingControl.WRITE)

In [50]:
def load_lexical_document_nodes(driver: Driver, document_nodes: list[Document]) -> None:
    """Load document nodes into the lexical graph"""
    records = [record.model_dump() for record in document_nodes]
    query = """
        UNWIND $records AS record
        MERGE (n:Document {id: record.id})
        ON CREATE SET n.name = record.name, n.source = record.source
    """
    driver.execute_query(query, 
                         records=records, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def load_lexical_chunk_nodes(driver: Driver, chunk_nodes: list[Chunk]) -> None:
    """Load chunk nodes into the lexical graph"""
    records = [record.model_dump() for record in chunk_nodes]
    query = """
        UNWIND $records AS record
        MERGE (n:Chunk {id: record.id})
        ON CREATE SET n.text = record.text
    """
    driver.execute_query(query, 
                         records=records, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def load_lexical_chunk_part_of_document_relationships(driver: Driver, chunk_part_of_document_relationships: list[ChunkPartOfDocument]) -> None:
    """Load Chunk - PART_OF -> Document relationships into the lexical graph"""
    records = [record.model_dump() for record in chunk_part_of_document_relationships]
    query = """
        UNWIND $records AS record
        MATCH (c:Chunk {id: record.chunk_id}), (d:Document {id: record.document_id})
        MERGE (c)-[:PART_OF]->(d)
    """
    driver.execute_query(query, 
                         records=records, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def load_lexical_graph(driver: Driver, lexical_ingest_records: dict[str, list[Document | Chunk | ChunkPartOfDocument]]) -> None:
    """Load the lexical graph"""
    print("Creating Constraints")
    create_lexical_constraints(driver)

    # The Document is always the first node; the remaining nodes are its Chunks
    print("Loading Document nodes")
    load_lexical_document_nodes(driver, [lexical_ingest_records.get("nodes")[0]])

    print("Loading Chunk nodes")
    load_lexical_chunk_nodes(driver, lexical_ingest_records.get("nodes")[1:])
    
    print("Loading Chunk - PART_OF -> Document relationships")
    load_lexical_chunk_part_of_document_relationships(driver, lexical_ingest_records.get("relationships"))

Process the articles

In [37]:
lexical_ingest_records = process_article("pubmed_abstracts.txt")

libmagic is unavailable but assists in filetype detection. Please consider installing libmagic for better results.


Check the first few records 

In [136]:
lexical_ingest_records.get("nodes")[:3]

[Document(id='a5ca5246-9454-42c6-88b4-93fadd9c439d', name='pubmed_abstracts.txt', source='pubmed'),
 Chunk(id='c7575a02b7776f6183bcb0c3aa2b3a58', text='1. Diabetol Metab Syndr. 2025 Jun 21;17(1):235. doi: 10.1186/s13098-025-01823-4.\n\nSeagrass Enhalus acoroides extract mitigates obesity and diabetes via GLP-1, PPARγ, SREBP-1c modulation and gut microbiome restoration in diabetic zebrafish.\n\nKadharusman MM(1), Syahputra RA(2), Kurniawan R(3), Hadinata E(4), Tjandrawinata RR(5), Taslim NA(6), Romano R(7), Santini A(8), Nurkolis F(9)(10)(11).'),
 Chunk(id='aaf28f4d0b350bd4c312418709797fd5', text='Author information: (1)Faculty of Medicine, Universitas Indonesia, Jakarta, Indonesia. (2)Department of Pharmacology, Faculty of Pharmacy, Universitas Sumatera Utara, Medan, 20155, Indonesia. (3)Graduate School of Medicine, Faculty of Medicine, Hasanuddin University, Makassar, Indonesia. (4)Faculty of Medicine, Ciputra University of Surabaya, Surabaya, 60219, Indonesia. (5)Center for Pharmaceu

Load the records into the graph

In [51]:
load_lexical_graph(driver, lexical_ingest_records)

Creating Constraints
Loading Document nodes
Loading Chunk nodes
Loading Chunk - PART_OF -> Document relationships


### Embed Lexical Graph

Here we will read Chunk nodes from the graph that don't have embedding properties yet. 

We will then embed the Chunk text property and add the embedding as a property.

In [52]:
vector_index = ...

def create_vector_index(driver: Driver) -> None:
    ...

In [54]:
def create_embeddings(driver: Driver) -> None:
    ...

def embed_lexical_graph(driver: Driver) -> None:
    ...
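
The stubs above are left unimplemented. One possible way to fill them in is sketched here, under several assumptions: the index name `chunk_embedding`, the dimension `1536`, the helper names, and the `toy_embed` placeholder are all hypothetical and not taken from this notebook. The Cypher uses the `CREATE VECTOR INDEX` command available in recent Neo4j 5 releases, and the dimension must match whichever embedding model is actually used.

```python
from typing import Callable, Iterator

VECTOR_INDEX_NAME = "chunk_embedding"  # assumed index name
EMBEDDING_DIMENSIONS = 1536            # assumed; must match the embedding model

CREATE_VECTOR_INDEX = f"""
CREATE VECTOR INDEX {VECTOR_INDEX_NAME} IF NOT EXISTS
FOR (c:Chunk) ON c.embedding
OPTIONS {{indexConfig: {{
    `vector.dimensions`: {EMBEDDING_DIMENSIONS},
    `vector.similarity_function`: 'cosine'
}}}}
"""

# Select only Chunk nodes that have not been embedded yet
FETCH_UNEMBEDDED = """
MATCH (c:Chunk) WHERE c.embedding IS NULL
RETURN c.id AS id, c.text AS text
"""

WRITE_EMBEDDINGS = """
UNWIND $records AS record
MATCH (c:Chunk {{id: record.id}})
SET c.embedding = record.embedding
""".replace("{{", "{").replace("}}", "}")

Embedder = Callable[[list[str]], list[list[float]]]


def batched(rows: list[dict], size: int) -> Iterator[list[dict]]:
    """Yield consecutive slices of at most `size` rows."""
    for start in range(0, len(rows), size):
        yield rows[start:start + size]


def build_embedding_records(rows: list[dict], embed_fn: Embedder, batch_size: int = 2) -> Iterator[list[dict]]:
    """Embed rows batch by batch and yield records shaped for WRITE_EMBEDDINGS."""
    for batch in batched(rows, batch_size):
        vectors = embed_fn([row["text"] for row in batch])
        yield [{"id": row["id"], "embedding": vec} for row, vec in zip(batch, vectors)]


def toy_embed(texts: list[str]) -> list[list[float]]:
    """Placeholder embedder: NOT a real model, just text length as a 1-d vector."""
    return [[float(len(t))] for t in texts]


rows = [{"id": "c1", "text": "abc"}, {"id": "c2", "text": "de"}, {"id": "c3", "text": "f"}]
record_batches = list(build_embedding_records(rows, toy_embed))
print(record_batches)
```

In the notebook, `rows` would come from running `FETCH_UNEMBEDDED` through `driver.execute_query`, and each batch would be written back with `driver.execute_query(WRITE_EMBEDDINGS, records=batch, ...)`. Filtering on `c.embedding IS NULL` makes the stage safe to rerun after a partial failure.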

### Extract Entities from Lexical Graph

We will now run entity extraction on the Chunk nodes. The extracted entities augment our domain graph of patient journey information and connect to it.

In [137]:
entity_constraints = [
    "CREATE CONSTRAINT medication_id IF NOT EXISTS FOR (n:Medication) REQUIRE n.id IS UNIQUE",
    "CREATE CONSTRAINT medication_name IF NOT EXISTS FOR (n:Medication) REQUIRE n.name IS NODE KEY",
    "CREATE CONSTRAINT study_medication_id IF NOT EXISTS FOR (n:StudyMedication) REQUIRE n.id IS UNIQUE",
    "CREATE CONSTRAINT medical_condition_id IF NOT EXISTS FOR (n:MedicalCondition) REQUIRE n.id IS UNIQUE",
    "CREATE CONSTRAINT medical_condition_name IF NOT EXISTS FOR (n:MedicalCondition) REQUIRE n.name IS NODE KEY",
    "CREATE CONSTRAINT study_population_id IF NOT EXISTS FOR (n:StudyPopulation) REQUIRE n.id IS UNIQUE",
    "CREATE CONSTRAINT study_population_name IF NOT EXISTS FOR (n:StudyPopulation) REQUIRE n.name IS NODE KEY",
    "CREATE CONSTRAINT clinical_outcome_id IF NOT EXISTS FOR (n:ClinicalOutcome) REQUIRE n.id IS UNIQUE",
    "CREATE CONSTRAINT clinical_outcome_name IF NOT EXISTS FOR (n:ClinicalOutcome) REQUIRE n.name IS NODE KEY",
]

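def create_entity_constraints(driver: Driver) -> None:
    """Create constraints for the entity graph"""

    # execute_query takes a single Cypher statement, so run the constraints one at a time
    for constraint in entity_constraints:
        driver.execute_query(constraint, 
                             database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                             routing_=RoutingControl.WRITE)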

In [138]:
def get_chunk_nodes_to_process(driver: Driver, article_name: str) -> list[Chunk]:
    """
    Retrieve Chunk node id and text from the database that have a relationship to the Document with the article name provided.
    """

    query = """
    MATCH (d:Document {name: $article_name})<-[:PART_OF]-(c:Chunk)
    RETURN c.id as id, c.text as text
    """

    result = driver.execute_query(query, 
                                  {"article_name": article_name}, 
                                  database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                                  routing_=RoutingControl.READ, 
                                  result_transformer_=lambda x: [Chunk(**args) for args in x.data()])
    return result

In [167]:
ENTITY_LABELS = {
    "Medication", 
    "StudyMedication",
    "MedicalCondition",
    "StudyPopulation",
    "ClinicalOutcome",
}

ENTITY_RELS = {
    "StudyMedicationUsesMedication",
    "StudyMedicationProducesClinicalOutcome",
    "StudyPopulationHasMedicalCondition",
    "StudyPopulationReceivesStudyMedication",
    "StudyPopulationHasOutcome",
}

def prepare_entities_for_ingestion(entities: list[tuple[str, list[Any]]]) -> dict[str, dict[str, list[dict[str, Any]]]]:
    """
    Prepare entities for ingestion into the graph.
    This function takes the results of the `extract_entities_from_chunk_nodes` function and returns a dictionary with
    "nodes" and "relationships" keys, each mapping a class name to a list of records tagged with their source chunk_id.
    """

    records_node_dict = {lbl: list() for lbl in ENTITY_LABELS}
    records_rel_dict = {lbl: list() for lbl in ENTITY_RELS}

    for chunk_id, chunk_entities in entities:
        for entity in chunk_entities:
            to_add = entity.model_dump()
            to_add.update({"chunk_id": chunk_id})
            # nodes
            if entity.__class__.__name__ in ENTITY_LABELS:
                records_node_dict[entity.__class__.__name__].append(to_add)
            # rels
            elif entity.__class__.__name__ in ENTITY_RELS:
                records_rel_dict[entity.__class__.__name__].append(to_add)
            else:
                print(f"Unknown entity type: {entity.__class__.__name__}")

    return {"nodes": records_node_dict, "relationships": records_rel_dict}
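
The grouped output of `prepare_entities_for_ingestion` could then be handed to one loader per class name. The sketch below shows one way to do that with a lookup table, loading every node group before any relationship group because relationship queries need both endpoints to exist already. `make_fake_loader`, `NODE_LOADERS`, `REL_LOADERS`, and `ingest_prepared` are hypothetical names; the fake loaders only record that they were called, whereas the real loaders in this notebook also take the `driver` as their first argument.

```python
from typing import Any, Callable

Records = list[dict[str, Any]]
calls: list[str] = []  # records the order loaders ran in, for demonstration only


def make_fake_loader(label: str) -> Callable[[Records], None]:
    """Hypothetical stand-in for a loader such as ingest_medication_nodes."""
    def loader(records: Records) -> None:
        calls.append(f"{label}:{len(records)}")
    return loader


NODE_LOADERS = {lbl: make_fake_loader(lbl) for lbl in ("Medication", "StudyMedication")}
REL_LOADERS = {lbl: make_fake_loader(lbl) for lbl in ("StudyMedicationUsesMedication",)}


def ingest_prepared(prepared: dict[str, dict[str, Records]]) -> None:
    """Load all node groups first, then relationships, skipping empty groups."""
    for group, loaders in (("nodes", NODE_LOADERS), ("relationships", REL_LOADERS)):
        for label, loader in loaders.items():
            records = prepared[group].get(label, [])
            if records:
                loader(records)


# Relationships are listed first on purpose: the loading order must not depend on dict order
prepared = {
    "relationships": {
        "StudyMedicationUsesMedication": [
            {"study_medication_id": "STUDY_MED001", "medication_name": "Semaglutide"}
        ]
    },
    "nodes": {
        "Medication": [{"medication_id": "MED001", "name": "Semaglutide"}],
        "StudyMedication": [{"study_medication_id": "STUDY_MED001"}],
        "MedicalCondition": [],
    },
}
ingest_prepared(prepared)
print(calls)
```

Keeping the mapping in plain dictionaries means that adding a new entity type only requires a new entry, not another branch in the ingestion code.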

In [139]:
chunks_to_process = get_chunk_nodes_to_process(driver, "pubmed_abstracts.txt")
print(f"Found {len(chunks_to_process)} chunks to process")
print(f"First chunk: {chunks_to_process[0]}")

Found 171 chunks to process
First chunk: id='c7575a02b7776f6183bcb0c3aa2b3a58' text='1. Diabetol Metab Syndr. 2025 Jun 21;17(1):235. doi: 10.1186/s13098-025-01823-4.\n\nSeagrass Enhalus acoroides extract mitigates obesity and diabetes via GLP-1, PPARγ, SREBP-1c modulation and gut microbiome restoration in diabetic zebrafish.\n\nKadharusman MM(1), Syahputra RA(2), Kurniawan R(3), Hadinata E(4), Tjandrawinata RR(5), Taslim NA(6), Romano R(7), Santini A(8), Nurkolis F(9)(10)(11).'


In [164]:
entity_ingest_records = await extract_entities_from_chunk_nodes(chunks_to_process[:8])

In [165]:
entity_ingest_records

[('c7575a02b7776f6183bcb0c3aa2b3a58',
  [Medication(medication_id='MED002', name='Seagrass Enhalus acoroides extract', medication_class='Natural extract', mechanism='GLP-1, PPARγ, SREBP-1c modulation and gut microbiome restoration', generic_name=None, brand_names=None, approval_status=None)]),
 ('aaf28f4d0b350bd4c312418709797fd5', []),
 ('def7de5c464d6254bdaa2d73b86034eb',
  [Medication(medication_id='MED001', name='Semaglutide', medication_class='GLP-1 receptor agonist', mechanism='GLP-1 receptor activation', generic_name='semaglutide', brand_names=['Ozempic', 'Wegovy', 'Rybelsus'], approval_status='FDA approved'),
   StudyMedication(study_medication_id='STUDY_MED001', dosage='1.0 mg', route='subcutaneous', frequency='weekly', treatment_duration='12 weeks', treatment_arm='Active treatment group', comparator='placebo', adherence_rate=85.5, formulation='pre-filled pen'),
   StudyMedicationUsesMedication(study_medication_id='STUDY_MED001', medication_name='Semaglutide')]),
 ('e0313766b61

### Ingest Entities Into Knowledge Graph

These functions load the extracted entities and relationships into the graph.
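
The node loaders in the following cell write each property as `SET n.prop = COALESCE(n.prop, incoming.prop)`. Since `COALESCE` returns its first non-null argument, a value already stored on the node is kept and an incoming value only fills a gap. The sketch below mimics that behavior on plain dictionaries; `coalesce_merge` is a hypothetical helper written for this illustration and is not part of the ingestion code.

```python
from typing import Any


def coalesce_merge(existing: dict[str, Any], incoming: dict[str, Any]) -> dict[str, Any]:
    """Mimic `SET n.prop = COALESCE(n.prop, incoming.prop)` for every incoming key."""
    merged = dict(existing)
    for key, value in incoming.items():
        # A stored non-null value wins; the incoming value is used only to fill a gap
        if merged.get(key) is None:
            merged[key] = value
    return merged


# The same medication extracted from two different chunks
first_chunk = {
    "medication_class": "GLP-1 receptor agonist",
    "mechanism": None,
    "approval_status": "FDA approved",
}
second_chunk = {
    "medication_class": "GLP-1 RA",
    "mechanism": "GLP-1 receptor activation",
    "approval_status": None,
}

node = coalesce_merge({}, first_chunk)
node = coalesce_merge(node, second_chunk)
print(node)
```

The consequence is first-write-wins: when two chunks disagree about a property, the earlier value is retained, so the order in which chunks are ingested can affect the final node.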

In [181]:
def ingest_medication_nodes(driver: Driver, medications: list[Medication]) -> None:
    """Ingest medication nodes into the graph"""

    query = """
    UNWIND $medications AS medication
    MERGE (n:Medication {id: medication.medication_id, name: medication.name})
    SET n.medication_class = COALESCE(n.medication_class, medication.medication_class), 
        n.mechanism = COALESCE(n.mechanism, medication.mechanism), 
        n.generic_name = COALESCE(n.generic_name, medication.generic_name), 
        n.brand_names = COALESCE(n.brand_names, medication.brand_names), 
        n.approval_status = COALESCE(n.approval_status, medication.approval_status)
    """

    driver.execute_query(query, 
                         {"medications": medications}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_study_medication_nodes(driver: Driver, study_medications: list[StudyMedication]) -> None:
    """Ingest study medication nodes into the graph"""

    query = """
    UNWIND $study_medications AS study_medication
    MERGE (n:StudyMedication {id: study_medication.study_medication_id})
    SET n.dosage = COALESCE(n.dosage, study_medication.dosage), 
        n.route = COALESCE(n.route, study_medication.route), 
        n.frequency = COALESCE(n.frequency, study_medication.frequency), 
        n.treatment_duration = COALESCE(n.treatment_duration, study_medication.treatment_duration), 
        n.treatment_arm = COALESCE(n.treatment_arm, study_medication.treatment_arm), 
        n.comparator = COALESCE(n.comparator, study_medication.comparator), 
        n.adherence_rate = COALESCE(n.adherence_rate, study_medication.adherence_rate), 
        n.formulation = COALESCE(n.formulation, study_medication.formulation)
    """

    driver.execute_query(query, 
                         {"study_medications": study_medications}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_clinical_outcome_nodes(driver: Driver, clinical_outcomes: list[ClinicalOutcome]) -> None:
    """Ingest clinical outcome nodes into the graph"""

    query = """
    UNWIND $clinical_outcomes AS clinical_outcome
    MERGE (n:ClinicalOutcome {id: clinical_outcome.clinical_outcome_id, name: clinical_outcome.name})
    SET n.category = COALESCE(n.category, clinical_outcome.category), 
        n.measurement_unit = COALESCE(n.measurement_unit, clinical_outcome.measurement_unit), 
        n.normal_range = COALESCE(n.normal_range, clinical_outcome.normal_range), 
        n.baseline_value = COALESCE(n.baseline_value, clinical_outcome.baseline_value), 
        n.post_treatment_value = COALESCE(n.post_treatment_value, clinical_outcome.post_treatment_value), 
        n.change_from_baseline = COALESCE(n.change_from_baseline, clinical_outcome.change_from_baseline), 
        n.p_value = COALESCE(n.p_value, clinical_outcome.p_value), 
        n.confidence_interval = COALESCE(n.confidence_interval, clinical_outcome.confidence_interval), 
        n.effect_size = COALESCE(n.effect_size, clinical_outcome.effect_size)
    """

    driver.execute_query(query, 
                         {"clinical_outcomes": clinical_outcomes}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_medical_condition_nodes(driver: Driver, medical_conditions: list[MedicalCondition]) -> None:
    """Ingest medical condition nodes into the graph"""

    query = """
    UNWIND $medical_conditions AS medical_condition
    MERGE (n:MedicalCondition {id: medical_condition.medical_condition_id, name: medical_condition.name})
    SET n.category = COALESCE(n.category, medical_condition.category), 
        n.severity = COALESCE(n.severity, medical_condition.severity), 
        n.icd10_code = COALESCE(n.icd10_code, medical_condition.icd10_code), 
        n.duration = COALESCE(n.duration, medical_condition.duration), 
        n.prevalence = COALESCE(n.prevalence, medical_condition.prevalence)
    """

    driver.execute_query(query, 
                         {"medical_conditions": medical_conditions}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_study_population_nodes(driver: Driver, study_populations: list[StudyPopulation]) -> None:
    """Ingest study population nodes into the graph"""

    query = """
    UNWIND $study_populations AS study_population
    MERGE (n:StudyPopulation {id: study_population.study_population_id})
    SET n.description = COALESCE(n.description, study_population.description), 
        n.age_range = COALESCE(n.age_range, study_population.age_range), 
        n.mean_age = COALESCE(n.mean_age, study_population.mean_age), 
        n.male_percentage = COALESCE(n.male_percentage, study_population.male_percentage), 
        n.female_percentage = COALESCE(n.female_percentage, study_population.female_percentage), 
        n.other_gender_percentage = COALESCE(n.other_gender_percentage, study_population.other_gender_percentage), 
        n.sample_size = COALESCE(n.sample_size, study_population.sample_size), 
        n.study_type = COALESCE(n.study_type, study_population.study_type), 
        n.location = COALESCE(n.location, study_population.location), 
        n.inclusion_criteria = COALESCE(n.inclusion_criteria, study_population.inclusion_criteria), 
        n.exclusion_criteria = COALESCE(n.exclusion_criteria, study_population.exclusion_criteria), 
        n.study_duration = COALESCE(n.study_duration, study_population.study_duration)
    """

    driver.execute_query(query, 
                         {"study_populations": study_populations}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)



def ingest_study_medication_uses_medication_relationships(driver: Driver, study_medication_uses_medications: list[StudyMedicationUsesMedication]) -> None:
    """Links study medication to base medication"""

    query = """
    UNWIND $study_medication_uses_medications AS study_medication_uses_medication
    MATCH (s:StudyMedication {id: study_medication_uses_medication.study_medication_id}), (m:Medication {name: study_medication_uses_medication.medication_name})
    MERGE (s)-[:USES]->(m)
    """

    driver.execute_query(query, 
                         {"study_medication_uses_medications": study_medication_uses_medications}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_study_medication_produces_clinical_outcome_relationships(driver: Driver, study_medication_produces_clinical_outcomes: list[StudyMedicationProducesClinicalOutcome]) -> None:
    """Links study medication usage to clinical outcomes"""

    query = """
    UNWIND $study_medication_produces_clinical_outcomes AS study_medication_produces_clinical_outcome
    MATCH (s:StudyMedication {id: study_medication_produces_clinical_outcome.study_medication_id}), (o:ClinicalOutcome {name: study_medication_produces_clinical_outcome.clinical_outcome_name})
    MERGE (s)-[:PRODUCES]->(o)
    """

    driver.execute_query(query, 
                         {"study_medication_produces_clinical_outcomes": study_medication_produces_clinical_outcomes}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)



def ingest_study_population_has_medical_condition_relationships(driver: Driver, study_population_has_medical_conditions: list[StudyPopulationHasMedicalCondition]) -> None:
    """Links study population to medical condition"""
    
    query = """
    UNWIND $study_population_has_medical_conditions AS study_population_has_medical_condition
    MATCH (s:StudyPopulation {id: study_population_has_medical_condition.study_population_id}), (m:MedicalCondition {name: study_population_has_medical_condition.medical_condition_name})
    MERGE (s)-[:HAS]->(m)
    """ 

    driver.execute_query(query, 
                         {"study_population_has_medical_conditions": study_population_has_medical_conditions}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)


def ingest_study_population_receives_study_medication_relationships(driver: Driver, study_population_receives_study_medications: list[StudyPopulationReceivesStudyMedication]) -> None:
    """Links study population to study medication"""

    query = """
    UNWIND $study_population_receives_study_medications AS study_population_receives_study_medication
    MATCH (s:StudyPopulation {id: study_population_receives_study_medication.study_population_id}), (m:StudyMedication {id: study_population_receives_study_medication.study_medication_id})
    MERGE (s)-[:RECEIVES]->(m)
    """

    driver.execute_query(query, 
                         {"study_population_receives_study_medications": study_population_receives_study_medications}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_study_population_has_clinical_outcome_relationships(driver: Driver, study_population_has_outcomes: list[StudyPopulationHasOutcome]) -> None:
    """Links study population to clinical outcome"""

    query = """
    UNWIND $study_population_has_outcomes AS study_population_has_outcome
    MATCH (s:StudyPopulation {id: study_population_has_outcome.study_population_id}), (o:ClinicalOutcome {name: study_population_has_outcome.clinical_outcome_name})
    MERGE (s)-[:HAS_OUTCOME]->(o)
    """

    driver.execute_query(query, 
                         {"study_population_has_outcomes": study_population_has_outcomes}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)


These functions link the extracted entities to the text chunk nodes they were extracted from.

In [182]:
def ingest_chunk_has_entity_medication_relationships(driver: Driver, records: list[dict[str, Any]]) -> None:
    """Ingest chunk has entity medication relationships into the graph"""

    query = """
    UNWIND $records AS record
    MATCH (c:Chunk {id: record.chunk_id}), (e:Medication {id: record.medication_id})
    MERGE (c)-[:HAS_ENTITY]->(e)
    """

    driver.execute_query(query, 
                         {"records": records}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_chunk_has_entity_medical_condition_relationships(driver: Driver, records: list[dict[str, Any]]) -> None:
    """Ingest chunk has entity medical condition relationships into the graph"""

    query = """
    UNWIND $records AS record
    MATCH (c:Chunk {id: record.chunk_id}), (e:MedicalCondition {id: record.medical_condition_id})
    MERGE (c)-[:HAS_ENTITY]->(e)
    """

    driver.execute_query(query, 
                         {"records": records}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_chunk_has_entity_study_medication_relationships(driver: Driver, records: list[dict[str, Any]]) -> None:
    """Ingest chunk has entity study medication relationships into the graph"""

    query = """
    UNWIND $records AS record
    MATCH (c:Chunk {id: record.chunk_id}), (e:StudyMedication {id: record.study_medication_id})
    MERGE (c)-[:HAS_ENTITY]->(e)
    """

    driver.execute_query(query,
                         {"records": records}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_chunk_has_entity_study_population_relationships(driver: Driver, records: list[dict[str, Any]]) -> None:
    """Ingest chunk has entity study population relationships into the graph"""

    query = """
    UNWIND $records AS record
    MATCH (c:Chunk {id: record.chunk_id}), (e:StudyPopulation {id: record.study_population_id})
    MERGE (c)-[:HAS_ENTITY]->(e)
    """

    driver.execute_query(query, 
                         {"records": records}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)

def ingest_chunk_has_entity_clinical_outcome_relationships(driver: Driver, records: list[dict[str, Any]]) -> None:
    """Ingest chunk has entity clinical outcome relationships into the graph"""

    query = """
    UNWIND $records AS record
    MATCH (c:Chunk {id: record.chunk_id}), (e:ClinicalOutcome {id: record.clinical_outcome_id})
    MERGE (c)-[:HAS_ENTITY]->(e)
    """

    driver.execute_query(query, 
                         {"records": records}, 
                         database_=os.getenv("NEO4J_DATABASE", "neo4j"), 
                         routing_=RoutingControl.WRITE)
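
Each linking query above reads `record.chunk_id` plus an entity id key such as `record.medication_id`. A `MATCH` on a missing or null id finds nothing, so the relationship is skipped without an error. The following is a minimal sketch of a pre-flight check for that case. The helper name `missing_link_keys` and the example records are assumptions made for illustration; they are not part of the notebook's pipeline.

```python
from typing import Any


def missing_link_keys(records: list[dict[str, Any]], id_key: str) -> list[int]:
    """Return the indexes of records that lack `chunk_id` or the entity id key."""
    required = ("chunk_id", id_key)
    return [
        index
        for index, record in enumerate(records)
        # A key that is absent or set to None cannot be matched in Cypher.
        if any(record.get(key) is None for key in required)
    ]


# Hypothetical records; real ones come from the extraction step.
example_records = [
    {"chunk_id": "chunk-1", "medication_id": "med-1", "name": "metformin"},
    {"chunk_id": None, "medication_id": "med-2", "name": "insulin"},
    {"medication_id": "med-3", "name": "aspirin"},
]

bad_rows = missing_link_keys(example_records, "medication_id")
print(bad_rows)  # [1, 2]
```

Running a check like this before each linking call makes it easier to tell whether an absent `HAS_ENTITY` relationship comes from the extraction output or from the graph itself.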

In [183]:
def ingest_entities(driver: Driver, records_dict: dict[str, list[dict[str, Any]]]) -> None:
    """Ingest entities into the graph"""

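    # Fall back to empty dicts so a missing top-level key does not raise AttributeError below
    nodes = records_dict.get("nodes") or {}
    rels = records_dict.get("relationships") or {}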

    print("Ingesting Nodes")
    ingest_medication_nodes(driver, nodes.get("Medication"))
    ingest_study_medication_nodes(driver, nodes.get("StudyMedication"))
    ingest_medical_condition_nodes(driver, nodes.get("MedicalCondition"))
    ingest_study_population_nodes(driver, nodes.get("StudyPopulation"))
    ingest_clinical_outcome_nodes(driver, nodes.get("ClinicalOutcome"))

    print("Linking Nodes to Chunks")
    ingest_chunk_has_entity_medication_relationships(driver, nodes.get("Medication"))
    ingest_chunk_has_entity_medical_condition_relationships(driver, nodes.get("MedicalCondition"))
    ingest_chunk_has_entity_study_medication_relationships(driver, nodes.get("StudyMedication"))
    ingest_chunk_has_entity_study_population_relationships(driver, nodes.get("StudyPopulation"))
    ingest_chunk_has_entity_clinical_outcome_relationships(driver, nodes.get("ClinicalOutcome"))

    print("Ingesting Relationships")
    ingest_study_medication_uses_medication_relationships(driver, rels.get("StudyMedicationUsesMedication"))
    ingest_study_population_has_medical_condition_relationships(driver, rels.get("StudyPopulationHasMedicalCondition"))
    ingest_study_population_receives_study_medication_relationships(driver, rels.get("StudyPopulationReceivesStudyMedication"))
    ingest_study_population_has_clinical_outcome_relationships(driver, rels.get("StudyPopulationHasClinicalOutcome"))
    ingest_study_medication_produces_clinical_outcome_relationships(driver, rels.get("StudyMedicationProducesClinicalOutcome"))


In [184]:
ingest_records = prepare_entities_for_ingestion(entity_ingest_records)

In [185]:
ingest_entities(driver, ingest_records)

Ingesting Nodes
Linking Nodes to Chunks
Ingesting Relationships
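
The printed output confirms only that each stage ran, not that any records were written. `ingest_entities` looks up every label with `.get`, so a misspelled or missing key passes `None` to the query, and `UNWIND` over a null list produces no rows. A small sketch like the one below could summarize the payload before ingestion. The function name `summarize_ingest_records` and the example payload are hypothetical; only the two-level `nodes` / `relationships` layout is taken from `ingest_entities`.

```python
from typing import Any


def summarize_ingest_records(
    records_dict: dict[str, dict[str, list[dict[str, Any]]]],
) -> dict[str, int]:
    """Count records per node label and relationship type."""
    summary: dict[str, int] = {}
    for group in ("nodes", "relationships"):
        # Treat a missing group or a None list as empty rather than failing.
        for name, records in (records_dict.get(group) or {}).items():
            summary[f"{group}.{name}"] = len(records or [])
    return summary


# Hypothetical payload with the same layout that ingest_entities reads.
example = {
    "nodes": {"Medication": [{"name": "metformin"}], "StudyPopulation": []},
    "relationships": {"StudyMedicationUsesMedication": None},
}

summary = summarize_ingest_records(example)
print(summary)
```

Comparing the keys in this summary with the names passed to `nodes.get(...)` and `rels.get(...)` is a quick way to spot a label that would be skipped without warning.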
