# Graph RAG implementation

## Setup

In [163]:
from langchain.graphs import Neo4jGraph
import ast

# from langchain.vectorstores.neo4j_vector import Neo4jVector
# from langchain.embeddings.openai import OpenAIEmbeddings
# from langchain_ollama import OllamaEmbeddings
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import ollama
import os

Here we use the default embedding model from the AuraDB graph builder: sentence-transformers/all-MiniLM-L6-v2

Load the Neo4j credentials used for authentication from the local `.env` file.

In [164]:
# Load environment variables (e.g. NEO4J_PASSWORD, read further down) from .env
load_dotenv()

True

Connect to our hosted db

In [165]:
# Connection details for the hosted Neo4j database
url = "neo4j+s://b9009f0e.databases.neo4j.io"
username = "neo4j"
# The password is never hard-coded: it is read from the NEO4J_PASSWORD
# environment variable (os.getenv returns None when it is not set).
password = os.getenv("NEO4J_PASSWORD")

# Shared graph handle passed to every query helper in this notebook
graph = Neo4jGraph(url=url, username=username, password=password)

## Vector RAG

Key steps:
* embed the whole query
* search for the closest nodes
* extract the text chunks linked to the closest node

### Vector Index Creation

In [166]:
def create_vector_index(graph, name, dimensions=384, similarity_function="cosine"):
    """Drop and recreate a vector index on ``__Entity__.embedding``.

    Any existing index called ``name`` is dropped first, so calling this
    function repeatedly is safe but rebuilds the index every time.

    Args:
        graph: Object exposing ``query(cypher)`` (e.g. a ``Neo4jGraph``).
        name: Name of the index to (re)create.
        dimensions: Length of the stored embedding vectors. The default of
            384 matches sentence-transformers/all-MiniLM-L6-v2; change it
            when a different embedding model is used.
        similarity_function: ``"cosine"`` or ``"euclidean"``.

    Raises:
        ValueError: If ``dimensions`` is not a positive integer or
            ``similarity_function`` is not one of the supported values.
    """
    # bool is a subclass of int, so reject it explicitly
    if (
        isinstance(dimensions, bool)
        or not isinstance(dimensions, int)
        or dimensions <= 0
    ):
        raise ValueError(f"dimensions must be a positive integer, got {dimensions!r}")

    similarity = str(similarity_function).lower()
    if similarity not in ("cosine", "euclidean"):
        raise ValueError(
            "similarity_function must be 'cosine' or 'euclidean', "
            f"got {similarity_function!r}"
        )

    graph.query(f"DROP INDEX `{name}` IF EXISTS")
    graph.query(
        f"""
    CREATE VECTOR INDEX `{name}`
    FOR (a:__Entity__) ON (a.embedding)
    OPTIONS {{
      indexConfig: {{
        `vector.dimensions`: {dimensions},
        `vector.similarity_function`: '{similarity}'
      }}
    }}
    """
    )

In [167]:
# (Re)build the "entities" index; an existing index with that name is dropped first
create_vector_index(graph, "entities")

### Vector Search

In [168]:
def vector_search(graph, query_embedding, index_name="entities", k=5):
    """Return the ``k`` indexed nodes most similar to ``query_embedding``.

    The vector index procedure is called exactly once. A preceding
    ``MATCH (n:__Entity__)`` would re-run the procedure for every entity node
    and only be masked by ``DISTINCT``/``LIMIT``, so it is deliberately absent.
    All values are sent as query parameters instead of being spliced into the
    Cypher text.

    Args:
        graph: Object exposing ``query(cypher, params=...)`` (e.g. ``Neo4jGraph``).
        query_embedding: Sequence of numbers with the index's dimensionality.
        index_name: Name of the vector index to search.
        k: Maximum number of nodes to return.

    Returns:
        Whatever ``graph.query`` returns; for ``Neo4jGraph`` a list of dicts
        with the keys ``"node.id"`` and ``"score"``, best match first.
    """
    similarity_query = """
    CALL db.index.vector.queryNodes($index_name, $k, $query_embedding)
    YIELD node, score
    RETURN DISTINCT node.id, score
    ORDER BY score DESC
    LIMIT $k
    """
    params = {
        "index_name": index_name,
        "k": int(k),
        # plain floats so array-like inputs serialize as a Cypher list
        "query_embedding": [float(value) for value in query_embedding],
    }
    return graph.query(similarity_query, params=params)

