# Link Prediction with GCN

In [1]:
import networkx as nx
import pandas as pd
import numpy as np
from scipy import stats
import os
import time
import stellargraph as sg
from stellargraph import StellarGraph
import scipy.sparse as sp
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GCN
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses, metrics, Model, regularizers
from sklearn import preprocessing, feature_extraction, model_selection
from copy import deepcopy
import matplotlib.pyplot as plt
from stellargraph import datasets
from IPython.display import display, HTML
import dill
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
import math
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from copy import deepcopy
import torch
from scipy.sparse import identity
import numpy as np
import logging

import plotly.graph_objects as go
import numpy as np
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from sklearn.metrics import precision_recall_curve
from collections import Counter
# from tqdm import tqdm_notebook as tqdm
# tqdm().pandas()
import json
import pickle
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
from keras.models import load_model
from sklearn.model_selection import StratifiedKFold
from keras.models import load_model
from sklearn.metrics import log_loss
from keras import regularizers
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import Activation
from keras.layers import Dropout
from keras.layers import Flatten
import tensorflow as tf
import random
from matplotlib import pyplot as plt
import seaborn as sns
import os
import umap
from sklearn.cluster import KMeans
from keras import backend as K
from sklearn.manifold import TSNE
from IPython.core.display import display, HTML
from sklearn.preprocessing import StandardScaler, MinMaxScaler
display(HTML("<style>.container { width:80% !important; }</style>"))
from plotly.offline import init_notebook_mode, iplot
from sklearn.decomposition import PCA
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import FullBatchLinkGenerator
from stellargraph.layer import GCN, LinkEmbedding


from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection

from stellargraph import globalvar
from stellargraph import datasets
from IPython.display import display, HTML
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

Using TensorFlow backend.


## Loading the paper citation network

In [2]:
# Directory holding the pickled graph inputs that the loader in the next cell
# reads (new_edges.pkl, new_features.pkl, new_labels.pkl). The trailing "/" is
# required because the loader appends file names by plain string concatenation.
data_path = "keyword_and_abstract_SentenceTransform/"


In [3]:
def prepare_data_for_stellargraph(out_path):
    """Build a StellarGraph and a label series from pickled inputs.

    Reads ``new_edges.pkl``, ``new_features.pkl`` and ``new_labels.pkl`` from
    ``out_path`` and assembles a graph with a single node type ("paper") and a
    single edge type ("cites"). Progress and the graph summary are printed.

    Parameters
    ----------
    out_path : str
        Prefix of the pickle files. File names are appended with plain string
        concatenation, so a directory path must end with a path separator.

    Returns
    -------
    my_graph : StellarGraph
        Graph whose node features are the columns ``w0 .. w{d-1}`` of the
        loaded feature matrix and whose edges all carry weight 1.
    label_series : pandas.Series
        The loaded labels as a series named "label".
    """
    def _load_pickle(file_name):
        # SECURITY: dill (like pickle) can execute arbitrary code while
        # deserialising; only load files that come from a trusted source.
        with open(out_path + file_name, 'rb') as f:
            return dill.load(f)

    print("Reading raw inputs...")
    edges = _load_pickle('new_edges.pkl')
    features = _load_pickle('new_features.pkl')
    labels = _load_pickle('new_labels.pkl')

    # Node frame: one row per node, one column per feature dimension.
    # NOTE(review): the default RangeIndex becomes the node id, so the edge
    # endpoints are presumably 0-based row positions -- confirm against the
    # code that produced the pickles.
    print("creating nodes...")
    feature_df = pd.DataFrame(features)
    feature_df.columns = [f"w{i}" for i in range(feature_df.shape[1])]

    # Edge frame: the loaded edges plus a constant weight column. Assigning
    # exactly three column names requires each edge to have two fields.
    # `axis` must be passed by keyword: the positional form is rejected by
    # pandas >= 2.0.
    print("creating edges...")
    edge_df = pd.concat(
        [pd.DataFrame(edges), pd.DataFrame([1] * len(edges))], axis=1
    )
    edge_df.columns = ["source", "target", "weight"]

    print("creating labels...")
    label_series = pd.DataFrame({"label": labels})["label"]

    my_graph = StellarGraph(
        {"paper": feature_df}, {"cites": edge_df}
    )
    print("Done!")

    print(my_graph.info())
    return my_graph, label_series
