In [1]:
%%capture
%pip install torch torch_geometric stark-qa neo4j python-dotenv pcst_fast datasets pandas transformers langchain langchain-openai

## Setup

load env variables

In [1]:
from dotenv import load_dotenv
import os

load_dotenv('db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

helper functions, utilities, and such

In [2]:
from torch_geometric.data.data import Data
from neo4j import Driver
import pandas as pd
from typing import List
import torch
import numpy as np
from langchain_openai import OpenAIEmbeddings
from tqdm import tqdm

embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
embedding_dimension = 1536

def chunks(xs, n=500):
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]


def embed(doc_list, chunk_size=500):
    embeddings = []
    for docs in chunks(doc_list):
        embeddings.extend(embedding_model.embed_documents(docs))
    return embeddings
def get_nodes_by_vector_search(prompt:str, driver:Driver) -> List:
  res = driver.execute_query("""
    WITH genai.vector.encode(
      $searchPrompt,
      "OpenAI",
      {token:$token}) AS queryVector
    CALL db.index.vector.queryNodes($index, $k, queryVector) YIELD node
    RETURN node.nodeId AS nodeId
    """,
    parameters_={
        "searchPrompt":prompt,
        "token":OPENAI_API_KEY,
        "index":"text_embeddings",
        "k":4})
  return [rec.data()['nodeId'] for rec in res.records]

def get_subgraph_rels(node_ids:List, driver:Driver):
  res = driver.execute_query("""
    UNWIND $nodeIds AS nodeId
    MATCH(node:_Entity_ {nodeId:nodeId})
    // create filtered cartesian product
    WITH collect(node) AS sources, collect(node) AS targets
    UNWIND sources as source
    UNWIND targets as target
    WITH source, target
    WHERE source > target //how is this calculated? on element id?...it works

    // find connecting paths
    MATCH (source)-[rl]->{0,2}(target)

    //get rels
    UNWIND rl AS r
    WITH DISTINCT r
    MATCH (m)-[r]->(n)
    RETURN
    m.nodeId AS src,
    n.nodeId AS tgt,
    n.name + ' - ' + type(r) +  ' -> ' + m.name AS text
    """,
    parameters_={"nodeIds":node_ids})
  return pd.DataFrame([rec.data() for rec in res.records])

def get_all_node_ids(initial_node_ids, rel_df):
  node_ids = set(initial_node_ids)
  if rel_df.shape[0] > 0:
    node_ids.update(rel_df.src)
    node_ids.update(rel_df.tgt)
  return list(node_ids)

def get_node_df(initial_node_ids, rel_df, driver:Driver):
  node_ids = get_all_node_ids(initial_node_ids, rel_df)
  res = driver.execute_query("""
    UNWIND $nodeIds AS nodeId
    MATCH(n:_Entity_ {nodeId:nodeId})
    RETURN n.nodeId AS nodeId, n.name AS name, n.textEmbedding AS textEmbedding, n.details AS details
    """,
    parameters_={"nodeIds":node_ids})
  return pd.DataFrame([rec.data() for rec in res.records])

def create_data_obj(node_df, rel_df, prompt):
  # sub-graph re-index for edge_index
  node_df = node_df.reset_index()
  n_df = node_df.reset_index()[['index', 'nodeId']]
  rel_df = (rel_df
    .merge(n_df, left_on='src', right_on='nodeId')
    .rename(columns={'index': 'src_index'})
    .drop(columns='nodeId')
    .merge(n_df, left_on='tgt', right_on='nodeId')
    .rename(columns={'index': 'tgt_index'})
    .drop(columns='nodeId')
  )

  # node attributes
  x = torch.tensor(np.stack(node_df.textEmbedding), dtype=torch.float)

  # edge attributes
  edge_attr = torch.tensor(np.stack(rel_df.textEmbedding), dtype=torch.float)

  # edge index
  edge_index = torch.tensor(np.array(rel_df[['src_index', 'tgt_index']]).T)

  # answer - leaving blank for now
  answer=''

  # desc - leaving blank for now
  desc=''

  return Data(x, edge_index, edge_attr, question=prompt, answer=answer, desc=desc)


def retrieve(prompt:str, driver:Driver) -> Data:
    init_node_ids = get_nodes_by_vector_search(prompt, driver)
    rel_df = get_subgraph_rels(init_node_ids, driver)
    node_df = get_node_df(init_node_ids, rel_df, driver)
    #doing this outside of the graph for now
    print('generating edge embeddings')
    rel_df['textEmbedding'] = embed(rel_df['text'])
    return create_data_obj(node_df, rel_df, prompt)

## Test Example

TODO: Adding answer and "desc" attributes. Desc is used as additional context...I think it is the "textualized graph" from the paper. 

In [3]:
from neo4j import GraphDatabase

with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    res = retrieve("Which gene or protein is engaged in DCC-mediated attractive signaling, can bind to actin filaments, and belongs to the actin-binding LIM protein family?", driver)
res

generating edge embeddings


Data(x=[148, 1536], edge_index=[2, 472], edge_attr=[472, 1536], question='Which gene or protein is engaged in DCC-mediated attractive signaling, can bind to actin filaments, and belongs to the actin-binding LIM protein family?', answer='', desc='')