In [169]:
# Loaded SentenceTransformer instances keyed by model name, so the model is
# loaded once instead of on every embed_entity call.
_EMBEDDING_MODELS = {}


def embed_entity(entity, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed ``entity`` with a sentence-transformers model.

    Args:
        entity: Text to embed (passed straight to ``SentenceTransformer.encode``).
        model_name: Model identifier. The default matches the 384-dimensional
            vector index created in this notebook.

    Returns:
        The embedding converted with ``.tolist()``.
    """
    model = _EMBEDDING_MODELS.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name)
        _EMBEDDING_MODELS[model_name] = model
    return model.encode(entity).tolist()

### Chunk Extraction

In [170]:
def chunk_finder(graph, query, limit=8):
    """Return text chunks linked to the entity node closest to ``query``.

    The whole query string is embedded, the single most similar ``__Entity__``
    node is looked up through the vector index, and the ``Chunk`` nodes with a
    relationship pointing at that entity are returned.

    Args:
        graph: Object exposing ``query(cypher, params=...)`` (e.g. ``Neo4jGraph``).
        query: Free-text query.
        limit: Maximum number of chunks to return.

    Returns:
        A list of ``(file_name, text)`` tuples; empty when the vector search
        finds no node.
    """
    query_embedding = embed_entity(query)
    # only the best match is used, so ask the index for a single node
    matches = vector_search(graph, query_embedding, k=1)
    if not matches:
        return []
    node_id = matches[0]["node.id"]

    # The id travels as a parameter so quotes in it cannot break the query
    chunk_find_query = """
    MATCH (n:Chunk)-[r]->(m:`__Entity__` {id: $node_id})
    RETURN n.text, n.fileName LIMIT $limit
    """
    rows = graph.query(
        chunk_find_query, params={"node_id": node_id, "limit": int(limit)}
    )
    return [(row["n.fileName"], row["n.text"]) for row in rows]

In [171]:
# Smoke test: list the (file name, text) chunks retrieved for a short query
query = "blood platelet"
print(chunk_finder(graph, query))

[('ALBIOS.txt', ' within the intravascular compartment and, in addition, possible effects of albumin as a scavenger of nitric oxide,12 mediating peripheral vasodilatation during sepsis.25,26'), ('ALBIOS.txt', 'The secondary outcomes also provide a detailed profile of the safety of albumin administration during severe sepsis. The incidence of new organ failures during the study was similar in the two groups. We observed slightly higher average SOFA subscores for liver and coagulation in the albumin group, indicating a higher serum bilirubin and a lower platelet count, respectively, than were observed in the crystalloid group. Nonetheless, the absolute excess in the serum bilirubin concentration in the albumin group was marginal (median, 1.0 mg per deciliter [interquartile range, 0.6 to 1.7] vs. 0.9 mg per deciliter [interquartile range, 0.5 to 1.5], P<0.001) and was probably related to the methods used to prepare albumin solutions, which may be inefficient in clearing bilirubin content 

## Graph RAG

### Entity Extraction

Key steps:
* Extract entities from the prompt
* Embed the entities
* query the graph db for the most similar nodes and their relationships
* generate a response


In [172]:
def get_entities(
    prompt,
    correction_context=" ",
    model="llama3.2:latest",
    max_retries=3,
    generate=None,
):
    """Ask an LLM to extract entities from ``prompt`` as a list of strings.

    The model's reply is parsed with ``ast.literal_eval``. If it is not a
    valid Python list of strings, the request is repeated with a description
    of the failure added to the instructions. Every attempt is built from the
    caller's original ``prompt``, and the number of attempts is bounded.

    Args:
        prompt: Text to extract entities from.
        correction_context: Extra instruction text inserted into the first
            attempt.
        model: Model name passed to the generate callable.
        max_retries: Number of additional attempts after the first one fails.
        generate: Callable invoked as ``generate(model=..., prompt=...)`` that
            returns an object with a ``response`` attribute. Defaults to
            ``ollama.generate``.

    Returns:
        ``(entities, correction_context)``: the parsed list and the correction
        text that was in effect for the attempt that succeeded.

    Raises:
        ValueError: If no attempt produced a list of strings.
    """
    if generate is None:
        generate = ollama.generate

    raw = ""
    for _ in range(max(0, max_retries) + 1):
        # use generate because we are not chatting with this instance of the model
        output = generate(
            model=model, prompt=_build_entity_prompt(prompt, correction_context)
        )
        raw = output.response

        try:
            parsed = ast.literal_eval(raw.strip())
        except (ValueError, SyntaxError) as e:
            print(f"Error converting to list: {e}")
            correction_context = (
                f"The previous output threw this error: {e}. "
                "Output only a Python list of strings."
            )
            continue

        if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
            return parsed, correction_context

        correction_context = f"The previous output threw this error: Expected a list of strings, but got {type(parsed)} with value {parsed}"

    raise ValueError(
        f"Could not extract a list of entities after {max(0, max_retries) + 1} "
        f"attempt(s); last model output: {raw!r}"
    )


def _build_entity_prompt(text, correction_context):
    """Fill the entity-extraction instruction template for a single attempt."""
    return f"""
    You are a highly capable natural language processing assistant with extensive medical knowledge. 
    Your task is to extract medical entities from a given prompt. 
    Entities are specific names, places, dates, times, objects, organizations, or other identifiable items explicitly mentioned in the text.
    Please output the entities as a list of strings in the format ["string 1", "string 2"]. Do not include duplicates. 
    Do not include any other text. Always include at least one entity.

    {correction_context}

    Here is the input prompt:
    {text}

    Extracted entities: 
    """

> This is an interesting design choice: what qualifies as an entity? We can always adjust the instructions provided to the model at inference time.

In [173]:
# Multi-sentence prompt mixing medical terms, a person name and generic nouns
test_prompt = """The blood platelet is a type of cell that helps blood to clot. 
John Coogan is also important. 
Surgeon is a noun, sepsis is also a disease.
Many of the Research Papers are about sepsis and what treatments may be effective."""
# get_entities returns (entities, correction_context)
resp, corr_cont = get_entities(test_prompt)
print(resp, f"corr_cont: {corr_cont}")
print(type(resp))
print("*******************************************")
# Short single-question prompt
test_prompt2 = "What is the leading cause of sepsis?"
resp, corr_cont = get_entities(test_prompt2)
print(resp, f"corr_cont: {corr_cont}")
print(type(resp))

['blood platelet', 'John Coogan', 'sepsis', 'Research Papers'] corr_cont:  
<class 'list'>
*******************************************
['sepsis'] corr_cont:  
<class 'list'>


> Here we could use a regex to search for nodes whose ids resemble the extracted entities, but that is not efficient. Instead, we embed each entity and search for the most similar node embeddings, as in the vector search above.

### Graph Retrieval

In [174]:
def graph_retriever(graph, query):
    """Build a textual graph context from the entities mentioned in ``query``.

    Entities are extracted from the query, each one is mapped to its closest
    ``__Entity__`` node through the vector index, and every path of up to two
    relationships from those nodes (ignoring ``HAS_ENTITY`` and ``MENTIONS``)
    is rendered as ``" source TYPE target,"`` triples, one path per line.

    Each triple is oriented by the relationship's own start and end node, so
    second-hop relationships are printed in their stored direction rather
    than in path-traversal order.

    Args:
        graph: Object exposing ``query(cypher, params=...)`` (e.g. ``Neo4jGraph``).
        query: Free-text user query.

    Returns:
        The concatenated path lines; an empty string when nothing matched.
    """
    entities, _ = get_entities(query)

    # dict keys keep first-seen order and drop entities that resolve to the
    # same node, so each neighbourhood is fetched only once
    node_ids = {}
    for entity in entities:
        closest = vector_search(graph, embed_entity(entity), k=1)
        if closest:  # skip entities without any match in the index
            node_ids[closest[0]["node.id"]] = None

    neighbors_query = """
    MATCH path = (n:`__Entity__` {id: $node_id})-[r*..2]-(m:`__Entity__`)
    WHERE ALL(rel IN relationships(path) WHERE NOT type(rel) IN ['HAS_ENTITY', 'MENTIONS'])
    RETURN
    [rel IN relationships(path) |
        {
        type: type(rel),
        source: startNode(rel).id,
        target: endNode(rel).id
        }] AS relationshipDetails
    """

    lines = []
    for node_id in node_ids:
        for record in graph.query(neighbors_query, params={"node_id": node_id}):
            triples = [
                f" {rel['source']} {rel['type']} {rel['target']},"
                for rel in record["relationshipDetails"]
            ]
            lines.append("".join(triples) + "\n")

    return "".join(lines)

In [175]:
# Same multi-entity prompt as in the entity-extraction test
test_prompt = """The blood platelet is a type of cell that helps blood to clot. 
John Coogan is also important. 
Surgeon is a noun, sepsis is also a disease.
Many of the Research Papers are about sepsis and what treatments may be effective."""
# Prints one line per retrieved path, each made of "source TYPE target," triples
resp = graph_retriever(graph, test_prompt)
print(resp)

 Albumin GROUP_HAS_LOWER Platelet Count,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin EXPANDS Intravascular Compartment,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin GROUP_HAS_LOWER Mortality at 90 Days,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin GROUP_HAS_HIGHER Serum Bilirubin,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin ADMINISTERED_DURING Severe Sepsis,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin MEDIATES Peripheral Vasodilatation,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin SCAVENGES Nitric Oxide,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin HAS_PROPERTY buffer molecule for acid–base equilibrium,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin HAS_PROPERTY plasma colloid osmotic pressure,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin HAS_PROPERTY antioxidant and antiinflammatory properties,
 Albumin GROUP_HAS_LOWER Platelet Count, Albumin HAS_PROPERTY scavenger of reactive oxygen and nitrogen species,
 Albumin GROUP_HAS_LOWER Platelet Count, Al

## RAG Chain

### Context generation

In [None]:
def context_builder(graph, query, method="hybrid"):
    """
    Assemble the context string handed to the LLM for ``query``.

    Args:
    graph: Neo4jGraph object
    query: string
    method: "vector" (chunk text only), "graph" (relationship paths only) or
        "hybrid" (graph paths followed by chunk text). Any other value
        yields an empty context.

    Returns:
    context: string
    """
    if method == "graph":
        return graph_retriever(graph, query)

    if method == "vector":
        chunks = chunk_finder(graph, query)
        header = "Given the following context in the format [(File Name, Text),...] \n"
        return header + str(chunks)

    if method == "hybrid":
        # graph retrieval runs first, then the chunk lookup
        graph_part = graph_retriever(graph, query)
        chunk_part = str(chunk_finder(graph, query))
        separator = "\n And Given the following context in the format [(File Name, Text),...] \n"
        return graph_part + separator + chunk_part

    # unrecognised method: no context
    return ""

In [None]:
def generate_response(graph, query, method="hybrid", model="llama3.2:latest"):
    """Answer ``query`` with an Ollama model, grounded in retrieved context.

    Args:
        graph: Neo4jGraph object forwarded to ``context_builder``.
        query: The user's question.
        method: Retrieval method forwarded to ``context_builder``.
        model: Name of the Ollama model to run.

    Returns:
        ``(response, prompt)``: the object returned by ``ollama.generate`` and
        the full prompt string that was sent to the model.
    """
    retrieved = context_builder(graph, query, method)

    # Assembled line by line; whitespace is spelled out explicitly so the
    # resulting text is exactly the prompt sent to the model.
    prompt_lines = [
        " ",
        "    You are a highly capable natural language processing assistant with extensive medical knowledge.",
        "    Answer the following question based on the provided context:",
        f"    Question: {query}",
        f"    Context: {retrieved}",
        "    ",
    ]
    full_prompt = "\n".join(prompt_lines)

    llm_output = ollama.generate(model=model, prompt=full_prompt)
    return llm_output, full_prompt

### Testing Outputs

In [None]:
# Question reused by all comparison cells below
Question = "What is EGDT?"
# Default method is "hybrid". NOTE: the second return value of
# generate_response is the full prompt sent to the model, not just the context.
response, context = generate_response(graph, Question)
print(response.response)
# print(context)

EGDT stands for Early Goal-Directed Therapy. It is a treatment protocol that involves administering intravenous fluids and vasoactive drugs to patients with severe sepsis, as well as antimicrobial therapy, in an effort to improve their chances of survival.

The name "Early Goal-Directed Therapy" refers to the idea that early intervention can help guide the treatment plan to achieve specific "goals," such as restoring adequate blood flow and oxygenation, and correcting any underlying imbalances.

However, according to the provided context, EGDT does not appear to have a significant impact on reducing mortality rates in patients with early septic shock. The ARISE study found no reduction in 90-day all-cause mortality among patients treated with EGDT compared to those receiving usual care.

In fact, some of the context highlights concerns and controversies surrounding the use of EGDT, such as potential risks associated with individual elements of the protocol, uncertainty about external v

In [None]:
# Same question, context built from vector-retrieved text chunks only
vector_response, vector_context = generate_response(graph, Question, method="vector")
print(vector_response.response)
# print(vector_context)

Based on the provided context, EGDT stands for Early Goal-Directed Therapy. It is a protocol of hemodynamic resuscitation that was initially shown to improve outcomes in patients presenting to the emergency department with severe sepsis in a 2001 proof-of-concept trial. However, subsequent trials and studies have raised concerns about its effectiveness and potential risks, leading to controversy surrounding its role in treating patients with severe sepsis.


In [None]:
# Same question, context built from graph relationship paths only
graph_response, graph_context = generate_response(graph, Question, method="graph")
print(graph_response.response)
# print(graph_context)

EGDT stands for Early Goal-Directed Therapy. It is a treatment protocol that was developed to improve outcomes in patients with sepsis, a life-threatening condition caused by an infection. The goal of EGDt is to quickly and aggressively treat patients with sepsis, using a combination of antibiotics, fluid resuscitation, and vasopressors (medications that constrict blood vessels) to stabilize their vital signs and improve oxygen delivery to organs.

The EGDt protocol was widely adopted in the early 2000s, but it has been the subject of much controversy and debate in recent years. While some studies suggested that EGDt could reduce mortality rates in patients with sepsis, more recent research has failed to replicate these findings, suggesting that EGDt may not provide a survival benefit for all patients.

In fact, several large trials have shown that EGDt does not reduce 90-day all-cause mortality in patients with early septic shock. Additionally, some studies have found no significant d

In [194]:
# Baseline without retrieval: "none" is not a method context_builder recognises,
# so it returns an empty context string.
base_response, base_context = generate_response(graph, Question, method="none")
print(base_response.response)

I'm happy to help, but it seems like you forgot to provide the context for the question. Could you please complete the sentence or provide more information about what "EGDT" refers to? I'll do my best to answer your question based on my medical knowledge and natural language processing capabilities.

That being said, I can suggest a few possibilities:

* EGDt could refer to Endoscopic Gastric Drainage Treatment, which is a treatment approach for certain types of gastritis or gastric ulcers.
* EGDt might stand for Endoscopic Gastrointestinal Disease Treatment, which encompasses various treatments for gastrointestinal disorders.

If you provide more context or clarify what "EGDT" means in your question, I'll be happy to give a more specific and accurate answer.


In [None]:
# Hybrid retrieval again, answered by a different Ollama model
granite_response, granite_context = generate_response(
    graph, Question, model="granite3-dense:2b", method="hybrid"
)
print(granite_response.response)

EGDT, or Early Goal-Directed Therapy, is a medical intervention used in the treatment of severe sepsis. It involves a specific protocol for hemodynamic resuscitation in patients presenting to the emergency department with early septic shock. The ARISE study aimed to test the hypothesis that EGDT would decrease 90-day all-cause mortality in these patients compared to usual care. However, the study found that EGDT did not reduce all-cause mortality at 90 days. Despite some nonrandomized studies showing survival benefits with bundle-based care that included EGDT, there is considerable controversy surrounding the role of EGDT in treating patients with severe sepsis due to concerns about potential risks, external validity, and infrastructure requirements.


In [None]:
# Baseline for the second model: method="none" yields an empty context
granite_base_response, granite_base_context = generate_response(
    graph, Question, model="granite3-dense:2b", method="none"
)
print(granite_base_response.response)

EGDT, or Emergency General Surgery Department, is a department in a hospital that specializes in treating emergency surgical cases. It is designed to handle urgent and critical surgical situations that require immediate attention. The department typically operates 24/7 and has the necessary resources and equipment to perform a wide range of surgical procedures.