# Build the graph and the label series once; the cells below reuse G and subjects.
G, subjects = prepare_data_for_stellargraph(data_path)

Reading raw inputs...
creating nodes...
creating edges...
creating labels...
Done!
StellarGraph: Undirected multigraph
 Nodes: 1178, Edges: 5699

 Node types:
  paper: [1178]
    Features: float32 vector, length 868
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5699]
        Weights: all 1 (default)
        Features: none


In [4]:
# Peek at the last few labels as a quick sanity check of the load.
subjects.tail()

1173    2
1174    2
1175    2
1176    2
1177    2
Name: label, dtype: int64

### Splitting the data

In [5]:
# Define an edge splitter on the original graph G:
edge_splitter_test = EdgeSplitter(G)

# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G, and obtain
# the reduced graph G_test with the sampled links removed. The fixed seed keeps the held-out test links (and every
# metric computed from them) identical across kernel restarts:
G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(
    p=0.1, method="global", keep_connected=True, seed=42
)

** Sampled 569 positive and 569 negative edges. **


In [6]:
# Define an edge splitter on the reduced graph G_test:
edge_splitter_train = EdgeSplitter(G_test)

# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G_test, and
# obtain the reduced graph G_train with the sampled links removed. The fixed seed keeps the training links
# identical across kernel restarts:
G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(
    p=0.1, method="global", keep_connected=True, seed=43
)

** Sampled 513 positive and 513 negative edges. **


In [7]:
# Summary of the full graph, before any links are held out.
print(G.info())

StellarGraph: Undirected multigraph
 Nodes: 1178, Edges: 5699

 Node types:
  paper: [1178]
    Features: float32 vector, length 868
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5699]
        Weights: all 1 (default)
        Features: none


In [8]:
# Summary of G_test: the full graph with the sampled test links removed.
print(G_test.info())

StellarGraph: Undirected multigraph
 Nodes: 1178, Edges: 5130

 Node types:
  paper: [1178]
    Features: float32 vector, length 868
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5130]
        Weights: all 1 (default)
        Features: none


In [9]:
# Summary of G_train: G_test with the sampled training links removed as well.
print(G_train.info())

StellarGraph: Undirected multigraph
 Nodes: 1178, Edges: 4617

 Node types:
  paper: [1178]
    Features: float32 vector, length 868
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [4617]
        Weights: all 1 (default)
        Features: none


### Creating the GCN model in Keras

In [None]:
# Number of passes over the training links.
epochs = 700

# Each generator is built on the graph that has its own target links removed
# (G_train for training, G_test for testing), so the links being predicted are
# not part of the graph the GCN sees.
train_gen = FullBatchLinkGenerator(G_train, method="gcn")
train_flow = train_gen.flow(edge_ids_train, edge_labels_train)
test_gen = FullBatchLinkGenerator(G_test, method="gcn")
test_flow = test_gen.flow(edge_ids_test, edge_labels_test)

# Two-layer GCN encoder producing 32-dimensional node embeddings.
gcn = GCN(
    layer_sizes=[80, 32],
    activations=["elu", "elu"],
    generator=train_gen,
    dropout=0.25,
)
x_inp, x_out = gcn.in_out_tensors()

# Score each candidate link by the inner product ("ip") of its two node embeddings.
# NOTE(review): the relu output is not bounded to [0, 1] although it is fed to
# binary cross-entropy as if it were a probability. The flat ~7.62 loss in the
# first epochs of the recorded log looks like the score being stuck at 0 --
# consider activation="sigmoid" and re-check the 0.4 threshold used further down.
prediction = LinkEmbedding(activation="relu", method="ip")(x_out)
prediction = keras.layers.Reshape((-1,))(prediction)

model = keras.Model(inputs=x_inp, outputs=prediction)

