# Load [`ogbn-arxiv`](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) Graph Into Neo4j (~5 Min)
![Neo4j version](https://img.shields.io/badge/Neo4j->=4.4.9-brightgreen)
![GDS version](https://img.shields.io/badge/GDS-2.2-brightgreen)
![GDS Python Client version](https://img.shields.io/badge/GDS_Python_Client-1.5-brightgreen)

In [13]:
!pip install ogb



In [14]:
from ogb.nodeproppred import NodePropPredDataset

# Download (if not already present) and process ogbn-arxiv under the 'dataset/' root directory.
dataset = NodePropPredDataset(name='ogbn-arxiv', root='dataset/')

In [15]:
# Item 0 unpacks into the graph dict (node_feat / node_year / edge_index are read below)
# and the per-paper subject labels.
(graph, subject) = dataset[0]

In [16]:
import pandas as pd

# Feature matrix: row i holds the feature vector of paper i.
node_feat = graph['node_feat']
raw_node_df = pd.DataFrame(node_feat)

# Store each paper's feature vector as a Python list (sent to Neo4j as a float-list property).
# `ndarray.tolist()` converts the whole matrix in one vectorised call, instead of a row-wise
# `DataFrame.apply(axis=1)` that would build a pandas Series for every one of the ~170k rows.
raw_node_df['wordEmbedding'] = pd.Series(node_feat.tolist(), index=raw_node_df.index)

# The positional row index becomes the paperId.
node_df = raw_node_df[['wordEmbedding']].reset_index().rename(columns={'index': 'paperId'})
node_df

Unnamed: 0,paperId,wordEmbedding
0,0,"[-0.05794300138950348, -0.05253000184893608, -..."
1,1,"[-0.12449999898672104, -0.07066500186920166, -..."
2,2,"[-0.08024200052022934, -0.02332800067961216, -..."
3,3,"[-0.1450439989566803, 0.05491499975323677, -0...."
4,4,"[-0.07115399837493896, 0.07076600193977356, -0..."
...,...,...
169338,169338,"[-0.32135099172592163, -0.03933500126004219, -..."
169339,169339,"[-0.15121200680732727, -0.12470199912786484, -..."
169340,169340,"[-0.22053000330924988, -0.03656800091266632, -..."
169341,169341,"[-0.13823600113391876, 0.04088500142097473, -0..."


In [17]:
# reshape(-1) flattens the year / label arrays (presumably OGB column vectors of shape (N, 1);
# TODO confirm) into plain 1-D columns aligned with paperId order.
node_df['year'] = graph['node_year'].reshape(-1)
node_df['subject'] = subject.reshape(-1)

In [18]:
# Official OGB split: row-index arrays keyed 'train', 'valid' and 'test'.
split_index = (
    dataset.get_idx_split()
)

In [19]:
# Defaults for any paper that is not listed in one of the official splits.
node_df['split'], node_df['splitName'] = -1, 'UNKNOWN'

In [20]:
# Stamp the official OGB split onto each paper as a numeric code (0/1/2) plus a readable name.
# paperId equals the positional row index, so the split index arrays can be used as .loc labels.
split_spec = (('train', 'TRAIN'), ('valid', 'VALID'), ('test', 'TEST'))
for split_code, (split_key, split_name) in enumerate(split_spec):
    split_rows = split_index[split_key]
    node_df.loc[split_rows, 'split'] = split_code
    node_df.loc[split_rows, 'splitName'] = split_name

In [21]:
# Number of papers per split (a quick sanity check of the assignment above).
node_df.groupby(['split', 'splitName']).agg(cnt=('paperId', 'count'))

Unnamed: 0_level_0,Unnamed: 1_level_0,cnt
split,splitName,Unnamed: 2_level_1
0,TRAIN,90941
1,VALID,29799
2,TEST,48603


In [22]:
# edge_index has two rows (citing paper, cited paper); transpose to get one row per citation.
edge_df = pd.DataFrame(graph['edge_index'].T, columns=['paperId', 'citedPaperId'])
edge_df

Unnamed: 0,paperId,citedPaperId
0,104447,13091
1,15858,47283
2,107156,69161
3,107156,136440
4,107156,107366
...,...,...
1166238,45118,79124
1166239,45118,147994
1166240,45118,162473
1166241,45118,162537


In [23]:
import json

# Connection settings ('host', 'username', 'password') are kept outside the notebook.
with open('secrets.json') as secrets_file:
    secrets = json.load(secrets_file)

In [24]:
from graphdatascience import GraphDataScience

# URI and credentials come from secrets.json; aura_ds=True is for an AuraDS instance.
# Adjust these to match your own Neo4j setup.
gds = GraphDataScience(
    secrets['host'],
    auth=(secrets['username'], secrets['password']),
    aura_ds=True,
)

In [25]:
# Clear the previous load. WARNING: this wipes the ENTIRE target database: every node,
# every relationship, and (via APOC) every index and constraint, not only ogbn-arxiv data.
# Delete in batches so a re-run does not need one huge transaction for a full load
# (~170k nodes, ~1.17M relationships). `CALL {...} IN TRANSACTIONS` (Neo4j 4.4+) only
# works in an auto-commit transaction.
gds.run_cypher('''
    MATCH (n)
    CALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF 10000 ROWS
''')
# Empty index/constraint specs: APOC drops all existing indexes and constraints.
gds.run_cypher('CALL apoc.schema.assert({},{})')

Unnamed: 0,label,key,keys,unique,action


In [26]:
# Uniqueness constraint on Paper.paperId. Its backing index also speeds up the
# MERGE / MATCH lookups by paperId done during ingestion.
# `FOR ... REQUIRE` is the Neo4j 4.4+ syntax; the older `ON ... ASSERT` form was removed in 5.x.
gds.run_cypher('CREATE CONSTRAINT paper_unique IF NOT EXISTS FOR (n:Paper) REQUIRE n.paperId IS UNIQUE')

In [27]:
# Slice the node table into 10k-row pieces so each UNWIND payload sent to Neo4j stays bounded.
node_df_chunks = [
    node_df.iloc[start:start + 10_000]
    for start in range(0, node_df.shape[0], 10_000)
]
len(node_df_chunks)

17

In [28]:
# Upsert Paper nodes chunk by chunk. MERGE on paperId followed by SET means re-running
# this cell updates existing Paper nodes instead of duplicating them.
node_ingest_query = '''
    UNWIND $nodeRecords AS nodeRecord
    WITH toInteger(nodeRecord.paperId) AS paperId,
        toFloatList(nodeRecord.wordEmbedding) AS wordEmbedding,
        toInteger(nodeRecord.year) AS year,
        toInteger(nodeRecord.subject) AS subject,
        toInteger(nodeRecord.split) AS split,
        nodeRecord.splitName AS splitName
    MERGE (n:Paper {paperId: paperId})
    SET n.wordEmbedding = wordEmbedding,
        n.year = year,
        n.subject = subject,
        n.split = split,
        n.splitName = splitName
    RETURN count(n)
'''

for i, node_df_chunk in enumerate(node_df_chunks, start=1):
    node_records = node_df_chunk.to_dict('records')
    result = gds.run_cypher(node_ingest_query, params={'nodeRecords': node_records})
    ingested = int(result.iloc[0, 0])
    # MERGE yields one row per UNWIND row, so any other count means a partial write.
    assert ingested == len(node_records), (
        f'chunk {i}: expected {len(node_records)} nodes, got {ingested}')
    print(f'Ingested {i} of {len(node_df_chunks)} chunks ({ingested} nodes)')

   count(n)
0     10000
Ingested 1 of 17 chunks
   count(n)
0     10000
Ingested 2 of 17 chunks
   count(n)
0     10000
Ingested 3 of 17 chunks
   count(n)
0     10000
Ingested 4 of 17 chunks
   count(n)
0     10000
Ingested 5 of 17 chunks
   count(n)
0     10000
Ingested 6 of 17 chunks
   count(n)
0     10000
Ingested 7 of 17 chunks
   count(n)
0     10000
Ingested 8 of 17 chunks
   count(n)
0     10000
Ingested 9 of 17 chunks
   count(n)
0     10000
Ingested 10 of 17 chunks
   count(n)
0     10000
Ingested 11 of 17 chunks
   count(n)
0     10000
Ingested 12 of 17 chunks
   count(n)
0     10000
Ingested 13 of 17 chunks
   count(n)
0     10000
Ingested 14 of 17 chunks
   count(n)
0     10000
Ingested 15 of 17 chunks
   count(n)
0     10000
Ingested 16 of 17 chunks
   count(n)
0      9343
Ingested 17 of 17 chunks


In [29]:
# Add a split-specific label to each Paper: split 0/1/2 -> TrainPaper/ValidPaper/TestPaper.
# Labels cannot be query parameters in Cypher, so the label name is interpolated into the text.
for split_value, split_label in ((0, 'TrainPaper'), (1, 'ValidPaper'), (2, 'TestPaper')):
    gds.run_cypher(f'''
    MATCH (n:Paper) WHERE n.split={split_value}
    SET n:{split_label}
''')

In [30]:
# Slice the citation table into 150k-row pieces, one Cypher call per piece.
edge_df_chunks = [
    edge_df.iloc[start:start + 150_000]
    for start in range(0, edge_df.shape[0], 150_000)
]
len(edge_df_chunks)

8

In [31]:
# Create CITED relationships chunk by chunk. Both endpoints are looked up with MATCH, so a
# citation whose paper is missing would be silently skipped; the returned count is therefore
# checked against the chunk size. MERGE still yields a row when the relationship already
# exists, so duplicates or re-runs do not lower the count.
edge_ingest_query = '''
    UNWIND $edgeRecords AS edgeRecord
    WITH toInteger(edgeRecord.paperId) AS paperId,
        toInteger(edgeRecord.citedPaperId) AS citedPaperId
    MATCH (n0:Paper {paperId: paperId})
    MATCH (n1:Paper {paperId: citedPaperId})
    MERGE (n0)-[r:CITED]->(n1)
    RETURN count(r)
'''

for i, edge_df_chunk in enumerate(edge_df_chunks, start=1):
    edge_records = edge_df_chunk.to_dict('records')
    result = gds.run_cypher(edge_ingest_query, params={'edgeRecords': edge_records})
    ingested = int(result.iloc[0, 0])
    assert ingested == len(edge_records), (
        f'chunk {i}: only {ingested} of {len(edge_records)} citations matched existing Paper nodes')
    print(f'Ingested {i} of {len(edge_df_chunks)} chunks ({ingested} relationships)')

   count(r)
0    150000
Ingested 1 of 8 chunks
   count(r)
0    150000
Ingested 2 of 8 chunks
   count(r)
0    150000
Ingested 3 of 8 chunks
   count(r)
0    150000
Ingested 4 of 8 chunks
   count(r)
0    150000
Ingested 5 of 8 chunks
   count(r)
0    150000
Ingested 6 of 8 chunks
   count(r)
0    150000
Ingested 7 of 8 chunks
   count(r)
0    116243
Ingested 8 of 8 chunks
