# Link prediction with GCN

The goal of this notebook is to get complete predictions for whole dataset split into cross-validation folds. For more detailed description of the training refer to the `gcn-link-prediction-demo.ipynb` notebook.

In [1]:
import numpy as np
import pandas as pd
import stellargraph as sg
from stellargraph.mapper import FullBatchLinkGenerator
from stellargraph.layer import GCN, LinkEmbedding
from tensorflow import keras
from tqdm.keras import TqdmCallback

from graph import load_splits

In [2]:
def train_predict(split, epochs=5000):
    (G_train, edge_ids_train, edge_labels_train), (G_test, edge_ids_test, edge_labels_test) = split

    train_gen = FullBatchLinkGenerator(G_train, method="gcn")
    train_flow = train_gen.flow(edge_ids_train, edge_labels_train)
    test_gen = FullBatchLinkGenerator(G_test, method="gcn")
    test_flow = test_gen.flow(edge_ids_test, edge_labels_test)

    gcn = GCN(
        layer_sizes=[256, 256, 256, 256], activations=["relu", "relu", "relu", "relu"], generator=train_gen, dropout=0.25
    )

    x_inp, x_out = gcn.in_out_tensors()

    prediction = LinkEmbedding(activation="tanh", method="ip")(x_out)
    prediction = keras.layers.Reshape((-1,))(prediction)

    model = keras.Model(inputs=x_inp, outputs=prediction)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        loss=keras.losses.binary_crossentropy,
        metrics=[
            keras.metrics.BinaryAccuracy(), 
            keras.metrics.Precision(), 
            keras.metrics.Recall(),
            keras.metrics.AUC()
        ],
    )

    model.fit(
        train_flow, epochs=epochs, validation_data=test_flow, 
        verbose=0, callbacks=[TqdmCallback(verbose=1)], shuffle=True
    )

    y_test = test_flow[0][-1][0]
    y_prob = model.predict(test_flow)[0]

    rows = []

    for edge, p in zip(test_flow[0][0][1][0], y_prob):
        phn, gen = edge
        rows.append([G_test.nodes()[gen], G_test.nodes()[phn], p])

    df = pd.DataFrame(rows, columns=["Gen", "Phn", "p"])

    return df

In [3]:
dfs = []

splits = load_splits(sample_test_negatives=False, drop_disconnected=True)

for split in splits:
    dfs.append(train_predict(split))

df = pd.concat(dfs)
df

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...


0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...


0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...


0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...


0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]

Using GCN (local pooling) filters...
Using GCN (local pooling) filters...


0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]



Unnamed: 0,Gen,Phn,p
0,ERCC1,D1071,0.881632
1,DDX54,D1071,0.845112
2,SNCAIP,D1071,0.698243
3,DNMT3A,D1071,0.833472
4,PAICS,D1071,0.470386
...,...,...,...
179923,GKN1,D1222,0.000014
179924,ITPKB,D1222,0.000014
179925,Mar-05,D1222,0.000014
179926,MORF4,D1222,0.000000


In [4]:
df.to_csv('results.csv', index=False)