model.compile(
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
    # releases no longer accept.
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss=keras.losses.binary_crossentropy,
    metrics=["binary_accuracy"],
)

# Baseline: metrics of the randomly initialised model, before any training.
init_train_metrics = model.evaluate(train_flow)
init_test_metrics = model.evaluate(test_flow)

for header, values in (
    ("\nTrain Set Metrics of the initial (untrained) model:", init_train_metrics),
    ("\nTest Set Metrics of the initial (untrained) model:", init_test_metrics),
):
    print(header)
    for name, val in zip(model.metrics_names, values):
        print("\t{}: {:0.4f}".format(name, val))

# NOTE(review): the test links double as validation data, so the val_* values in
# the training log are test-set metrics; do not use them for model selection or
# early stopping without a separate validation split.
history = model.fit(
    train_flow, epochs=epochs, validation_data=test_flow, verbose=2, shuffle=False
)

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...

Train Set Metrics of the initial (untrained) model:
	loss: 7.6246
	binary_accuracy: 0.5000

Test Set Metrics of the initial (untrained) model:
	loss: 7.6246
	binary_accuracy: 0.5000
Epoch 1/700
1/1 - 0s - loss: 7.5897 - binary_accuracy: 0.5000 - val_loss: 7.6246 - val_binary_accuracy: 0.5000
Epoch 2/700
1/1 - 0s - loss: 7.6246 - binary_accuracy: 0.5000 - val_loss: 7.6246 - val_binary_accuracy: 0.5000
Epoch 3/700
1/1 - 0s - loss: 7.6246 - binary_accuracy: 0.5000 - val_loss: 7.6246 - val_binary_accuracy: 0.5000
Epoch 4/700
1/1 - 0s - loss: 7.5996 - binary_accuracy: 0.5000 - val_loss: 7.6246 - val_binary_accuracy: 0.5000
Epoch 5/700
1/1 - 0s - loss: 7.5625 - binary_accuracy: 0.5000 - val_loss: 7.6149 - val_binary_accuracy: 0.5000
Epoch 6/700
1/1 - 0s - loss: 7.5870 - binary_accuracy: 0.5000 - val_loss: 7.6134 - val_binary_accuracy: 0.5000
Epoch 7/700
1/1 - 0s - loss: 7.5570 - binary_accuracy: 0.5019 - val_loss: 7.6

Epoch 71/700
1/1 - 0s - loss: 0.7203 - binary_accuracy: 0.7729 - val_loss: 0.6063 - val_binary_accuracy: 0.7777
Epoch 72/700
1/1 - 0s - loss: 0.6457 - binary_accuracy: 0.7602 - val_loss: 0.5766 - val_binary_accuracy: 0.7803
Epoch 73/700
1/1 - 0s - loss: 0.5857 - binary_accuracy: 0.7700 - val_loss: 0.5663 - val_binary_accuracy: 0.7830
Epoch 74/700
1/1 - 0s - loss: 0.6083 - binary_accuracy: 0.7710 - val_loss: 0.5554 - val_binary_accuracy: 0.7689
Epoch 75/700
1/1 - 0s - loss: 0.5036 - binary_accuracy: 0.7778 - val_loss: 0.5569 - val_binary_accuracy: 0.7742
Epoch 76/700
1/1 - 0s - loss: 0.6001 - binary_accuracy: 0.7505 - val_loss: 0.5581 - val_binary_accuracy: 0.7724
Epoch 77/700
1/1 - 0s - loss: 0.5882 - binary_accuracy: 0.7788 - val_loss: 0.5594 - val_binary_accuracy: 0.7715
Epoch 78/700
1/1 - 0s - loss: 0.6488 - binary_accuracy: 0.7485 - val_loss: 0.5692 - val_binary_accuracy: 0.7698
Epoch 79/700
1/1 - 0s - loss: 0.6320 - binary_accuracy: 0.7485 - val_loss: 0.5708 - val_binary_accuracy:

