# Transductive Node Classification

![Neo4j version](https://img.shields.io/badge/Neo4j->=4.4.9-brightgreen)
![GDS version](https://img.shields.io/badge/GDS-2.3-brightgreen)
![GDS Python Client version](https://img.shields.io/badge/GDS_Python_Client-1.6-brightgreen)

__This notebook demonstrates how graph features can be used to improve Machine Learning accuracy__ in a transductive setting.
In this example, adding graph embeddings improves test accuracy for supervised node classification by roughly 21% to 26% relative to a text-only (non-graph) model.

## Transductive Node Classification
In this application of graph machine learning, we are given a graph in which some node labels are missing. The goal is to train a model on the subset of labeled nodes and then use it to predict the missing labels. The entire graph is available during training, including the nodes with missing labels.
We refer to this as *__transductive__* graph machine learning. It is a powerful way to leverage knowledge graphs and other network relationships to improve classifier performance.

## Dataset
The dataset we use to demonstrate is the [`ogbn-arxiv`](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) citation graph. Papers are nodes and citations between papers are relationships. Each paper comes with a 128-dimensional floating point vector obtained by averaging the word embeddings of the paper's title and abstract.

__Classification Task:__ The goal is to predict the paper subject, i.e. what the paper is about. There are 40 possible subjects.

To reflect the transductive setting we will redact a random subset of node labels.

## Models
We will compare the test set accuracy of four approaches:

- __Default Best Guess__: A naive heuristic (predict the most frequent class in the training set) used as a sanity check. A model must do better than this to be useful.
- __Non-Graph (NLP Only)__: Use just the 128-dimensional word embeddings as feature inputs to a Neural Network classifier.
- __Graph ML with FastRP Embeddings__: Generate Fast Random Projection (FastRP) node embeddings, using the word embeddings as node feature properties. Use the FastRP embeddings as inputs to a Neural Network classifier.
- __Graph ML with GraphSAGE Embeddings__: Train an unsupervised GraphSAGE model (a type of graph neural network) on a sampled subgraph, using the word embeddings as input features. Then use the trained model to generate node embeddings for the entire graph, and use those embeddings as inputs to a Neural Network classifier.


In [1]:
import torch
import pandas as pd
import numpy as np
from graphdatascience import GraphDataScience
from dotenv import load_dotenv
import os
import benchmark.ogbn_arxiv as bm_ogbn_arxiv
from graph_data.data_import import get_ogbn_arxiv_data



## Prepare Data
Load the source data, mask a proportion of node labels to reflect the transductive setting, and create train, validation, and test set indexes.

In [2]:
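RANDOM_SEED = 7474
MISSING_LABEL_PROPORTION = 0.15  # share of all papers whose subject is masked
TEST_SET_PROPORTION = 0.2        # share of the labeled papers held out for testing
VALID_SET_PROPORTION = 0.25      # share of the remaining labeled papers used for validation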

In [3]:
paper_source_df, citation_source_df = get_ogbn_arxiv_data()
paper_source_df

| | nodeId | textEmbedding | year | subjectId |
|---|---|---|---|---|
| 0 | 0 | [-0.05794300138950348, -0.05253000184893608, -... | 2013 | 4 |
| 1 | 1 | [-0.12449999898672104, -0.07066500186920166, -... | 2015 | 5 |
| 2 | 2 | [-0.08024200052022934, -0.02332800067961216, -... | 2014 | 28 |
| 3 | 3 | [-0.1450439989566803, 0.05491499975323677, -0.... | 2014 | 8 |
| 4 | 4 | [-0.07115399837493896, 0.07076600193977356, -0... | 2014 | 27 |
| ... | ... | ... | ... | ... |
| 169338 | 169338 | [-0.32135099172592163, -0.03933500126004219, -... | 2020 | 4 |
| 169339 | 169339 | [-0.15121200680732727, -0.12470199912786484, -... | 2020 | 24 |
| 169340 | 169340 | [-0.22053000330924988, -0.03656800091266632, -... | 2020 | 10 |
| 169341 | 169341 | [-0.13823600113391876, 0.04088500142097473, -0... | 2020 | 4 |


In [4]:
# Mask a random subset of subject labels to reflect the transductive setting
paper_df = paper_source_df.copy()
missing_label_idx = paper_df.sample(frac=MISSING_LABEL_PROPORTION, random_state=RANDOM_SEED).nodeId
paper_df.loc[paper_df.nodeId.isin(missing_label_idx), 'subjectId'] = np.nan  # np.NaN was removed in NumPy 2.0

# Randomly split the remaining labeled papers into test, validation, and train sets
labeled = ~paper_df.nodeId.isin(missing_label_idx)
test_idx = paper_df[labeled].sample(frac=TEST_SET_PROPORTION, random_state=RANDOM_SEED).nodeId
not_test = labeled & ~paper_df.nodeId.isin(test_idx)
valid_idx = paper_df[not_test].sample(frac=VALID_SET_PROPORTION, random_state=RANDOM_SEED).nodeId
train_idx = paper_df.nodeId[not_test & ~paper_df.nodeId.isin(valid_idx)]

print(f'{100*sum(paper_df.subjectId.isna())/paper_df.shape[0]:.6}% of the papers are now masked as missing a subject')
print(f'{100*len(test_idx)/paper_df.shape[0]:.6}% of the papers are assigned to the test set')
print(f'{100*len(valid_idx)/paper_df.shape[0]:.6}% of the papers are assigned to the validation set')
print(f'{100*len(train_idx)/paper_df.shape[0]:.6}% of the papers are assigned to the train set')
print(f'{100*(len(train_idx) + len(test_idx) + len(valid_idx) + len(missing_label_idx))/paper_df.shape[0]}% of the papers are assigned to just one of the above sets')

14.9997% of the papers are now masked as missing a subject
16.9998% of the papers are assigned to the test set
16.9998% of the papers are assigned to the validation set
51.0006% of the papers are assigned to the train set
100.0% of the papers are assigned to just one of the above sets
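The split proportions are nested rather than absolute. `TEST_SET_PROPORTION` applies to the labeled papers, and `VALID_SET_PROPORTION` applies to the labeled papers left after removing the test set. The small helper below is hypothetical and not part of the notebook. It reproduces the expected shares printed above. The observed values differ slightly (for example 14.9997%) because `DataFrame.sample(frac=...)` rounds to a whole number of rows.

```python
def expected_split_shares(missing=0.15, test=0.2, valid=0.25):
    """Return the expected share of all papers in each set, given nested proportions."""
    labeled = 1.0 - missing           # papers that keep their subject label
    test_share = labeled * test       # test set is sampled from the labeled papers
    remaining = labeled - test_share  # labeled papers not in the test set
    valid_share = remaining * valid   # validation set is sampled from that remainder
    train_share = remaining - valid_share
    return {'missing': missing, 'test': test_share, 'valid': valid_share, 'train': train_share}


shares = expected_split_shares()
print({k: round(v, 4) for k, v in shares.items()})
# {'missing': 0.15, 'test': 0.17, 'valid': 0.17, 'train': 0.51}
```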


## Default Best Guess
As a dummy baseline: how well would we do if we predicted every example as the most frequent class in the training set?
A useful model must do better than this.

In [5]:
%%time
default_stats = bm_ogbn_arxiv.default_best_guess_benchmark(paper_df, train_idx, valid_idx, test_idx)
default_stats

CPU times: user 23.9 ms, sys: 2.36 ms, total: 26.3 ms
Wall time: 25.1 ms


{'train_acc': 0.1608155987309821,
 'valid_acc': 0.1653119355286925,
 'test_acc': 0.16037932471863275}

As expected, this performs poorly, with an accuracy of about 16% on every split.
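The baseline cell calls `bm_ogbn_arxiv.default_best_guess_benchmark`, a project helper whose source is not shown here. The following is a minimal sketch of a most-frequent-class baseline. It assumes the labels are a pandas Series indexed by node id, as `paper_df.subjectId` effectively is here, with NaN for masked papers, which never appear in the three splits.

```python
import pandas as pd


def most_frequent_class_baseline(labels, train_idx, valid_idx, test_idx):
    """Predict the most frequent training label for every node and report accuracy per split.

    Sketch only: `labels` is assumed to be a pandas Series indexed by node id.
    """
    # mode() returns ties in sorted order, so the smallest label wins a tie
    majority = labels.loc[train_idx].mode().iloc[0]
    splits = [('train', train_idx), ('valid', valid_idx), ('test', test_idx)]
    return {f'{name}_acc': float((labels.loc[idx] == majority).mean()) for name, idx in splits}
```

For example, `most_frequent_class_baseline(paper_df.subjectId, train_idx, valid_idx, test_idx)` would return a dict with the same keys as `default_stats`.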

## Non Graph (NLP Only)
Use only non-graph features (the 128-dimensional text embeddings). For a model with graph features to be useful, it must do better than this.

In [6]:
# convert indexes to tensors for PyTorch
train_idx = torch.tensor(train_idx.to_numpy(), dtype=torch.long)
valid_idx = torch.tensor(valid_idx.to_numpy(), dtype=torch.long)
test_idx = torch.tensor(test_idx.to_numpy(), dtype=torch.long)

In [7]:
%%time
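x = torch.tensor(np.stack(paper_df.textEmbedding), dtype=torch.float)
# Masked (NaN) subjects do not map to valid class ids here; they are never used,
# because the masked papers are excluded from train_idx, valid_idx, and test_idx
y = torch.tensor(paper_df.subjectId.to_numpy(), dtype=torch.long)

non_graph_benchmark, _ = bm_ogbn_arxiv.run_model(x, y, train_idx, valid_idx, test_idx,
                                                 hidden_dims=[64], epochs=300, patience=3, verbose=False)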

CPU times: user 16.8 s, sys: 5.18 s, total: 22 s
Wall time: 8.15 s
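`bm_ogbn_arxiv.run_model` is also a project helper whose internals are not shown. Judging only from its arguments (`hidden_dims`, `epochs`, `patience`) and the `bestStats` it reports, it appears to train a feed-forward classifier with early stopping on validation accuracy. The PyTorch sketch below is one way to do that. These choices are our assumptions, not the helper's confirmed behavior:

- the Adam optimizer and its learning rate,
- full-batch training,
- keeping the stats from the best validation epoch.

```python
import torch
from torch import nn


def train_mlp_sketch(x, y, train_idx, valid_idx, test_idx, hidden_dims=(64,),
                     epochs=300, patience=3, lr=0.01, seed=0):
    """Full-batch MLP training with early stopping on validation accuracy (illustrative sketch)."""
    torch.manual_seed(seed)
    num_classes = int(y[train_idx].max()) + 1  # assumes class ids are 0..K-1

    # Linear -> ReLU block per hidden layer, then a final Linear output layer
    layers, in_dim = [], x.shape[1]
    for h in hidden_dims:
        layers += [nn.Linear(in_dim, h), nn.ReLU()]
        in_dim = h
    layers.append(nn.Linear(in_dim, num_classes))
    model = nn.Sequential(*layers)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    def accuracy(idx):
        with torch.no_grad():
            return (model(x[idx]).argmax(dim=1) == y[idx]).float().mean().item()

    best, epochs_without_improvement = None, 0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(x[train_idx]), y[train_idx])
        loss.backward()
        optimizer.step()

        model.eval()
        stats = {'epoch': epoch, 'loss': loss.item(), 'train_acc': accuracy(train_idx),
                 'valid_acc': accuracy(valid_idx), 'test_acc': accuracy(test_idx)}
        if best is None or stats['valid_acc'] > best['valid_acc']:
            best, epochs_without_improvement = stats, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # validation accuracy has not improved for `patience` epochs
    # Returns the final model; `best` holds the stats from the best validation epoch
    return model, best
```

On the notebook's data this would be called as `train_mlp_sketch(x, y, train_idx, valid_idx, test_idx, hidden_dims=[64])`. Only rows selected by the three index tensors are ever read, so the entries of `y` for masked papers do not affect training or evaluation.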


In [8]:
non_graph_benchmark.qualityChecks

{'train': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'valid': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'test': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0}}
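`qualityChecks` reports how many feature vectors in each split are entirely zero. Such rows give the classifier nothing to learn from. For graph embeddings this can happen, for instance, for isolated nodes, depending on the configuration. The helper's code is not shown, so here is a NumPy sketch of an equivalent check. Reporting the percentage on a 0 to 100 scale is our assumption.

```python
import numpy as np


def all_zero_feature_checks(x, splits):
    """Count rows of `x` that are entirely zero within each named split of row indexes."""
    x = np.asarray(x)
    checks = {}
    for name, idx in splits.items():
        rows = x[np.asarray(idx)]
        count = int((~rows.any(axis=1)).sum())  # rows with no nonzero entry
        checks[name] = {'allZeroFeatureVec_Count': count,
                        'allZeroFeatureVec_Percent': 100.0 * count / len(rows)}
    return checks
```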

In [9]:
non_graph_benchmark.bestStats

{'epoch': 86,
 'loss': 1.651861548423767,
 'train_acc': 0.5499965264108562,
 'valid_acc': 0.5484924274003057,
 'test_acc': 0.5385229956926497}

## Load Graph Into Neo4j for Feature Engineering

In [10]:
load_dotenv('db-credentials.env', override=True)

# Use Neo4j URI and credentials according to our setup
gds = GraphDataScience(
    os.getenv('NEO4J_URI'),
    auth=(os.getenv('NEO4J_USERNAME'),
          os.getenv('NEO4J_PASSWORD')),
    # Parse the AURA_DS flag as a boolean instead of calling eval() on an environment variable
    aura_ds=os.getenv('AURA_DS', 'false').strip().lower() == 'true')

# Necessary if you enabled Arrow on the db; this is true for AuraDS
gds.set_database("neo4j")
PROJ_NAME = 'proj'
PROJ_NAME = 'proj'

In [11]:
gds.version()

'2.3.0'

In [12]:
node_df = paper_df.drop(columns=['subjectId', 'year'])
node_df['labels'] = 'Paper'
node_df

| | nodeId | textEmbedding | labels |
|---|---|---|---|
| 0 | 0 | [-0.05794300138950348, -0.05253000184893608, -... | Paper |
| 1 | 1 | [-0.12449999898672104, -0.07066500186920166, -... | Paper |
| 2 | 2 | [-0.08024200052022934, -0.02332800067961216, -... | Paper |
| 3 | 3 | [-0.1450439989566803, 0.05491499975323677, -0.... | Paper |
| 4 | 4 | [-0.07115399837493896, 0.07076600193977356, -0... | Paper |
| ... | ... | ... | ... |
| 169338 | 169338 | [-0.32135099172592163, -0.03933500126004219, -... | Paper |
| 169339 | 169339 | [-0.15121200680732727, -0.12470199912786484, -... | Paper |
| 169340 | 169340 | [-0.22053000330924988, -0.03656800091266632, -... | Paper |
| 169341 | 169341 | [-0.13823600113391876, 0.04088500142097473, -0... | Paper |


In [13]:
rel_df = citation_source_df.rename(columns={'paper': 'sourceNodeId', 'citedPaper': 'targetNodeId'})
rel_df['relationshipType'] = 'CITED'
rel_df

| | sourceNodeId | targetNodeId | relationshipType |
|---|---|---|---|
| 0 | 104447 | 13091 | CITED |
| 1 | 15858 | 47283 | CITED |
| 2 | 107156 | 69161 | CITED |
| 3 | 107156 | 136440 | CITED |
| 4 | 107156 | 107366 | CITED |
| ... | ... | ... | ... |
| 1166238 | 45118 | 79124 | CITED |
| 1166239 | 45118 | 147994 | CITED |
| 1166240 | 45118 | 162473 | CITED |
| 1166241 | 45118 | 162537 | CITED |


In [14]:
if gds.graph.exists(PROJ_NAME)['exists']:
    gds.graph.get(PROJ_NAME).drop()

In [15]:
%%time
g = gds.alpha.graph.construct(PROJ_NAME, node_df, rel_df, undirected_relationship_types=['CITED'])


CPU times: user 939 ms, sys: 1.83 s, total: 2.77 s
Wall time: 1min 15s


In [16]:
print(f'Node Count: {g.node_count():,}')
print(f'Relationship Count: {g.relationship_count():,}')

Node Count: 169,343
Relationship Count: 2,332,486
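The projected relationship count is exactly twice the number of citation rows uploaded (2 × 1,166,243 = 2,332,486). This is because `CITED` was listed in `undirected_relationship_types`, so each citation is stored in both directions. Below is a small pandas illustration of that mirroring on toy data. The helper name is hypothetical, and GDS does this internally.

```python
import pandas as pd


def mirror_relationships(rels):
    """Append a reversed copy of every relationship, mimicking an undirected projection."""
    # rename() applies both mappings at once, so this swaps source and target
    reversed_rels = rels.rename(columns={'sourceNodeId': 'targetNodeId',
                                         'targetNodeId': 'sourceNodeId'})
    return pd.concat([rels, reversed_rels[rels.columns]], ignore_index=True)


toy = pd.DataFrame({'sourceNodeId': [104447, 15858, 107156],
                    'targetNodeId': [13091, 47283, 69161],
                    'relationshipType': 'CITED'})
print(len(mirror_relationships(toy)))  # 6
```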


## Graph ML with FastRP Embeddings

In [17]:
# Create FastRP embeddings
gds.fastRP.mutate(g, embeddingDimension=256, mutateProperty='fastrpEmb', featureProperties=['textEmbedding'],
                  propertyRatio=0.5, randomSeed=RANDOM_SEED)

nodePropertiesWritten                                               169343
mutateMillis                                                             0
nodeCount                                                           169343
preProcessingMillis                                                      0
computeMillis                                                         1559
configuration            {'nodeSelfInfluence': 0, 'propertyRatio': 0.5,...
Name: 0, dtype: object
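FastRP builds embeddings through random projection followed by repeated neighborhood averaging. As we understand the GDS documentation, `featureProperties` with `propertyRatio=0.5` gives each node an initial vector in two halves. One half is a random projection of its `textEmbedding`, and the other half is a purely random node vector. The final embedding is a weighted sum of the propagated vectors.

The NumPy sketch below is a simplified version of that idea on a dense adjacency matrix. These details are assumptions meant to resemble, not reproduce, the GDS implementation:

- the sparse ±√3 projection,
- the per-step L2 normalization,
- the (0, 1, 1) iteration weights.

```python
import numpy as np


def fastrp_sketch(adj, features, dim=8, property_ratio=0.5,
                  iteration_weights=(0.0, 1.0, 1.0), seed=0):
    """Simplified FastRP: random initial vectors, then weighted neighbor-averaging steps."""
    rng = np.random.default_rng(seed)
    n, num_features = features.shape
    property_dim = int(round(dim * property_ratio))

    def sparse_random(shape):
        # Very sparse random projection: sqrt(3) * {+1, 0, -1} with probabilities {1/6, 2/3, 1/6}
        return np.sqrt(3) * rng.choice([1.0, 0.0, -1.0], size=shape, p=[1 / 6, 2 / 3, 1 / 6])

    # Property part projects the node features; node part is random per node
    initial = np.hstack([features @ sparse_random((num_features, property_dim)),
                         sparse_random((n, dim - property_dim))])

    # Row-normalized adjacency, so one multiplication averages each node's neighbors
    degree = adj.sum(axis=1, keepdims=True)
    transition = adj / np.where(degree == 0, 1.0, degree)

    embedding, current = np.zeros((n, dim)), initial
    for weight in iteration_weights:
        current = transition @ current
        norms = np.linalg.norm(current, axis=1, keepdims=True)
        current = current / np.where(norms == 0, 1.0, norms)
        embedding += weight * current
    return embedding
```

Because each step averages over neighbors, nodes with identical neighborhoods end up with identical embeddings. This is how FastRP encodes graph structure on top of the text features.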

In [18]:
# stream embeddings and merge into other data
fastrp_df = gds.graph.nodeProperties.stream(g, node_properties=['fastrpEmb'], separate_property_columns=True)
fastrp_df = paper_df.merge(fastrp_df, on='nodeId').drop(columns=['textEmbedding'])
fastrp_df

| | nodeId | year | subjectId | fastrpEmb |
|---|---|---|---|---|
| 0 | 0 | 2013 | 4.0 | [-0.010436784, 2.0225583e-05, -0.0057415017, 0... |
| 1 | 1 | 2015 | 5.0 | [0.031432256, -0.0018280161, 0.010902237, -0.0... |
| 2 | 2 | 2014 | 28.0 | [0.0007805123, 0.0008067958, 0.008566914, 0.01... |
| 3 | 3 | 2014 | 8.0 | [0.01189787, -0.029362265, -0.013015488, -0.01... |
| 4 | 4 | 2014 | 27.0 | [0.018842973, 0.007072161, -0.012446582, -0.00... |
| ... | ... | ... | ... | ... |
| 169338 | 169338 | 2020 | NaN | [-0.008318075, -0.0039453013, -0.0050907666, 2... |
| 169339 | 169339 | 2020 | 24.0 | [-0.0048398594, -0.0017861046, 0.003423419, 0.... |
| 169340 | 169340 | 2020 | 10.0 | [-0.00029355922, 0.0054876376, -0.0012249375, ... |
| 169341 | 169341 | 2020 | 4.0 | [-0.0060406476, 0.0008232696, -0.0027380043, 0... |


In [19]:
%%time
x = torch.tensor(np.stack(fastrp_df.fastrpEmb), dtype=torch.float)
y = torch.tensor(fastrp_df.subjectId.to_numpy(), dtype=torch.long)

fastrp_benchmark, _ = bm_ogbn_arxiv.run_model(x, y, train_idx, valid_idx, test_idx,
                                              hidden_dims=[64], epochs=300, patience=3, verbose=False)

CPU times: user 23.2 s, sys: 7.69 s, total: 30.9 s
Wall time: 11 s


In [20]:
fastrp_benchmark.qualityChecks

{'train': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'valid': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'test': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0}}

In [21]:
fastrp_benchmark.bestStats

{'epoch': 67,
 'loss': 1.104737639427185,
 'train_acc': 0.6897505962994697,
 'valid_acc': 0.6842781714603307,
 'test_acc': 0.6795192441294984}

In [22]:
# clean the graph projection
gds.graph.nodeProperties.drop(g, 'fastrpEmb')

graphName                   proj
nodeProperties       [fastrpEmb]
propertiesRemoved         169343
Name: 0, dtype: object

## Graph ML with GraphSAGE Embeddings

In [23]:
# create a graph subsample to train graphSage. This will speed up computation
if gds.graph.exists(PROJ_NAME + '_sample')['exists']:
    gds.graph.get(PROJ_NAME + '_sample').drop()
g_sample, _ = gds.alpha.graph.sample.rwr(PROJ_NAME + '_sample', g, samplingRatio=0.4,
                                         restartProbability=0.05, concurrency=1, randomSeed=RANDOM_SEED)
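`gds.alpha.graph.sample.rwr` draws a subgraph using random walks with restarts. Walks start from start nodes, jump back to a start node with probability `restartProbability` at each step, and keep collecting visited nodes until `samplingRatio` of all nodes have been reached. The pure-Python sketch below follows that outline on an adjacency dict and returns the subgraph induced by the visited nodes. Two details are simplifications of ours, not GDS's exact logic: the single start node, and the "stall" rule for picking a new start node.

```python
import random


def rwr_sample(adjacency, sampling_ratio=0.4, restart_probability=0.05, seed=0, stall_limit=1000):
    """Sample nodes with random walks with restarts and return the induced subgraph.

    adjacency: dict mapping each node to a list of its neighbors (undirected graph).
    """
    rng = random.Random(seed)
    nodes = sorted(adjacency)
    target = max(1, round(sampling_ratio * len(nodes)))
    start = rng.choice(nodes)
    visited, current, stalled = {start}, start, 0
    while len(visited) < target:
        if stalled >= stall_limit:
            # The walk keeps revisiting known nodes (e.g. a small component): add a new start node
            start = rng.choice([v for v in nodes if v not in visited])
            visited.add(start)
            current, stalled = start, 0
            continue
        neighbors = adjacency[current]
        if not neighbors or rng.random() < restart_probability:
            current = start  # restart the walk from the start node
        else:
            current = rng.choice(neighbors)
        if current in visited:
            stalled += 1
        else:
            visited.add(current)
            stalled = 0
    # Keep only relationships whose endpoints were both sampled (induced subgraph)
    return {v: [u for u in adjacency[v] if u in visited] for v in visited}
```

A low restart probability such as 0.05 lets walks wander away from the start node. As a result, the sample covers a connected region of the graph rather than a tight ball around one node.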

In [24]:
# train GraphSage
gds.beta.graphSage.train(g_sample, modelName='gsModel', embeddingDimension=256, sampleSizes=[30, 30],
                         searchDepth=20, epochs=20, learningRate=0.001, activationFunction='RELU',
                         aggregator='MEAN', featureProperties=['textEmbedding'], randomSeed=RANDOM_SEED,
                         batchSize=10)

(GraphSageModel({'modelInfo': {0: {'modelName': 'gsModel', 'modelType': 'graphSage', 'metrics': {'ranIterationsPerEpoch': [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], 'iterationLossesPerEpoch': [[25.990710291134878, 25.284054177938096, 24.87264464276741, 23.36564632372362, 22.979104190158175, 21.109724708851353, 19.43509812234374, 19.309677713411965, 18.188368736570037, 16.54027878135621], [16.66206119471486, 17.65879630702002, 17.626914687481012, 17.072811829716365, 17.075401323594075, 17.24369978112747, 16.304216817174037, 16.19267271540229, 16.178919873941943, 15.583437476216588], [15.906564445072167, 17.20602711188626, 17.39985436674118, 17.861158531903456, 16.59074805266497, 16.565785453005283, 16.39429658944531, 16.074582979402113, 16.004782809136156, 15.52596174033766], [15.16432652532097, 17.228808211696872, 16.82976115328669, 15.969812854370025, 17.0671662025759, 16.2748804962936, 16.70824146852329, 15.418074858515217, 16.090494153009228, 1
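With `aggregator='MEAN'` and `sampleSizes=[30, 30]`, each of the two GraphSAGE layers updates a node in four steps:

1. Average its own vector with up to 30 sampled neighbor vectors.
2. Apply a learned weight matrix.
3. Apply the activation function.
4. Normalize the result.

This is the mean aggregator from the GraphSAGE paper. The sketch below is a NumPy forward pass only, with random weights standing in for trained ones. It omits the unsupervised training that GDS runs (governed by `searchDepth`, `epochs`, `learningRate`, and `batchSize`). Details such as normalizing after every layer are assumptions.

```python
import numpy as np


def graphsage_mean_forward(adjacency, features, weights, sample_sizes, seed=0):
    """Forward pass of stacked GraphSAGE mean-aggregator layers (untrained sketch).

    adjacency: dict mapping node index -> list of neighbor indexes (0..n-1)
    features: (n, d) array of input features, e.g. text embeddings
    weights: one (d_in, d_out) matrix per layer; random here, learned in practice
    """
    rng = np.random.default_rng(seed)
    h = np.asarray(features, dtype=float)
    for w, sample_size in zip(weights, sample_sizes):
        new_h = np.empty((h.shape[0], w.shape[1]))
        for v in range(h.shape[0]):
            neighbors = list(adjacency[v])
            if len(neighbors) > sample_size:
                # Uniformly sample a fixed number of neighbors for this layer
                neighbors = list(rng.choice(neighbors, size=sample_size, replace=False))
            # MEAN aggregator: average the node's own vector with its sampled neighbors' vectors
            pooled = np.vstack([h[v], h[neighbors]]).mean(axis=0)
            new_h[v] = np.maximum(pooled @ w, 0.0)  # ReLU activation
        norms = np.linalg.norm(new_h, axis=1, keepdims=True)
        h = new_h / np.where(norms == 0, 1.0, norms)  # L2-normalize each node's vector
    return h
```

The weights depend only on feature dimensions, not on particular nodes. That is why a model trained on `g_sample` can be applied to every node of `g`, which is what the next cell does.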

In [25]:
# generate graphSage embeddings
gds.beta.graphSage.mutate(g, modelName='gsModel', mutateProperty='graphSageEmb')

nodePropertiesWritten                                               169343
mutateMillis                                                             0
nodeCount                                                           169343
preProcessingMillis                                                      0
computeMillis                                                        34361
configuration            {'jobId': 'b3de3da0-b792-4743-a9be-5291cdd9276...
Name: 0, dtype: object

In [26]:
# stream embeddings and merge into other data
graphsage_df = gds.graph.nodeProperties.stream(g, node_properties=['graphSageEmb'], separate_property_columns=True)
graphsage_df = paper_df.merge(graphsage_df, on='nodeId').drop(columns=['textEmbedding'])
graphsage_df

| | nodeId | year | subjectId | graphSageEmb |
|---|---|---|---|---|
| 0 | 0 | 2013 | 4.0 | [-0.005759144506133723, -0.004078297011198453,... |
| 1 | 1 | 2015 | 5.0 | [-0.00420300212688539, -0.00583400407235607, -... |
| 2 | 2 | 2014 | 28.0 | [-0.000851027726619688, -0.00466705959319111, ... |
| 3 | 3 | 2014 | 8.0 | [-0.0012063443654581204, -0.001550210771127644... |
| 4 | 4 | 2014 | 27.0 | [-0.0019912510476662666, -0.006279553423920565... |
| ... | ... | ... | ... | ... |
| 169338 | 169338 | 2020 | NaN | [-0.01091124838158967, -0.009865168182162234, ... |
| 169339 | 169339 | 2020 | 24.0 | [-0.04382474807042934, -0.0021663353088544153,... |
| 169340 | 169340 | 2020 | 10.0 | [-0.011078653669222765, 0.008526905611005758, ... |
| 169341 | 169341 | 2020 | 4.0 | [-0.005614548899110106, -0.004550669226187732,... |


In [27]:
%%time
x = torch.tensor(np.stack(graphsage_df.graphSageEmb), dtype=torch.float)
y = torch.tensor(graphsage_df.subjectId.to_numpy(), dtype=torch.long)

gs_benchmark, _ = bm_ogbn_arxiv.run_model(x, y, train_idx, valid_idx, test_idx,
                                          hidden_dims=[64], epochs=300, patience=3, verbose=False)

CPU times: user 19.3 s, sys: 6.22 s, total: 25.5 s
Wall time: 9.14 s


In [28]:
gs_benchmark.qualityChecks

{'train': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'valid': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0},
 'test': {'allZeroFeatureVec_Count': 0, 'allZeroFeatureVec_Percent': 0.0}}

In [29]:
gs_benchmark.bestStats

{'epoch': 102,
 'loss': 1.2033071517944336,
 'train_acc': 0.6709005858787023,
 'valid_acc': 0.6583993330554397,
 'test_acc': 0.6537793525079895}

In [30]:
# clean the graph projections
g_sample.drop()
gds.graph.nodeProperties.drop(g, 'graphSageEmb')

graphName                      proj
nodeProperties       [graphSageEmb]
propertiesRemoved            169343
Name: 0, dtype: object

## Results

In [31]:
non_graph_acc = non_graph_benchmark.bestStats['test_acc']
fastrp_acc = fastrp_benchmark.bestStats['test_acc']
gs_acc = gs_benchmark.bestStats['test_acc']

print('====== Model Results =========')
pd.DataFrame({
    'Model': ['Non Graph (NLP Only)', 'Graph ML with FastRP Embeddings', 'Graph ML with GraphSAGE Embeddings'],
    'Test Set Accuracy': [round(non_graph_acc, 3), round(fastrp_acc, 3), round(gs_acc, 3)],
    '% Improvement Over Non-Graph': ['.',
                                     f'{(fastrp_acc - non_graph_acc) / non_graph_acc:.2%}',
                                     f'{(gs_acc - non_graph_acc) / non_graph_acc:.2%}']})



| | Model | Test Set Accuracy | % Improvement Over Non-Graph |
|---|---|---|---|
| 0 | Non Graph (NLP Only) | 0.539 | . |
| 1 | Graph ML with FastRP Embeddings | 0.68 | 26.18% |
| 2 | Graph ML with GraphSAGE Embeddings | 0.654 | 21.40% |
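The "% Improvement Over Non-Graph" column is a relative change: (graph accuracy - non-graph accuracy) / non-graph accuracy. In absolute terms, the gains are about 14.1 (FastRP) and 11.5 (GraphSAGE) percentage points of test accuracy. Here is a quick check using the `bestStats` values printed earlier. The helper name is ours.

```python
def relative_and_absolute_gain(model_acc, baseline_acc):
    """Return (relative gain, absolute gain in percentage points) over a baseline accuracy."""
    return (model_acc - baseline_acc) / baseline_acc, 100 * (model_acc - baseline_acc)


non_graph_acc = 0.5385229956926497
for name, acc in [('FastRP', 0.6795192441294984), ('GraphSAGE', 0.6537793525079895)]:
    rel, pts = relative_and_absolute_gain(acc, non_graph_acc)
    print(f'{name}: {rel:.2%} relative, {pts:.1f} points absolute')
# FastRP: 26.18% relative, 14.1 points absolute
# GraphSAGE: 21.40% relative, 11.5 points absolute
```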


## Cleanup

In [32]:
if gds.beta.model.exists('gsModel')['exists']:
    gds.model.get('gsModel').drop()

In [33]:
g.drop()

graphName                                                             proj
database                                                             neo4j
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                           169343
relationshipCount                                                  2332486
configuration                                                           {}
density                                                           0.000081
creationTime                           2023-01-23T00:40:36.088904449+00:00
modificationTime                       2023-01-23T00:43:40.798960598+00:00
schema                   {'graphProperties': {}, 'relationships': {'CIT...
schemaWithOrientation    {'graphProperties': {}, 'relationships': {'CIT...
Name: 0, dtype: object