# Link Prediction using Graph Neural Networks

Graph neural networks (GNNs) are powerful tools for many machine learning tasks on graphs. This tutorial teaches the basic workflow of using GNNs for link prediction. We again use the Cora citation network and try to predict whether one paper cites another.

In this tutorial, you will learn how to:
* Prepare training and testing sets for a link prediction task.
* Build a GNN-based link prediction model.
* Train the model and verify the result.

In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp

Using backend: pytorch


## Problem formulation

- Given the graph structure and node features
- Predict whether any two nodes in the graph are connected

## Load graph and features

Following the last session, we first load the Cora network.

In [2]:
from dgl.data import CoraGraphDataset
# ----------- 0. load graph -------------- #
data = CoraGraphDataset()
g = data[0]
features = g.ndata['feat']
labels = g.ndata['label']

in_feats = features.shape[1]
n_classes = data.num_classes
n_edges = g.number_of_edges()

Loading from cache failed, re-processing.
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.




## Prepare training and testing sets

- A link prediction data set contains two types of edges: *training edges* and *testing edges*.
- Testing edges are usually drawn from the existing edges in the graph. 
- In this example, we randomly pick 1000 edges for testing and leave the rest for training. 

In [3]:
# Split edge set for training and testing
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_pos_u, test_pos_v = u[eids[:1000]], v[eids[:1000]]
train_pos_u, train_pos_v = u[eids[1000:]], v[eids[1000:]]
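
To make the split concrete, here is a minimal NumPy-only sketch on a six-edge toy graph. The helper name `split_edges`, the `seed` parameter, and the toy edge arrays are assumptions made for illustration; they are not part of DGL or of the cell above.

```python
import numpy as np

def split_edges(u, v, num_test, seed=0):
    """Shuffle edge IDs and split (u, v) into a test part and a train part.

    Hypothetical helper: it mirrors the permutation-based split used above,
    but on plain NumPy arrays and with a fixed seed for reproducibility.
    """
    rng = np.random.default_rng(seed)
    eids = rng.permutation(len(u))
    test_ids, train_ids = eids[:num_test], eids[num_test:]
    return (u[test_ids], v[test_ids]), (u[train_ids], v[train_ids])

# Toy directed edge list (assumed data, not Cora).
toy_u = np.array([0, 0, 1, 2, 3, 4])
toy_v = np.array([1, 2, 2, 3, 4, 0])

(toy_test_u, toy_test_v), (toy_train_u, toy_train_v) = split_edges(
    toy_u, toy_v, num_test=2
)
print(len(toy_test_u), "test edges,", len(toy_train_u), "training edges")
```

Every edge lands in exactly one of the two sets, because both are slices of the same permutation.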

### Negative links

- Positive links prompt the GNN to connect the corresponding nodes.
- Negative links are used to train the GNN to NOT connect certain nodes.
- Negative links are typically sampled from the non-existent links in the graph.
- Choosing a proper negative sampling algorithm is a widely studied topic and is out of the scope of this tutorial.
- We enumerate all the missing edges, randomly pick 2000 of them, and use 500 for testing and 1500 for training.

In [4]:
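# Find all negative edges and split them for training and testing
num_nodes = g.number_of_nodes()
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(num_nodes, num_nodes))
# Nonzero entries mark missing edges; the identity term excludes self-loops
adj_neg = 1 - adj.toarray() - np.eye(num_nodes)
neg_u, neg_v = np.where(adj_neg != 0)
# Sample without replacement so no negative edge is drawn twice or lands in both sets
neg_eids = np.random.choice(len(neg_u), 2000, replace=False)
test_neg_u, test_neg_v = neg_u[neg_eids[:500]], neg_v[neg_eids[:500]]
train_neg_u, train_neg_v = neg_u[neg_eids[500:]], neg_v[neg_eids[500:]]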
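
The cell above materializes a dense N×N matrix, which is manageable for Cora's 2708 nodes but grows quadratically with graph size. As an alternative sketch (rejection sampling, which is not the method this tutorial uses), one can draw random node pairs and keep those that are neither self-loops nor existing edges. The function name `sample_negative_edges` and the toy graph are assumptions for illustration.

```python
import numpy as np

def sample_negative_edges(u, v, num_nodes, num_samples, seed=0):
    """Draw `num_samples` distinct non-edges (src != dst) by rejection sampling.

    Hypothetical helper, not a DGL function. It never builds an N x N matrix;
    it only stores the existing edges and the accepted samples.
    """
    rng = np.random.default_rng(seed)
    existing = set(zip(u.tolist(), v.tolist()))
    # Number of ordered pairs (s, d) with s != d that are not edges.
    max_neg = num_nodes * (num_nodes - 1) - sum(1 for s, d in existing if s != d)
    if num_samples > max_neg:
        raise ValueError("not enough non-edges to sample from")

    seen, neg_src, neg_dst = set(), [], []
    while len(neg_src) < num_samples:
        s, d = (int(x) for x in rng.integers(num_nodes, size=2))
        if s == d or (s, d) in existing or (s, d) in seen:
            continue  # reject self-loops, real edges, and repeats
        seen.add((s, d))
        neg_src.append(s)
        neg_dst.append(d)
    return np.array(neg_src), np.array(neg_dst)

# Toy directed graph with 5 nodes and 6 edges (assumed data, not Cora).
toy_u = np.array([0, 0, 1, 2, 3, 4])
toy_v = np.array([1, 2, 2, 3, 4, 0])
toy_neg_u, toy_neg_v = sample_negative_edges(toy_u, toy_v, num_nodes=5, num_samples=4)
print(list(zip(toy_neg_u.tolist(), toy_neg_v.tolist())))
```

Rejection sampling is efficient when the graph is sparse, since almost every random pair is a non-edge; on dense graphs most draws would be rejected.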

In [9]:
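# Put positive and negative edges together and form training and testing sets.
# Label convention used in this tutorial: 0 = positive (existing) edge, 1 = negative (missing) edge.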

# Create training set.
train_u = torch.cat([torch.as_tensor(train_pos_u), torch.as_tensor(train_neg_u)])
train_v = torch.cat([torch.as_tensor(train_pos_v), torch.as_tensor(train_neg_v)])
train_label = torch.cat([torch.zeros(len(train_pos_u)), torch.ones(len(train_neg_u))])

# Create testing set.
test_u = torch.cat([torch.as_tensor(test_pos_u), torch.as_tensor(test_neg_u)])
test_v = torch.cat([torch.as_tensor(test_pos_v), torch.as_tensor(test_neg_v)])
test_label = torch.cat([torch.zeros(len(test_pos_u)), torch.ones(len(test_neg_u))])

## Define a GCN model

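Our model consists of two layers, each of which computes new node representations by aggregating neighbor information as follows:

$$
 h_i^{(l+1)} = \sigma\left(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)}\right)
$$

Here $\mathcal{N}(i)$ is the set of neighbors of node $i$, $c_{ij}$ is a normalization constant (with the default setting of DGL's `GraphConv`, $c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}$), $W^{(l)}$ and $b^{(l)}$ are the learnable weight and bias of layer $l$, and $\sigma$ is an activation function.

<img src='https://tkipf.github.io/graph-convolutional-networks/images/gcn_web.png' align='center' width="400px" height="300px" />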
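
To make the update rule concrete, the following NumPy sketch evaluates a single layer on a three-node path graph. It assumes the symmetric normalization $c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}$, ReLU as $\sigma$, and made-up features and weights; `gcn_layer` is an illustrative name, not a library function.

```python
import numpy as np

def gcn_layer(adj, h, weight, bias):
    """One GCN update: relu(bias + sum_j h_j W / c_ij), c_ij = sqrt(d_i * d_j).

    Hypothetical dense implementation for illustration only.
    """
    deg = adj.sum(axis=1)                                # |N(i)| for every node
    inv_sqrt_deg = 1.0 / np.sqrt(np.maximum(deg, 1.0))   # guard isolated nodes
    norm_adj = inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    return np.maximum(norm_adj @ h @ weight + bias, 0.0)

# Path graph 0 - 1 - 2 (assumed toy data).
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
weight = np.eye(2)     # identity weights keep the arithmetic easy to follow
bias = np.zeros(2)

h_next = gcn_layer(adj, h, weight, bias)
print(h_next)
```

Node 1 has two neighbors, so its new representation is $(h_0 + h_2)/\sqrt{2}$, while nodes 0 and 2 each receive $h_1/\sqrt{2}$.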

In [6]:
from dgl.nn import GraphConv

# ----------- 2. create model -------------- #
# build a two-layer GCN model
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    
# Create the model with given dimensions
# input layer dimension: 1433 (node features)
# hidden layer dimension: 16
# output (node embedding) dimension: 16
model = GCN(in_feats, 16, 16)

## Link prediction loss function

- The following loss function is used.

$$
\hat{y}_{u\sim v} = \sigma(h_u^T h_v)
$$

$$
\mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
$$

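- The model scores each candidate edge $u\sim v$ with the dot product of the two node representations $h_u$ and $h_v$; the sigmoid $\sigma$ maps that score to $\hat{y}_{u\sim v}\in(0,1)$.
- The binary cross-entropy loss uses the target $y_{u\sim v}$, which in this tutorial is 0 for positive links and 1 for negative links, so $\hat{y}_{u\sim v}$ is the predicted probability that the pair is a negative link.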
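
As a small worked example of the two formulas, the NumPy sketch below scores three node pairs and evaluates the summed cross-entropy. The embeddings and labels are made up to exercise both terms of the loss, and `link_bce` is an illustrative name. Note that `F.binary_cross_entropy`, used in the training cell, averages over pairs by default instead of summing.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def link_bce(h, src, dst, y):
    """Return (y_hat, loss) for candidate edges src[k] ~ dst[k].

    y_hat = sigmoid(h_u . h_v); loss is the summed binary cross-entropy.
    Hypothetical helper for illustration only.
    """
    scores = (h[src] * h[dst]).sum(axis=1)   # row-wise dot products
    y_hat = sigmoid(scores)
    loss = -np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
    return y_hat, loss

# Four 2-d node embeddings (assumed values).
h = np.array([[1.0, 0.0],
              [1.0, 0.0],
              [0.0, 1.0],
              [-1.0, 0.0]])
src = np.array([0, 0, 0])
dst = np.array([1, 2, 3])
y = np.array([1.0, 0.0, 0.0])   # arbitrary targets for the three pairs

y_hat, loss = link_bce(h, src, dst, y)
print(y_hat, loss)
```

The three dot products are 1, 0, and -1, so the predictions are $\sigma(1)$, $0.5$, and $\sigma(-1)$; aligned embeddings push $\hat{y}$ toward 1 and opposed embeddings push it toward 0.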

In [7]:
# ----------- 3. set up loss and optimizer -------------- #
# in this case, the loss is computed inside the training loop
optimizer = torch.optim.Adam(itertools.chain(model.parameters()), lr=0.005)

# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(500):
    # forward
    logits = model(g, features)
    pred = torch.sigmoid((logits[train_u] * logits[train_v]).sum(dim=1))
    
    # compute loss
    loss = F.binary_cross_entropy(pred, train_label)
    
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(logits.detach())
    
    if e % 50 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))

In epoch 0, loss: 0.693248987197876
In epoch 50, loss: 0.6931480765342712
In epoch 100, loss: 0.6931473016738892
In epoch 150, loss: 0.6931399703025818
In epoch 200, loss: 0.6319965124130249
In epoch 250, loss: 0.5070324540138245
In epoch 300, loss: 0.46513524651527405
In epoch 350, loss: 0.4416673481464386
In epoch 400, loss: 0.4180602729320526
In epoch 450, loss: 0.39537087082862854


In [8]:
# ----------- 5. check results ------------------------ #
pred = torch.sigmoid((logits[test_u] * logits[test_v]).sum(dim=1))
print('Accuracy', ((pred >= 0.5) == test_label).sum().item() / len(pred))

Accuracy 0.692