In [4]:
res.x

tensor([[-0.0225,  0.0008,  0.0099,  ..., -0.0139,  0.0011, -0.0444],
        [-0.0140,  0.0068, -0.0137,  ..., -0.0098, -0.0041, -0.0517],
        [-0.0159, -0.0007, -0.0074,  ..., -0.0104, -0.0044, -0.0491],
        ...,
        [-0.0517, -0.0009, -0.0132,  ..., -0.0161, -0.0225, -0.0461],
        [-0.0104, -0.0070, -0.0021,  ..., -0.0232, -0.0206, -0.0446],
        [-0.0072,  0.0103, -0.0005,  ..., -0.0144, -0.0177, -0.0433]])

In [5]:
res.edge_index

tensor([[ 22, 129,  22,  35,  22,  64,  22, 109,  22, 117,  22, 126,  22, 130,
          22,  74,  22,  84,  22,  94,  22, 114,  22, 138,  22, 141,  22,  76,
          22, 105,  22,  91,  22, 103,  22,  97,  22,  46,  56, 129,  56,  75,
          56,  35,  56,  64,  56, 109,  56, 117,  56, 126,  56, 130,  56, 131,
          56,  74,  56,  97,  56,  84,  56,  85,  56,  87,  56,  88,  56,  90,
          56,  92,  56,  94,  56, 111,  56, 114,  56, 138,  56, 141,  56,  76,
          56, 105,  56,  15,  56,  24,  56,  46,  56,  91,  56, 103,  56,  20,
          56, 123,  56,  36,  56,  38,  56,  42,  56,  47,  56,  54,  56,  68,
          56,  81,  56, 113,  56, 128,  56, 143,  56, 146,  56,   5,  56,   6,
          56,  21,  56,  26,  56,  31,  56,  45,  56,  58,  56,  60,  56,  66,
          56,  67,  56, 102,  56,  70,  56, 100,  56, 106,  56, 121,  56, 137,
          56, 139,  56,   4,  56,   9,  56,  10,  56,  11,  56,  12,  56,  13,
          56,  16,  56,  23,  56,  25,  56,  77,  56

## Load the Prime Dataset

In [6]:
from stark_qa import load_qa, load_skb

dataset_name = 'prime'

# Load the retrieval dataset
qa_dataset = load_qa(dataset_name)

Use file from /Users/sbr/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/7b0352c7dcefbf254478c203bcfdf284a08866ac/qa/prime/stark_qa/stark_qa_human_generated_eval.csv.


In [7]:
qa_dataset.data

Unnamed: 0,id,query,answer_ids
0,0,Could you identify any skin diseases associate...,[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,[15450]
2,2,What is the name of the condition characterize...,"[98851, 98853]"
3,3,What drugs are used to treat epithelioid sarco...,[15698]
4,4,Can you supply a compilation of genes and prot...,"[7161, 22045]"
...,...,...,...
11199,11199,Which gene or protein is not expressed in fema...,[2414]
11200,11200,Could you identify a biological pathway in whi...,[128199]
11201,11201,Is there an interaction between genes or prote...,"[127611, 62903]"
11202,11202,Which pharmacological agents that stimulate os...,[20180]


In [8]:
data_list = []
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    for prompt in tqdm(qa_dataset.data['query'][:10]):
        data_list.append(retrieve(prompt, driver))  

  0%|          | 0/10 [00:00<?, ?it/s]

generating edge embeddings


 10%|█         | 1/10 [00:01<00:15,  1.76s/it]

generating edge embeddings


 20%|██        | 2/10 [00:22<01:44, 13.10s/it]

generating edge embeddings


 30%|███       | 3/10 [00:24<00:55,  7.86s/it]

generating edge embeddings


 40%|████      | 4/10 [00:37<00:59,  9.87s/it]

generating edge embeddings


 50%|█████     | 5/10 [00:38<00:33,  6.71s/it]

generating edge embeddings


 60%|██████    | 6/10 [00:39<00:18,  4.69s/it]

generating edge embeddings


 70%|███████   | 7/10 [00:40<00:11,  3.67s/it]

generating edge embeddings


 80%|████████  | 8/10 [00:42<00:05,  2.98s/it]

generating edge embeddings


 90%|█████████ | 9/10 [00:43<00:02,  2.37s/it]

generating edge embeddings


100%|██████████| 10/10 [00:46<00:00,  4.63s/it]


In [9]:
for d in data_list:
    print(d)

Data(x=[5, 1536], edge_index=[2, 3], edge_attr=[3, 1536], question='Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.', answer='', desc='')
Data(x=[747, 1536], edge_index=[2, 2931], edge_attr=[2931, 1536], question='What drugs target the CYP3A4 enzyme and are used to treat strongyloidiasis?', answer='', desc='')
Data(x=[4, 1536], edge_index=[2, 5], edge_attr=[5, 1536], question='What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?', answer='', desc='')
Data(x=[1197, 1536], edge_index=[2, 2405], edge_attr=[2405, 1536], question='What drugs are used to treat epithelioid sarcoma and also affect the EZH2 gene product?', answer='', desc='')
Data(x=[8, 1536], edge_index=[2, 13], edge_attr=[13, 1536], question='Can you supply a compilation of genes and p