Epoch 144/700
1/1 - 0s - loss: 0.5965 - binary_accuracy: 0.7573 - val_loss: 0.5642 - val_binary_accuracy: 0.7460
Epoch 145/700
1/1 - 0s - loss: 0.6209 - binary_accuracy: 0.7339 - val_loss: 0.5513 - val_binary_accuracy: 0.7557
Epoch 146/700
1/1 - 0s - loss: 0.7077 - binary_accuracy: 0.7427 - val_loss: 0.5486 - val_binary_accuracy: 0.7548
Epoch 147/700
1/1 - 0s - loss: 0.5875 - binary_accuracy: 0.7290 - val_loss: 0.5548 - val_binary_accuracy: 0.7636
Epoch 148/700
1/1 - 0s - loss: 0.6458 - binary_accuracy: 0.7602 - val_loss: 0.5628 - val_binary_accuracy: 0.7680
Epoch 149/700
1/1 - 0s - loss: 0.6535 - binary_accuracy: 0.7349 - val_loss: 0.5628 - val_binary_accuracy: 0.7698
Epoch 150/700
1/1 - 0s - loss: 0.6199 - binary_accuracy: 0.7437 - val_loss: 0.5625 - val_binary_accuracy: 0.7724
Epoch 151/700
1/1 - 0s - loss: 0.6646 - binary_accuracy: 0.7417 - val_loss: 0.5605 - val_binary_accuracy: 0.7724
Epoch 152/700
1/1 - 0s - loss: 0.6102 - binary_accuracy: 0.7398 - val_loss: 0.5589 - val_binary_

Epoch 217/700
1/1 - 0s - loss: 0.6225 - binary_accuracy: 0.7437 - val_loss: 0.5359 - val_binary_accuracy: 0.7575
Epoch 218/700
1/1 - 0s - loss: 0.5552 - binary_accuracy: 0.7544 - val_loss: 0.5316 - val_binary_accuracy: 0.7496
Epoch 219/700
1/1 - 0s - loss: 0.5602 - binary_accuracy: 0.7427 - val_loss: 0.5352 - val_binary_accuracy: 0.7434
Epoch 220/700
1/1 - 0s - loss: 0.5648 - binary_accuracy: 0.7398 - val_loss: 0.5361 - val_binary_accuracy: 0.7434
Epoch 221/700
1/1 - 0s - loss: 0.5927 - binary_accuracy: 0.7320 - val_loss: 0.5356 - val_binary_accuracy: 0.7434
Epoch 222/700
1/1 - 0s - loss: 0.6094 - binary_accuracy: 0.7290 - val_loss: 0.5231 - val_binary_accuracy: 0.7460
Epoch 223/700
1/1 - 0s - loss: 0.6009 - binary_accuracy: 0.7290 - val_loss: 0.5171 - val_binary_accuracy: 0.7513
Epoch 224/700
1/1 - 0s - loss: 0.5452 - binary_accuracy: 0.7329 - val_loss: 0.5116 - val_binary_accuracy: 0.7627
Epoch 225/700
1/1 - 0s - loss: 0.5808 - binary_accuracy: 0.7602 - val_loss: 0.5076 - val_binary_

Epoch 290/700
1/1 - 0s - loss: 0.6321 - binary_accuracy: 0.7505 - val_loss: 0.5366 - val_binary_accuracy: 0.7821
Epoch 291/700
1/1 - 0s - loss: 0.5663 - binary_accuracy: 0.7788 - val_loss: 0.5342 - val_binary_accuracy: 0.7803
Epoch 292/700
1/1 - 0s - loss: 0.6171 - binary_accuracy: 0.7544 - val_loss: 0.5429 - val_binary_accuracy: 0.7821
Epoch 293/700
1/1 - 0s - loss: 0.5165 - binary_accuracy: 0.7651 - val_loss: 0.5572 - val_binary_accuracy: 0.7750
Epoch 294/700
1/1 - 0s - loss: 0.5463 - binary_accuracy: 0.7739 - val_loss: 0.5724 - val_binary_accuracy: 0.7742
Epoch 295/700
1/1 - 0s - loss: 0.5713 - binary_accuracy: 0.7729 - val_loss: 0.5740 - val_binary_accuracy: 0.7707
Epoch 296/700
1/1 - 0s - loss: 0.6544 - binary_accuracy: 0.7495 - val_loss: 0.5737 - val_binary_accuracy: 0.7707
Epoch 297/700
1/1 - 0s - loss: 0.5387 - binary_accuracy: 0.7661 - val_loss: 0.5715 - val_binary_accuracy: 0.7733
Epoch 298/700
1/1 - 0s - loss: 0.5887 - binary_accuracy: 0.7602 - val_loss: 0.5688 - val_binary_

