# PubMed Knowledge Graph

This notebook is part of a series that walks through the process of generating a knowledge graph of PubMed articles.

It will use the Neo4j GraphRAG Python Package to create a pipeline that parses documents, extracts entities and loads data into Neo4j. 

This notebook is based on the custom pipeline example provided by the [Neo4j GraphRAG Python Package documentation](https://github.com/neo4j/neo4j-graphrag-python/blob/main/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py).

**Please see the [Neo4j GraphRAG Section README](./README.md) for details on differences from the main notebooks and known bugs.**

## Imports

In [1]:
# allows for async operations in notebooks
import nest_asyncio
nest_asyncio.apply()

In [2]:
import asyncio
import os
from typing import Any

from neo4j_graphrag.experimental.components.entity_relation_extractor import (
    LLMEntityRelationExtractor,
    OnError,
)
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.components.resolver import (
    FuzzyMatchResolver,
)
from neo4j_graphrag.experimental.components.schema import (
    SchemaBuilder,
    NodeType,
    RelationshipType,
    PropertyType,
)
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)
from neo4j_graphrag.llm import LLMInterface, OpenAILLM

from neo4j_graphrag.experimental.pipeline import Pipeline

from neo4j import Driver, GraphDatabase

## Entity Graph Schema Definition

Here we define the entities and relationships to extract from the documents. 

These are the same entities and relationships defined in `3_generate_entity_graph.ipynb`; however, here we use the Neo4j GraphRAG classes `NodeType`, `RelationshipType` and `PropertyType` instead of defining our own Pydantic classes.

Here is what our entity graph data model looks like.

<img src="../assets/images/entity-data-model.png" alt="entity-data-model" width="600px">

### Node Definitions

In [3]:
medication = NodeType(
    label="Medication",
    description="A substance used for medical treatment - a medicine or drug. This is a general representation of a medication. A Medication node may have relationships to TreatmentArm nodes that are specific to a particular study.",
    properties=[
        PropertyType(name="name", type="STRING", required=True, description="Name of the medication. Should also be uniquely identifiable. Do not include dosage, administration, frequency, or other details."),
        PropertyType(name="medicationClass", type="STRING", required=True, description="Drug class (e.g., GLP-1 RA, SGLT2i)"),
        PropertyType(name="mechanism", type="STRING", description="Mechanism of action"),
        PropertyType(name="genericName", type="STRING", description="Generic name of the medication"),
        PropertyType(name="brandNames", type="LIST", description="Commercial brand names"),
        PropertyType(name="approvalStatus", type="STRING", description="FDA approval status"),
    ],
)

treatment_arm = NodeType(
    label="TreatmentArm",
    description="A treatment arm is a group of participants who receive the same treatment. It may be a control arm or an experimental arm.",
    properties=[
        PropertyType(name="id", type="STRING", required=True, description="The unique id of the treatment arm. This is a combination of the study name and treatment arm name. Follows the pattern: <study_name>_<treatment_arm_name>"),
        PropertyType(name="studyName", type="STRING", required=True, description="Name of the study. This is used to uniquely identify the TreatmentArm node."),
        PropertyType(name="name", type="STRING", required=True, description="Name of the treatment arm"),
    ],
)

clinical_outcome = NodeType(
    label="ClinicalOutcome",
    description="A clinical outcome of a treatment arm. This describes the resulting effect a treatment has on a treatment arm population.",
    properties=[
        PropertyType(name="id", type="STRING", required=True, description="The unique id of the clinical outcome. This is a combination of the study name and the name of the outcome. Follows the pattern: <study_name>_<outcome_name>"),
        PropertyType(name="studyName", type="STRING", required=True, description="Name of the study this outcome is associated with. This is used to uniquely identify the ClinicalOutcome node."),
        PropertyType(name="name", type="STRING", required=True, description="Name of the clinical outcome."),
    ],
)

medical_condition = NodeType(
    label="MedicalCondition",
    description="Medical conditions and comorbidities studied",
    properties=[
        PropertyType(name="name", type="STRING", required=True, description="Name of the medical condition"),
        PropertyType(name="category", type="STRING", required=True, description="Category of condition"),
        PropertyType(name="icd10Code", type="STRING", description="ICD-10 code when available"),
    ],
)

study_population = NodeType(
    label="StudyPopulation",
    description="Patient populations and demographics in research studies",
    properties=[
        PropertyType(name="id", type="STRING", required=True, description="The unique id of the study population. This is a combination of the study name and the population description. Follows the pattern: <study_name>_<population_description>"),
        PropertyType(name="studyName", type="STRING", required=True, description="Name of the study. This is used to uniquely identify the StudyPopulation node."),
        PropertyType(name="description", type="STRING", required=True, description="Description of the population"),
        PropertyType(name="minAge", type="INTEGER", description="Minimum age in years"),
        PropertyType(name="maxAge", type="INTEGER", description="Maximum age in years"),
        PropertyType(name="malePercentage", type="FLOAT", description="Percentage of male gender participants"),
        PropertyType(name="femalePercentage", type="FLOAT", description="Percentage of female gender participants"),
        PropertyType(name="otherGenderPercentage", type="FLOAT", description="Percentage of participants that identify as another gender"),
        PropertyType(name="sampleSize", type="INTEGER", description="Number of participants"),
        PropertyType(name="studyType", type="STRING", required=True, description="Type of study"),
        PropertyType(name="inclusionCriteria", type="LIST", description="Key inclusion criteria"),
        PropertyType(name="exclusionCriteria", type="LIST", description="Key exclusion criteria"),
        PropertyType(name="studyDuration", type="STRING", description="Duration of study"),
    ],
)

nodes = [medication, treatment_arm, clinical_outcome, medical_condition, study_population]
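
The `id` descriptions above ask the LLM to scope each identifier to its study using the `<study_name>_<name>` pattern. As a rough illustration of why this matters, the sketch below builds such ids for two studies that share an arm name. The `scoped_id` helper and the study names are hypothetical and are not part of the pipeline.

```python
def scoped_id(study_name: str, local_name: str) -> str:
    """Combine a study name and a local name into a study-scoped id."""
    return f"{study_name}_{local_name}"


# Two hypothetical studies that both have an arm called "Placebo".
arm_a = scoped_id("STUDY-A", "Placebo")
arm_b = scoped_id("STUDY-B", "Placebo")

# The ids differ, so the two arms would not collapse into a single node.
print(arm_a)
print(arm_b)
```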

### Relationship Definitions

Here we define the relationship types that will be extracted from the documents.

In [4]:
medication_used_in_treatment_arm = RelationshipType(
    label="USED_IN_TREATMENT_ARM",
    description="Study-specific medication usage - how a Medication was used in a particular TreatmentArm. This describes an instance of a medication that is used in a particular treatment arm. All treatment arms should have a relationship with at least one Medication node.",
    properties=[
        PropertyType(name="dosage", type="STRING", description="Dosage used in this study"),
        PropertyType(name="route", type="STRING", description="Route of administration"),
        PropertyType(name="frequency", type="STRING", description="Dosing frequency"),
        PropertyType(name="treatment_duration", type="STRING", description="Duration of treatment"),
        PropertyType(name="comparator", type="STRING", description="What this was compared against"),
        PropertyType(name="adherence_rate", type="FLOAT", description="Treatment adherence rate"),
        PropertyType(name="formulation", type="STRING", description="Specific formulation used"),
    ],
)

treatment_arm_has_clinical_outcome = RelationshipType(
    label="HAS_CLINICAL_OUTCOME",
    description="Links TreatmentArm to ClinicalOutcome nodes. TreatmentArm nodes should have a relationship with a ClinicalOutcome node.",
)

study_population_has_medical_condition = RelationshipType(
    label="HAS_MEDICAL_CONDITION",
    description="Links StudyPopulation to MedicalCondition nodes. StudyPopulation nodes should have a relationship with a MedicalCondition node.",
)

study_population_in_treatment_arm = RelationshipType(
    label="IN_TREATMENT_ARM",
    description="Links StudyPopulation to TreatmentArm nodes. StudyPopulation nodes should have a relationship with a TreatmentArm node.",
)

relationships = [medication_used_in_treatment_arm, 
                 treatment_arm_has_clinical_outcome, 
                 study_population_has_medical_condition, 
                 study_population_in_treatment_arm]


We must also define the explicit patterns that exist in our entity graph.

These are triples of `(source, relationship, target)`, where each element is a node or relationship label defined above.

In [5]:
patterns = [
    ("Medication", "USED_IN_TREATMENT_ARM", "TreatmentArm"),
    ("TreatmentArm", "HAS_CLINICAL_OUTCOME", "ClinicalOutcome"),
    ("StudyPopulation", "HAS_MEDICAL_CONDITION", "MedicalCondition"),
    ("StudyPopulation", "IN_TREATMENT_ARM", "TreatmentArm"),
]
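
Because the patterns refer to labels by plain strings, a typo would go unnoticed until extraction. The sketch below is one possible sanity check, written against copies of the labels used in this notebook. The `find_pattern_problems` helper is hypothetical and is not part of the Neo4j GraphRAG package.

```python
# Copies of the labels defined in the cells above.
node_labels = {
    "Medication",
    "TreatmentArm",
    "ClinicalOutcome",
    "MedicalCondition",
    "StudyPopulation",
}
relationship_labels = {
    "USED_IN_TREATMENT_ARM",
    "HAS_CLINICAL_OUTCOME",
    "HAS_MEDICAL_CONDITION",
    "IN_TREATMENT_ARM",
}
patterns_to_check = [
    ("Medication", "USED_IN_TREATMENT_ARM", "TreatmentArm"),
    ("TreatmentArm", "HAS_CLINICAL_OUTCOME", "ClinicalOutcome"),
    ("StudyPopulation", "HAS_MEDICAL_CONDITION", "MedicalCondition"),
    ("StudyPopulation", "IN_TREATMENT_ARM", "TreatmentArm"),
]


def find_pattern_problems(
    patterns: list[tuple[str, str, str]],
    known_nodes: set[str],
    known_relationships: set[str],
) -> list[str]:
    """Return a message for every label in a pattern that is not defined."""
    problems: list[str] = []
    for source, relationship, target in patterns:
        if source not in known_nodes:
            problems.append(f"Unknown source node label: {source}")
        if relationship not in known_relationships:
            problems.append(f"Unknown relationship label: {relationship}")
        if target not in known_nodes:
            problems.append(f"Unknown target node label: {target}")
    return problems


problems = find_pattern_problems(patterns_to_check, node_labels, relationship_labels)
print(problems or "All patterns reference known labels")
```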

## Processing Pipeline

We can now create a pipeline to read our PDF documents, extract entities and load the data into Neo4j.

### Define Pipeline

Our processing pipeline will flow like this:

**PDF Loader &rarr; Text Splitter &rarr; Entity Extraction &rarr; Neo4j Writer &rarr; Entity Resolver**

**&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Schema Builder &nearr;**


Where
* **PDF Loader** &rarr; Loads the PDF into a text format
* **Text Splitter** &rarr; Chunks text according to configuration
* **Schema Builder** &rarr; Creates the entity schema according to the nodes, relationships and patterns defined above
* **Entity Extraction** &rarr; Extracts entities and relationships from the text chunks according to the defined schema
* **Neo4j Writer** &rarr; Ingests lexical and entity graphs
* **Entity Resolver** &rarr; Resolves `Medication` nodes according to the `name` property
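
To make the Text Splitter step more concrete, here is a simplified sketch of fixed-size chunking with overlap, using the same sizes configured later in this notebook (1000 characters with a 200 character overlap). It is a stand-in for illustration only, not the library's implementation, and the `split_fixed` helper is hypothetical.

```python
def split_fixed(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> list[str]:
    """Split text into fixed-size character chunks that overlap their neighbours."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")

    step = chunk_size - chunk_overlap  # how far the window moves each time
    chunks: list[str] = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last chunk reached the end of the text
        start += step
    return chunks


# A synthetic 2500 character document.
sample_text = "".join(str(i % 10) for i in range(2500))
sample_chunks = split_fixed(sample_text)

print([len(chunk) for chunk in sample_chunks])
# The last 200 characters of one chunk are the first 200 of the next.
print(sample_chunks[0][-200:] == sample_chunks[1][:200])
```

The overlap means a sentence that straddles a chunk boundary still appears whole in at least one chunk, at the cost of sending some text to the LLM twice.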

In [6]:
def create_pipeline(llm: LLMInterface, neo4j_driver: Driver) -> Pipeline:
    """
    Create a pipeline for generating a knowledge graph from a PDF file.

    The node types, relationship types and patterns are not arguments of this
    function; they are supplied to the "schema" component when the pipeline runs.

    Parameters
    ----------
    llm : LLMInterface
        The LLM to use for the pipeline.
    neo4j_driver : neo4j.Driver
        The Neo4j driver to use for the pipeline.

    Returns
    -------
    Pipeline
        The pipeline for generating a knowledge graph from a PDF file.
    """

    pipe = Pipeline()

    # PdfLoader will load the PDF file into a string.
    pipe.add_component(PdfLoader(), "pdf_loader")

    # FixedSizeSplitter will split the text into chunks of 1000 characters with 200 character overlap.
    pipe.add_component(
        FixedSizeSplitter(chunk_size=1000, chunk_overlap=200, approximate=False),
        "splitter",
    )

    # SchemaBuilder will build the schema for the graph based on the input nodes and relationships.
    pipe.add_component(SchemaBuilder(), "schema")

    # LLMEntityRelationExtractor will extract the entities and relationships from the text.
    pipe.add_component(
        LLMEntityRelationExtractor(
            llm=llm,
            on_error=OnError.RAISE,
        ),
        "extractor",
    )

    # Neo4jWriter will write the graph to the Neo4j database.
    pipe.add_component(Neo4jWriter(neo4j_driver), "writer")

    # FuzzyMatchResolver will merge similar Medication nodes based on their name property only.
    pipe.add_component(FuzzyMatchResolver(neo4j_driver, 
                                          filter_query="WHERE entity:Medication",
                                          resolve_properties=["name"]), "medication_resolver")
    
    # Connect the components together.
    pipe.connect("pdf_loader", "splitter", input_config={"text": "pdf_loader.text"})
    pipe.connect("splitter", "extractor", input_config={"chunks": "splitter"})
    pipe.connect(
        "schema",
        "extractor",
        input_config={
            "schema": "schema",
            "document_info": "pdf_loader.document_info",
        },
    )
    pipe.connect(
        "extractor",
        "writer",
        input_config={"graph": "extractor"},
    )
    pipe.connect("writer", "medication_resolver")

    return pipe
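
The final component resolves `Medication` nodes whose names are close but not identical. The sketch below illustrates the general idea of fuzzy name grouping with the standard library's `difflib`. It is not how `FuzzyMatchResolver` is implemented; the `group_similar` helper, the 0.8 threshold and the sample names are assumptions chosen for illustration.

```python
from difflib import SequenceMatcher


def group_similar(names: list[str], threshold: float = 0.8) -> list[list[str]]:
    """Greedily group names whose similarity to a group's first member meets the threshold."""
    groups: list[list[str]] = []
    for name in names:
        for group in groups:
            score = SequenceMatcher(None, name.lower(), group[0].lower()).ratio()
            if score >= threshold:
                group.append(name)
                break
        else:
            # No existing group was similar enough, so start a new one.
            groups.append([name])
    return groups


# Hypothetical extracted names, including a casing variant and a misspelling.
groups = group_similar(["Semaglutide", "semaglutide", "Semaglutid", "Metformin"])
print(groups)
```

Each resulting group corresponds to a set of nodes that a resolver could merge into one.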

Our inputs need to be properly formatted, so we create the function below to ensure we are passing the appropriate input object.

In [7]:
def create_pipeline_input(file_path: str, 
                          nodes: list[NodeType], 
                          relationships: list[RelationshipType], 
                          patterns: list[tuple[str, str, str]]) -> dict[str, Any]:
    """
    Create an input for the pipeline.
    This will create a single input to be fed to the knowledge graph generation pipeline.

    Parameters
    ----------
    file_path : str
        The path to the PDF file to load.
    nodes : list[NodeType]
        The list of nodes to use for the pipeline.
    relationships : list[RelationshipType]
        The list of relationships to use for the pipeline.
    patterns : list[tuple[str, str, str]]
        The list of patterns to use for the pipeline.

    Returns
    -------
    dict[str, Any]
        The input for the pipeline.
    """
    return {
        "pdf_loader": {
            "filepath": file_path,
        },
        "schema": {
            "node_types": nodes,
            "relationship_types": relationships,
            "patterns": patterns,
        },
    }
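
The top-level keys of this dictionary (`pdf_loader`, `schema`) are the names the components were registered under in `create_pipeline`, and each inner dictionary holds that component's run parameters. A small check like the sketch below could catch a misspelled component name before a run starts. The `unknown_components` helper and the example file name are hypothetical.

```python
# Component names as registered in create_pipeline above.
registered_components = {
    "pdf_loader",
    "splitter",
    "schema",
    "extractor",
    "writer",
    "medication_resolver",
}


def unknown_components(pipeline_input: dict, registered: set[str]) -> list[str]:
    """Return top-level input keys that do not match a registered component name."""
    return sorted(set(pipeline_input) - registered)


# Same shape as the dictionary returned by create_pipeline_input, with placeholder values.
example_input = {
    "pdf_loader": {"filepath": "example.pdf"},
    "schema": {"node_types": [], "relationship_types": [], "patterns": []},
}

print(unknown_components(example_input, registered_components))
```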

### Define LLM

We will be using OpenAI GPT-4o for our entity extraction process.

Note that we are using the Neo4j GraphRAG Python Package client wrapper.

In [8]:
llm = OpenAILLM(
    model_name="gpt-4o",
    model_params={
        "response_format": {"type": "json_object"},
    },
)

### Initialize Pipeline

Now we create our Neo4j driver instance...

In [9]:
neo4j_driver = GraphDatabase.driver(os.getenv("NEO4J_URI"), 
                                    auth=(os.getenv("NEO4J_USERNAME"), os.getenv("NEO4J_PASSWORD")),
                                    database=os.getenv("NEO4J_DATABASE"))

and initialize the pipeline defined above with the `OpenAILLM` wrapper and the Neo4j driver.

In [10]:
pipeline = create_pipeline(llm, neo4j_driver)
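
The driver cell above reads its connection settings with `os.getenv`, which returns `None` for unset variables rather than raising an error. A quick check such as the sketch below can surface missing settings early. The `missing_settings` helper is hypothetical; the variable names are the ones used in this notebook.

```python
import os
from typing import Mapping

REQUIRED_VARS = ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD", "NEO4J_DATABASE")


def missing_settings(env: Mapping[str, str], required: tuple[str, ...] = REQUIRED_VARS) -> list[str]:
    """Return the names of required variables that are unset or empty."""
    return [name for name in required if not env.get(name)]


missing = missing_settings(os.environ)
if missing:
    print(f"Missing environment variables: {missing}")
else:
    print("All Neo4j connection settings are present")
```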

### Run Pipeline

Before running our pipeline, we need to collect the PDFs we'd like to process.

In [11]:
PDF_DIR = "../articles/pdf/"

pdf_files = [f for f in os.listdir(PDF_DIR) if f.endswith(".pdf")]

print("Found PDFs")
for pdf in pdf_files:
    print("* ", pdf)

Found PDFs
*  nihms-1852972.pdf
*  fendo-11-00178.pdf
*  Diabetic Medicine - 2023 - Brønden - Effects of DPP‐4 inhibitors  GLP‐1 receptor agonists  SGLT‐2 inhibitors and.pdf
*  jama_rosenstock_2019_oi_190026_1655321720.77793.pdf
*  jciinsight-3-93936.pdf


In [12]:
async def run_pipeline_for_many_pdfs(pipeline: Pipeline, pdf_directory: str, pdf_files: list[str]) -> list[str]:
    """
    Run a collection of PDFs through the provided pipeline.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to run the PDFs through.
    pdf_directory : str
        The directory containing the PDFs.
    pdf_files : list[str]
        The list of PDF file names to run through the pipeline.

    Returns
    -------
    list[str]
        The list of PDF file names that failed to be processed.
    """

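    failed_cache: list[str] = []

    # enumerate() returns a regular iterator, so a plain `for` loop is required here.
    # Each PDF is awaited in turn, so the files are processed sequentially.
    for i, pdf_file in enumerate(pdf_files):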
        print(f"Processing PDF {i + 1}/{len(pdf_files)}: {pdf_file}")
        try:
            pdf_path = os.path.join(pdf_directory, pdf_file)
            pipeline_input = create_pipeline_input(pdf_path, nodes, relationships, patterns)
            await pipeline.run(pipeline_input)
        except Exception as e:
            print(f"Error processing PDF {pdf_file}: {e}")
            failed_cache.append(pdf_file)
            continue

    return failed_cache

Now that we've collected the PDF files we'd like to load, we can run the pipeline. Note that only the first two PDFs (`pdf_files[:2]`) are processed here.

In [13]:
failed_pdfs = asyncio.run(run_pipeline_for_many_pdfs(pipeline, PDF_DIR, pdf_files[:2]))

# Basic retry logic for failed PDFs
# Here we retry a document only once
if len(failed_pdfs) > 0:
    print(f"Failed to process {len(failed_pdfs)} PDFs: {failed_pdfs}")
    print("Retrying failed PDFs...")
    failed_pdfs = asyncio.run(run_pipeline_for_many_pdfs(pipeline, PDF_DIR, failed_pdfs))
    print(f"Failed to process {len(failed_pdfs)} PDFs: {failed_pdfs}")

Processing PDF 1/2: nihms-1852972.pdf


Multiple definitions in dictionary at byte 0x2d4a7 for key /MediaBox
Multiple definitions in dictionary at byte 0x2d6e8 for key /MediaBox
Multiple definitions in dictionary at byte 0x2d8d2 for key /MediaBox
Multiple definitions in dictionary at byte 0x2da94 for key /MediaBox
Multiple definitions in dictionary at byte 0x2dc2e for key /MediaBox
Multiple definitions in dictionary at byte 0x2de38 for key /MediaBox
Multiple definitions in dictionary at byte 0x2e01a for key /MediaBox
Multiple definitions in dictionary at byte 0x2e204 for key /MediaBox
Multiple definitions in dictionary at byte 0x2e41e for key /MediaBox
Multiple definitions in dictionary at byte 0x2e630 for key /MediaBox
Multiple definitions in dictionary at byte 0x2e8aa for key /MediaBox
Multiple definitions in dictionary at byte 0x2eafc for key /MediaBox
Multiple definitions in dictionary at byte 0x2ed5e for key /MediaBox


Processing PDF 2/2: fendo-11-00178.pdf
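
The retry cell above makes a single extra pass over the failed PDFs. If more attempts are wanted, the same idea can be written as a bounded loop. The sketch below is one possible shape, using a stand-in worker instead of the real pipeline; `run_with_retries`, `flaky_worker` and the file names are hypothetical.

```python
import asyncio
from typing import Awaitable, Callable


async def run_with_retries(
    items: list[str],
    worker: Callable[[str], Awaitable[None]],
    max_attempts: int = 2,
) -> list[str]:
    """Process items sequentially, retrying failures up to max_attempts passes in total."""
    pending = list(items)
    for attempt in range(1, max_attempts + 1):
        failed: list[str] = []
        for item in pending:
            try:
                await worker(item)
            except Exception as exc:
                print(f"Attempt {attempt} failed for {item}: {exc}")
                failed.append(item)
        pending = failed
        if not pending:
            break  # everything succeeded, no further passes needed
    return pending


# Stand-in worker: fails the first time it sees "b.pdf", then succeeds.
seen: set[str] = set()


async def flaky_worker(name: str) -> None:
    if name == "b.pdf" and name not in seen:
        seen.add(name)
        raise RuntimeError("temporary failure")


still_failed = asyncio.run(run_with_retries(["a.pdf", "b.pdf"], flaky_worker))
print(still_failed)
```

In this notebook the worker would wrap `pipeline.run(create_pipeline_input(...))` for a single file.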


Finally, we should close our Neo4j driver and the LLM's async client connection.

In [None]:
neo4j_driver.close()
await llm.async_client.close()