# Part 2: Link Prediction on a Knowledge Graph


Link prediction is the task of predicting missing connections (links) between two nodes in a (directed) graph. By convention, we call those two nodes the subject $s$ and the object $o$. In most cases the edges are typed, meaning that each edge (i.e. each relation) carries a label, denoted $r$. An example triplet in the context of social networks could be $\Big(s="John Doe", r="follows", o="Karen Smith"\Big)$. Link prediction is at the heart of many applications of Knowledge Graphs (KGs). For instance, it can be used to suggest that a user connect with or follow another user.
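
Datasets such as FB15k-237 store each triplet as a row of three integer IDs rather than as strings. The following is a minimal sketch of that encoding, based on the toy social-network example above. The extra triplets and the `encode` helper are hypothetical and only serve as an illustration.

```python
# Toy triplets in (subject, relation, object) form; illustrative data only.
triplets = [
    ("John Doe", "follows", "Karen Smith"),
    ("Karen Smith", "follows", "John Doe"),
    ("John Doe", "likes", "Karen Smith"),
]


def encode(triplets):
    """Map entity and relation names to consecutive integer IDs."""
    entity2id, relation2id = {}, {}
    encoded = []
    for s, r, o in triplets:
        # setdefault assigns the next free ID the first time a name is seen.
        s_id = entity2id.setdefault(s, len(entity2id))
        o_id = entity2id.setdefault(o, len(entity2id))
        r_id = relation2id.setdefault(r, len(relation2id))
        encoded.append((s_id, r_id, o_id))
    return encoded, entity2id, relation2id


encoded, entity2id, relation2id = encode(triplets)
print(encoded)
```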

In this hands-on tutorial, we work with the FB15k-237 dataset. This dataset contains textual relations that were extracted from 200 million sentences in the ClueWeb12 corpus coupled with Freebase mention annotations, and it includes textual links of all co-occurring entities from the KB set. As the loader output in section A shows, FB15k-237 has 14541 entities (nodes), 237 relation types and 272115 training edges; a further 17535 edges are used for validation and 20466 for testing.

## Objectives:

- Sharpen your understanding of Graph Neural Networks with the R-GCN model 
- Use the high-level API of the DGL library 
- Train a link prediction task: predicting missing links in a Knowledge Graph (KG)

Disclaimer: This hands-on exercise is inspired by the official DGL tutorial introducing the Relational Graph Convolutional Network (R-GCN), available [here](https://github.com/dmlc/dgl/tree/master/examples/pytorch/rgcn).

## A.) Dataloading with DGL:

### Load FB15k-237 dataset

We can directly use the `dgl.data.knowledge_graph` module, which can load several Knowledge Graphs, e.g. `FB15k-237`.

### Pre-processing

Working with a KG requires some data manipulation to extract the test graph from the full knowledge graph.

In [1]:
import torch

from utils.graph import *
from dgl.data.knowledge_graph import load_data

# 1. load KG
data = load_data("FB15k-237")

# 2. extract meta info
num_nodes = data.num_nodes
train_data = data.train
valid_data = torch.LongTensor(data.valid)
test_data = torch.LongTensor(data.test)
num_rels = data.num_rels

# 3. build test graph
test_graph, test_rel, test_norm = build_test_graph(num_nodes, num_rels, train_data)
test_deg = test_graph.in_degrees(range(test_graph.number_of_nodes())).float().view(-1,1)
test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
test_rel = torch.from_numpy(test_rel)
test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1))

# build adj list and calculate degrees for sampling
adj_list, degrees = get_adj_and_degrees(num_nodes, train_data)


Using backend: pytorch


# entities: 14541
# relations: 237
# training edges: 272115
# validation edges: 17535
# testing edges: 20466
Done loading data from cached files.
Test graph:


  norm = 1.0 / in_deg


# nodes: 14541, # edges: 544230
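
The test graph reports 544230 edges, which is exactly twice the 272115 training edges. This is consistent with adding, for every triplet $(s, r, o)$, an inverse edge from $o$ to $s$ with relation id $r + \text{num\_rels}$, as the official DGL R-GCN example does. Because `build_test_graph` lives in the local `utils.graph` module and is not shown here, the sketch below is an assumption about its behavior rather than its actual code, and `add_inverse_edges` is a hypothetical name.

```python
import numpy as np


def add_inverse_edges(triplets, num_rels):
    """Return src, rel, dst arrays with one inverse edge per input triplet."""
    src, rel, dst = triplets.T
    # Forward edges keep relation id r; inverse edges get id r + num_rels.
    all_src = np.concatenate([src, dst])
    all_dst = np.concatenate([dst, src])
    all_rel = np.concatenate([rel, rel + num_rels])
    return all_src, all_rel, all_dst


toy = np.array([[0, 0, 1], [1, 1, 2]])  # two (s, r, o) triplets
src, rel, dst = add_inverse_edges(toy, num_rels=2)
print(src, rel, dst)
```

Under this assumption the graph has `2 * num_rels` edge types, which would explain why the encoder in section B is built with `num_rels * 2` relations.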


## B.) Designing the model

### High-level model design

- Encode each node using a Graph Neural Network that can operate on directed graphs with labeled nodes and edges. An example is the relational GCN (R-GCN).
- Decode using a tensor factorization method to derive a score for each candidate triplet. An example is DistMult.

### Model formulation & implementation 

#### R-GCN

We make use of the high-level API `dgl.nn.pytorch.conv`, which provides ready-made GNN layers. In particular, this task is based on the R-GCN model, which updates each node as:
\begin{equation}
h(v)^{(k)} = \sigma \Big( \sum_{r \in R} \sum_{u \in N(v)^r} \frac{1}{c(v)^r} W_r^{(k)} h(u)^{(k-1)} + W_0^{(k)} h(v)^{(k-1)}   \Big)
\end{equation}
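
Here $N(v)^r$ is the set of neighbors of $v$ under relation $r$, $c(v)^r$ is a normalization constant, $W_r^{(k)}$ is the weight matrix of relation $r$ at layer $k$, and $W_0^{(k)}$ is the self-loop weight. The update can be sketched in plain NumPy as follows. This is an unoptimized illustration and not DGL's `RelGraphConv` implementation; the function name `rgcn_layer`, the choice $c(v)^r = |N(v)^r|$ and the ReLU non-linearity are assumptions made for the example.

```python
import numpy as np


def rgcn_layer(h, edges, W_rel, W_self):
    """One R-GCN update with ReLU and c(v)^r = |N(v)^r|.

    h:      (num_nodes, d_in) node states from the previous layer
    edges:  iterable of (u, r, v), meaning an edge u -> v of type r
    W_rel:  (num_rels, d_in, d_out), one weight matrix per relation
    W_self: (d_in, d_out), the self-loop weight W_0
    """
    num_nodes = h.shape[0]
    num_rels, _, d_out = W_rel.shape
    msg_sum = np.zeros((num_rels, num_nodes, d_out))
    count = np.zeros((num_rels, num_nodes, 1))
    for u, r, v in edges:
        msg_sum[r, v] += h[u] @ W_rel[r]  # message from u to v (row-vector form)
        count[r, v] += 1                  # running size of N(v)^r
    # Normalize per (relation, node); nodes without neighbors contribute zero.
    agg = (msg_sum / np.maximum(count, 1)).sum(axis=0)
    return np.maximum(agg + h @ W_self, 0.0)  # ReLU


h = np.eye(3)  # one-hot input features for 3 nodes
edges = [(0, 0, 2), (1, 0, 2), (2, 1, 0)]
W_rel = np.stack([np.full((3, 2), 1.0), np.full((3, 2), 2.0)])
W_self = np.zeros((3, 2))
out = rgcn_layer(h, edges, W_rel, W_self)
print(out)
```

In this toy graph, node 2 averages two messages under relation 0, node 0 receives a single message under relation 1, and node 1 has no incoming edges, so its output is zero.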

#### DistMult

The score of a triplet (subject, relation, object) is computed with the `DistMult` factorization as:
\begin{equation}
f(s, r, o) = e^T_s R_r e_o
\end{equation}
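where $e_s$ and $e_o$ are the final node embeddings of the subject and object respectively, i.e. $e_s = h(s)^{(K)}$ for an encoder with $K$ layers, and $R_r$ is a diagonal matrix learned for relation $r$. The score is a logit; it becomes a probability only after a sigmoid is applied.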
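
Because $R_r$ is diagonal in DistMult, the score reduces to an element-wise product followed by a sum, which is what `calc_score` in the code below computes. A minimal NumPy sketch of this equivalence, using random vectors (the function name `distmult_score` and the toy dimensions are assumptions for illustration):

```python
import numpy as np


def distmult_score(e_s, w_r, e_o):
    """DistMult score e_s^T diag(w_r) e_o for a batch of triplets."""
    return np.sum(e_s * w_r * e_o, axis=1)


rng = np.random.default_rng(0)
e_s, w_r, e_o = rng.normal(size=(3, 4, 5))  # batch of 4 triplets, dimension 5
scores = distmult_score(e_s, w_r, e_o)

# Same result with an explicit diagonal matrix R_r for the first triplet.
explicit = e_s[0] @ np.diag(w_r[0]) @ e_o[0]
print(np.isclose(scores[0], explicit))
```

One consequence of the diagonal form is that the score is symmetric in subject and object, so $f(s, r, o) = f(o, r, s)$.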

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.knowledge_graph import load_data
from dgl.nn.pytorch import RelGraphConv


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, cuda=False):
        super(RGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        # stored as use_cuda so that nn.Module.cuda() is not shadowed
        self.use_cuda = cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)

    def build_input_layer(self):
        return EmbeddingLayer(self.num_nodes, self.h_dim)

    def build_hidden_layer(self, idx):
        act = F.relu if idx < self.num_hidden_layers - 1 else None
        return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "bdd",
                self.num_bases, activation=act, self_loop=True,
                dropout=self.dropout)
      
    def build_output_layer(self):
        return None

    def forward(self, g, h, r, norm):
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h


class EmbeddingLayer(nn.Module):
    def __init__(self, num_nodes, h_dim):
        super(EmbeddingLayer, self).__init__()
        self.embedding = torch.nn.Embedding(num_nodes, h_dim)

    def forward(self, g, h, r, norm):
        return self.embedding(h.squeeze())


class LinkPredict(nn.Module):
    def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100,
                 num_hidden_layers=2, dropout=0.2, cuda=False, reg_param=0.01):
        super(LinkPredict, self).__init__()
        # pass cuda by keyword: positionally it would land in use_self_loop
        self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, num_bases,
                         num_hidden_layers, dropout, cuda=cuda)
        self.reg_param = reg_param
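        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        # torch.Tensor(...) allocates uninitialized memory, so initialize explicitly
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))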

    def calc_score(self, embedding, triplets):
        # DistMult
        s = embedding[triplets[:,0]]
        r = self.w_relation[triplets[:,1]]
        o = embedding[triplets[:,2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def forward(self, g, h, r, norm):
        return self.rgcn.forward(g, h, r, norm)

    def regularization_loss(self, embedding):
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        # triplets is a list of data samples (positive and negative)
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss
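
For reference, the objective implemented by `get_loss` above can be written out as follows, where $f_i$ is the DistMult score of sample $i$, $y_i \in \{0, 1\}$ its label, $\sigma$ the sigmoid, $N$ the number of samples, $\lambda$ the `reg_param` argument, $E$ the matrix of node embeddings, $W$ the relation parameters `w_relation`, and $\mathrm{mean}(\cdot)$ the average over all entries of the element-wise square:
\begin{equation}
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \Big[ y_i \log \sigma(f_i) + (1 - y_i) \log \big(1 - \sigma(f_i)\big) \Big] + \lambda \Big( \mathrm{mean}(E^2) + \mathrm{mean}(W^2) \Big)
\end{equation}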


## C.) Define the training and testing loop

### Use classic PyTorch training loop 
- Define the model parameters (num layers, GNN dimensions)
- Define the training parameters (optimizer, learning rate, weight decay, number of epochs)

### Define positive and negative samples

To train a link prediction system, we need positive and negative triplets $(s, r, o)$, each associated with a binary label (the triplet exists or does not exist).

**Positive samples:** During training, the system operates on a subset of the original graph, where we simply randomly drop edges. The dropped relationships are then used to evaluate the system during inference.

**Negative samples:** For each observed example we sample negative ones. We sample by randomly corrupting either the subject or the object of each positive example.
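
The corruption step can be sketched as follows. This is an illustration under stated assumptions, not the code of `generate_sampled_graph_and_labels` (which lives in `utils.graph` and is not shown here); the function name `corrupt_triplets` and the 50/50 split between subject and object corruption are hypothetical choices.

```python
import numpy as np


def corrupt_triplets(pos, num_nodes, num_neg, rng):
    """Create num_neg corrupted copies of each positive triplet.

    pos is an (n, 3) integer array of (subject, relation, object) rows.
    Returns the stacked samples and 0/1 labels (1 = observed triplet).
    """
    neg = np.tile(pos, (num_neg, 1))
    random_nodes = rng.integers(0, num_nodes, size=len(neg))
    corrupt_subject = rng.random(len(neg)) < 0.5
    # Replace the subject in roughly half of the rows, the object in the rest.
    neg[corrupt_subject, 0] = random_nodes[corrupt_subject]
    neg[~corrupt_subject, 2] = random_nodes[~corrupt_subject]
    samples = np.concatenate([pos, neg])
    labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])
    return samples, labels


rng = np.random.default_rng(0)
pos = np.array([[0, 0, 1], [2, 1, 3]])
samples, labels = corrupt_triplets(pos, num_nodes=10, num_neg=3, rng=rng)
print(samples.shape, labels)
```

Note that this sketch does not check whether a corrupted triplet happens to coincide with a true one; such accidental collisions are labeled as negatives.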

In [None]:
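import os

from utils.graph import *
from utils.metrics import *

import torch.nn.functional as F
import torch
from tqdm import tqdm

# make sure the directory used for checkpoints exists before training
os.makedirs('checkpoints', exist_ok=True)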

# with cuda?
cuda = torch.cuda.is_available()
device = 'cuda:0' if cuda else 'cpu'

# declare model
model = LinkPredict(num_nodes, num_rels, cuda=cuda)
model.to(device)

# build optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-2,
    weight_decay=5e-4
)

epoch = 0
best_mrr = 0
while True:
    model.train()
    epoch += 1

    # perform edge neighborhood sampling to generate training graph and data
    g, node_id, edge_type, node_norm, data, labels = generate_sampled_graph_and_labels(train_data, num_rels, adj_list, degrees)

    # set node/edge feature
    node_id = torch.from_numpy(node_id).view(-1, 1).long()
    edge_type = torch.from_numpy(edge_type)
    edge_norm = node_norm_to_edge_norm(g, torch.from_numpy(node_norm).view(-1, 1))
    data, labels = torch.from_numpy(data), torch.from_numpy(labels)
    deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)
    if cuda:
        node_id, deg = node_id.cuda(), deg.cuda()
        edge_type, edge_norm = edge_type.cuda(), edge_norm.cuda()
        data, labels = data.cuda(), labels.cuda()
        g = g.to(device)

    embed = model(g, node_id, edge_type, edge_norm)
    loss = model.get_loss(embed, data, labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # clip gradients
    optimizer.step()

    print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))

    optimizer.zero_grad()

    # validation
    if epoch % 500 == 0:
        # perform validation on CPU because full graph is too large
        model.cpu()
        model.eval()
        print("start eval")
        # no gradients are needed for evaluation on the full graph
        with torch.no_grad():
            embed = model(test_graph, test_node_id, test_rel, test_norm)
        mrr = calc_mrr(embed, model.w_relation, torch.LongTensor(train_data),
                             valid_data, test_data, hits=[1, 3, 10])
        # save best model
        if mrr < best_mrr:
            if epoch >= 6000:
                break
        else:
            best_mrr = mrr
            torch.save({'state_dict': model.state_dict(), 'epoch': epoch},
                       'checkpoints/model_state_{}.pt'.format(str(epoch)))
        if cuda:
            model.cuda()
print("training done")


# sampled nodes: 11824
# sampled edges: 30000
# nodes: 11824, # edges: 30000
Epoch 0001 | Loss 0.7253 | Best MRR 0.0000
# sampled nodes: 11835
# sampled edges: 30000
# nodes: 11835, # edges: 30000
Epoch 0002 | Loss 0.7162 | Best MRR 0.0000
# sampled nodes: 11739
# sampled edges: 30000
# nodes: 11739, # edges: 30000
Epoch 0003 | Loss 1.1346 | Best MRR 0.0000
# sampled nodes: 11753
# sampled edges: 30000
# nodes: 11753, # edges: 30000
Epoch 0004 | Loss 0.4699 | Best MRR 0.0000
# sampled nodes: 11837
# sampled edges: 30000
# nodes: 11837, # edges: 30000
Epoch 0005 | Loss 0.4522 | Best MRR 0.0000
# sampled nodes: 11780
# sampled edges: 30000
# nodes: 11780, # edges: 30000
Epoch 0006 | Loss 0.3425 | Best MRR 0.0000
# sampled nodes: 11796
# sampled edges: 30000
# nodes: 11796, # edges: 30000
Epoch 0007 | Loss 0.8033 | Best MRR 0.0000
# sampled nodes: 11804
# sampled edges: 30000
# nodes: 11804, # edges: 30000
Epoch 0008 | Loss 0.3704 | Best MRR 0.0000
# sampled nodes: 11793
[...]
Epoch 0069 | Loss 0.3010 | Best MRR 0.0000
# sampled nodes: 11765
# sampled edges: 30000
# nodes: 11765, # edges: 30000
Epoch 0070 | Loss 0.3006 | Best MRR 0.0000
# sampled nodes: 11733
# sampled edges: 30000
# nodes: 11733, # edges: 30000
Epoch 0071 | Loss 0.2987 | Best MRR 0.0000
# sampled nodes: 11826
# sampled edges: 30000
# nodes: 11826, # edges: 30000
Epoch 0072 | Loss 0.2962 | Best MRR 0.0000
# sampled nodes: 11763
# sampled edges: 30000
# nodes: 11763, # edges: 30000
Epoch 0073 | Loss 0.2959 | Best MRR 0.0000
# sampled nodes: 11805
# sampled edges: 30000
# nodes: 11805, # edges: 30000
Epoch 0074 | Loss 0.2946 | Best MRR 0.0000
# sampled nodes: 11778
# sampled edges: 30000
# nodes: 11778, # edges: 30000
Epoch 0075 | Loss 0.2951 | Best MRR 0.0000
# sampled nodes: 11859
# sampled edges: 30000
# nodes: 11859, # edges: 30000
Epoch 0076 | Loss 0.2935 | Best MRR 0.0000
# sampled nodes: 11832
# sampled edges: 30000
# nodes: 11832, # edges: 30000
Epoch 0077 | Loss 0.2955 | Best MRR 0.0000
[...]
Epoch 0138 | Loss 0.2995 | Best MRR 0.0000
# sampled nodes: 11768
# sampled edges: 30000
# nodes: 11768, # edges: 30000
Epoch 0139 | Loss 0.2988 | Best MRR 0.0000
# sampled nodes: 11772
# sampled edges: 30000
# nodes: 11772, # edges: 30000
Epoch 0140 | Loss 0.2993 | Best MRR 0.0000
# sampled nodes: 11781
# sampled edges: 30000
# nodes: 11781, # edges: 30000
Epoch 0141 | Loss 0.2995 | Best MRR 0.0000
# sampled nodes: 11785
# sampled edges: 30000
# nodes: 11785, # edges: 30000
Epoch 0142 | Loss 0.3009 | Best MRR 0.0000
# sampled nodes: 11849
# sampled edges: 30000
# nodes: 11849, # edges: 30000
Epoch 0143 | Loss 0.2998 | Best MRR 0.0000
# sampled nodes: 11880
# sampled edges: 30000
# nodes: 11880, # edges: 30000
Epoch 0144 | Loss 0.3064 | Best MRR 0.0000
# sampled nodes: 11827
# sampled edges: 30000
# nodes: 11827, # edges: 30000
Epoch 0145 | Loss 0.3017 | Best MRR 0.0000
# sampled nodes: 11794
# sampled edges: 30000
# nodes: 11794, # edges: 30000
Epoch 0146 | Loss 0.3056 | Best MRR 0.0000
[...]
# nodes: 11796, # edges: 30000
Epoch 0207 | Loss 0.3038 | Best MRR 0.0000
# sampled nodes: 11848
# sampled edges: 30000
# nodes: 11848, # edges: 30000
Epoch 0208 | Loss 0.3029 | Best MRR 0.0000
# sampled nodes: 11755
# sampled edges: 30000
# nodes: 11755, # edges: 30000
Epoch 0209 | Loss 0.3031 | Best MRR 0.0000
# sampled nodes: 11801
# sampled edges: 30000
# nodes: 11801, # edges: 30000
Epoch 0210 | Loss 0.3040 | Best MRR 0.0000
# sampled nodes: 11789
# sampled edges: 30000
# nodes: 11789, # edges: 30000
Epoch 0211 | Loss 0.3034 | Best MRR 0.0000
# sampled nodes: 11752
# sampled edges: 30000
# nodes: 11752, # edges: 30000
Epoch 0212 | Loss 0.3033 | Best MRR 0.0000
# sampled nodes: 11817
# sampled edges: 30000
# nodes: 11817, # edges: 30000
Epoch 0213 | Loss 0.3038 | Best MRR 0.0000
# sampled nodes: 11833
# sampled edges: 30000
# nodes: 11833, # edges: 30000
Epoch 0214 | Loss 0.3034 | Best MRR 0.0000
# sampled nodes: 11797
# sampled edges: 30000
# nodes: 11797, # edges: 30000
[...]
Epoch 0276 | Loss 0.3009 | Best MRR 0.0000
# sampled nodes: 11763
# sampled edges: 30000
# nodes: 11763, # edges: 30000
Epoch 0277 | Loss 0.3005 | Best MRR 0.0000
# sampled nodes: 11806
# sampled edges: 30000
# nodes: 11806, # edges: 30000
Epoch 0278 | Loss 0.3014 | Best MRR 0.0000
# sampled nodes: 11823
# sampled edges: 30000
# nodes: 11823, # edges: 30000
Epoch 0279 | Loss 0.3007 | Best MRR 0.0000
# sampled nodes: 11825
# sampled edges: 30000
# nodes: 11825, # edges: 30000
Epoch 0280 | Loss 0.3007 | Best MRR 0.0000
# sampled nodes: 11764
# sampled edges: 30000
# nodes: 11764, # edges: 30000
Epoch 0281 | Loss 0.3008 | Best MRR 0.0000
# sampled nodes: 11741
# sampled edges: 30000
# nodes: 11741, # edges: 30000
Epoch 0282 | Loss 0.3008 | Best MRR 0.0000
# sampled nodes: 11837
# sampled edges: 30000
# nodes: 11837, # edges: 30000
Epoch 0283 | Loss 0.3007 | Best MRR 0.0000
# sampled nodes: 11854
# sampled edges: 30000
# nodes: 11854, # edges: 30000
Epoch 0284 | Loss 0.3004 | Best MRR 0.0000
[...]
Epoch 0345 | Loss 0.2978 | Best MRR 0.0000
# sampled nodes: 11810
# sampled edges: 30000
# nodes: 11810, # edges: 30000
Epoch 0346 | Loss 0.2979 | Best MRR 0.0000
# sampled nodes: 11873
# sampled edges: 30000
# nodes: 11873, # edges: 30000
Epoch 0347 | Loss 0.2981 | Best MRR 0.0000
# sampled nodes: 11846
# sampled edges: 30000
# nodes: 11846, # edges: 30000
Epoch 0348 | Loss 0.2976 | Best MRR 0.0000
# sampled nodes: 11837
# sampled edges: 30000
# nodes: 11837, # edges: 30000
Epoch 0349 | Loss 0.2972 | Best MRR 0.0000
# sampled nodes: 11814
# sampled edges: 30000
# nodes: 11814, # edges: 30000
Epoch 0350 | Loss 0.2975 | Best MRR 0.0000
# sampled nodes: 11811
# sampled edges: 30000
# nodes: 11811, # edges: 30000
Epoch 0351 | Loss 0.2978 | Best MRR 0.0000
# sampled nodes: 11826
# sampled edges: 30000
# nodes: 11826, # edges: 30000
Epoch 0352 | Loss 0.2976 | Best MRR 0.0000
# sampled nodes: 11762
# sampled edges: 30000
# nodes: 11762, # edges: 30000
Epoch 0353 | Loss 0.2973 | Best MRR 0.0000
[...]

## D.) Testing 

### Metrics
- The system is assessed using:
    - the mean reciprocal rank ($MRR$)
    - Hits at n ($H@n$)
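
Both metrics are derived from the rank of the correct entity among all candidate entities for a given query: $MRR$ is the mean of $1/\text{rank}$, and $H@n$ is the fraction of queries whose correct entity is ranked within the top $n$. The sketch below illustrates these definitions on a small score matrix. It is not the code of `calc_mrr` from `utils.metrics`, which is not shown in this tutorial; the function name `mrr_and_hits` and the toy scores are assumptions.

```python
import numpy as np


def mrr_and_hits(scores, target, hits=(1, 3, 10)):
    """Compute MRR and Hits@n from a score matrix.

    scores: (num_queries, num_candidates), higher means more plausible
    target: (num_queries,) index of the correct candidate per query
    """
    target_score = scores[np.arange(len(target)), target][:, None]
    # Rank = 1 + number of candidates scored strictly higher than the target.
    ranks = 1 + (scores > target_score).sum(axis=1)
    mrr = float(np.mean(1.0 / ranks))
    hits_at = {n: float(np.mean(ranks <= n)) for n in hits}
    return mrr, hits_at


scores = np.array([
    [0.9, 0.1, 0.3, 0.2],  # target 0 is ranked 1st
    [0.2, 0.8, 0.5, 0.1],  # target 2 is ranked 2nd
    [0.4, 0.3, 0.2, 0.1],  # target 3 is ranked 4th
])
target = np.array([0, 2, 3])
mrr, hits_at = mrr_and_hits(scores, target, hits=(1, 3))
print(mrr, hits_at)
```

With ranks 1, 2 and 4, the MRR is $(1 + 1/2 + 1/4) / 3 \approx 0.583$, one query out of three is a hit at 1, and two out of three are hits at 3.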

In [None]:
# 1. load pre-trained network 
# map_location='cpu' lets the checkpoint load on machines without a GPU
checkpoint = torch.load('checkpoints/model_state_500.pt', map_location='cpu')

# 2. eval and compute MRR
model.cpu() 
model.eval()
model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch']))
with torch.no_grad():  # inference only, no gradients needed
    embed = model(test_graph, test_node_id, test_rel, test_norm)
calc_mrr(embed, model.w_relation, torch.LongTensor(train_data), valid_data, test_data, hits=[1, 3, 10])