# Multi-relational Link Prediction on Knowledge Graphs
By Haoxin Li, on 13 July 2020

In biology, two entities can be linked by several different types of relation. For example, a protein can be the *target, enzyme, carrier* or *transporter* of a drug/chemical compound, which gives 4 types of edges. Representing all of these relations with the same edge embedding would therefore not be ideal. In this demo, we explore the [Relational Graph Convolutional Neural Network](https://arxiv.org/pdf/1703.06103.pdf) (RGCN) and apply this architecture to a real-world biological dataset that includes protein-protein interactions and drug-protein interactions.

In [1]:
from torch_geometric.data import Data, GraphSAINTRandomWalkSampler, NeighborSampler, GraphSAINTEdgeSampler
from torch_geometric.nn import RGCNConv, Node2Vec, FastRGCNConv
from torch_geometric.utils import contains_isolated_nodes

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, average_precision_score
import matplotlib.pyplot as plt

### Preparing Data

In [2]:
edge_index = torch.load('data/edge_index.pt')
row, col = edge_index # row: first row, col: second row
edge_attr = torch.load('data/edge_attr.pt')
edge_meta_type = torch.load('data/edge_meta_type.pt')
edge_type = torch.load('data/edge_type.pt')
x = torch.load('data/x.pt')
y = torch.load('data/y.pt')
num_nodes = len(y) # total number of nodes in the graph

train_mask = torch.load('data/train_mask.pt') # training mask of edges, split randomly 80%
val_mask = torch.load('data/val_mask.pt') # validation mask of edges, split randomly 10%
test_mask = torch.load('data/test_mask.pt') # test_mask of edges, split randomly 10%

num_relations = edge_type.unique().size(0) # total number of edge types in the graph

data = Data(edge_attr=edge_attr, edge_index=edge_index, edge_type=edge_type, edge_meta_type=edge_meta_type, 
            x=x, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

- `edge_index` stores all the edges in the dataset as a 2-D tensor. Each column represents an edge formed by two nodes, and the number of columns indicates the total number of edges in the dataset. For example, the first column in `edge_index` is [0, 9052], which represents an edge between node 0 and node 9052.
- `edge_attr` contains edge attributes calculated using `1.0 / torch_geometric.utils.degree(col, num_nodes)[col]`. This attribute is used by the GraphSAINT sampler. Please see [this](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/graph_saint.py) and [this](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html) for reference.
- `edge_meta_type` identifies the meta edge type of each edge in `edge_index`. Because drug-protein edges are directional, we use meta edge types to make negative sampling easier. There are 3 meta edge types: `1` represents an edge from a drug (starting node) to a protein (ending node), `2` represents an edge between two proteins, and `3` represents an edge from a protein (starting node) to a drug (ending node).
- `edge_type` stores the edge type for each edge in `edge_index`. The meaning of each number is shown in the next cell. See `edge_type_mapping`.
- `x` stores the input embedding/attributes of each node, with a dimension of 128. The embeddings were learnt separately using [node2vec](https://arxiv.org/pdf/1607.00653.pdf). The main reason to use them is to reduce the input dimension of each node from 25455 to 128; the naive alternative would be to represent each node with a one-hot encoding.
- `y` stores the node type, where `0` represents a drug and `1` represents a protein.
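
To make the `edge_attr` formula concrete, here is a minimal NumPy sketch of the same computation on a made-up toy graph. The toy `edge_index` and the use of `np.bincount` in place of `torch_geometric.utils.degree` are assumptions for illustration only; the notebook itself loads a precomputed tensor.

```python
import numpy as np

# Hypothetical toy graph: 4 nodes, 5 directed edges, one edge per column.
edge_index = np.array([[0, 1, 2, 3, 0],
                       [1, 2, 1, 1, 3]])
row, col = edge_index
num_nodes = 4

# In-degree of every node: how many edges end at it.
in_degree = np.bincount(col, minlength=num_nodes)

# One weight per edge: the reciprocal of the in-degree of its ending node.
edge_attr = 1.0 / in_degree[col]

print(in_degree)  # [0 3 1 1]
print(edge_attr)  # three edges end at node 1, so each of them gets weight 1/3
```

With this normalization, the weights of all edges that end at the same node sum to 1.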

In [3]:
edge_type_mapping = {
    0: 'target', 
    1: 'enzyme', 
    2: 'carrier', 
    3: 'transporter', 
    4: 'ppi', 
    5: 'target_rev',
    6: 'enzyme_rev',
    7: 'carrier_rev',
    8: 'transporter_rev'}

Here we have 9 different edge types. The last 4 edge types are the reverses of the first 4, because we want the graph to be undirected.
E.g. Drug A **targets** Protein A is equivalent to Protein A is **targeted** by Drug A.
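
The dataset files already contain the reverse edges, so the notebook never builds them. As a rough illustration of how such edges could be derived, the NumPy sketch below flips each drug-protein edge and shifts its type id by 5, following `edge_type_mapping`. The toy edges, the constant names `PPI` and `REV_OFFSET`, and the choice to leave `ppi` edges untouched are all assumptions.

```python
import numpy as np

# Hypothetical toy edges: nodes 0-1 are drugs, nodes 2-4 are proteins.
edge_index = np.array([[0, 0, 1, 2],
                       [2, 3, 4, 3]])
edge_type = np.array([0, 1, 3, 4])  # target, enzyme, transporter, ppi

PPI = 4         # id of the protein-protein relation in edge_type_mapping
REV_OFFSET = 5  # target (0) -> target_rev (5), ..., transporter (3) -> transporter_rev (8)

is_dpi = edge_type < PPI               # select drug -> protein edges only
rev_index = edge_index[::-1, is_dpi]   # swap the starting and ending rows
rev_type = edge_type[is_dpi] + REV_OFFSET

full_index = np.concatenate([edge_index, rev_index], axis=1)
full_type = np.concatenate([edge_type, rev_type])

print(full_index)
print(full_type)  # [0 1 3 4 5 6 8]
```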

In [4]:
data_loader = GraphSAINTRandomWalkSampler(data, batch_size=128, walk_length=16, num_steps=32)

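We use the [GraphSAINT Random Walk Sampler](https://arxiv.org/pdf/1907.04931.pdf) to draw mini-batch subgraphs for training. Because each subgraph is built from random walks, every sampled node is connected to the walk that reached it rather than being picked in isolation.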
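
The idea behind random-walk sampling can be sketched in plain Python. This is a simplified illustration, not the library's implementation: the helper names `random_walk_nodes` and `induced_edges` and the toy graph are made up. Roughly, `batch_size` in the sampler corresponds to the number of root nodes and `walk_length` to the number of steps per walk.

```python
import random

def random_walk_nodes(adj, roots, walk_length, rng):
    """Collect every node visited by one walk of `walk_length` steps per root."""
    visited = set(roots)
    for node in roots:
        for _ in range(walk_length):
            neighbors = adj[node]
            if not neighbors:      # dead end: stop this walk early
                break
            node = rng.choice(neighbors)
            visited.add(node)
    return visited

def induced_edges(edges, nodes):
    """Keep only the edges whose two endpoints were both sampled."""
    return [(u, v) for u, v in edges if u in nodes and v in nodes]

# Hypothetical toy graph with two components: 0-1-2-3 and 4-5.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (4, 5), (5, 4)]
adj = {n: [] for n in range(6)}
for u, v in edges:
    adj[u].append(v)

rng = random.Random(0)
nodes = random_walk_nodes(adj, roots=[0], walk_length=4, rng=rng)
sub_edges = induced_edges(edges, nodes)
print(sorted(nodes), sub_edges)
```

A walk that starts at node 0 can never reach nodes 4 and 5, so the sampled subgraph stays inside the component of its root.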

### Constructing a GNN Model

In [5]:
class RGCN(torch.nn.Module):
    def __init__(self, in_dim, h_dim, out_dim, num_rels):
        super(RGCN, self).__init__()
        self.num_rels = num_rels
        self.conv1 = FastRGCNConv(
            in_dim, h_dim, num_rels)
        self.conv2 = FastRGCNConv(
            h_dim, out_dim, num_rels)
        self.relu = nn.ReLU()
        self.w_rels = nn.Parameter(torch.Tensor(num_rels, out_dim))
        nn.init.xavier_uniform_(self.w_rels,
                                gain=nn.init.calculate_gain('relu'))
        
    def forward(self, x, edge_index, edge_type):
        x1 = self.conv1(x, edge_index, edge_type)
        x1 = self.relu(x1)
        x2 = self.conv2(x1, edge_index, edge_type)
        out = F.log_softmax(x2, dim=1)
        
        return out
    
def get_metrics(model, embed, edge_index, edge_type, labels):
    probs = DistMult(embed, edge_index, edge_type, model)
    loss = F.binary_cross_entropy(probs, labels)

    probs = probs.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()

    return loss, probs, labels

def DistMult(embed, edge_index, edge_type, model):
    s = embed[edge_index[0, :]]
    o = embed[edge_index[1, :]]
    r = model.w_rels[edge_type]
    scores = torch.sum(s * r * o, dim=1)
    
    return torch.sigmoid(scores)


def get_link_labels(edge_index_pos_len, edge_index_neg_len):
    link_labels = torch.zeros(edge_index_pos_len + edge_index_neg_len).float().to(device)
    link_labels[:int(edge_index_pos_len)] = 1.
    return link_labels

def get_embeddings(data):
    data = data.to(device)
    x = data.x
    edge_index_pos = data.edge_index
    edge_type = torch.squeeze(data.edge_type)
    embed = model(x, edge_index_pos, edge_type)
    
    return embed

def negative_sample(edge_index, edge_meta_type):
    """
    generate negative samples but keep the node type the same
    """
    edge_index_copy = edge_index.clone()
    
    # resample ppi, the meta edge type for ppi is 2
    ppi = edge_index_copy[0, torch.squeeze(edge_meta_type == 2)]
    new_index = torch.randperm(ppi.shape[0])
    new_ppi = ppi[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 2)] = new_ppi

    # resample dpi, the meta edge type for dpi is 1
    dpi = edge_index_copy[0, torch.squeeze(edge_meta_type == 1)]
    new_index = torch.randperm(dpi.shape[0])
    new_dpi = dpi[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 1)] = new_dpi

    # resample dpi_rev, the meta edge type for dpi_rev is 3
    dpi_rev = edge_index_copy[0, torch.squeeze(edge_meta_type == 3)]
    new_index = torch.randperm(dpi_rev.shape[0])
    new_dpi_rev = dpi_rev[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 3)] = new_dpi_rev
    
    return edge_index_copy
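
The effect of `negative_sample` is easiest to see on a small example. The NumPy sketch below mirrors its logic under a hypothetical name, `negative_sample_np`, with a made-up toy graph: starting nodes are shuffled only among edges of the same meta type, so a drug is always replaced by a drug and a protein by a protein.

```python
import numpy as np

def negative_sample_np(edge_index, edge_meta_type, rng):
    """Shuffle the starting nodes within each meta edge type; ending nodes stay fixed."""
    neg = edge_index.copy()
    for meta in np.unique(edge_meta_type):
        mask = edge_meta_type == meta
        neg[0, mask] = rng.permutation(neg[0, mask])
    return neg

# Hypothetical toy graph: nodes 0-1 are drugs, nodes 2-5 are proteins.
edge_index = np.array([[0, 1, 0, 2, 3, 4],
                       [2, 3, 4, 3, 4, 5]])
edge_meta_type = np.array([1, 1, 1, 2, 2, 2])  # 1: drug -> protein, 2: protein - protein

rng = np.random.default_rng(0)
neg = negative_sample_np(edge_index, edge_meta_type, rng)
print(neg)
```

Note that a shuffled edge can coincide with a real edge, so a few of the "negatives" produced this way may in fact be positives.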

In [6]:
params = {'in_dim': 128, 
          'h_dim':64,
          'out_dim':64,
          'num_rels': num_relations,
          'epochs':50}

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RGCN(params['in_dim'], params['h_dim'], params['out_dim'], params['num_rels']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)

Here we construct a 2-layer RGCN with a hidden dimension of 64 for both node and edge embeddings. We model link prediction as a binary classification task: the model minimizes the binary cross-entropy between the predicted edge probabilities and labels that mark real edges as 1 and fake edges generated by negative sampling as 0. We use RGCN as the encoder for node embeddings and DistMult as the decoder.
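
As a small worked example of the decoder and the loss, the NumPy sketch below scores a handful of random embeddings the way `DistMult` does (an element-wise product of subject, relation and object vectors, summed and passed through a sigmoid) and then computes the binary cross-entropy. The random inputs and the helper names `distmult_prob` and `binary_cross_entropy` are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def distmult_prob(s, r, o):
    """Sigmoid of the DistMult score sum_d s_d * r_d * o_d, one value per edge."""
    scores = np.sum(s * r * o, axis=1)
    return 1.0 / (1.0 + np.exp(-scores))

def binary_cross_entropy(probs, labels):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    eps = 1e-12
    probs = np.clip(probs, eps, 1 - eps)
    return float(-np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

rng = np.random.default_rng(0)
s = rng.normal(size=(4, 8))  # subject (starting node) embeddings, one row per edge
o = rng.normal(size=(4, 8))  # object (ending node) embeddings
r = rng.normal(size=(4, 8))  # relation vectors, i.e. rows of w_rels picked by edge type
labels = np.array([1.0, 1.0, 0.0, 0.0])  # 2 real edges followed by 2 negative samples

probs = distmult_prob(s, r, o)
loss = binary_cross_entropy(probs, labels)
print(probs, loss)
```

One property worth noting: the DistMult score does not change when subject and object are swapped, so direction is only captured through separate relation ids such as `target` and `target_rev`.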

### Model Training
Note: in the loop below, both training and validation run on subgraphs sampled by GraphSAINT; within each sampled batch, `train_mask` and `val_mask` select the edges used for training and for validation. Parameter initialization may affect model convergence.

In [8]:
def train(data, embed):
    data = data.to(device)
    x = data.x
    
    edge_index_train_pos = data.edge_index[:, data.train_mask]
    edge_type_train = torch.squeeze(data.edge_type[data.train_mask])
    
    edge_meta_type = data.edge_meta_type[data.train_mask]
    edge_index_train_neg = negative_sample(edge_index_train_pos, edge_meta_type)

    edge_index_train_total = torch.cat([edge_index_train_pos, edge_index_train_neg], dim=-1)
    edge_type_train_total = torch.cat([edge_type_train, edge_type_train[:edge_index_train_neg.size(1)]], dim=-1)


    link_labels = get_link_labels(edge_index_train_pos.size(1), edge_index_train_neg.size(1))
    loss, probs, labels = get_metrics(model, embed, edge_index_train_total, edge_type_train_total, 
                                            link_labels)
    
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    loss_epoch_train.append(loss.item())
    auroc_epoch_train.append(auroc)
    
    loss.backward()
    optimizer.step()

@torch.no_grad()
def validation(data, embed, evaluate_rel=False):
    data = data.to(device)
    x = data.x
    
    edge_index_val_pos = data.edge_index[:, data.val_mask]
    edge_type_val = torch.squeeze(data.edge_type[data.val_mask])
    
    edge_meta_type = data.edge_meta_type[data.val_mask]
    edge_index_val_neg = negative_sample(edge_index_val_pos, edge_meta_type)
    
    edge_index_val_total = torch.cat([edge_index_val_pos, edge_index_val_neg], dim=-1)
    edge_type_val_total = torch.cat([edge_type_val, edge_type_val[:edge_index_val_neg.size(1)]], dim=-1)
    
    link_labels = get_link_labels(edge_index_val_pos.size(1), edge_index_val_neg.size(1))
    loss, probs, labels = get_metrics(model, embed, edge_index_val_total, edge_type_val_total, 
                                                                link_labels)
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    edge_type_val_total = edge_type_val_total.detach().cpu()
    
    loss_epoch_val.append(loss.item())
    auroc_epoch_val.append(auroc)
    
    if not evaluate_rel:
        return
    
    for i in range(num_relations):
        mask = (edge_type_val_total == i)
        if mask.sum() == 0:
            continue
        probs_per_rel = probs[mask]
        labels_per_rel = labels[mask]
        auroc_per_rel = roc_auc_score(labels_per_rel, probs_per_rel)
        auroc_edge_type[i].append(auroc_per_rel)

In [9]:
loss_train_total, loss_val_total = [], []
auroc_train_total, auroc_val_total = [], []

for epoch in range(0, params['epochs']):
    loss_epoch_train, loss_epoch_val = [], []
    auroc_epoch_train, auroc_epoch_val = [], []

    for batch in data_loader:
        
        optimizer.zero_grad()
        model.train()
        embed = get_embeddings(batch)
        train(batch, embed)
        model.eval()
        validation(batch, embed)
    
    loss_train_total.append(np.mean(loss_epoch_train))
    auroc_train_total.append(np.mean(auroc_epoch_train))
    loss_val_total.append(np.mean(loss_epoch_val))
    auroc_val_total.append(np.mean(auroc_epoch_val))

    print('Epoch: {} | train loss: {} | train auroc: {} |'.format(epoch + 1, 
                                                                  "%.2f" % np.mean(loss_epoch_train), 
                                                                  "%.2f" % np.mean(auroc_epoch_train)))
    print('Epoch: {} | val loss: {} | val auroc: {} |'.format(epoch + 1, 
                                                              "%.2f" % np.mean(loss_epoch_val), 
                                                              "%.2f" % np.mean(auroc_epoch_val)))
    
    print('----------------------------------------------------------------------------------------------')

Epoch: 1 | train loss: 1.50 | train auroc: 0.50 |
Epoch: 1 | val loss: 1.46 | val auroc: 0.50 |
----------------------------------------------------------------------------------------------
Epoch: 2 | train loss: 0.96 | train auroc: 0.51 |
Epoch: 2 | val loss: 0.96 | val auroc: 0.51 |
----------------------------------------------------------------------------------------------
Epoch: 3 | train loss: 0.90 | train auroc: 0.53 |
Epoch: 3 | val loss: 0.91 | val auroc: 0.53 |
----------------------------------------------------------------------------------------------
Epoch: 4 | train loss: 0.85 | train auroc: 0.60 |
Epoch: 4 | val loss: 0.86 | val auroc: 0.60 |
----------------------------------------------------------------------------------------------
Epoch: 5 | train loss: 0.75 | train auroc: 0.71 |
Epoch: 5 | val loss: 0.75 | val auroc: 0.71 |
----------------------------------------------------------------------------------------------
Epoch: 6 | train loss: 0.70 | train auroc: 0.

In [10]:
auroc_edge_type = {rel:[] for rel in range(num_relations)}

for batch in data_loader:
    embed = get_embeddings(batch)
    validation(batch, embed, evaluate_rel=True)

for rel, values in auroc_edge_type.items():
    print('auroc for relation type {}: {}'.format(edge_type_mapping[rel], "%.3f" % np.mean(values)))

auroc for relation type target: 0.706
auroc for relation type enzyme: 0.872
auroc for relation type carrier: 0.959
auroc for relation type transporter: 0.949
auroc for relation type ppi: 0.899
auroc for relation type target_rev: 0.724
auroc for relation type enzyme_rev: 0.694
auroc for relation type carrier_rev: 0.800
auroc for relation type transporter_rev: 0.836
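
The per-relation numbers above come from masking predictions by edge type before computing AUROC. The sketch below reproduces that idea on made-up scores with a small pairwise AUROC function written in NumPy; the helper `auroc` and the toy arrays are assumptions, whereas the notebook itself relies on `roc_auc_score` from scikit-learn.

```python
import numpy as np

def auroc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

# Hypothetical predictions for 6 edges of two relation types (0: target, 4: ppi).
labels = np.array([1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.4, 0.3, 0.6, 0.8, 0.1])
rel = np.array([0, 0, 0, 0, 4, 4])

per_rel = {int(t): auroc(labels[rel == t], scores[rel == t]) for t in np.unique(rel)}
print(per_rel)                # {0: 0.75, 4: 1.0}
print(auroc(labels, scores))  # 8 of 9 pairs ranked correctly
```

The example shows why a per-relation breakdown is informative: the pooled score hides the fact that one relation is ranked perfectly while the other is not.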
