In [27]:
import os
import polars as pl
from langchain_community.graphs import Neo4jGraph
from langchain.chains import GraphCypherQAChain
from langchain_openai import ChatOpenAI
from langchain.prompts import (
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
)
from neo4j import GraphDatabase
from langchain.agents import create_openai_functions_agent, Tool, AgentExecutor
from langchain import hub
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from dotenv import load_dotenv
load_dotenv("../.env")

True

In [28]:
HOSPITAL_QA_MODEL = os.getenv("HOSPITAL_QA_MODEL")
HOSPITAL_CYPHER_MODEL = os.getenv("HOSPITAL_CYPHER_MODEL")

graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD"),
)

graph.refresh_schema()

In [4]:
graph.query("""
       MATCH (n:Hospital) RETURN n;
""")

[{'n': {'state_name': 'CO', 'name': 'Wallace-Hamilton', 'id': 0}},
 {'n': {'state_name': 'NC', 'name': 'Burke, Griffin and Cooper', 'id': 1}},
 {'n': {'state_name': 'FL', 'name': 'Walton LLC', 'id': 2}},
 {'n': {'state_name': 'NC', 'name': 'Garcia Ltd', 'id': 3}},
 {'n': {'state_name': 'NC', 'name': 'Jones, Brown and Murray', 'id': 4}},
 {'n': {'state_name': 'GA', 'name': 'Boyd PLC', 'id': 5}},
 {'n': {'state_name': 'FL', 'name': 'Wheeler, Bryant and Johns', 'id': 6}},
 {'n': {'state_name': 'FL', 'name': 'Brown Inc', 'id': 7}},
 {'n': {'state_name': 'FL', 'name': 'Smith, Edwards and Obrien', 'id': 8}},
 {'n': {'state_name': 'NC', 'name': 'Brown-Golden', 'id': 9}},
 {'n': {'state_name': 'CO', 'name': 'Little-Spencer', 'id': 10}},
 {'n': {'state_name': 'FL', 'name': 'Rose Inc', 'id': 11}},
 {'n': {'state_name': 'NC', 'name': 'Malone, Thompson and Mejia', 'id': 12}},
 {'n': {'state_name': 'GA', 'name': 'Mcneil-Ali', 'id': 13}},
 {'n': {'state_name': 'CO', 'name': 'Jones, Taylor and Garcia

#### Few shot Cypher RAG
TODO:
- Create CSV with example questions and answers
- Add to initial ETL
- Create Questions node with properties "question" and "cypher"
- Write function that creates Question nodes from a question and corresponding cypher query
- Need to be able to embed questions on the fly

In [30]:
questions = [
    "Who is the oldest patient and how old are they?",
    "Which physician has billed the least to Cigna",
    "How many non-emergency patients in North Carolina have written reviews?",
]

cypher_answers = [
    """
    MATCH (p:Payer)<-[c:COVERED_BY]-(v:Visit)-[t:TREATS]-(phy:Physician)
    WHERE p.name = 'Cigna'
    RETURN phy.name AS physician_name, SUM(c.billing_amount) AS total_billed
    ORDER BY total_billed
    LIMIT 1
    """,
    """
    MATCH (p:Payer)<-[c:COVERED_BY]-(v:Visit)-[t:TREATS]-(phy:Physician)
    WHERE p.name = 'Cigna'
    RETURN phy.name AS physician_name, SUM(c.billing_amount) AS total_billed
    ORDER BY total_billed
    LIMIT 1
    """,
    """
    match (r:Review)<-[:WRITES]-(v:Visit)-[:AT]->(h:Hospital)
    where h.state_name = 'NC' and v.admission_type <> 'Emergency'
    return count(*)
    """
]

data = pl.DataFrame({
    "questions":questions,
    "cypher":cypher_answers
})



'''# Which state had the largest percent increase in Cigna visits
# from 2022 to 2023?
MATCH (h:Hospital)<-[:AT]-(v:Visit)-[:COVERED_BY]->(p:Payer)
WHERE p.name = 'Cigna' AND v.admission_date >= '2022-01-01' AND
v.admission_date < '2024-01-01'
WITH h.state_name AS state, COUNT(v) AS visit_count,
     SUM(CASE WHEN v.admission_date >= '2022-01-01' AND
     v.admission_date < '2023-01-01' THEN 1 ELSE 0 END) AS count_2022,
     SUM(CASE WHEN v.admission_date >= '2023-01-01' AND
     v.admission_date < '2024-01-01' THEN 1 ELSE 0 END) AS count_2023
WITH state, visit_count, count_2022, count_2023,
     (toFloat(count_2023) - toFloat(count_2022)) / toFloat(count_2022) * 100
     AS percent_increase
RETURN state, percent_increase
ORDER BY percent_increase DESC
LIMIT 1
'''

data.write_csv("../data/example_cypher.csv")

In [23]:
def add_document_node(uri, user, password, question, cypher, embedding):
    driver = GraphDatabase.driver(uri, auth=(user, password))
    with driver.session() as session:
        session.write_transaction(create_document_node, question, cypher, embedding)
    driver.close()

def create_document_node(tx, question, cypher, embedding):
    query = (
        "CREATE (d:Document {question: $question, cypher: $cypher, embedding: $embedding})"
    )
    tx.run(query, question=question, cypher=cypher, embedding=embedding)

In [None]:
neo4j_example_cypher_index = Neo4jVector.from_documents(
    
)

In [None]:
neo4j_vector_index = Neo4jVector.from_existing_graph(
    embedding=OpenAIEmbeddings(),
    url=os.getenv("NEO4J_URI"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD"),
    index_name="reviews",
    node_label="Review",
    text_node_properties=[
        "physician_name",
        "patient_name",
        "text",
        "hospital_name",
    ],
    embedding_node_property="embedding",
)

#### Cypher chain

In [5]:
cypher_generation_template = """
Task:
Generate Cypher query for a Neo4j graph database.

Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.

Schema:
{schema}

Note:
Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything other than
for you to construct a Cypher statement. Do not include any text except
the generated Cypher statement. Make sure the direction of the relationship is
correct in your queries. Make sure you alias both entities and relationships
properly. Do not run any queries that would add to or delete from
the database. Make sure to alias all statements that follow as with
statement (e.g. WITH v as visit, c.billing_amount as billing_amount)
If you need to divide numbers, make sure to
filter the denominator to be non zero.

Examples:
# Who is the oldest patient and how old are they?
MATCH (p:Patient)
RETURN p.name AS oldest_patient,
       duration.between(date(p.dob), date()).years AS age
ORDER BY age DESC
LIMIT 1

# Which physician has billed the least to Cigna
MATCH (p:Payer)<-[c:COVERED_BY]-(v:Visit)-[t:TREATS]-(phy:Physician)
WHERE p.name = 'Cigna'
RETURN phy.name AS physician_name, SUM(c.billing_amount) AS total_billed
ORDER BY total_billed
LIMIT 1

# Which state had the largest percent increase in Cigna visits
# from 2022 to 2023?
MATCH (h:Hospital)<-[:AT]-(v:Visit)-[:COVERED_BY]->(p:Payer)
WHERE p.name = 'Cigna' AND v.admission_date >= '2022-01-01' AND
v.admission_date < '2024-01-01'
WITH h.state_name AS state, COUNT(v) AS visit_count,
     SUM(CASE WHEN v.admission_date >= '2022-01-01' AND
     v.admission_date < '2023-01-01' THEN 1 ELSE 0 END) AS count_2022,
     SUM(CASE WHEN v.admission_date >= '2023-01-01' AND
     v.admission_date < '2024-01-01' THEN 1 ELSE 0 END) AS count_2023
WITH state, visit_count, count_2022, count_2023,
     (toFloat(count_2023) - toFloat(count_2022)) / toFloat(count_2022) * 100
     AS percent_increase
RETURN state, percent_increase
ORDER BY percent_increase DESC
LIMIT 1

# How many non-emergency patients in North Carolina have written reviews?
match (r:Review)<-[:WRITES]-(v:Visit)-[:AT]->(h:Hospital)
where h.state_name = 'NC' and v.admission_type <> 'Emergency'
return count(*)

String category values:
Test results are one of: 'Inconclusive', 'Normal', 'Abnormal'
Visit statuses are one of: 'OPEN', 'DISCHARGED'
Admission Types are one of: 'Elective', 'Emergency', 'Urgent'
Payer names are one of: 'Cigna', 'Blue Cross', 'UnitedHealthcare', 'Medicare',
'Aetna'

A visit is considered open if its status is 'OPEN' and the discharge date is
missing.
Use abbreviations when
filtering on hospital states (e.g. "Texas" is "TX",
"Colorado" is "CO", "North Carolina" is "NC",
"Florida" is "FL", "Georgia" is "GA, etc.)

Make sure to use IS NULL or IS NOT NULL when analyzing missing properties.
Never return embedding properties in your queries. You must never include the
statement "GROUP BY" in your query. Make sure to alias all statements that
follow as with statement (e.g. WITH v as visit, c.billing_amount as
billing_amount)
If you need to divide numbers, make sure to filter the denominator to be non
zero.

The question is:
{question}
"""

cypher_generation_prompt = PromptTemplate(
    input_variables=["schema", "question"], template=cypher_generation_template
)

qa_generation_template = """You are an assistant that takes the results
from a Neo4j Cypher query and forms a human-readable response. The
query results section contains the results of a Cypher query that was
generated based on a users natural language question. The provided
information is authoritative, you must never doubt it or try to use
your internal knowledge to correct it. Make the answer sound like a
response to the question.

Query Results:
{context}

Question:
{question}

If the provided information is empty, say you don't know the answer.
Empty information looks like this: []

If the information is not empty, you must provide an answer using the
results. If the question involves a time duration, assume the query
results are in units of days unless otherwise specified.

When names are provided in the query results, such as hospital names,
beware  of any names that have commas or other punctuation in them.
For instance, 'Jones, Brown and Murray' is a single hospital name,
not multiple hospitals. Make sure you return any list of names in
a way that isn't ambiguous and allows someone to tell what the full
names are.

Never say you don't have the right information if there is data in
the query results. Make sure to show all the relevant query results
if you're asked.

Helpful Answer:
"""

qa_generation_prompt = PromptTemplate(
    input_variables=["context", "question"], template=qa_generation_template
)

hospital_cypher_chain = GraphCypherQAChain.from_llm(
    cypher_llm=ChatOpenAI(model=HOSPITAL_CYPHER_MODEL, temperature=0),
    qa_llm=ChatOpenAI(model=HOSPITAL_QA_MODEL, temperature=0),
    graph=graph,
    verbose=True,
    qa_prompt=qa_generation_prompt,
    cypher_prompt=cypher_generation_prompt,
    validate_cypher=True,
    top_k=100,
)

In [6]:
hospital_cypher_chain.invoke("Which hospitals are in the system?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (h:Hospital)
RETURN h.name AS hospital_name, h.state_name AS state
ORDER BY state, hospital_name[0m
Full Context:
[32;1m[1;3m[{'hospital_name': 'Bell, Mcknight and Willis', 'state': 'CO'}, {'hospital_name': 'Castaneda-Hardy', 'state': 'CO'}, {'hospital_name': 'Huynh PLC', 'state': 'CO'}, {'hospital_name': 'Jones, Taylor and Garcia', 'state': 'CO'}, {'hospital_name': 'Little-Spencer', 'state': 'CO'}, {'hospital_name': 'Pearson LLC', 'state': 'CO'}, {'hospital_name': 'Vaughn PLC', 'state': 'CO'}, {'hospital_name': 'Wallace-Hamilton', 'state': 'CO'}, {'hospital_name': 'Brown Inc', 'state': 'FL'}, {'hospital_name': 'Jordan Inc', 'state': 'FL'}, {'hospital_name': 'Lewis-Nelson', 'state': 'FL'}, {'hospital_name': 'Rose Inc', 'state': 'FL'}, {'hospital_name': 'Smith, Edwards and Obrien', 'state': 'FL'}, {'hospital_name': 'Walton LLC', 'state': 'FL'}, {'hospital_name': 'Wheeler, Bryant and Johns', 'sta

{'query': 'Which hospitals are in the system?',
 'result': 'The hospitals in the system are:\n1. Bell, Mcknight and Willis in CO\n2. Castaneda-Hardy in CO\n3. Huynh PLC in CO\n4. Jones, Taylor and Garcia in CO\n5. Little-Spencer in CO\n6. Pearson LLC in CO\n7. Vaughn PLC in CO\n8. Wallace-Hamilton in CO\n9. Brown Inc in FL\n10. Jordan Inc in FL\n11. Lewis-Nelson in FL\n12. Rose Inc in FL\n13. Smith, Edwards and Obrien in FL\n14. Walton LLC in FL\n15. Wheeler, Bryant and Johns in FL\n16. Boyd PLC in GA\n17. Mcneil-Ali in GA\n18. Pugh-Rogers in GA\n19. Richardson-Powell in GA\n20. Shea LLC in GA\n21. Brown-Golden in NC\n22. Burch-White in NC\n23. Burke, Griffin and Cooper in NC\n24. Garcia Ltd in NC\n25. Jones, Brown and Murray in NC\n26. Malone, Thompson and Mejia in NC\n27. Rush, Owens and Johnson in NC\n28. Cunningham and Sons in TX\n29. Schultz-Powers in TX\n30. Taylor and Sons in TX'}

#### Review chain

In [19]:
neo4j_vector_index = Neo4jVector.from_existing_graph(
    embedding=OpenAIEmbeddings(),
    url=os.getenv("NEO4J_URI"),
    username=os.getenv("NEO4J_USERNAME"),
    password=os.getenv("NEO4J_PASSWORD"),
    index_name="reviews",
    node_label="Review",
    text_node_properties=[
        "physician_name",
        "patient_name",
        "text",
        "hospital_name",
    ],
    embedding_node_property="embedding",
)

review_template = """Your job is to use patient
reviews to answer questions about their experience at a hospital. Use
the following context to answer questions. Be as detailed as possible, but
don't make up any information that's not from the context. If you don't know
an answer, say you don't know.
{context}
"""

review_system_prompt = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        input_variables=["context"], template=review_template
    )
)

review_human_prompt = HumanMessagePromptTemplate(
    prompt=PromptTemplate(input_variables=["question"], template="{question}")
)
messages = [review_system_prompt, review_human_prompt]

review_prompt = ChatPromptTemplate(
    input_variables=["context", "question"], messages=messages
)

reviews_vector_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model=HOSPITAL_QA_MODEL, temperature=0),
    chain_type="stuff",
    retriever=neo4j_vector_index.as_retriever(k=12),
)
reviews_vector_chain.combine_documents_chain.llm_chain.prompt = review_prompt


#### Agent

In [11]:
HOSPITAL_AGENT_MODEL = os.getenv("HOSPITAL_AGENT_MODEL")

hospital_agent_prompt = hub.pull("hwchase17/openai-functions-agent")

tools = [
    Tool(
        name="Experiences",
        func=reviews_vector_chain.invoke,
        description="""Useful when you need to answer questions
        about patient experiences, feelings, or any other qualitative
        question that could be answered about a patient using semantic
        search. Not useful for answering objective questions that involve
        counting, percentages, aggregations, or listing facts. Use the
        entire prompt as input to the tool. For instance, if the prompt is
        "Are patients satisfied with their care?", the input should be
        "Are patients satisfied with their care?".
        """,
    ),
    Tool(
        name="Graph",
        func=hospital_cypher_chain.invoke,
        description="""Useful for answering questions about patients,
        physicians, hospitals, insurance payers, patient review
        statistics, and hospital visit details. Use the entire prompt as
        input to the tool. For instance, if the prompt is "How many visits
        have there been?", the input should be "How many visits have
        there been?".
        """,
    ),
]

chat_model = ChatOpenAI(
    model=HOSPITAL_AGENT_MODEL,
    temperature=0,
)

hospital_rag_agent = create_openai_functions_agent(
    llm=chat_model,
    prompt=hospital_agent_prompt,
    tools=tools,
)

hospital_rag_agent_executor = AgentExecutor(
    agent=hospital_rag_agent,
    tools=tools,
    return_intermediate_steps=True,
    verbose=True,
)

In [17]:
hospital_rag_agent_executor.invoke({"input":"What is the average duration in days for closed emergency visits?"})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Graph` with `What is the average duration in days for closed emergency visits?`


[0m

[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (v:Visit)-[:AT]->(h:Hospital)
WHERE v.status = 'DISCHARGED' AND v.admission_type = 'Emergency'
WITH v, duration.between(date(v.admission_date), date(v.discharge_date)).days AS duration
RETURN AVG(duration) AS average_duration_days[0m
Full Context:
[32;1m[1;3m[{'average_duration_days': 15.117444409646081}][0m

[1m> Finished chain.[0m
[33;1m[1;3m{'query': 'What is the average duration in days for closed emergency visits?', 'result': 'The average duration in days for closed emergency visits is approximately 15.12 days.'}[0m[32;1m[1;3mThe average duration in days for closed emergency visits is approximately 15.12 days.[0m

[1m> Finished chain.[0m


{'input': 'What is the average duration in days for closed emergency visits?',
 'output': 'The average duration in days for closed emergency visits is approximately 15.12 days.',
 'intermediate_steps': [(AgentActionMessageLog(tool='Graph', tool_input='What is the average duration in days for closed emergency visits?', log='\nInvoking: `Graph` with `What is the average duration in days for closed emergency visits?`\n\n\n', message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"__arg1":"What is the average duration in days for closed emergency visits?"}', 'name': 'Graph'}})]),
   {'query': 'What is the average duration in days for closed emergency visits?',
    'result': 'The average duration in days for closed emergency visits is approximately 15.12 days.'})]}