# Using Graph Sampling to Scale GraphSage GNN

![Neo4j version](https://img.shields.io/badge/Neo4j->=4.4.9-brightgreen)
![GDS version](https://img.shields.io/badge/GDS-2.2-brightgreen)
![GDS Python Client version](https://img.shields.io/badge/GDS_Python_Client-1.5-brightgreen)

This notebook shows how to use Neo4j Graph Data Science (GDS) to scale unsupervised GraphSage, a graph neural network, to larger graphs. Specifically, it demonstrates how to:

* Sample a part of the `ogbn-arxiv` graph using the [GDS random walk with restarts algorithm](https://neo4j.com/docs/graph-data-science/current/management-ops/projections/rwr/)
* Train an unsupervised [GraphSage](https://neo4j.com/docs/graph-data-science/current/machine-learning/node-embeddings/graph-sage/) model on the subgraph to generate node embeddings
* Apply the trained GraphSage model to generate node embeddings on the entire graph

## Prerequisites
- The `ogbn-arxiv` dataset must be loaded into Neo4j. You can do so by running the `load-ogbn-arxiv-data` notebook, which should take a few minutes to complete.

In [32]:
!pip install ogb



## Notebook Setup

In [33]:
import json
with open('secrets.json') as f:
    secrets = json.load(f)
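
The cell above expects a `secrets.json` file with the keys `host`, `username`, and `password`, which are the keys read in the next cell. As a minimal sketch, the loader below fails early with a readable message when one of those keys is missing. The helper name `load_secrets` and the constant `REQUIRED_KEYS` are hypothetical and not part of the notebook or of the GDS client.

```python
import json

# Keys the notebook reads from secrets.json (taken from the connection cell).
REQUIRED_KEYS = ("host", "username", "password")


def load_secrets(path="secrets.json"):
    """Load connection settings and fail early if a required key is missing.

    Hypothetical helper for illustration; not part of the GDS Python client.
    """
    with open(path) as f:
        secrets = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in secrets]
    if missing:
        raise KeyError(f"{path} is missing required keys: {missing}")
    return secrets
```

With this helper, `secrets = load_secrets()` could replace the three lines above without changing the rest of the notebook.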

In [34]:
from graphdatascience import GraphDataScience

# Use Neo4j URI and credentials according to your setup
gds = GraphDataScience(secrets['host'], auth=(secrets['username'], secrets['password']), aura_ds=True)

# Necessary if you enabled Arrow on the db
gds.set_database("neo4j")

In [35]:
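def clear_graph_by_name(g_name):
    """Drop the projected graph with the given name, if it exists."""
    if gds.graph.exists(g_name)["exists"]:
        g = gds.graph.get(g_name)
        gds.graph.drop(g)

def clear_all_graphs():
    """Drop every graph currently in the graph catalog."""
    g_names = gds.graph.list()["graphName"].tolist()
    for g_name in g_names:
        g = gds.graph.get(g_name)
        gds.graph.drop(g)

def clear_model_by_name(m_name):
    """Drop the model with the given name, if it exists."""
    if gds.beta.model.exists(m_name)["exists"]:
        m = gds.model.get(m_name)
        gds.beta.model.drop(m)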

## Project and Sample Graph

In [36]:
clear_all_graphs()
# Project the TrainPaper nodes with their features and treat citations as undirected
g, _ = gds.graph.project(
    'proj',
    {'TrainPaper': {'properties': ['wordEmbedding', 'subject']}},
    {'CITED': {'orientation': 'UNDIRECTED'}},
)
_

nodeProjection            {'TrainPaper': {'label': 'TrainPaper', 'proper...
relationshipProjection    {'CITED': {'orientation': 'UNDIRECTED', 'aggre...
graphName                                                              proj
nodeCount                                                             90941
relationshipCount                                                    749678
projectMillis                                                           477
Name: 0, dtype: object

In [37]:
# Sample with the random walk with restarts (RWR) algorithm; with samplingRatio=0.1 we expect about 0.1 * 90941 = 9094 nodes
g_sample, _ = gds.alpha.graph.sample.rwr('sample', g, samplingRatio=0.1, restartProbability=0.05, concurrency=1, randomSeed=42)
_

fromGraphName          proj
startNodeCount            1
graphName            sample
nodeCount              9094
relationshipCount    136548
projectMillis           157
Name: 0, dtype: object
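
To make the sampling step more concrete, here is a toy pure-Python sketch of the idea behind random walk with restarts: walk from a start node, jump back to it with probability `restart_probability`, and stop once the desired fraction of nodes has been visited. This is an illustration under simplifying assumptions (a single start node, an adjacency-list dict as input, a hypothetical `max_steps` cap), not the GDS implementation. The function name `rwr_sample` and its parameters are made up for this example.

```python
import random


def rwr_sample(adjacency, sampling_ratio, restart_probability, seed=None, max_steps=100_000):
    """Return the subgraph induced by nodes visited on a random walk with restarts.

    adjacency maps each node to a list of its neighbors.
    Toy illustration only; GDS handles start nodes and stalled walks differently.
    """
    rng = random.Random(seed)
    nodes = sorted(adjacency)
    target = max(1, round(sampling_ratio * len(nodes)))
    start = rng.choice(nodes)
    current = start
    visited = {start}
    for _ in range(max_steps):
        if len(visited) >= target:
            break
        neighbors = adjacency[current]
        # Restart at the start node on a dead end or with the restart probability.
        if not neighbors or rng.random() < restart_probability:
            current = start
        else:
            current = rng.choice(neighbors)
        visited.add(current)
    # Keep only relationships whose endpoints were both sampled.
    return {
        node: [nbr for nbr in adjacency[node] if nbr in visited]
        for node in visited
    }


# Toy usage: a ring of 100 nodes, sampled with the same ratio and restart
# probability as the notebook cell above.
toy_graph = {i: [(i - 1) % 100, (i + 1) % 100] for i in range(100)}
toy_sample = rwr_sample(toy_graph, sampling_ratio=0.1, restart_probability=0.05, seed=42)
print(len(toy_sample), "nodes sampled")
```

Because the walk keeps returning to its start node, the sample stays concentrated in one neighborhood of the graph, which tends to preserve local structure better than picking nodes uniformly at random.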

## Train GraphSage on Sample

In [38]:
clear_model_by_name('gsModel')

gds.beta.graphSage.train(g_sample, modelName='gsModel', embeddingDimension=512, sampleSizes=[30, 30], searchDepth=20,
                         epochs=5, learningRate=0.0001, activationFunction='RELU', aggregator='MEAN', featureProperties=['wordEmbedding'],
                         randomSeed=99, batchSize=10)

(GraphSageModel({'modelInfo': {0: {'modelName': 'gsModel', 'modelType': 'graphSage', 'metrics': {'ranIterationsPerEpoch': [10, 10, 10, 10, 10], 'iterationLossesPerEpoch': [[26.101553108713553, 25.96468016348165, 26.01535189532628, 25.763774801284846, 25.773294365996964, 25.573388611771332, 25.71351555960887, 25.49865176770185, 25.576237100018254, 25.011561267814475], [25.29816268642032, 25.11148399891449, 25.10644964509697, 24.84297034180481, 24.536259889122004, 24.31699363355826, 24.154708251162685, 24.024915116934157, 24.12494143918854, 23.791079130398572], [24.013198031084706, 22.839763994254934, 23.564918236135895, 22.763391478248135, 21.91081223335163, 22.494262953336488, 21.820997726188587, 21.51781860314818, 21.244757061956907, 21.222037056151656], [21.206315496024068, 20.522718717359727, 20.36431869546906, 19.653473607394307, 19.57347365371603, 19.266385907301903, 18.478333414486173, 18.268084080755575, 18.461586742264117, 16.718567940596287], [16.70566227400297, 15.87219059632
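
The raw training output is hard to read at a glance. As a small sketch, the helper below condenses an `iterationLossesPerEpoch` list of lists into one row per epoch. The function name `summarize_epoch_losses` is hypothetical, and the example input is the first two epochs from the output above, rounded to two decimals.

```python
def summarize_epoch_losses(iteration_losses_per_epoch):
    """Return (epoch, first loss, last loss, mean loss) for every non-empty epoch."""
    summary = []
    for epoch, losses in enumerate(iteration_losses_per_epoch, start=1):
        if not losses:
            continue
        mean_loss = sum(losses) / len(losses)
        summary.append((epoch, losses[0], losses[-1], mean_loss))
    return summary


# First two epochs from the training output above, rounded to two decimals.
losses = [
    [26.10, 25.96, 26.02, 25.76, 25.77, 25.57, 25.71, 25.50, 25.58, 25.01],
    [25.30, 25.11, 25.11, 24.84, 24.54, 24.32, 24.15, 24.02, 24.12, 23.79],
]
for epoch, first, last, mean_loss in summarize_epoch_losses(losses):
    print(f"epoch {epoch}: first={first:.2f} last={last:.2f} mean={mean_loss:.2f}")
```

A loss that falls both within and across epochs, as it does here, suggests the model is still learning after 5 epochs on the sample.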

## Use GraphSage Model to Generate Embeddings on Entire Graph

In [39]:
gds.beta.graphSage.mutate(g, modelName='gsModel', mutateProperty='gsEmbedding')

nodePropertiesWritten                                                90941
mutateMillis                                                             0
nodeCount                                                            90941
preProcessingMillis                                                      0
computeMillis                                                        22945
configuration            {'jobId': '3580c727-af55-49ce-9869-ba7a34a0f79...
Name: 0, dtype: object

In [40]:
# Stream the new embeddings from the in-memory graph to inspect them
gds.graph.nodeProperties.stream(g, node_properties='gsEmbedding')

| | nodeId | nodeProperty | propertyValue |
|---|---|---|---|
| 0 | 167532 | gsEmbedding | [-0.002488996944798357, -0.0005080555086773295... |
| 1 | 167533 | gsEmbedding | [0.03931017949390128, 0.08091303579171928, -0.... |
| 2 | 167534 | gsEmbedding | [0.10820473153838188, 0.10284406655456249, -0.... |
| 3 | 167535 | gsEmbedding | [-0.013440787413609578, -0.04416116902508773, ... |
| 4 | 167536 | gsEmbedding | [0.03350507899774372, 0.06807220574646418, -0.... |
| ... | ... | ... | ... |
| 90936 | 55834 | gsEmbedding | [-0.0006329617627241897, 0.001718238673257951,... |
| 90937 | 55835 | gsEmbedding | [-0.0008074166558277462, -0.000157246319203552... |
| 90938 | 55836 | gsEmbedding | [-0.00015979038170613373, -0.00214747828621173... |
| 90939 | 55837 | gsEmbedding | [0.0811878506079212, 0.012883562618342843, -0.... |
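
One common way to use embeddings like these is to look up the most similar nodes by cosine similarity. The sketch below is a plain-Python illustration and is not part of the notebook: the helper names `cosine_similarity` and `most_similar` are hypothetical, and it assumes the embeddings have been collected into a dict mapping `nodeId` to its `propertyValue` list.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def most_similar(query_id, embeddings, top_k=5):
    """Return the top_k (node id, similarity) pairs closest to query_id.

    embeddings is assumed to map node ids to embedding vectors.
    """
    query = embeddings[query_id]
    scores = [
        (node_id, cosine_similarity(query, vector))
        for node_id, vector in embeddings.items()
        if node_id != query_id
    ]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:top_k]
```

Assuming the streamed result is held in a DataFrame named `df`, the dict could be built with `dict(zip(df["nodeId"], df["propertyValue"]))`. For all 90941 nodes a vectorized library or a GDS similarity algorithm would likely be a better fit than this brute-force loop.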
