# Load STaRK Prime Into Neo4j

Resources
- [STaRK GitHub](https://github.com/snap-stanford/stark)
- [STaRK Prime Docs](https://stark.stanford.edu/dataset_prime.html)

In [1]:
%pip install stark-qa neo4j python-dotenv langchain-openai



## Get & Explore STaRK Prime Data

In [1]:
from stark_qa import load_qa, load_skb

dataset_name = 'prime'

# Load the retrieval dataset
qa_dataset = load_qa(dataset_name)
idx_split = qa_dataset.get_idx_split()

# Load the semi-structured knowledge base
skb = load_skb(dataset_name, download_processed=True, root=None)

Use file from /Users/sbr/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/7b0352c7dcefbf254478c203bcfdf284a08866ac/qa/prime/stark_qa/stark_qa_human_generated_eval.csv.
Loading from /Users/sbr/.cache/huggingface/hub/datasets--snap-stanford--stark/snapshots/7b0352c7dcefbf254478c203bcfdf284a08866ac/skb/prime/processed!


In [2]:
qa_dataset.data

| | id | query | answer_ids |
|---|---|---|---|
| 0 | 0 | Could you identify any skin diseases associate... | [95886] |
| 1 | 1 | What drugs target the CYP3A4 enzyme and are us... | [15450] |
| 2 | 2 | What is the name of the condition characterize... | [98851, 98853] |
| 3 | 3 | What drugs are used to treat epithelioid sarco... | [15698] |
| 4 | 4 | Can you supply a compilation of genes and prot... | [7161, 22045] |
| ... | ... | ... | ... |
| 11199 | 11199 | Which gene or protein is not expressed in fema... | [2414] |
| 11200 | 11200 | Could you identify a biological pathway in whi... | [128199] |
| 11201 | 11201 | Is there an interaction between genes or prote... | [127611, 62903] |
| 11202 | 11202 | Which pharmacological agents that stimulate os... | [20180] |


In [3]:
# Get one QA pair (metadata is masked out to avoid leaking the answer)
query, q_id, answer_ids, _ = qa_dataset[1]

In [4]:
query, q_id, answer_ids, _ = qa_dataset[4]
print('Query:', query)
print('Query ID:', q_id)
print('Answer:\n', '\n\n'.join([str(skb[aid].dictionary) for aid in answer_ids]))

Query: Can you supply a compilation of genes and proteins associated with endothelin B receptor interaction, involved in G alpha (q) signaling, and contributing to hypertension and ovulation-related biological functions?
Query ID: 4
Answer:
 {'id': 1906, 'type': 'gene/protein', 'name': 'EDN1', 'source': 'NCBI', 'details': {'query': 'EDN1', '_id': '1906', '_score': 17.339428, 'alias': ['ARCND3', 'ET1', 'HDLCQ7', 'PPET1', 'QME'], 'genomic_pos': {'chr': '6', 'end': 12297194, 'ensemblgene': 'ENSG00000078401', 'start': 12290361, 'strand': 1}, 'name': 'endothelin 1', 'summary': 'This gene encodes a preproprotein that is proteolytically processed to generate a secreted peptide that belongs to the endothelin/sarafotoxin family. This peptide is a potent vasoconstrictor and its cognate receptors are therapeutic targets in the treatment of pulmonary arterial hypertension. Aberrant expression of this gene may promote tumorigenesis. Alternative splicing results in multiple transcript variants. [pro

In [5]:
print(skb.META_DATA)
print(skb.NODE_TYPES)
print(skb.RELATION_TYPES)

['id', 'type', 'name', 'source', 'details']
['disease', 'gene/protein', 'molecular_function', 'drug', 'pathway', 'anatomy', 'effect/phenotype', 'biological_process', 'cellular_component', 'exposure']
['ppi', 'carrier', 'enzyme', 'target', 'transporter', 'contraindication', 'indication', 'off-label use', 'synergistic interaction', 'associated with', 'parent-child', 'phenotype absent', 'phenotype present', 'side effect', 'interacts with', 'linked to', 'expression present', 'expression absent']


In [6]:
skb[answer_ids[0]]

--id
--type
--name
--source
--details
    |-----query
    |-----_id
    |-----_score
    |-----alias
    |-----genomic_pos
    |-----    |-----chr
    |-----    |-----end
    |-----    |-----ensemblgene
    |-----    |-----start
    |-----    |-----strand
    |-----name
    |-----summary

## Format & Load Nodes

1. create a node dataframe
2. create a function to format node labels based on node type
3. define helper functions for loading
4. load Neo4j credentials from the db.env file
5. load the nodes

In [7]:
from tqdm import tqdm
import pandas as pd

# create node_df
node_list = []

for i in tqdm(range(skb.num_nodes())):
  node = skb[i].dictionary
  # nodeId is the node's integer index in the SKB; the relationship and embedding loads below match on it
  node['nodeId'] = i
  node_list.append(node)
node_df = pd.DataFrame(node_list)

# format details
node_df.loc[node_df.details.isna(), 'details'] = ''
node_df.details = node_df.details.astype(str)

node_df

100%|██████████| 129375/129375 [00:00<00:00, 430338.97it/s]


| | id | type | name | source | details | nodeId |
|---|---|---|---|---|---|---|
| 0 | 9796 | gene/protein | PHYHIP | NCBI | {'query': 'PHYHIP', '_id': '9796', '_score': 1... | 0 |
| 1 | 7918 | gene/protein | GPANK1 | NCBI | {'query': 'GPANK1', '_id': '7918', '_score': 1... | 1 |
| 2 | 8233 | gene/protein | ZRSR2 | NCBI | {'query': 'ZRSR2', '_id': '8233', '_score': 17... | 2 |
| 3 | 4899 | gene/protein | NRF1 | NCBI | {'query': 'NRF1', '_id': '4899', '_score': 17.... | 3 |
| 4 | 5297 | gene/protein | PI4KA | NCBI | {'query': 'PI4KA', '_id': '5297', '_score': 17... | 4 |
| ... | ... | ... | ... | ... | ... | ... |
| 129370 | R-HSA-936837 | pathway | Ion transport by P-type ATPases | REACTOME | {'dbId': 936837, 'displayName': 'Ion transport... | 129370 |
| 129371 | R-HSA-997272 | pathway | Inhibition of voltage gated Ca2+ channels via... | REACTOME | {'dbId': 997272, 'displayName': 'Inhibition o... | 129371 |
| 129372 | 1062 | anatomy | anatomical entity | UBERON | | 129372 |
| 129373 | 468 | anatomy | multi-cellular organism | UBERON | | 129373 |
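Neo4j property values must be primitives or homogeneous lists of primitives, so a nested dict such as `details` cannot be stored directly. Casting it with `astype(str)`, as above, stores Python's repr (single-quoted keys), which standard JSON parsers reject. If you want `details` to stay machine-readable, a minimal alternative sketch serializes it with `json.dumps` instead; `serialize_details` is a hypothetical helper, and `default=str` is a hedge for values that `json` cannot encode natively.

```python
import json
import math


def serialize_details(details):
    """Hypothetical helper: turn a nested details dict into a JSON string.

    Missing values (None, or the NaN pandas uses for absent entries) become ''
    to match the notebook's convention; default=str stringifies anything json
    cannot encode natively.
    """
    if details is None or (isinstance(details, float) and math.isnan(details)):
        return ''
    return json.dumps(details, default=str)


# usage on a record shaped like the NCBI gene/protein details shown earlier
example = {'query': 'EDN1', '_id': '1906', 'alias': ['ARCND3', 'ET1'],
           'genomic_pos': {'chr': '6', 'strand': 1}}
encoded = serialize_details(example)
print(encoded)
print(json.loads(encoded)['genomic_pos']['chr'])
```

Applied to the dataframe, `node_df['details'] = node_df['details'].map(serialize_details)` would take the place of the two `details` lines in the cell above; in Cypher the stored string can then be parsed back with a JSON function of your choice.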


In [8]:
# Note the node types; each one will be converted into a Neo4j node label.
skb.node_type_dict

{0: 'disease',
 1: 'gene/protein',
 2: 'molecular_function',
 3: 'drug',
 4: 'pathway',
 5: 'anatomy',
 6: 'effect/phenotype',
 7: 'biological_process',
 8: 'cellular_component',
 9: 'exposure'}

In [9]:
# convert a STaRK node type to a PascalCase node label, e.g. 'gene/protein' -> 'GeneOrProtein'
def format_node_label(s):
  ss = s.replace('/', '_or_').lower().split('_')
  return ''.join(t.title() for t in ss)

[(k, format_node_label(v)) for k, v in skb.node_type_dict.items()]

[(0, 'Disease'),
 (1, 'GeneOrProtein'),
 (2, 'MolecularFunction'),
 (3, 'Drug'),
 (4, 'Pathway'),
 (5, 'Anatomy'),
 (6, 'EffectOrPhenotype'),
 (7, 'BiologicalProcess'),
 (8, 'CellularComponent'),
 (9, 'Exposure')]
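Because every node type becomes its own label, and the loader later creates one constraint per label, it is worth confirming that the mapping is one-to-one and that every label is a plain identifier Cypher accepts without backtick quoting. A minimal sketch, using a standalone copy of `format_node_label` and the ten type names printed above; `check_labels` is a hypothetical helper, and the identifier pattern is a conservative ASCII subset of what Cypher allows.

```python
import re

# node type names as printed by skb.node_type_dict above
NODE_TYPES = ['disease', 'gene/protein', 'molecular_function', 'drug', 'pathway',
              'anatomy', 'effect/phenotype', 'biological_process',
              'cellular_component', 'exposure']

# conservative pattern for identifiers usable in Cypher without backticks
IDENTIFIER = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')


def format_node_label(s):
    # same logic as the notebook cell above
    ss = s.replace('/', '_or_').lower().split('_')
    return ''.join(t.title() for t in ss)


def check_labels(type_names):
    """Hypothetical helper: map types to labels, failing on collisions or unusual characters."""
    labels = {}
    for name in type_names:
        label = format_node_label(name)
        if not IDENTIFIER.match(label):
            raise ValueError(f'{name!r} -> {label!r} is not a plain Cypher identifier')
        if label in labels.values():
            raise ValueError(f'{name!r} collides on label {label!r}')
        labels[name] = label
    return labels


print(check_labels(NODE_TYPES))
```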

In [10]:
from typing import Tuple, Union
from numpy.typing import ArrayLike
from neo4j import GraphDatabase

# helper functions for loading nodes & rels

def _make_map(x):
    # normalize a str or 2-tuple into a pair, e.g. 'nodeId' -> ('nodeId', 'nodeId')
    if isinstance(x, str):
        return x, x
    elif isinstance(x, tuple):
        return x
    else:
        raise TypeError("Entry must be of type string or tuple")

def _make_constraint_query(constraint_type: str, node_label, prop_name) -> str:
  const_name = f'{constraint_type.lower()}_{node_label.lower()}_{prop_name.lower()}'
  return f'CREATE CONSTRAINT {const_name} IF NOT EXISTS FOR (n:{node_label}) REQUIRE n.{prop_name} IS {constraint_type}'


def _make_set_clause(prop_names: ArrayLike, element_name='n', item_name='rec'):
    clause_list = []
    for prop_name in prop_names:
        clause_list.append(f'{element_name}.{prop_name} = {item_name}.{prop_name}')
    return 'SET ' + ', '.join(clause_list)


def _make_node_merge_query(node_key_name: str, node_label: str, cols: ArrayLike):
    template = f'''UNWIND $recs AS rec\nMERGE(n:{node_label} {{{node_key_name}: rec.{node_key_name}}})'''
    prop_names = [x for x in cols if x != node_key_name]
    if len(prop_names) > 0:
        template = template + '\n' + _make_set_clause(prop_names)
    return template + '\nRETURN count(n) AS nodeLoadedCount'


def _make_rel_merge_query(source_target_labels: Union[Tuple[str, str], str],
                          source_node_key: Union[Tuple[str, str], str],
                          target_node_key: Union[Tuple[str, str], str],
                          rel_type: str,
                          cols: ArrayLike,
                          rel_key: str = None):
    source_target_label_map = _make_map(source_target_labels)
    source_node_key_map = _make_map(source_node_key)
    target_node_key_map = _make_map(target_node_key)

    merge_statement = f'MERGE(s)-[r:{rel_type}]->(t)'
    if rel_key is not None:
        merge_statement = f'MERGE(s)-[r:{rel_type} {{{rel_key}: rec.{rel_key}}}]->(t)'

    template = f'''UNWIND $recs AS rec
    MATCH(s:{source_target_label_map[0]} {{{source_node_key_map[0]}: rec.{source_node_key_map[1]}}})
    MATCH(t:{source_target_label_map[1]} {{{target_node_key_map[0]}: rec.{target_node_key_map[1]}}})\n''' + merge_statement
    prop_names = [x for x in cols if x not in [rel_key, source_node_key_map[1], target_node_key_map[1]]]
    if len(prop_names) > 0:
        template = template + '\n' + _make_set_clause(prop_names, 'r')
    return template + '\nRETURN count(r) AS relLoadedCount'


def chunks(xs, n: int = 10_000):
    """
    Split an array-like object into chunks of size n.

    Parameters
    -------
    :param xs: array-like
        The sequence to split (anything that supports len() and slicing).
    :param n: int
        The size of each chunk. The last chunk holds the remainder, if there is one.
    """
    n = max(1, n)
    return [xs[i:i + n] for i in range(0, len(xs), n)]

def load_nodes(node_df: pd.DataFrame,
               node_key_col: str,
               node_label: str,
               chunk_size: int = 5_000,
               constraint: str = 'UNIQUE',
               neo4j_uri: str = 'bolt://localhost:7687',
               neo4j_password: str = 'password',
               neo4j_username: str = 'neo4j'):
    """
    Load nodes from a dataframe.

    Parameters
    -------
    :param node_df: pd.DataFrame
        The dataframe containing node data
    :param node_key_col: str
        The column of the dataframe to use as the MERGE key property
    :param node_label: str
        The node label to use (only one allowed).
    :param chunk_size: int , default 5_000
        The chunk size to use when batching rows for loading
    :param constraint: str , default "UNIQUE"
        The constraint to use for the node key. Can be "UNIQUE", "KEY", or None.
        More details at https://neo4j.com/docs/cypher-manual/current/constraints/examples/#constraints-examples-node-uniqueness.
        Using 'None' (no node constraint) can result in very poor load performance.
    :param neo4j_uri: str , default "bolt://localhost:7687"
        The uri for the Neo4j database
    :param neo4j_password: str , default "password"
        The password for the Neo4j database
    :param neo4j_username: str , default "neo4j"
        The username for the Neo4j database
    """

    print(f'======  loading {node_label} nodes  ======')

    records = node_df.to_dict('records')
    total = len(records)
    print(f'staged {total:,} records')
    with GraphDatabase.driver(neo4j_uri,
                              auth=(neo4j_username, neo4j_password)) as driver:
      if constraint:
        constraint = constraint.upper()
        if constraint not in ["UNIQUE", "KEY"]:
          raise ValueError(f'constraint must be one of ["UNIQUE", "KEY", None] but was {constraint}')
        const_query = _make_constraint_query(constraint, node_label, node_key_col)
        print(f'\ncreating constraint:\n```\n{const_query}\n```\n')
        driver.execute_query(const_query)

      query = _make_node_merge_query(node_key_col, node_label, node_df.columns.copy())
      print(f'\nusing this Cypher query to load data:\n```\n{query}\n```\n')
      cumulative_count = 0
      for recs in chunks(records, chunk_size):
          res = driver.execute_query(query, parameters_={'recs': recs})
          cumulative_count += res[0][0][0]
          print(f'loaded {cumulative_count:,} of {total:,} nodes')
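`chunks` builds every slice up front, so for the roughly 8.1 million relationship rows it holds a second full set of list references in memory before the first batch is sent. A lazy alternative, sketched below as a hypothetical `iter_chunks` generator built on `itertools.islice`, yields one batch at a time and accepts any iterable; it could stand in wherever the loaders iterate over `chunks(records, chunk_size)`.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def iter_chunks(xs: Iterable[T], n: int = 10_000) -> Iterator[List[T]]:
    """Hypothetical lazy variant of chunks(): yield lists of up to n items.

    Only one batch is materialized at a time, and xs may be any iterable
    (a list, a generator, a DataFrame.itertuples() iterator, ...).
    """
    n = max(1, n)
    it = iter(xs)
    while True:
        batch = list(islice(it, n))
        if not batch:
            return
        yield batch


# usage: prints [0, 1, 2], then [3, 4, 5], then [6]
for batch in iter_chunks(range(7), 3):
    print(batch)
```

Each batch is a plain list, so it can be passed as the `recs` parameter to `driver.execute_query` exactly like the slices `chunks` returns.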

In [3]:
from neo4j import GraphDatabase
from dotenv import load_dotenv
import os

# load Neo4j credentials from db.env

load_dotenv('../db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')

In [4]:
print(os.getenv('NEO4J_URI'))

bolt://localhost:7687
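`os.getenv` returns `None` for a missing variable, so a typo in `db.env` only surfaces later as a connection or authentication error inside `GraphDatabase.driver`. A minimal fail-fast sketch, assuming the same three variable names; `require_env` is a hypothetical helper, not part of `python-dotenv`.

```python
import os
from typing import Dict, Iterable


def require_env(names: Iterable[str]) -> Dict[str, str]:
    """Hypothetical helper: return the named environment variables, or raise listing every missing one.

    Empty strings are treated as missing.
    """
    names = list(names)
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise RuntimeError(f'missing environment variables: {", ".join(missing)}')
    return {name: os.environ[name] for name in names}


# usage, after load_dotenv('../db.env', override=True):
# creds = require_env(['NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD'])
```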


In [19]:
for node_type in skb.node_type_dict.values():
  single_node_type_df = (node_df[node_df['type'] == node_type]
                         .drop(columns=['type']))
  node_label = format_node_label(node_type)
  load_nodes(single_node_type_df,
             'nodeId',
             node_label,
             neo4j_uri=NEO4J_URI,
             neo4j_username=NEO4J_USERNAME,
             neo4j_password=NEO4J_PASSWORD)

staged 17,080 records

creating constraint:
```
CREATE CONSTRAINT unique_disease_nodeid IF NOT EXISTS FOR (n:Disease) REQUIRE n.nodeId IS UNIQUE
```


using this Cypher query to load data:
```
UNWIND $recs AS rec
MERGE(n:Disease {nodeId: rec.nodeId})
SET n.id = rec.id, n.name = rec.name, n.source = rec.source, n.details = rec.details
RETURN count(n) AS nodeLoadedCount
```

loaded 5,000 of 17,080 nodes
loaded 10,000 of 17,080 nodes
loaded 15,000 of 17,080 nodes
loaded 17,080 of 17,080 nodes
staged 27,671 records

creating constraint:
```
CREATE CONSTRAINT unique_geneorprotein_nodeid IF NOT EXISTS FOR (n:GeneOrProtein) REQUIRE n.nodeId IS UNIQUE
```


using this Cypher query to load data:
```
UNWIND $recs AS rec
MERGE(n:GeneOrProtein {nodeId: rec.nodeId})
SET n.id = rec.id, n.name = rec.name, n.source = rec.source, n.details = rec.details
RETURN count(n) AS nodeLoadedCount
```

loaded 5,000 of 27,671 nodes
loaded 10,000 of 27,671 nodes
loaded 15,000 of 27,671 nodes
loaded 20,000 of 27,67

## Format & Load Relationships
1. create a relationship dataframe
2. create a function to format relationship types based on edge type
3. load the relationships

In [20]:
import torch
import pandas as pd

rel_df = pd.DataFrame(
    torch.cat([skb.edge_index,
               skb.edge_types.reshape(1, -1)],
              dim=0).t().numpy(),
    columns=['src', 'tgt', 'typeId'])
rel_df

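| | src | tgt | typeId |
|---|---|---|---|
| 0 | 0 | 8889 | 0 |
| 1 | 1 | 2798 | 0 |
| 2 | 2 | 5646 | 0 |
| 3 | 3 | 11592 | 0 |
| 4 | 4 | 2122 | 0 |
| ... | ... | ... | ... |
| 8100493 | 66747 | 5259 | 17 |
| 8100494 | 63824 | 58254 | 17 |
| 8100495 | 63826 | 58254 | 17 |
| 8100496 | 64523 | 58254 | 17 |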
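`skb.edge_index` has two rows (source node indices, then target node indices) and `skb.edge_types` holds one type id per edge, so stacking them gives a 3 × E tensor whose transpose has one row per edge. The same reshaping in plain Python, on a tiny hypothetical graph, shows the record layout the relationship loader consumes; `edges_to_records` is a hypothetical helper.

```python
from typing import Dict, List, Sequence


def edges_to_records(edge_index: Sequence[Sequence[int]],
                     edge_types: Sequence[int]) -> List[Dict[str, int]]:
    """Hypothetical helper: turn a 2 x E edge index plus E type ids into one dict per edge.

    Mirrors torch.cat([edge_index, edge_types.reshape(1, -1)], dim=0).t()
    followed by naming the columns src, tgt, typeId.
    """
    src, tgt = edge_index
    if not (len(src) == len(tgt) == len(edge_types)):
        raise ValueError('edge_index rows and edge_types must have the same length')
    return [{'src': s, 'tgt': t, 'typeId': k}
            for s, t, k in zip(src, tgt, edge_types)]


# hypothetical 3-edge graph: 0->5 (type 0), 1->5 (type 0), 2->7 (type 17)
toy_index = [[0, 1, 2],
             [5, 5, 7]]
toy_types = [0, 0, 17]
for rec in edges_to_records(toy_index, toy_types):
    print(rec)
```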


In [5]:
rel_types = skb.edge_type_dict
rel_types

{0: 'ppi',
 1: 'carrier',
 2: 'enzyme',
 3: 'target',
 4: 'transporter',
 5: 'contraindication',
 6: 'indication',
 7: 'off-label use',
 8: 'synergistic interaction',
 9: 'associated with',
 10: 'parent-child',
 11: 'phenotype absent',
 12: 'phenotype present',
 13: 'side effect',
 14: 'interacts with',
 15: 'linked to',
 16: 'expression present',
 17: 'expression absent'}

In [6]:
import re

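def format_rel_type(s):
  # upper-case, then collapse each run of non-alphanumerics into '_', e.g. 'off-label use' -> 'OFF_LABEL_USE'
  return re.sub('[^0-9A-Z]+', '_', s.upper())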

In [7]:
[(k, format_rel_type(v)) for k, v in skb.edge_type_dict.items()]

[(0, 'PPI'),
 (1, 'CARRIER'),
 (2, 'ENZYME'),
 (3, 'TARGET'),
 (4, 'TRANSPORTER'),
 (5, 'CONTRAINDICATION'),
 (6, 'INDICATION'),
 (7, 'OFF_LABEL_USE'),
 (8, 'SYNERGISTIC_INTERACTION'),
 (9, 'ASSOCIATED_WITH'),
 (10, 'PARENT_CHILD'),
 (11, 'PHENOTYPE_ABSENT'),
 (12, 'PHENOTYPE_PRESENT'),
 (13, 'SIDE_EFFECT'),
 (14, 'INTERACTS_WITH'),
 (15, 'LINKED_TO'),
 (16, 'EXPRESSION_PRESENT'),
 (17, 'EXPRESSION_ABSENT')]
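If you later need to translate a Neo4j relationship type back to the STaRK edge type id and name (for example, when comparing Cypher results with `skb.RELATION_TYPES`), an inverse lookup can be built from the same dictionary. The sketch below uses the 18 names printed above and a hypothetical `build_rel_type_lookup`; it raises if two names would collapse to the same type (as `'parent-child'` and `'parent child'` would), since the lookup would then be ambiguous.

```python
import re

# edge types as printed by skb.edge_type_dict above
EDGE_TYPES = {0: 'ppi', 1: 'carrier', 2: 'enzyme', 3: 'target', 4: 'transporter',
              5: 'contraindication', 6: 'indication', 7: 'off-label use',
              8: 'synergistic interaction', 9: 'associated with', 10: 'parent-child',
              11: 'phenotype absent', 12: 'phenotype present', 13: 'side effect',
              14: 'interacts with', 15: 'linked to', 16: 'expression present',
              17: 'expression absent'}


def format_rel_type(s):
    # same logic as the notebook cell above
    return re.sub('[^0-9A-Z]+', '_', s.upper())


def build_rel_type_lookup(edge_types):
    """Hypothetical helper: map each Neo4j relationship type back to (type id, STaRK name)."""
    lookup = {}
    for type_id, name in edge_types.items():
        rel_type = format_rel_type(name)
        if rel_type in lookup:
            raise ValueError(f'{name!r} and {lookup[rel_type][1]!r} both map to {rel_type}')
        lookup[rel_type] = (type_id, name)
    return lookup


lookup = build_rel_type_lookup(EDGE_TYPES)
print(lookup['OFF_LABEL_USE'])
```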

In [24]:
# add a shared _Entity_ label to every node so relationships can be matched on a
# single label, and back it with a uniqueness constraint on nodeId

with GraphDatabase.driver(NEO4J_URI,
                          auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
  driver.execute_query('MATCH(n) SET n:_Entity_')
  driver.execute_query('CREATE CONSTRAINT unique__entity__nodeid IF NOT EXISTS FOR (n:_Entity_) REQUIRE n.nodeId IS UNIQUE')

In [25]:
def load_rels(rel_df: pd.DataFrame,
              source_target_labels: Union[Tuple[str, str], str],
              source_node_key: Union[Tuple[str, str], str],
              target_node_key: Union[Tuple[str, str], str],
              rel_type: str,
              rel_key: str = None,
              chunk_size: int = 10_000,
              neo4j_uri: str = 'bolt://localhost:7687',
              neo4j_password: str = 'password',
              neo4j_username: str = 'neo4j'):
    """
    Load relationships from a dataframe.

    Parameters
    -------
    :param rel_df: pd.DataFrame
        The dataframe containing relationship data
    :param source_target_labels: Union[Tuple[str, str], str]
        The source and target node labels to use.
        Can pass a single string if source and target nodes have the same labels,
        otherwise a tuple of the form (source_node_label, target_node_label)
    :param source_node_key: Union[Tuple[str, str], str]
        The column of the dataframe to use as the source node MERGE key property.
        Can optionally pass a tuple of the form (source_node_key_name, df_column_name) to map as appropriate if the
        column name is different
    :param target_node_key: Union[Tuple[str, str], str]
        The column of the dataframe to use as the target node MERGE key property.
        Can optionally pass a tuple of the form (target_node_key_name, df_column_name) to map as appropriate if the
        column name is different
    :param rel_type: str
        The relationship type to use (only one allowed).
    :param rel_key: str
        A key to distinguish unique parallel relationships.
        The default behavior of this function is to assume only one instance of a relationship type between two nodes.
        A duplicate insert will have the behavior of overriding the existing relationship.
        If this behavior is undesirable, and you want to allow multiple instances of the same relationship type between
        two nodes (a.k.a. parallel relationships), provide this key to use for merging relationships uniquely.
    :param chunk_size: int , default 10_000
        The chunk size to use when batching rows for loading
    :param neo4j_uri: str , default "bolt://localhost:7687"
        The uri for the Neo4j database
    :param neo4j_password: str , default "password"
        The password for the Neo4j database
    :param neo4j_username: str , default "neo4j"
        The username for the Neo4j database
    """
    records = rel_df.to_dict('records')
    print(f'======  loading {rel_type} relationships  ======')
    total = len(records)
    print(f'staged {total:,} records')
    with GraphDatabase.driver(neo4j_uri,
                              auth=(neo4j_username, neo4j_password)) as driver:
      query = _make_rel_merge_query(source_target_labels, source_node_key,
                                  target_node_key, rel_type, rel_df.columns.copy(), rel_key)
      print(f'\nusing this cypher query to load data:\n```\n{query}\n```\n')
      cumulative_count = 0
      for recs in chunks(records, chunk_size):
          res = driver.execute_query(query, parameters_={'recs': recs})
          cumulative_count += res[0][0][0]
          print(f'loaded {cumulative_count:,} of {total:,} relationships')
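As the `rel_key` docstring notes, `MERGE(s)-[r:TYPE]->(t)` without a key keeps at most one relationship of a given type per ordered (source, target) pair, so duplicate rows are folded together, while `relLoadedCount` (a `count(r)` over the unwound rows) still counts every row. A quick way to see how many rows would be folded before loading is sketched below over `(src, tgt, typeId)` triples; `count_merge_duplicates` is a hypothetical helper, and for `rel_df` you could pass `rel_df.itertuples(index=False)`.

```python
from collections import Counter
from typing import Iterable, Tuple


def count_merge_duplicates(rows: Iterable[Tuple[int, int, int]]) -> Counter:
    """Hypothetical helper: per typeId, count rows a key-less MERGE would fold into an existing relationship.

    Each row is (src, tgt, typeId); direction matters, so (a, b) and (b, a)
    are distinct relationships.
    """
    seen = set()
    duplicates = Counter()
    for src, tgt, type_id in rows:
        key = (src, tgt, type_id)
        if key in seen:
            duplicates[type_id] += 1
        else:
            seen.add(key)
    return duplicates


# hypothetical rows: the second (0, 5, 0) row would merge into the first
rows = [(0, 5, 0), (0, 5, 0), (5, 0, 0), (2, 7, 17)]
print(count_merge_duplicates(rows))
```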



In [26]:
for ind, edge_type in skb.edge_type_dict.items():
  single_rel_type_df = (rel_df[rel_df['typeId'] == ind]
                        .drop(columns=['typeId']))
  rel_type = format_rel_type(edge_type)
  load_rels(single_rel_type_df,
            source_target_labels='_Entity_',
            source_node_key=('nodeId', 'src'),
            target_node_key=('nodeId', 'tgt'),
            rel_type=rel_type,
            neo4j_uri=NEO4J_URI,
            neo4j_username=NEO4J_USERNAME,
            neo4j_password=NEO4J_PASSWORD)

staged 642,150 records

using this cypher query to load data:
```
UNWIND $recs AS rec
    MATCH(s:_Entity_ {nodeId: rec.src})
    MATCH(t:_Entity_ {nodeId: rec.tgt})
MERGE(s)-[r:PPI]->(t)
RETURN count(r) AS relLoadedCount
```

loaded 10,000 of 642,150 relationships
loaded 20,000 of 642,150 relationships
loaded 30,000 of 642,150 relationships
loaded 40,000 of 642,150 relationships
loaded 50,000 of 642,150 relationships
loaded 60,000 of 642,150 relationships
loaded 70,000 of 642,150 relationships
loaded 80,000 of 642,150 relationships
loaded 90,000 of 642,150 relationships
loaded 100,000 of 642,150 relationships
loaded 110,000 of 642,150 relationships
loaded 120,000 of 642,150 relationships
loaded 130,000 of 642,150 relationships
loaded 140,000 of 642,150 relationships
loaded 150,000 of 642,150 relationships
loaded 160,000 of 642,150 relationships
loaded 170,000 of 642,150 relationships
loaded 180,000 of 642,150 relationships
loaded 190,000 of 642,150 relationships
loaded 200,000 of 642,
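After the load, the relationship counts stored in Neo4j can be compared with the row counts per type; because a key-less `MERGE` folds duplicate pairs, Neo4j may report fewer relationships than rows. A sketch that builds one count query per (already formatted) relationship type and reports mismatches; `count_query` and `compare_counts` are hypothetical helpers, and the numbers in the usage lines are made up for illustration.

```python
from typing import Dict, Tuple


def count_query(rel_type: str) -> str:
    # one Cypher count per relationship type; rel_type must already be formatted
    return f'MATCH ()-[r:{rel_type}]->() RETURN count(r) AS relCount'


def compare_counts(expected: Dict[str, int], actual: Dict[str, int]) -> Dict[str, Tuple[int, int]]:
    """Hypothetical helper: return {rel_type: (expected, actual)} for every type whose counts differ."""
    return {t: (n, actual.get(t, 0)) for t, n in expected.items() if actual.get(t, 0) != n}


# usage sketch with made-up numbers; in the notebook, expected could be built from
# rel_df['typeId'].value_counts() mapped through format_rel_type, and actual from
# driver.execute_query(count_query(t)).records[0]['relCount'] for each type t
expected = {'PPI': 4, 'CARRIER': 2}
actual = {'PPI': 3, 'CARRIER': 2}
print(count_query('PPI'))
print(compare_counts(expected, actual))
```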

## Get & Load Embeddings
1. download pre-computed text-embedding-ada-002 embeddings
2. format embeddings
3. load embeddings
4. create vector index

In [21]:
# Download pre-generated OpenAI text-embedding-ada-002 embeddings.
# Get emb_download.py from https://github.com/snap-stanford/stark; see its README for other ways to generate embeddings.
! python emb_download.py --dataset prime --emb_dir emb/

Downloading...
From (original): https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU
From (redirected): https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU&confirm=t&uuid=e3431122-19cf-4346-a4f3-f372aed5defb
To: /content/emb/prime/text-embedding-ada-002/query/query_emb_dict.pt
100% 72.0M/72.0M [00:00<00:00, 81.5MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy
From (redirected): https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy&confirm=t&uuid=46dc3fa9-85b6-4e4a-b25d-2788c8c3bfff
To: /content/emb/prime/text-embedding-ada-002/doc/candidate_emb_dict.pt
100% 832M/832M [00:14<00:00, 56.6MB/s]


In [27]:
import torch

emb = torch.load('emb/prime/text-embedding-ada-002/doc/candidate_emb_dict.pt')

In [29]:
emb[0]

tensor([[-0.0497, -0.0080, -0.0108,  ..., -0.0098, -0.0167, -0.0184]])
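Each value in `emb` is a 1 × d tensor (note the double brackets), which is why the next cell calls `.squeeze()` before `.tolist()`. A vector index expects every vector to have the same length, so it can help to validate dimensions while building the records. A sketch of that step on plain nested lists standing in for `v.tolist()`; `to_embedding_records` is a hypothetical helper, and the default of 1536 is the output size of text-embedding-ada-002.

```python
from typing import Dict, List, Sequence


def to_embedding_records(emb: Dict[int, Sequence[Sequence[float]]],
                         expected_dim: int = 1536) -> List[dict]:
    """Hypothetical helper: flatten 1 x d embeddings into {'nodeId', 'textEmbedding'} records.

    Each value is a nested list shaped like tensor.tolist() on a 1 x d tensor;
    a shape or dimension mismatch raises instead of reaching the vector index.
    """
    records = []
    for node_id, rows in emb.items():
        if len(rows) != 1:
            raise ValueError(f'node {node_id}: expected a single row, got {len(rows)}')
        vector = [float(x) for x in rows[0]]
        if len(vector) != expected_dim:
            raise ValueError(f'node {node_id}: expected {expected_dim} dims, got {len(vector)}')
        records.append({'nodeId': node_id, 'textEmbedding': vector})
    return records


# usage with toy 4-dimensional embeddings
toy = {0: [[-0.0497, -0.0080, -0.0108, 0.0349]], 1: [[0.1, 0.2, 0.3, 0.4]]}
print(to_embedding_records(toy, expected_dim=4)[0])
```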

In [30]:
from tqdm import tqdm

# format embedding records
emb_records = []
for k,v in tqdm(emb.items()):
  emb_records.append({"nodeId":k ,"textEmbedding": v.squeeze().tolist()})
emb_records[:10]

100%|██████████| 129375/129375 [00:04<00:00, 30529.62it/s]


[{'nodeId': 0,
  'textEmbedding': [-0.0497407391667366,
   -0.008042690344154835,
   -0.010807578451931477,
   -0.035548556596040726,
   -0.015404374338686466,
   0.03494926914572716,
   -6.900514563312754e-05,
   0.026314103975892067,
   -0.00907101109623909,
   -0.026804428547620773,
   0.03331485390663147,
   -0.01841442473232746,
   0.0020464255940169096,
   0.008424054831266403,
   -0.021138451993465424,
   0.005403789225965738,
   0.018210122361779213,
   -0.021846698597073555,
   -0.003117308719083667,
   0.015445235185325146,
   -0.021097591146826744,
   0.008975669741630554,
   -0.021451715379953384,
   -0.02226892299950123,
   3.306607322883792e-05,
   0.022745627909898758,
   0.023304054513573647,
   -0.028820209205150604,
   -0.017093271017074585,
   -0.018373563885688782,
   0.011869949288666248,
   0.007348063867539167,
   0.005478699691593647,
   0.005281208083033562,
   0.013722287490963936,
   -0.025401555001735687,
   0.0019085216335952282,
   -0.0138448690995574,
   

In [31]:
# load embeddings

print(f'======  loading text embeddings ======')

total = len(emb_records)
print(f'staged {total:,} records')
with GraphDatabase.driver(NEO4J_URI,
                          auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:

  query = """
  UNWIND $recs AS rec
  MATCH(n:_Entity_ {nodeId: rec.nodeId})
  CALL db.create.setNodeVectorProperty(n, "textEmbedding", rec.textEmbedding)
  RETURN count(n) AS embeddingLoadedCount
  """
  print(f'\nusing this Cypher query to load data:\n```\n{query}\n```\n')
  cumulative_count = 0
  for recs in chunks(emb_records, 1_000):
      res = driver.execute_query(query, parameters_={'recs': recs})
      cumulative_count += res[0][0][0]
      print(f'loaded {cumulative_count:,} of {total:,} embeddings')



staged 129,375 records

using this Cypher query to load data:
```

  UNWIND $recs AS rec
  MATCH(n:_Entity_ {nodeId: rec.nodeId})
  CALL db.create.setNodeVectorProperty(n, "textEmbedding", rec.textEmbedding)
  RETURN count(n) AS embeddingLoadedCount
  
```

loaded 1,000 of 129,375 embeddings
loaded 2,000 of 129,375 embeddings
loaded 3,000 of 129,375 embeddings
loaded 4,000 of 129,375 embeddings
loaded 5,000 of 129,375 embeddings
loaded 6,000 of 129,375 embeddings
loaded 7,000 of 129,375 embeddings
loaded 8,000 of 129,375 embeddings
loaded 9,000 of 129,375 embeddings
loaded 10,000 of 129,375 embeddings
loaded 11,000 of 129,375 embeddings
loaded 12,000 of 129,375 embeddings
loaded 13,000 of 129,375 embeddings
loaded 14,000 of 129,375 embeddings
loaded 15,000 of 129,375 embeddings
loaded 16,000 of 129,375 embeddings
loaded 17,000 of 129,375 embeddings
loaded 18,000 of 129,375 embeddings
loaded 19,000 of 129,375 embeddings
loaded 20,000 of 129,375 embeddings
loaded 21,000 of 129,375 embedd

In [32]:
# create vector index

with GraphDatabase.driver(NEO4J_URI,
                          auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
  driver.execute_query('''
  CREATE VECTOR INDEX text_embeddings IF NOT EXISTS FOR (n:_Entity_) ON (n.textEmbedding)
  OPTIONS {indexConfig: {
  `vector.dimensions`: toInteger($dimension),
  `vector.similarity_function`: 'cosine'
  }}''', parameters_={'dimension': len(emb_records[0]['textEmbedding'])})
  driver.execute_query('CALL db.awaitIndex("text_embeddings", 300)')
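With `vector.similarity_function` set to `'cosine'`, the index ranks nodes by cosine similarity between a query vector and each stored `textEmbedding`. A brute-force version of that ranking in plain Python is a handy sanity check on a handful of vectors; it is a sketch, not how Neo4j performs its approximate nearest-neighbour search, and `top_k_cosine` is hypothetical. Neo4j reports vector index scores on a [0, 1] scale rather than as raw cosine values in [-1, 1], so compare orderings rather than scores.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    # cosine similarity: dot product divided by the product of the vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_cosine(query: Sequence[float], vectors: Dict[int, Sequence[float]],
                 k: int = 5) -> List[Tuple[int, float]]:
    """Hypothetical helper: exact top-k (nodeId, cosine) pairs, highest first."""
    scored = [(node_id, cosine(query, v)) for node_id, v in vectors.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# toy 3-dimensional vectors keyed by nodeId
vectors = {0: [1.0, 0.0, 0.0], 1: [0.7, 0.7, 0.0], 2: [0.0, 0.0, 1.0]}
print(top_k_cosine([1.0, 0.1, 0.0], vectors, k=2))
```

In Neo4j, the corresponding lookup against the index created above is `CALL db.index.vector.queryNodes('text_embeddings', 5, $queryEmbedding) YIELD node, score`.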

## Generate Relationship Type Embeddings
Generate a text-embedding-ada-002 embedding for each of the 18 relationship types.

In [None]:
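import torch
from langchain_openai import OpenAIEmbeddings

# Embed each relationship type's natural-language name (e.g. 'off-label use'),
# keyed by its formatted Neo4j relationship type (e.g. 'OFF_LABEL_USE').
# OpenAIEmbeddings reads the API key from the OPENAI_API_KEY environment variable.
embedding_model = OpenAIEmbeddings(model="text-embedding-ada-002")
reltype_emb = {format_rel_type(v): embedding_model.embed_query(v)
               for v in skb.edge_type_dict.values()}
torch.save(reltype_emb, 'emb/prime/text-embedding-ada-002/doc/reltype_emb_dict.pt')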