In [49]:
%%capture
%pip install torch torch_geometric stark-qa neo4j python-dotenv pcst_fast datasets pandas transformers langchain langchain-openai langchain-community

## Setup

load env variables

In [1]:
from dotenv import load_dotenv
import os

load_dotenv('.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

Get stark-qa dataset

In [4]:
from stark_qa import load_qa, load_skb

dataset_name = 'prime'

# Load the retrieval dataset
qa_dataset = load_qa(dataset_name)

Use file from /Users/zachblumenfeld/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/e3387ba9c873528521e125c3da687092d9872104/qa/prime/stark_qa/stark_qa_human_generated_eval_uncleaned.csv.


In [5]:
qa_dataset.data

Unnamed: 0,id,query,answer_ids
0,0,Could you identify any skin diseases associate...,[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,[15450]
2,2,What is the name of the condition characterize...,"[98851, 98853]"
3,3,What drugs are used to treat epithelioid sarco...,[15698]
4,4,Can you supply a compilation of genes and prot...,"[7161, 22045]"
...,...,...,...
11199,11199,Which gene or protein is not expressed in fema...,[2414]
11200,11200,Could you identify a biological pathway in whi...,[128199]
11201,11201,Is there an interaction between genes or prote...,"[127611, 62903]"
11202,11202,Which pharmacological agents that stimulate os...,[20180]


## Vector Retriever

In [2]:
import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings

driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

#create text embedder
embedder = OpenAIEmbeddings()

In [13]:
from neo4j_graphrag.retrievers import VectorRetriever

vector_retriever = VectorRetriever(
    driver,
    index_name="text_embeddings",
    embedder=embedder,
    return_properties=["details", "name", "nodeId"],
)

In [34]:
from neo4j_graphrag.llm import OpenAILLM as LLM
from neo4j_graphrag.generation import RagTemplate
from neo4j_graphrag.generation.graphrag import GraphRAG

llm=LLM(
    model_name="gpt-4o",
    model_params={
        "response_format": {"type": "json_object"}, # use json_object formatting for best results
        "temperature": 0, # turning temperature down for more deterministic results
        "seed": 42
    }
)

rag_template = RagTemplate(template='''
Answer the Question using the following Context. Only respond with information mentioned in the Context. Do not inject any speculative information not mentioned. 

Return result as JSON using the following format:
{{"nodeIds": [list of nodeIds containing the correct answers in order of probability highest to smallest], "explanation": "written explanation as to why the node id answers were provided"}}

# Question:
{query_text}
 
# Context:
{context}

# Answer:
''', expected_inputs=['query_text', 'context'])

v_rag  = GraphRAG(llm=llm, retriever=vector_retriever, prompt_template=rag_template)

In [35]:
q=qa_dataset.data['query'][0]
print(q)

v_res = v_rag.search(q, return_context=True, retriever_config={"top_k": 10})

Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.


In [36]:
import json

print(json.dumps(eval(v_res.answer), indent=1))

{
 "nodeIds": [
  96054,
  96057,
  98637,
  39254
 ],
 "explanation": "The tiny, yellowish lesion observed on sun-exposed areas of the face and neck could be associated with basal cell carcinoma, a type of epithelial skin neoplasm. The context provides information on various types of basal cell carcinoma, which often develop on sun-exposed parts of the body, especially the head and neck. These lesions can appear as pearly white, skin-colored, or pink bumps that are translucent, and tiny blood vessels are often visible. The nodeIds 96054 (skin fibroepithelial basal cell carcinoma), 96057 (follicular basal cell carcinoma), 98637 (follicular atrophoderma-basal cell carcinoma), and 39254 (skin adenoid basal cell carcinoma) are all related to basal cell carcinoma, which matches the description of the lesion observed."
}


## Text2Cypher Retriever

In [51]:
from neo4j_graphrag.experimental.components.types import Neo4jGraph
from typing import Dict
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.graphs import Neo4jGraph

embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
graph = Neo4jGraph()


QUERY_TEMPLATE = '''
UNWIND $queryVectors AS q
CALL(q) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, q.queryVector) YIELD node AS n
  WHERE apoc.label.exists(n, q.label)
  WITH n LIMIT 500
  RETURN collect(elementId(n)) AS nodeIds
}
WITH collect(nodeIds) AS nodeIdsLists
MATCH p=(d1:Disease)-[r1:PARENT_CHILD]->(d2:Disease)-[r2:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)
WHERE (elementId(d1) IN nodeIdsLists[0]) AND (elementId(d2) IN nodeIdsLists[1]) AND (elementId(ep) IN nodeIdsLists[2])
RETURN collect(d1.name +  "{" + "nodeId:" + d1.nodeId + "}" +
  " - " + type(r1) + " -> " + 
  d2.name +  "{" + "nodeId:" + d2.nodeId + "}" +
  " - " + type(r2) + " -> " + 
  ep.name +  "{" + "nodeId:" + d2.nodeId + "}")
  AS facts,
  collect(DISTINCT d1.name +  "{" + "nodeId:" + d1.nodeId + ", details:" +  d1.details + "}") AS disease01TextDetails,
  collect(DISTINCT d2.name +  "{" + "nodeId:" + d2.nodeId + ", details:" +  d2.details + "}") AS disease02TextDetails,
  collect(DISTINCT d2.name +  "{" + "nodeId:" + d2.nodeId + ", details:" +  d2.details + "}") AS effectOrPhenotypeTextDetails
'''

def retrieve(search_prompts: Dict[str, str]):
    query_vectors = [{'queryVector':embedding_model.embed_query(sp['searchPrompt']), 'label':sp['label'] } for sp in search_prompts]
    query_results = graph.query(QUERY_TEMPLATE, params={"queryVectors": query_vectors})
    return query_results

sps = [
    {
        'searchPrompt': "epithelial skin neoplasms",
        'label': "Disease"
    },
    {
        'searchPrompt': "skin diseases",
        'label': "Disease"
    },
    {
        'searchPrompt': "tiny, yellowish lesion",
        'label': "EffectOrPhenotype"
    }
]

res = retrieve(sps)

In [53]:
res[0]['facts']

['epidermal disease{nodeId:38028} - PARENT_CHILD -> ichthyosis (disease){nodeId:33648} - PHENOTYPE_PRESENT -> Localized epidermolytic hyperkeratosis{nodeId:33648}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Hypermelanotic macule{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Abnormality of skin pigmentation{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Macule{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Skin vesicle{nodeId:31039}',
 'epidermal disease{nodeId:38028} - PARENT_CHILD -> Darier disease{nodeId:31039} - PHENOTYPE_PRESENT -> Anal mucosal leukoplakia{nodeId:31039}',
 'vulvar seborrheic keratosis{nodeId:38032} - PARENT_CHILD -> seborrheic keratosis{nodeId:31850} - PHENOTYPE_PRESENT -> Seborrheic kera

In [55]:
{k:v for k,v in res[0].items() if k != 'facts'}

{'disease01TextDetails': ["epidermal disease{nodeId:38028, details:{'mondo_id': 19268, 'mondo_name': 'epidermal disease', 'mondo_definition': 'A skin disease that involves the epidermis.'}}",
  'vulvar seborrheic keratosis{nodeId:38032, details:{\'mondo_id\': 6622, \'mondo_name\': \'vulvar seborrheic keratosis\', \'mondo_definition\': \'A benign squamous neoplasm that arises from the vulva. It is characterized by the proliferation of the basal cells in the squamous epithelium, acanthosis, hyperkeratosis, and cysts formation.\', \'mayo_symptoms\': \'A seborrheic keratosis usually looks like a waxy or wartlike growth. It typically appears on the face, chest, shoulders or back. You may develop a single growth, though multiple growths are more common. A seborrheic keratosis: Ranges in color from light tan to brown or black, Is round or oval shaped, Has a characteristic \\\\pasted on\\\\" look, Is flat or slightly raised with a scaly surface, Ranges in size from very small to more than 1 in

In [58]:
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

PROMPT = PromptTemplate.from_template('''
Answer the Question using the following Facts and Details. Only respond with information mentioned in the Facts and Details. Do not inject any speculative information not mentioned. 

Return result as JSON using the following format:
{{"nodeIds": [list of nodeIds containing the correct answers in order of probability highest to smallest], "explanation": "written explanation as to why the node id answers were provided"}}

# Question:
{question}
 
# Facts:
{facts}

# Details
{details}

# Answer:
''')

llm = ChatOpenAI(
    model_name="gpt-4o",
    response_format={"type": "json_object"}, # use json_object formatting for best results
    temperature=0, # turning temperature down for more deterministic results
    seed=42
)



def answer(question: str, search_prompts: Dict[str, str]):
    context = retrieve(search_prompts)[0]
    prompt = PROMPT.format(question=question, facts=context['facts'], details={k:v for k,v in context.items() if k != 'facts'})
    return llm.invoke(prompt)
    

In [60]:
q="Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected."
answer(q, sps).content

'{"nodeIds": [31039, 31850, 32367], "explanation": "The observed tiny, yellowish lesion on sun-exposed areas of the face and neck could be associated with several skin conditions related to epithelial skin neoplasms. Based on the provided facts, Darier disease (nodeId:31039) is characterized by greasy and colored (yellow-brown or brown) keratotic papules, which may be exacerbated by exposure to sunlight. Seborrheic keratosis (nodeId:31850) is a common benign skin neoplasm that appears as black or brown, slightly elevated skin lesions, often on the face. Porokeratosis (nodeId:32367) involves the development of localized or multiple atrophic skin patches surrounded by an annular keratotic ring, which can progress to cutaneous neoplasm. These conditions are relevant to the description of the lesion and its location on sun-exposed areas."}'

In [None]:
"""
WITH genai.vector.encode($d1, "OpenAI", {token:$token}) AS d1Emb,
genai.vector.encode($d2, "OpenAI", {token:$token}) AS d2Emb,
genai.vector.encode($ep, "OpenAI", {token:$token}) AS epEmb
CALL(d1Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, d1Emb) YIELD node AS d1
  WHERE d1:Disease
  WITH d1 LIMIT 500
  RETURN collect(elementId(d1)) AS d1s
}
CALL(d2Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, d2Emb) YIELD node AS d2
  WHERE d2:Disease
  WITH d2 LIMIT 500
  RETURN collect(elementId(d2)) AS d2s
}
CALL(epEmb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, epEmb) YIELD node AS ep
  WHERE ep:EffectOrPhenotype
  WITH ep LIMIT 500
  RETURN collect(elementId(ep)) AS eps
}

MATCH p=(d1:Disease)-[r1:PARENT_CHILD]->(d2:Disease)-[r2:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)
WHERE (elementId(d1) IN d1s) AND (elementId(d2) IN d2s) AND (elementId(ep) IN eps)
RETURN d1.name +  "(" + "id:" + d1.nodeId + ")"  +
  " - " + type(r1) + " -> " + 
  d2.name + "(" + "id:" + d2.nodeId + ")"  +
  " - " + type(r2) + " -> " + 
  ep.name + "(" + "id:" + ep.nodeId + ")"  
  AS fact


"""

In [42]:
query_template = """
//get node candidates
CALL($d1Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $d1Emb) YIELD node AS d1
  WHERE d1:Disease
  RETURN collect(elementId(d1)) AS d1s LIMIT 100
}
CALL($d2Emb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $d2Emb) YIELD node AS d2
  WHERE d2:Disease
  RETURN collect(elementId(d2)) AS d2s LIMIT 100
}
CALL($epEmb) {
  CALL db.index.vector.queryNodes("text_embeddings", 1000, $epEmb) YIELD node AS ep
  WHERE ep:EffectOrPhenotype
  RETURN collect(elementId(ep)) AS eps LIMIT 100
}

//graph pattern
MATCH p=(d1:Disease)-[PARENT_CHILD]-(d2:Disease)-[:PHENOTYPE_PRESENT]->(ep:EffectOrPhenotype)

//filter
WHERE (elementId(d1) IN d1s) AND (elementId(d2) IN d2s) AND (elementId(ep) IN eps)

//return
RETURN p
"""

pattern_query_template = Template("""
CALL($embParam) {{
  CALL db.index.vector.queryNodes('text_embeddings', 1000, $embParam) YIELD node AS n
  WHERE n:$label
  RETURN collect(elementId(n)) AS $vars LIMIT 100
}}
""")

search_node_query_template.substitute(embParam="$disemb", label="Disease", var="d1s")

"\nCALL($disemb) {{\n  CALL db.index.vector.queryNodes('text_embeddings', 1000, $disemb) YIELD node AS n\n  WHERE n:Disease\n  RETURN collect(elementId(n)) AS d1s\\ss LIMIT 100\n}}\n"

In [None]:
from abc import ABC, abstractmethod
from typing import Type, Dict
from langchain_core.tools import BaseTool
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)


class NodeSearchClauses(BaseModel, ABC):
    nodeDict: Dict[str,str]

    @abstractmethod
    def generate(self):
        pass


class NodeVectorSearchClauses(BaseModel, ABC):
    nodeDict: Dict[str,str]

    @abstractmethod
    def generate(self):
        pass



class FindInput(BaseModel):
    nodeLabels: Dict[str,str] = Field(description="pairs of node labels to search for")

# come up with query template
class FindDiseases0(BaseTool):
    name = "FindDiseases0"
    description = "Find diseases with zero indication drug and are associated with <effect/phenotype>",
    args_schema: Type[BaseModel] = MemoryInput
    
    base_query = '''(n:EffectOrPhenotype)->[:PHENOTYPE_ABSENT]->(m:Disease)<-[:!INDICATION]<-(o:Drug)'''

    def _run(
            self,
            movie: str,
            rating: int,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return store_movie_rating(movie, rating)

    async def _arun(
            self,
            movie: str,
            rating: int,
            run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return store_movie_rating(movie, rating)

In [None]:
'''
{
(effect/phenotype → [phenotype absent] → disease ← [!indication] ← drug):
"Find diseases with zero indication drug and are associated with <effect/phenotype>",
(drug → [contraindication] → disease ← [associated with] ← gene/protein):
"Identify diseases associated with <gene/protein> and are contraindicated with <drug>",
(anatomy → [expression present] → gene/protein ← [expression absent] ← anatomy):
"What gene or protein is expressed in <anatomy1> while is absent in <anatomy2>?",
(anatomy → [expression absent] → gene/protein ← [expression absent] ← anatomy):
"What gene/protein is absent in both <anatomy1> and <anatomy2>?",
(drug → [carrier] → gene/protein ← [carrier] ← drug):
"Which target genes are shared carriers between <drug1> and <drug2>?",
(anatomy → [expression present] → gene/protein → [target] → drug):
"What is the drug that targets the genes or proteins which are expressed in <anatomy>?",
(drug → [side effect] → effect/phenotype → [side effect] → drug):
"What drug has common side effects as <drug>?",
(drug → [carrier] → gene/protein → [carrier] → drug):
"What is the drug that has common gene/protein carrier with <drug>?",
(anatomy → [expression present] → gene/protein → enzyme → drug):
"What is the drug that some genes or proteins act as an enzyme upon,
where the genes or proteins are expressed in <anatomy>?",
(cellular_component → [interacts with] → gene/protein → [carrier] → drug):
"What is the drug carried by genes or proteins that interact with <cellular_component>?",
(molecular_function → [interacts with] → gene/protein → [target] → drug):
"What drug targets the genes or proteins that interact with <molecular_function>?",
(effect/phenotype → [side effect] → drug → [synergistic interaction] → drug):
"What drug has a synergistic interaction with the drug that has <effect/phenotype> as a side effect?",
(disease → [indication] → drug → [contraindication] → disease):
"What disease is a contraindication for the drugs indicated for <disease>?",
(disease → [parent-child] → disease → [phenotype present] → effect/phenotype):
"What effect or phenotype is present in the sub type of <disease>?",
(gene/protein → [transporter] → drug → [side effect] → effect/phenotype):
"What effect or phenotype is a [side effect] of the drug transported by <gene/protein>?",
(drug → [transporter] → gene/protein → [interacts with] → exposure):
"What exposure may affect <drug>s efficacy by acting on its transporter genes?",
(pathway → [interacts with] → gene/protein → [ppi] → gene/protein):
"What gene/protein interacts with the gene/protein that related to <pathway>?",
19
(drug → [synergistic interaction] → drug → [transporter] → gene/protein):
"What gene or protein transports the drugs that have a synergistic interaction with <drug>?",
(biological_process → [interacts with] → gene/protein → [interacts with] → biological_process):
"What biological process has the common interactino pattern with gene or proteins as <biological_process>?",
(effect/phenotype → [associated with] → gene/protein → [interacts with] → biological_process):
"What biological process interacts with the gene/protein associated with <effect/phenotype>?",
(drug → [transporter] → gene/protein → [expression present] → anatomy):
"What anatomy expressesed by the gene/protein that affect the transporter of <drug>?",
(drug → [target] → gene/protein → [interacts with] → cellular_component):
"What cellular component interacts with genes or proteins targeted by <drug>?",
(biological_process → [interacts with] → gene/protein → [expression absent] → anatomy):
"What anatomy does not express the genes or proteins that interacts with <biological_process>?",
(effect/phenotype → [associated with] → gene/protein → [expression absent] → anatomy):
"What anatomy does not express the genes or proteins associated with <effect/phenotype>?",
(drug → [indication] → disease → [indication] → drug)
& (drug → [synergistic interaction] → drug):
"Find drugs that has a synergistic interaction with <drug> and both are indicated for the same disease.",
(pathway → [interacts with] → gene/protein → [interacts with] → pathway)
& (pathway → [parent-child] → pathway):
"Find pathway that is related with <pathway> and both can [interacts with] the same gene/protein.",
(gene/protein → [associated with] → disease → [associated with] → gene/protein)
& (gene/protein → [ppi] → gene/protein):
"Find gene/protein that can interect with <gene/protein> and both are associated with the same disease.",
(gene/protein → [associated with] → effect/phenotype → [associated with] → gene/protein)
& (gene/protein → [ppi] → gene/protein):
"Find gene/protein that can interect with <gene/protein> and both are associated with the same effect/phenotype."
}
'''

## Multi-Part Question Retrievers

helper functions, utilities, and such

In [2]:
from torch_geometric.data.data import Data
from neo4j import Driver
import pandas as pd
from typing import List
import torch
import numpy as np
from langchain_openai import OpenAIEmbeddings
from tqdm import tqdm

embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
embedding_dimension = 1536

def chunks(xs, n=500):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]


def embed(doc_list, chunk_size=500):
    embeddings = []
    for docs in chunks(doc_list):
        embeddings.extend(embedding_model.embed_documents(docs))
    return embeddings
def get_nodes_by_vector_search(prompt:str, driver:Driver) -> List:
  res = driver.execute_query("""
    WITH genai.vector.encode(
      $searchPrompt,
      "OpenAI",
      {token:$token}) AS queryVector
    CALL db.index.vector.queryNodes($index, $k, queryVector) YIELD node
    RETURN node.nodeId AS nodeId
    """,
    parameters_={
        "searchPrompt":prompt,
        "token":OPENAI_API_KEY,
        "index":"text_embeddings",
        "k":4})
  return [rec.data()['nodeId'] for rec in res.records]

def get_subgraph_rels(node_ids:List, driver:Driver):
  res = driver.execute_query("""
    UNWIND $nodeIds AS nodeId
    MATCH(node:_Entity_ {nodeId:nodeId})
    // create filtered cartesian product
    WITH collect(node) AS sources, collect(node) AS targets
    UNWIND sources as source
    UNWIND targets as target
    WITH source, target
    WHERE source > target //how is this calculated? on element id?...it works

    // find connecting paths
    MATCH (source)-[rl]->{0,2}(target)

    //get rels
    UNWIND rl AS r
    WITH DISTINCT r
    MATCH (m)-[r]->(n)
    RETURN
    m.nodeId AS src,
    n.nodeId AS tgt,
    n.name + ' - ' + type(r) +  ' -> ' + m.name AS text
    """,
    parameters_={"nodeIds":node_ids})
  return pd.DataFrame([rec.data() for rec in res.records])

def get_all_node_ids(initial_node_ids, rel_df):
  node_ids = set(initial_node_ids)
  if rel_df.shape[0] > 0:
    node_ids.update(rel_df.src)
    node_ids.update(rel_df.tgt)
  return list(node_ids)

def get_node_df(initial_node_ids, rel_df, driver:Driver):
  node_ids = get_all_node_ids(initial_node_ids, rel_df)
  res = driver.execute_query("""
    UNWIND $nodeIds AS nodeId
    MATCH(n:_Entity_ {nodeId:nodeId})
    RETURN n.nodeId AS nodeId, n.name AS name, n.textEmbedding AS textEmbedding, n.details AS details
    """,
    parameters_={"nodeIds":node_ids})
  return pd.DataFrame([rec.data() for rec in res.records])

def create_data_obj(node_df, rel_df, prompt):
  # sub-graph re-index for edge_index
  node_df = node_df.reset_index()
  n_df = node_df.reset_index()[['index', 'nodeId']]
  rel_df = (rel_df
    .merge(n_df, left_on='src', right_on='nodeId')
    .rename(columns={'index': 'src_index'})
    .drop(columns='nodeId')
    .merge(n_df, left_on='tgt', right_on='nodeId')
    .rename(columns={'index': 'tgt_index'})
    .drop(columns='nodeId')
  )

  # node attributes
  x = torch.tensor(np.stack(node_df.textEmbedding), dtype=torch.float)

  # edge attributes
  edge_attr = torch.tensor(np.stack(rel_df.textEmbedding), dtype=torch.float)

  # edge index
  edge_index = torch.tensor(np.array(rel_df[['src_index', 'tgt_index']]).T)

  # answer - leaving blank for now
  answer=''

  # desc - leaving blank for now
  desc=''

  return Data(x, edge_index, edge_attr, question=prompt, answer=answer, desc=desc)


def retrieve(prompt:str, driver:Driver) -> Data:
    init_node_ids = get_nodes_by_vector_search(prompt, driver)
    rel_df = get_subgraph_rels(init_node_ids, driver)
    node_df = get_node_df(init_node_ids, rel_df, driver)
    #doing this outside of the graph for now
    print('generating edge embeddings')
    rel_df['textEmbedding'] = embed(rel_df['text'])
    return create_data_obj(node_df, rel_df, prompt)

## Test Example

TODO: Adding answer and "desc" attributes. Desc is used as additional context...I think it is the "textualized graph" from the paper. 

In [3]:
from neo4j import GraphDatabase

with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    res = retrieve("Which gene or protein is engaged in DCC-mediated attractive signaling, can bind to actin filaments, and belongs to the actin-binding LIM protein family?", driver)
res

generating edge embeddings


Data(x=[148, 1536], edge_index=[2, 472], edge_attr=[472, 1536], question='Which gene or protein is engaged in DCC-mediated attractive signaling, can bind to actin filaments, and belongs to the actin-binding LIM protein family?', answer='', desc='')

In [4]:
res.x

tensor([[-0.0225,  0.0008,  0.0099,  ..., -0.0139,  0.0011, -0.0444],
        [-0.0140,  0.0068, -0.0137,  ..., -0.0098, -0.0041, -0.0517],
        [-0.0159, -0.0007, -0.0074,  ..., -0.0104, -0.0044, -0.0491],
        ...,
        [-0.0517, -0.0009, -0.0132,  ..., -0.0161, -0.0225, -0.0461],
        [-0.0104, -0.0070, -0.0021,  ..., -0.0232, -0.0206, -0.0446],
        [-0.0072,  0.0103, -0.0005,  ..., -0.0144, -0.0177, -0.0433]])

In [5]:
res.edge_index

tensor([[ 22, 129,  22,  35,  22,  64,  22, 109,  22, 117,  22, 126,  22, 130,
          22,  74,  22,  84,  22,  94,  22, 114,  22, 138,  22, 141,  22,  76,
          22, 105,  22,  91,  22, 103,  22,  97,  22,  46,  56, 129,  56,  75,
          56,  35,  56,  64,  56, 109,  56, 117,  56, 126,  56, 130,  56, 131,
          56,  74,  56,  97,  56,  84,  56,  85,  56,  87,  56,  88,  56,  90,
          56,  92,  56,  94,  56, 111,  56, 114,  56, 138,  56, 141,  56,  76,
          56, 105,  56,  15,  56,  24,  56,  46,  56,  91,  56, 103,  56,  20,
          56, 123,  56,  36,  56,  38,  56,  42,  56,  47,  56,  54,  56,  68,
          56,  81,  56, 113,  56, 128,  56, 143,  56, 146,  56,   5,  56,   6,
          56,  21,  56,  26,  56,  31,  56,  45,  56,  58,  56,  60,  56,  66,
          56,  67,  56, 102,  56,  70,  56, 100,  56, 106,  56, 121,  56, 137,
          56, 139,  56,   4,  56,   9,  56,  10,  56,  11,  56,  12,  56,  13,
          56,  16,  56,  23,  56,  25,  56,  77,  56

## Load the Prime Dataset

In [6]:
from stark_qa import load_qa, load_skb

dataset_name = 'prime'

# Load the retrieval dataset
qa_dataset = load_qa(dataset_name)

Use file from /Users/sbr/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/7b0352c7dcefbf254478c203bcfdf284a08866ac/qa/prime/stark_qa/stark_qa_human_generated_eval.csv.


In [7]:
qa_dataset.data

Unnamed: 0,id,query,answer_ids
0,0,Could you identify any skin diseases associate...,[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,[15450]
2,2,What is the name of the condition characterize...,"[98851, 98853]"
3,3,What drugs are used to treat epithelioid sarco...,[15698]
4,4,Can you supply a compilation of genes and prot...,"[7161, 22045]"
...,...,...,...
11199,11199,Which gene or protein is not expressed in fema...,[2414]
11200,11200,Could you identify a biological pathway in whi...,[128199]
11201,11201,Is there an interaction between genes or prote...,"[127611, 62903]"
11202,11202,Which pharmacological agents that stimulate os...,[20180]


In [8]:
data_list = []
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    for prompt in tqdm(qa_dataset.data['query'][:10]):
        data_list.append(retrieve(prompt, driver))  

  0%|          | 0/10 [00:00<?, ?it/s]

generating edge embeddings


 10%|█         | 1/10 [00:01<00:15,  1.76s/it]

generating edge embeddings


 20%|██        | 2/10 [00:22<01:44, 13.10s/it]

generating edge embeddings


 30%|███       | 3/10 [00:24<00:55,  7.86s/it]

generating edge embeddings


 40%|████      | 4/10 [00:37<00:59,  9.87s/it]

generating edge embeddings


 50%|█████     | 5/10 [00:38<00:33,  6.71s/it]

generating edge embeddings


 60%|██████    | 6/10 [00:39<00:18,  4.69s/it]

generating edge embeddings


 70%|███████   | 7/10 [00:40<00:11,  3.67s/it]

generating edge embeddings


 80%|████████  | 8/10 [00:42<00:05,  2.98s/it]

generating edge embeddings


 90%|█████████ | 9/10 [00:43<00:02,  2.37s/it]

generating edge embeddings


100%|██████████| 10/10 [00:46<00:00,  4.63s/it]


In [9]:
for d in data_list:
    print(d)

Data(x=[5, 1536], edge_index=[2, 3], edge_attr=[3, 1536], question='Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.', answer='', desc='')
Data(x=[747, 1536], edge_index=[2, 2931], edge_attr=[2931, 1536], question='What drugs target the CYP3A4 enzyme and are used to treat strongyloidiasis?', answer='', desc='')
Data(x=[4, 1536], edge_index=[2, 5], edge_attr=[5, 1536], question='What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?', answer='', desc='')
Data(x=[1197, 1536], edge_index=[2, 2405], edge_attr=[2405, 1536], question='What drugs are used to treat epithelioid sarcoma and also affect the EZH2 gene product?', answer='', desc='')
Data(x=[8, 1536], edge_index=[2, 13], edge_attr=[13, 1536], question='Can you supply a compilation of genes and p