In [49]:
%%capture
%pip install torch torch_geometric stark-qa neo4j python-dotenv pcst_fast datasets pandas transformers langchain langchain-openai langchain-community

## Setup

load env variables

In [1]:
from dotenv import load_dotenv
import os

load_dotenv('.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

Get stark-qa dataset

In [4]:
from stark_qa import load_qa, load_skb

dataset_name = 'prime'

# Load the retrieval dataset
qa_dataset = load_qa(dataset_name)

Use file from /Users/zachblumenfeld/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/e3387ba9c873528521e125c3da687092d9872104/qa/prime/stark_qa/stark_qa_human_generated_eval_uncleaned.csv.


In [5]:
qa_dataset.data

Unnamed: 0,id,query,answer_ids
0,0,Could you identify any skin diseases associate...,[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,[15450]
2,2,What is the name of the condition characterize...,"[98851, 98853]"
3,3,What drugs are used to treat epithelioid sarco...,[15698]
4,4,Can you supply a compilation of genes and prot...,"[7161, 22045]"
...,...,...,...
11199,11199,Which gene or protein is not expressed in fema...,[2414]
11200,11200,Could you identify a biological pathway in whi...,[128199]
11201,11201,Is there an interaction between genes or prote...,"[127611, 62903]"
11202,11202,Which pharmacological agents that stimulate os...,[20180]


## Vector Retriever

In [2]:
import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings

driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

#create text embedder
embedder = OpenAIEmbeddings()

In [13]:
from neo4j_graphrag.retrievers import VectorRetriever

vector_retriever = VectorRetriever(
    driver,
    index_name="text_embeddings",
    embedder=embedder,
    return_properties=["details", "name", "nodeId"],
)

In [34]:
from neo4j_graphrag.llm import OpenAILLM as LLM
from neo4j_graphrag.generation import RagTemplate
from neo4j_graphrag.generation.graphrag import GraphRAG

llm=LLM(
    model_name="gpt-4o",
    model_params={
        "response_format": {"type": "json_object"}, # use json_object formatting for best results
        "temperature": 0, # turning temperature down for more deterministic results
        "seed": 42
    }
)

rag_template = RagTemplate(template='''
Answer the Question using the following Context. Only respond with information mentioned in the Context. Do not inject any speculative information not mentioned. 

Return result as JSON using the following format:
{{"nodeIds": [list of nodeIds containing the correct answers in order of probability highest to smallest], "explanation": "written explanation as to why the node id answers were provided"}}

# Question:
{query_text}
 
# Context:
{context}

# Answer:
''', expected_inputs=['query_text', 'context'])

v_rag  = GraphRAG(llm=llm, retriever=vector_retriever, prompt_template=rag_template)

In [35]:
q=qa_dataset.data['query'][0]
print(q)

v_res = v_rag.search(q, return_context=True, retriever_config={"top_k": 10})

Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.


In [36]:
import json

print(json.dumps(eval(v_res.answer), indent=1))

{
 "nodeIds": [
  96054,
  96057,
  98637,
  39254
 ],
 "explanation": "The tiny, yellowish lesion observed on sun-exposed areas of the face and neck could be associated with basal cell carcinoma, a type of epithelial skin neoplasm. The context provides information on various types of basal cell carcinoma, which often develop on sun-exposed parts of the body, especially the head and neck. These lesions can appear as pearly white, skin-colored, or pink bumps that are translucent, and tiny blood vessels are often visible. The nodeIds 96054 (skin fibroepithelial basal cell carcinoma), 96057 (follicular basal cell carcinoma), 98637 (follicular atrophoderma-basal cell carcinoma), and 39254 (skin adenoid basal cell carcinoma) are all related to basal cell carcinoma, which matches the description of the lesion observed."
}


## Text2Cypher Retriever

In [51]:
from neo4j_graphrag.experimental.components.types import Neo4jGraph
from typing import Dict
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.graphs import Neo4jGraph

embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
graph = Neo4jGraph()


QUERY_TEMPLATE = '''
UNWIND $queryVectors AS q
CALL(q) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, q.queryVector) YIELD node AS n
  WHERE apoc.label.exists(n, q.label)
  WITH n LIMIT 500
  RETURN collect(elementId(n)) AS nodeIds
}
WITH collect(nodeIds) AS nodeIdsLists
MATCH p=(d1:Disease)-[r1:PARENT_CHILD]->(d2:Disease)-[r2:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)
WHERE (elementId(d1) IN nodeIdsLists[0]) AND (elementId(d2) IN nodeIdsLists[1]) AND (elementId(ep) IN nodeIdsLists[2])
RETURN collect(d1.name +  "{" + "nodeId:" + d1.nodeId + "}" +
  " - " + type(r1) + " -> " + 
  d2.name +  "{" + "nodeId:" + d2.nodeId + "}" +
  " - " + type(r2) + " -> " + 
  ep.name +  "{" + "nodeId:" + d2.nodeId + "}")
  AS facts,
  collect(DISTINCT d1.name +  "{" + "nodeId:" + d1.nodeId + ", details:" +  d1.details + "}") AS disease01TextDetails,
  collect(DISTINCT d2.name +  "{" + "nodeId:" + d2.nodeId + ", details:" +  d2.details + "}") AS disease02TextDetails,
  collect(DISTINCT d2.name +  "{" + "nodeId:" + d2.nodeId + ", details:" +  d2.details + "}") AS effectOrPhenotypeTextDetails
'''

def retrieve(search_prompts: Dict[str, str]):
    query_vectors = [{'queryVector':embedding_model.embed_query(sp['searchPrompt']), 'label':sp['label'] } for sp in search_prompts]
    query_results = graph.query(QUERY_TEMPLATE, params={"queryVectors": query_vectors})
    return query_results

sps = [
    {
        'searchPrompt': "epithelial skin neoplasms",
        'label': "Disease"
    },
    {
        'searchPrompt': "skin diseases",
        'label': "Disease"
    },
    {
        'searchPrompt': "tiny, yellowish lesion",
        'label': "EffectOrPhenotype"
    }
]

res = retrieve(sps)

In [53]:
res[0]['facts']

['epidermal disease{nodeId:38028} - PARENT_CHILD -> ichthyosis (disease){nodeId:33648} - PHENOTYPE_PRESENT -> Localized epidermolytic hyperkeratosis{nodeId:33648}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Hypermelanotic macule{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Abnormality of skin pigmentation{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Macule{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Skin vesicle{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Anal mucosal leukoplakia{nodeId:31039}',
 'vulvar seborrheic keratosis{nodeId:38032} - PARENT_CHILD -> seborrheic keratosis{nodeId:31850} - PHENOTYPE_PRESENT -> Seborrheic kera

In [55]:
{k:v for k,v in res[0].items() if k != 'facts'}

{'disease01TextDetails': ["epidermal disease{nodeId:38028, details:{'mondo_id': 19268, 'mondo_name': 'epidermal disease', 'mondo_definition': 'A skin disease that involves the epidermis.'}}",
  'vulvar seborrheic keratosis{nodeId:38032, details:{\'mondo_id\': 6622, \'mondo_name\': \'vulvar seborrheic keratosis\', \'mondo_definition\': \'A benign squamous neoplasm that arises from the vulva. It is characterized by the proliferation of the basal cells in the squamous epithelium, acanthosis, hyperkeratosis, and cysts formation.\', \'mayo_symptoms\': \'A seborrheic keratosis usually looks like a waxy or wartlike growth. It typically appears on the face, chest, shoulders or back. You may develop a single growth, though multiple growths are more common. A seborrheic keratosis: Ranges in color from light tan to brown or black, Is round or oval shaped, Has a characteristic \\\\pasted on\\\\" look, Is flat or slightly raised with a scaly surface, Ranges in size from very small to more than 1 in

In [61]:
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

PROMPT = PromptTemplate.from_template('''
Yu are a medical professional using your expertise to Answer the Question with the following Facts and Details knowledge. Only respond with information mentioned in the Facts and Details. Do not inject any speculative information not mentioned. 

Return result as JSON using the following format:
{{"nodeIds": [list of nodeIds containing the correct answers in order of probability highest to smallest], "explanation": "written explanation as to why the node id answers were provided"}}

# Question:
{question}
 
# Facts:
{facts}

# Details
{details}

# Answer:
''')

llm = ChatOpenAI(
    model_name="gpt-4o",
    response_format={"type": "json_object"}, # use json_object formatting for best results
    temperature=0, # turning temperature down for more deterministic results
    seed=42
)



def answer(question: str, search_prompts: Dict[str, str]):
    context = retrieve(search_prompts)[0]
    prompt = PROMPT.format(question=question, facts=context['facts'], details={k:v for k,v in context.items() if k != 'facts'})
    return llm.invoke(prompt)
    

In [62]:
q="Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected."
answer(q, sps).content

'{"nodeIds": [31039, 31850], "explanation": "The observed tiny, yellowish lesion on sun-exposed areas of the face and neck could be associated with Darier disease (nodeId:31039) or seborrheic keratosis (nodeId:31850). Darier disease is characterized by greasy, yellow-brown keratotic papules that can be exacerbated by sun exposure, which aligns with the description of the lesion. Seborrheic keratosis is a common benign skin neoplasm that appears as waxy or wartlike growths, often on the face, and can range in color from light tan to brown or black, which might also match the observed lesion."}'

In [None]:
"""
WITH genai.vector.encode($d1, "OpenAI", {token:$token}) AS d1Emb,
genai.vector.encode($d2, "OpenAI", {token:$token}) AS d2Emb,
genai.vector.encode($ep, "OpenAI", {token:$token}) AS epEmb
CALL(d1Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, d1Emb) YIELD node AS d1
  WHERE d1:Disease
  WITH d1 LIMIT 500
  RETURN collect(elementId(d1)) AS d1s
}
CALL(d2Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, d2Emb) YIELD node AS d2
  WHERE d2:Disease
  WITH d2 LIMIT 500
  RETURN collect(elementId(d2)) AS d2s
}
CALL(epEmb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, epEmb) YIELD node AS ep
  WHERE ep:EffectOrPhenotype
  WITH ep LIMIT 500
  RETURN collect(elementId(ep)) AS eps
}

MATCH p=(d1:Disease)-[r1:PARENT_CHILD]->(d2:Disease)-[r2:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)
WHERE (elementId(d1) IN d1s) AND (elementId(d2) IN d2s) AND (elementId(ep) IN eps)
RETURN d1.name +  "(" + "id:" + d1.nodeId + ")"  +
  " - " + type(r1) + " -> " + 
  d2.name + "(" + "id:" + d2.nodeId + ")"  +
  " - " + type(r2) + " -> " + 
  ep.name + "(" + "id:" + ep.nodeId + ")"  
  AS fact


"""

In [42]:
query_template = """
//get node candidates
CALL($d1Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $d1Emb) YIELD node AS d1
  WHERE d1:Disease
  RETURN collect(elementId(d1)) AS d1s LIMIT 100
}
CALL($d2Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $d2Emb) YIELD node AS d2
  WHERE d2:Disease
  RETURN collect(elementId(d2)) AS d2s LIMIT 100
}
CALL($epEmb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $epEmb) YIELD node AS ep
  WHERE ep:EffectOrPhenotype
  RETURN collect(elementId(ep)) AS eps LIMIT 100
}

//graph pattern
MATCH p=(d1:Disease)-[PARENT_CHILD]-(d2:Disease)-[:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)

//filter
WHERE (elementId(d1) IN d1s) AND (elementId(d2) IN d2s) AND (elementId(ep) IN eps)

//return
RETURN p
"""

pattern_query_template = Template("""
CALL($embParam) {{
  CALL db.index.vector.queryNodes('text_embeddings', 1000, $embParam) YIELD node AS n
  WHERE n:$label
  RETURN collect(elementId(n)) AS $vars LIMIT 100
}}
""")

search_node_query_template.substitute(embParam="$disemb", label="Disease", var="d1s")

"\nCALL($disemb) {{\n  CALL db.index.vector.queryNodes('text_embeddings', 1000, $disemb) YIELD node AS n\n  WHERE n:Disease\n  RETURN collect(elementId(n)) AS d1s\\ss LIMIT 100\n}}\n"

In [None]:
from abc import ABC, abstractmethod
from typing import Type, Dict
from langchain_core.tools import BaseTool
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)


class NodeSearchClauses(BaseModel, ABC):
    nodeDict: Dict[str,str]

    @abstractmethod
    def generate(self):
        pass


class NodeVectorSearchClauses(BaseModel, ABC):
    nodeDict: Dict[str,str]

    @abstractmethod
    def generate(self):
        pass



class FindInput(BaseModel):
    nodeLabels: Dict[str,str] = Field(description="pairs of node labels to search for")

# come up with query template
class FindDiseases0(BaseTool):
    name = "FindDiseases0"
    description = "Find diseases with zero indication drug and are associated with <effect/phenotype>",
    args_schema: Type[BaseModel] = MemoryInput
    
    base_query = '''(n:EffectOrPhenotype)->[:PHENOTYPE_ABSENT]->(m:Disease)<-[:!INDICATION]<-(o:Drug)'''

    def _run(
            self,
            movie: str,
            rating: int,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return store_movie_rating(movie, rating)

    async def _arun(
            self,
            movie: str,
            rating: int,
            run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return store_movie_rating(movie, rating)

In [64]:
from neo4j_graphrag.schema import get_schema, get_structured_schema

print(get_schema(driver))

Node properties:
Disease {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
GeneOrProtein {nodeId: INTEGER, id: INTEGER, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
MolecularFunction {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
Drug {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
Pathway {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
Anatomy {nodeId: INTEGER, id: INTEGER, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
EffectOrPhenotype {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
BiologicalProcess {nodeId: INTEGER, id: STRING, name: STRING, source: STRING, details: STRING, textEmbedding: LIST}
CellularComponent {nodeId: INTEGER, id: INTEGER, name: STRING, source: STRING, details: STRING, text

In [157]:
schema = """(:Disease)-[:ASSOCIATED_WITH]->(:GeneOrProtein)
(:Disease)-[:PARENT_CHILD]->(:Disease)
(:Disease)-[:PHENOTYPE_ABSENT]->(:EffectOrPhenotype)
(:Disease)-[:PHENOTYPE_PRESENT]->(:EffectOrPhenotype)
(:Disease)-[:CONTRAINDICATION]->(:Drug)
(:Disease)-[:INDICATION]->(:Drug)
(:Disease)-[:OFF_LABEL_USE]->(:Drug)
(:Disease)-[:LINKED_TO]->(:Exposure)
(:GeneOrProtein)-[:ASSOCIATED_WITH]->(:Disease)
(:GeneOrProtein)-[:ASSOCIATED_WITH]->(:EffectOrPhenotype)
(:GeneOrProtein)-[:PPI]->(:GeneOrProtein)
(:GeneOrProtein)-[:EXPRESSION_PRESENT]->(:Anatomy)
(:GeneOrProtein)-[:INTERACTS_WITH]->(:MolecularFunction)
(:GeneOrProtein)-[:INTERACTS_WITH]->(:Pathway)
(:GeneOrProtein)-[:INTERACTS_WITH]->(:BiologicalProcess)
(:GeneOrProtein)-[:INTERACTS_WITH]->(:CellularComponent)
(:GeneOrProtein)-[:INTERACTS_WITH]->(:Exposure)
(:GeneOrProtein)-[:TARGET]->(:Drug)
(:GeneOrProtein)-[:EXPRESSION_ABSENT]->(:Anatomy)
(:GeneOrProtein)-[:TRANSPORTER]->(:Drug)
(:GeneOrProtein)-[:ENZYME]->(:Drug)
(:GeneOrProtein)-[:CARRIER]->(:Drug)
(:MolecularFunction)-[:PARENT_CHILD]->(:MolecularFunction)
(:MolecularFunction)-[:INTERACTS_WITH]->(:GeneOrProtein)
(:MolecularFunction)-[:INTERACTS_WITH]->(:Exposure)
(:Drug)-[:CARRIER]->(:GeneOrProtein)
(:Drug)-[:ENZYME]->(:GeneOrProtein)
(:Drug)-[:TARGET]->(:GeneOrProtein)
(:Drug)-[:TRANSPORTER]->(:GeneOrProtein)
(:Drug)-[:CONTRAINDICATION]->(:Disease)
(:Drug)-[:INDICATION]->(:Disease)
(:Drug)-[:SYNERGISTIC_INTERACTION]->(:Drug)
(:Drug)-[:SIDE_EFFECT]->(:EffectOrPhenotype)
(:Drug)-[:OFF_LABEL_USE]->(:Disease)
(:Pathway)-[:PARENT_CHILD]->(:Pathway)
(:Pathway)-[:INTERACTS_WITH]->(:GeneOrProtein)
(:Anatomy)-[:PARENT_CHILD]->(:Anatomy)
(:Anatomy)-[:EXPRESSION_PRESENT]->(:GeneOrProtein)
(:Anatomy)-[:EXPRESSION_ABSENT]->(:GeneOrProtein)
(:EffectOrPhenotype)-[:PARENT_CHILD]->(:EffectOrPhenotype)
(:EffectOrPhenotype)-[:PHENOTYPE_PRESENT]->(:Disease)
(:EffectOrPhenotype)-[:PHENOTYPE_ABSENT]->(:Disease)
(:EffectOrPhenotype)-[:ASSOCIATED_WITH]->(:GeneOrProtein)
(:EffectOrPhenotype)-[:SIDE_EFFECT]->(:Drug)
(:BiologicalProcess)-[:PARENT_CHILD]->(:BiologicalProcess)
(:BiologicalProcess)-[:INTERACTS_WITH]->(:GeneOrProtein)
(:BiologicalProcess)-[:INTERACTS_WITH]->(:Exposure)
(:CellularComponent)-[:PARENT_CHILD]->(:CellularComponent)
(:CellularComponent)-[:INTERACTS_WITH]->(:GeneOrProtein)
(:CellularComponent)-[:INTERACTS_WITH]->(:Exposure)
(:Exposure)-[:PARENT_CHILD]->(:Exposure)
(:Exposure)-[:INTERACTS_WITH]->(:GeneOrProtein)
(:Exposure)-[:INTERACTS_WITH]->(:BiologicalProcess)
(:Exposure)-[:INTERACTS_WITH]->(:MolecularFunction)
(:Exposure)-[:INTERACTS_WITH]->(:CellularComponent)
(:Exposure)-[:LINKED_TO]->(:Disease)"""



In [215]:
access_logics = {
    "(effect/phenotype → [phenotype absent] → disease ← [!indication] ← drug)": "Find diseases with zero indication drug and are associated with <effect/phenotype>",
    "(drug → [contraindication] → disease ← [associated with] ← gene/protein)": "Identify diseases associated with <gene/protein> and are contraindicated with <drug>",
    "(anatomy → [expression present] → gene/protein ← [expression absent] ← anatomy)": "What gene or protein is expressed in <anatomy1> while is absent in <anatomy2>?",
    "(anatomy → [expression absent] → gene/protein ← [expression absent] ← anatomy)": "What gene/protein is absent in both <anatomy1> and <anatomy2>?",
    "(drug → [carrier] → gene/protein ← [carrier] ← drug)": "Which target genes are shared carriers between <drug1> and <drug2>?",
    "(anatomy → [expression present] → gene/protein → [target] → drug)": "What is the drug that targets the genes or proteins which are expressed in <anatomy>?",
    "(drug → [side effect] → effect/phenotype → [side effect] → drug)": "What drug has common side effects as <drug>?",
    "(drug → [carrier] → gene/protein → [carrier] → drug)": "What is the drug that has common gene/protein carrier with <drug>?",
    "(anatomy → [expression present] → gene/protein → enzyme → drug)": "What is the drug that some genes or proteins act as an enzyme upon, where the genes or proteins are expressed in <anatomy>?",
    "(cellular_component → [interacts with] → gene/protein → [carrier] → drug)": "What is the drug carried by genes or proteins that interact with <cellular_component>?",
    "(molecular_function → [interacts with] → gene/protein → [target] → drug)": "What drug targets the genes or proteins that interact with <molecular_function>?",
    "(effect/phenotype → [side effect] → drug → [synergistic interaction] → drug)": "What drug has a synergistic interaction with the drug that has <effect/phenotype> as a side effect?",
    "(disease → [indication] → drug → [contraindication] → disease)": "What disease is a contraindication for the drugs indicated for <disease>?",
    "(disease → [parent-child] → disease → [phenotype present] → effect/phenotype)": "What effect or phenotype is present in the sub type of <disease>?",
    "(gene/protein → [transporter] → drug → [side effect] → effect/phenotype)": "What effect or phenotype is a [side effect] of the drug transported by <gene/protein>?",
    "(drug → [transporter] → gene/protein → [interacts with] → exposure)": "What exposure may affect <drug>s efficacy by acting on its transporter genes?",
    "(pathway → [interacts with] → gene/protein → [ppi] → gene/protein)": "What gene/protein interacts with the gene/protein that related to <pathway>?",
    "(drug → [synergistic interaction] → drug → [transporter] → gene/protein)": "What gene or protein transports the drugs that have a synergistic interaction with <drug>?",
    "(biological_process → [interacts with] → gene/protein → [interacts with] → biological_process)": "What biological process has the common interactino pattern with gene or proteins as <biological_process>?",
    "(effect/phenotype → [associated with] → gene/protein → [interacts with] → biological_process)": "What biological process interacts with the gene/protein associated with <effect/phenotype>?",
    "(drug → [transporter] → gene/protein → [expression present] → anatomy)": "What anatomy expressesed by the gene/protein that affect the transporter of <drug>?",
    "(drug → [target] → gene/protein → [interacts with] → cellular_component)": "What cellular component interacts with genes or proteins targeted by <drug>?",
    "(biological_process → [interacts with] → gene/protein → [expression absent] → anatomy)": "What anatomy does not express the genes or proteins that interacts with <biological_process>?",
    "(effect/phenotype → [associated with] → gene/protein → [expression absent] → anatomy)": "What anatomy does not express the genes or proteins associated with <effect/phenotype>?",
    "(drug → [indication] → disease → [indication] → drug) & (drug → [synergistic interaction] → drug)": "Find drugs that has a synergistic interaction with <drug> and both are indicated for the same disease.",
    "(pathway → [interacts with] → gene/protein → [interacts with] → pathway) & (pathway → [parent-child] → pathway)": "Find pathway that is related with <pathway> and both can [interacts with] the same gene/protein.",
    "(gene/protein → [associated with] → disease → [associated with] → gene/protein) & (gene/protein → [ppi] → gene/protein)": "Find gene/protein that can interect with <gene/protein> and both are associated with the same disease.",
    "(gene/protein → [associated with] → effect/phenotype → [associated with] → gene/protein) & (gene/protein → [ppi] → gene/protein)": "Find gene/protein that can interect with <gene/protein> and both are associated with the same effect/phenotype."
}

In [216]:
examples = """
## Example 1
### AccessLogic
(disease → [parent-child] → disease → [phenotype present] → effect/phenotype)
#### MATCH Statement
MATCH p=(n1:Disease)-[r1:PARENT_CHILD]->(n2:Disease)-[r2:PHENOTYPE_PRESENT]->(n3:EffectOrPhenotype)

## Example 2
### AccessLogic
(effect/phenotype → [phenotype absent] → disease ← [!indication] ← drug)
#### MATCH Statement
MATCH p=(n1:EffectOrPhenotype)-[r1:PHENOTYPE_ABSENT]->(n2:Disease)<-[r2:!INDICATION]-(n3:Drug)

## Example 3
### AccessLogic
(pathway → [interacts with] → gene/protein → [interacts with] → pathway) & (pathway → [parent-child] → pathway)
#### MATCH Statement
MATCH (n1:Pathway)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:Pathway), (n1:Pathway)-[r3:PARENT_CHILD]->(n3:Pathway)

## Example 4
### AccessLogic
(gene/protein → [associated with] → effect/phenotype → [associated with] → gene/protein) & (gene/protein → [ppi] → gene/protein)
#### MATCH Statement
MATCH (n1:GeneOrProtein)-[r1:ASSOCIATED_WITH]->(n2:EffectOrPhenotype)-[r2:ASSOCIATED_WITH]->(n3:GeneOrProtein), (n1:GeneOrProtein)-[r3:PPI]->(n3:GeneOrProtein)


"""

In [217]:
CYPHER_TEMPLATE = """
Task: Translate the below AccessLogic into a MATCH Statement in Cypher, Neo4j's Graph Query language, based on the supplied GraphSchema. 
Follow the format used in the Examples. 


# AccessLogic
{accessLogic}

# GraphSchema:
{schema}

# Examples:
{examples}


- Do not use any labels or relationships not included in the GraphSchema.
- The symbol ! should be respected and maintained in translation as long as it is valid syntax
- When an "&" is used in AccessLogic, it usually means no new nodes are being introduced after the "&" symbol. Instead reuse node variables assuming no cyclical relationships.  See examples.
- Do not include triple backticks ``` or any additional text except the generated Cypher MATCH Statement in your response.

# MATCH Statement:
"""

In [218]:
from typing import Tuple, Optional
from pydantic import BaseModel
import re

p="(n1:EffectOrPhenotype)-[r1:PHENOTYPE_ABSENT]->(n2:Disease)<-[r2:!INDICATION]-(n3:Drug)"

class Head(BaseModel):
    node: str = None
    rel: Tuple[str, str, str, str] = None
    
    def from_string(s:str):
        
        if ":" not in s:
            return Head()
        else:
            node = s[:s.index(":")]
            rel_pattern = re.findall('\)(.*)\[(.*):(.*)](.*)', s)
            if len(rel_pattern) > 0:
                return Head(node=node, rel=rel_pattern[0])
            else:
                return Head(node=node)
    
    def make_where_part(self, ind:int) -> str:
        return f'(elementId({self.node}) IN nodeIdsLists[{ind}])'
        
    
    def make_fact_part(self) -> str:
        if self.node is None:
            return ''
        else:
            fact_part = f'{self.node}.name +  "{{" + "nodeId:" + {self.node}.nodeId "}}"'
            if self.rel is not None:
                fact_part = fact_part + f'+ " {self.rel[0]} " + type({self.rel[1]}) + " {self.rel[3]} "'

        return fact_part
    
    def make_details_part(self) -> str:
        return f'collect(DISTINCT {self.node}.name +  "{{" + "nodeId:" + {self.node}.nodeId + ", details:" +  {self.node}.details + "}}") AS {self.node}TextDetails'
        
    
    

In [219]:
from typing import List

QUERY_START = '''
UNWIND $queryVectors AS q
CALL(q) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, q.queryVector) YIELD node AS n
  WHERE apoc.label.exists(n, q.label)
  WITH n LIMIT 500
  RETURN collect(elementId(n)) AS nodeIds
}
WITH collect(nodeIds) AS nodeIdsLists'''

class GraphPattern(BaseModel):
    heads: List[Head] = []
    pattern:str

    def from_string(s:str):
        heads = []
        for split in s.split("("):
            if len(split) > 0:
                h = Head.from_string(split)
                if h.node is not None:
                    heads.append(Head.from_string(split))
        return GraphPattern(heads=heads, pattern=s)
    
    def make_where(self) -> str:
        res_list = []
        nodes_used=[]
        iter = 0
        for head in self.heads:
            if head.node not in nodes_used:
                res_list.append(head.make_where_part(iter))
                nodes_used.append(head.node)
                iter+=1
        return "WHERE " + " AND ".join(res_list)

    def make_facts(self) -> str:
        fact_list = []
        for head in self.heads:
            fact_part = head.make_fact_part()
            if len(fact_part) > 1:
                fact_list.append(fact_part)
        return " + ".join(fact_list)
    
    def make_details(self) -> str:
        res_list = []
        nodes_used=[]
        for head in self.heads:
            if head.node not in nodes_used:
                res_list.append(head.make_details_part())
                nodes_used.append(head.node)
        return ",\n ".join(res_list)
    
    def make_query(self) -> str:
        return f'''
{QUERY_START}
{self.pattern}
{self.make_where()}
RETURN {self.make_facts()} AS facts,
{self.make_details()}
'''

In [220]:
p="MATCH (n1:Pathway)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:Pathway), (n1:Pathway)-[r3:PARENT_CHILD]->(n3:Pathway)"
gp = GraphPattern.from_string(p)
print(gp.make_query())



UNWIND $queryVectors AS q
CALL(q) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, q.queryVector) YIELD node AS n
  WHERE apoc.label.exists(n, q.label)
  WITH n LIMIT 500
  RETURN collect(elementId(n)) AS nodeIds
}
WITH collect(nodeIds) AS nodeIdsLists
MATCH (n1:Pathway)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:Pathway), (n1:Pathway)-[r3:PARENT_CHILD]->(n3:Pathway)
WHERE (elementId(n1) IN nodeIdsLists[0]) AND (elementId(n2) IN nodeIdsLists[1]) AND (elementId(n3) IN nodeIdsLists[2])
RETURN n1.name +  "{" + "nodeId:" + n1.nodeId "}"+ " - " + type(r1) + " -> " + n2.name +  "{" + "nodeId:" + n2.nodeId "}"+ " - " + type(r2) + " -> " + n3.name +  "{" + "nodeId:" + n3.nodeId "}" + n1.name +  "{" + "nodeId:" + n1.nodeId "}"+ " - " + type(r3) + " -> " + n3.name +  "{" + "nodeId:" + n3.nodeId "}" AS facts,
collect(DISTINCT n1.name +  "{" + "nodeId:" + n1.nodeId + ", details:" +  n1.details + "}") AS n1TextDetails,
 collect(DISTINCT n2.name +  "{" + "nodeId:"

In [255]:
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI


cypher_llm = ChatOpenAI(
    model_name="gpt-4o-mini",
    temperature=0, # turning temperature down for more deterministic results
    seed=42
)



def gen_match_pattern(access_logic:str):
    prompt = CYPHER_TEMPLATE.format(accessLogic = access_logic, schema=schema, examples=examples)
    match_pattern = cypher_llm.invoke(prompt)
    gp = GraphPattern.from_string(match_pattern.content)
    return match_pattern.content #gp.make_query()

In [256]:
access_patterns = []
for k, v in access_logics.items():
    print('\n=========================')
    print(k)
    print('-------------------')
    access_patterns.append({"pattern": gen_match_pattern(k), "description": v})


(effect/phenotype → [phenotype absent] → disease ← [!indication] ← drug)
-------------------

(drug → [contraindication] → disease ← [associated with] ← gene/protein)
-------------------

(anatomy → [expression present] → gene/protein ← [expression absent] ← anatomy)
-------------------

(anatomy → [expression absent] → gene/protein ← [expression absent] ← anatomy)
-------------------

(drug → [carrier] → gene/protein ← [carrier] ← drug)
-------------------

(anatomy → [expression present] → gene/protein → [target] → drug)
-------------------

(drug → [side effect] → effect/phenotype → [side effect] → drug)
-------------------

(drug → [carrier] → gene/protein → [carrier] → drug)
-------------------

(anatomy → [expression present] → gene/protein → enzyme → drug)
-------------------

(cellular_component → [interacts with] → gene/protein → [carrier] → drug)
-------------------

(molecular_function → [interacts with] → gene/protein → [target] → drug)
-------------------

(effect/phenoty

In [257]:
print(json.dumps(access_patterns, indent=4))

[
    {
        "pattern": "MATCH (n1:EffectOrPhenotype)-[r1:PHENOTYPE_ABSENT]->(n2:Disease)<-[r2:CONTRAINDICATION]-(n3:Drug)",
        "description": "Find diseases with zero indication drug and are associated with <effect/phenotype>"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:CONTRAINDICATION]->(n2:Disease)<-[r2:ASSOCIATED_WITH]-(n3:GeneOrProtein)",
        "description": "Identify diseases associated with <gene/protein> and are contraindicated with <drug>"
    },
    {
        "pattern": "MATCH (n1:Anatomy)-[r1:EXPRESSION_PRESENT]->(n2:GeneOrProtein)<-[r2:EXPRESSION_ABSENT]-(n3:Anatomy)",
        "description": "What gene or protein is expressed in <anatomy1> while is absent in <anatomy2>?"
    },
    {
        "pattern": "MATCH (n1:Anatomy)-[r1:EXPRESSION_ABSENT]->(n2:GeneOrProtein)<-[r2:EXPRESSION_ABSENT]-(n3:Anatomy)",
        "description": "What gene/protein is absent in both <anatomy1> and <anatomy2>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:CARRIER]->(n

In [278]:
access_patterns = [
    {
        "pattern": "MATCH (n1:EffectOrPhenotype)-[r1:PHENOTYPE_ABSENT]->(n2:Disease)<-[r2:CONTRAINDICATION]-(n3:Drug)",
        "questionType": "Find diseases with zero indication drug and are associated with <effect/phenotype>"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:CONTRAINDICATION]->(n2:Disease)<-[r2:ASSOCIATED_WITH]-(n3:GeneOrProtein)",
        "questionType": "Identify diseases associated with <gene/protein> and are contraindicated with <drug>"
    },
    {
        "pattern": "MATCH (n1:Anatomy)-[r1:EXPRESSION_PRESENT]->(n2:GeneOrProtein)<-[r2:EXPRESSION_ABSENT]-(n3:Anatomy)",
        "questionType": "What gene or protein is expressed in <anatomy1> while is absent in <anatomy2>?"
    },
    {
        "pattern": "MATCH (n1:Anatomy)-[r1:EXPRESSION_ABSENT]->(n2:GeneOrProtein)<-[r2:EXPRESSION_ABSENT]-(n3:Anatomy)",
        "questionType": "What gene/protein is absent in both <anatomy1> and <anatomy2>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:CARRIER]->(n2:GeneOrProtein)<-[r2:CARRIER]-(n3:Drug)",
        "questionType": "Which target genes are shared carriers between <drug1> and <drug2>?"
    },
    {
        "pattern": "MATCH (n1:Anatomy)-[r1:EXPRESSION_PRESENT]->(n2:GeneOrProtein)-[r2:TARGET]->(n3:Drug)",
        "questionType": "What is the drug that targets the genes or proteins which are expressed in <anatomy>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:SIDE_EFFECT]->(n2:EffectOrPhenotype)-[r2:SIDE_EFFECT]->(n3:Drug)",
        "questionType": "What drug has common side effects as <drug>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:CARRIER]->(n2:GeneOrProtein)-[r2:CARRIER]->(n3:Drug)",
        "questionType": "What is the drug that has common gene/protein carrier with <drug>?"
    },
    {
        "pattern": "MATCH p=(n1:Anatomy)-[r1:EXPRESSION_PRESENT]->(n2:GeneOrProtein)-[r2:ENZYME]->(n3:Drug)",
        "questionType": "What is the drug that some genes or proteins act as an enzyme upon, where the genes or proteins are expressed in <anatomy>?"
    },
    {
        "pattern": "MATCH p=(n1:CellularComponent)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:CARRIER]->(n3:Drug)",
        "questionType": "What is the drug carried by genes or proteins that interact with <cellular_component>?"
    },
    {
        "pattern": "MATCH (n1:MolecularFunction)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:TARGET]->(n3:Drug)",
        "questionType": "What drug targets the genes or proteins that interact with <molecular_function>?"
    },
    {
        "pattern": "MATCH (n1:EffectOrPhenotype)-[r1:SIDE_EFFECT]->(n2:Drug)-[r2:SYNERGISTIC_INTERACTION]->(n3:Drug)",
        "questionType": "What drug has a synergistic interaction with the drug that has <effect/phenotype> as a side effect?"
    },
    {
        "pattern": "MATCH p=(n1:Disease)-[r1:INDICATION]->(n2:Drug)-[r2:CONTRAINDICATION]->(n3:Disease)",
        "questionType": "What disease is a contraindication for the drugs indicated for <disease>?"
    },
    {
        "pattern": "MATCH (n1:Disease)-[r1:PARENT_CHILD]->(n2:Disease)-[r2:PHENOTYPE_PRESENT]->(n3:EffectOrPhenotype)",
        "questionType": "What effect or phenotype is present in the sub type of <disease>?"
    },
    {
        "pattern": "MATCH (n1:GeneOrProtein)-[r1:TRANSPORTER]->(n2:Drug)-[r2:SIDE_EFFECT]->(n3:EffectOrPhenotype)",
        "questionType": "What effect or phenotype is a [side effect] of the drug transported by <gene/protein>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:TRANSPORTER]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:Exposure)",
        "questionType": "What exposure may affect <drug>s efficacy by acting on its transporter genes?"
    },
    {
        "pattern": "MATCH (n1:Pathway)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:PPI]->(n3:GeneOrProtein)",
        "questionType": "What gene/protein interacts with the gene/protein that related to <pathway>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:SYNERGISTIC_INTERACTION]->(n2:Drug)-[r2:TRANSPORTER]->(n3:GeneOrProtein)",
        "questionType": "What gene or protein transports the drugs that have a synergistic interaction with <drug>?"
    },
    {
        "pattern": "MATCH (n1:BiologicalProcess)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:BiologicalProcess)",
        "questionType": "What biological process has the common interactino pattern with gene or proteins as <biological_process>?"
    },
    {
        "pattern": "MATCH (n1:EffectOrPhenotype)-[r1:ASSOCIATED_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:BiologicalProcess)",
        "questionType": "What biological process interacts with the gene/protein associated with <effect/phenotype>?"
    },
    {
        "pattern": "MATCH p=(n1:Drug)-[r1:TRANSPORTER]->(n2:GeneOrProtein)-[r2:EXPRESSION_PRESENT]->(n3:Anatomy)",
        "questionType": "What anatomy expressesed by the gene/protein that affect the transporter of <drug>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:TARGET]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:CellularComponent)",
        "questionType": "What cellular component interacts with genes or proteins targeted by <drug>?"
    },
    {
        "pattern": "MATCH (n1:BiologicalProcess)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:EXPRESSION_ABSENT]->(n3:Anatomy)",
        "questionType": "What anatomy does not express the genes or proteins that interacts with <biological_process>?"
    },
    {
        "pattern": "MATCH (n1:EffectOrPhenotype)-[r1:ASSOCIATED_WITH]->(n2:GeneOrProtein)-[r2:EXPRESSION_ABSENT]->(n3:Anatomy)",
        "questionType": "What anatomy does not express the genes or proteins associated with <effect/phenotype>?"
    },
    {
        "pattern": "MATCH (n1:Drug)-[r1:INDICATION]->(n2:Disease)-[r2:INDICATION]->(n3:Drug), (n1:Drug)-[r3:SYNERGISTIC_INTERACTION]->(n3:Drug)",
        "questionType": "Find drugs that has a synergistic interaction with <drug> and both are indicated for the same disease."
    },
    {
        "pattern": "MATCH (n1:Pathway)-[r1:INTERACTS_WITH]->(n2:GeneOrProtein)-[r2:INTERACTS_WITH]->(n3:Pathway), (n1:Pathway)-[r3:PARENT_CHILD]->(n3:Pathway)",
        "questionType": "Find pathway that is related with <pathway> and both can [interacts with] the same gene/protein."
    },
    {
        "pattern": "MATCH (n1:GeneOrProtein)-[r1:ASSOCIATED_WITH]->(n2:Disease)-[r2:ASSOCIATED_WITH]->(n3:GeneOrProtein), (n1:GeneOrProtein)-[r3:PPI]->(n3:GeneOrProtein)",
        "questionType": "Find gene/protein that can interect with <gene/protein> and both are associated with the same disease."
    },
    {
        "pattern": "MATCH (n1:GeneOrProtein)-[r1:ASSOCIATED_WITH]->(n2:EffectOrPhenotype)-[r2:ASSOCIATED_WITH]->(n3:GeneOrProtein), (n1:GeneOrProtein)-[r3:PPI]->(n3:GeneOrProtein)",
        "questionType": "Find gene/protein that can interect with <gene/protein> and both are associated with the same effect/phenotype."
    }
]


In [259]:
MATCH_TEMPLATE = """
Task: Choose a MATCH Statement That best matches the Question. 
MATCH Statements consist of only a Cypher query language Match statement.  They do not contain WHERE clauses or filters which are accounted for elsewhere. 
If no MATCH Statements match the question well, generate a new one. 
DO not generate new access patterns just to filter on properties or add WHERE clauses. 
Each MATCH Statement is a MATCH Statement in Cypher, Neo4j's Graph Query language, based on the supplied GraphSchema. 

# Question
{question}

# MATCH Statements
{accessPatterns}

# GraphSchema:
{schema}


If generating a new MATCH statement
- Do not use any labels or relationships not included in the GraphSchema.
- It must be a valid Cypher Match statement with correct syntax.
- use the same format as shown in the other match statements
- Do not include triple backticks ``` or any additional text except the generated Cypher MATCH Statement in your response.

# Chosen MATCH Statement:
"""
    

In [319]:
MATCH_TEMPLATE = """
Task: From AccessPatterns, choose a questionType that best matches the Below Question. 
Return:
- the questionType
- associated pattern
- nodeSearchTerms: For each node in `pattern`, determine if the node requires a search phrase/filter based on the question.  return a seperate nodeSearchTerms mapping in the JSON form of {{"d1":"<search_prompt>","d2":"<search_prompt>"...}}  For each node that requires a search phrase/filter.
- goodMatchFound
- feedback


If no match seems appropriate, provide feedback explaining why. 
Focus on matching intent and key concepts from the Question.



# Question
{question}

# AccessPatterns
{accessPatterns}

# Chosen MATCH Statement:
"""

In [323]:
from typing import Any
from pydantic import BaseModel, Field

class AccessPattern(BaseModel):
    """Access Pattern (Cypher Match Statement)"""
    pattern: str = Field(description="The Match statement")
    nodeSearchTerms: str = Field(description="Search terms to use for each node, decomposed from the question")
    questionType: str = Field(description="the type of question")
    goodMatchFound: bool = Field(description="Does the MATCH Statement fit the question well?")
    matchFeedback: str = Field(description="feedback explaining why.")

match_llm = cypher_llm.with_structured_output(AccessPattern)

In [324]:
def gen_match_pattern(question:str):
    prompt = MATCH_TEMPLATE.format(question=question, accessPatterns=json.dumps(access_patterns, indent=4), schema=schema)
    #print(prompt)
    match_pattern = match_llm.invoke(prompt)
    return match_pattern

## Multi-Part Question Retrievers

In [325]:
for q in qa_dataset.data["query"][:5]:
    print('==================')
    print(q)
    print('----')
    print(gen_match_pattern(q))

Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.
----
pattern='MATCH (n1:Disease)-[r1:PARENT_CHILD]->(n2:Disease)-[r2:PHENOTYPE_PRESENT]->(n3:EffectOrPhenotype)' nodeSearchTerms='{"n1":"skin diseases","n2":"epithelial skin neoplasms","n3":"phenotype"}' questionType='What effect or phenotype is present in the sub type of <disease>?' goodMatchFound=True matchFeedback='The question is about identifying skin diseases associated with epithelial skin neoplasms, which aligns with the pattern that looks for diseases and their associated phenotypes.'
What drugs target the CYP3A4 enzyme and are used to treat strongyloidiasis?
----
pattern='MATCH (n1:Drug)-[r1:TARGET]->(n2:GeneOrProtein)' nodeSearchTerms='{"n1":"drugs targeting CYP3A4","n2":"CYP3A4 enzyme"}' questionType='Identify drugs that target the CYP3A4 enzyme' goodMatchFound=True matchFeedback