Epoch 363/700
1/1 - 0s - loss: 0.7377 - binary_accuracy: 0.6881 - val_loss: 0.7052 - val_binary_accuracy: 0.6388
Epoch 364/700
1/1 - 0s - loss: 0.7769 - binary_accuracy: 0.6598 - val_loss: 0.7297 - val_binary_accuracy: 0.6643
Epoch 365/700
1/1 - 0s - loss: 0.7760 - binary_accuracy: 0.6725 - val_loss: 0.7543 - val_binary_accuracy: 0.6951
Epoch 366/700
1/1 - 0s - loss: 0.6381 - binary_accuracy: 0.7261 - val_loss: 0.9389 - val_binary_accuracy: 0.7083
Epoch 367/700
1/1 - 0s - loss: 0.8200 - binary_accuracy: 0.7212 - val_loss: 1.4478 - val_binary_accuracy: 0.6626
Epoch 368/700
1/1 - 0s - loss: 1.2949 - binary_accuracy: 0.6793 - val_loss: 2.0761 - val_binary_accuracy: 0.6063
Epoch 369/700
1/1 - 0s - loss: 1.8913 - binary_accuracy: 0.6589 - val_loss: 2.6661 - val_binary_accuracy: 0.5756
Epoch 370/700
1/1 - 0s - loss: 2.3635 - binary_accuracy: 0.6218 - val_loss: 3.0131 - val_binary_accuracy: 0.5580
Epoch 371/700
1/1 - 0s - loss: 2.7832 - binary_accuracy: 0.6101 - val_loss: 3.0270 - val_binary_

Epoch 436/700
1/1 - 0s - loss: 0.5069 - binary_accuracy: 0.7622 - val_loss: 0.5787 - val_binary_accuracy: 0.7733
Epoch 437/700
1/1 - 0s - loss: 0.5399 - binary_accuracy: 0.7632 - val_loss: 0.5764 - val_binary_accuracy: 0.7750
Epoch 438/700
1/1 - 0s - loss: 0.5216 - binary_accuracy: 0.7749 - val_loss: 0.5760 - val_binary_accuracy: 0.7724
Epoch 439/700
1/1 - 0s - loss: 0.4954 - binary_accuracy: 0.7612 - val_loss: 0.5669 - val_binary_accuracy: 0.7671
Epoch 440/700
1/1 - 0s - loss: 0.6050 - binary_accuracy: 0.7427 - val_loss: 0.5670 - val_binary_accuracy: 0.7575
Epoch 441/700
1/1 - 0s - loss: 0.5876 - binary_accuracy: 0.7515 - val_loss: 0.5676 - val_binary_accuracy: 0.7531
Epoch 442/700
1/1 - 0s - loss: 0.5639 - binary_accuracy: 0.7719 - val_loss: 0.5684 - val_binary_accuracy: 0.7504
Epoch 443/700
1/1 - 0s - loss: 0.5672 - binary_accuracy: 0.7524 - val_loss: 0.5690 - val_binary_accuracy: 0.7504
Epoch 444/700
1/1 - 0s - loss: 0.4756 - binary_accuracy: 0.7749 - val_loss: 0.5406 - val_binary_

Epoch 509/700
1/1 - 0s - loss: 0.5499 - binary_accuracy: 0.7875 - val_loss: 0.5810 - val_binary_accuracy: 0.7575
Epoch 510/700
1/1 - 0s - loss: 0.5972 - binary_accuracy: 0.7593 - val_loss: 0.5858 - val_binary_accuracy: 0.7540
Epoch 511/700
1/1 - 0s - loss: 0.5415 - binary_accuracy: 0.7661 - val_loss: 0.5975 - val_binary_accuracy: 0.7504
Epoch 512/700
1/1 - 0s - loss: 0.6187 - binary_accuracy: 0.7495 - val_loss: 0.6000 - val_binary_accuracy: 0.7504
Epoch 513/700
1/1 - 0s - loss: 0.5858 - binary_accuracy: 0.7524 - val_loss: 0.6018 - val_binary_accuracy: 0.7469
Epoch 514/700
1/1 - 0s - loss: 0.5923 - binary_accuracy: 0.7505 - val_loss: 0.6038 - val_binary_accuracy: 0.7496
Epoch 515/700
1/1 - 0s - loss: 0.6177 - binary_accuracy: 0.7602 - val_loss: 0.6115 - val_binary_accuracy: 0.7531
Epoch 516/700
1/1 - 0s - loss: 0.6424 - binary_accuracy: 0.7466 - val_loss: 0.6037 - val_binary_accuracy: 0.7575
Epoch 517/700
1/1 - 0s - loss: 0.5896 - binary_accuracy: 0.7583 - val_loss: 0.6014 - val_binary_

Epoch 582/700
1/1 - 0s - loss: 0.4546 - binary_accuracy: 0.7797 - val_loss: 0.5214 - val_binary_accuracy: 0.7970
Epoch 583/700
1/1 - 0s - loss: 0.4859 - binary_accuracy: 0.7836 - val_loss: 0.5212 - val_binary_accuracy: 0.7970


In [None]:
# Plot the per-epoch training and validation curves recorded by model.fit.
sg.utils.plot_history(history)

In [None]:
# Score the trained model on both link sets first, then report them together.
train_metrics = model.evaluate(train_flow)
test_metrics = model.evaluate(test_flow)

reports = (
    ("\nTrain Set Metrics of the trained model:", train_metrics),
    ("\nTest Set Metrics of the trained model:", test_metrics),
)
for heading, metric_values in reports:
    print(heading)
    # metrics_names and the evaluate() result are paired up position by position.
    for name, val in zip(model.metrics_names, metric_values):
        print("\t{}: {:0.4f}".format(name, val))

In [None]:
# Raw test-set metric values, in the same order as model.metrics_names.
test_metrics

In [None]:
# Predicted link scores for the held-out test links.
res_pred = model.predict(test_flow)

In [None]:
# Pair every test link's predicted score with its ground-truth label.
# res_pred[0] drops the leading axis of the prediction array -- presumably the
# single full-batch dimension; TODO confirm the shape is (1, n_links).
tmp = pd.DataFrame({"pred": res_pred[0], "label": edge_labels_test})
tmp

In [None]:
# (accuracy at a 0.4 score cut-off, threshold-free ROC AUC) on the test links.
# NOTE(review): the 0.4 cut-off differs from the threshold behind the compiled
# "binary_accuracy" metric (presumably Keras' default of 0.5), so this accuracy
# is not directly comparable with the training log -- confirm 0.4 is intentional.
accuracy_score(edge_labels_test, res_pred[0]>0.4), roc_auc_score(edge_labels_test, res_pred[0])

In [None]:
# Compare the predicted-score distributions of positive and negative test links.
# A legend, title and axis labels make the figure readable on its own; seaborn,
# numpy and matplotlib are already imported in the first cell.
# NOTE(review): sns.distplot is deprecated since seaborn 0.11; switch to
# sns.histplot / sns.kdeplot once the environment's seaborn provides them.
fig, ax = plt.subplots()
sns.distplot(tmp[tmp.label==1].pred, label="positive links (label = 1)", ax=ax)
sns.distplot(tmp[tmp.label==0].pred, color="r", label="negative links (label = 0)", ax=ax)
ax.set(
    title="Predicted link scores on the test set",
    xlabel="predicted score",
    ylabel="density",
)
ax.legend();