# Multi-relational Link Prediction on Knowledge Graphs
By Haoxin Li, on 13 June 2020

In the biological world, different types of relation could exist between two entities. For example, a drug/chemical compound can act as a *target, enzyme, carrier* or *transporter* on proteins, forming 4 types of edges. Thus, it would not be ideal to represent these relations using the same edge embeddings. In this example, we explore Relational Graph Convolutional Neural Network (RGCN) and apply this achitecture on real world biological dataset, including protein-protein interactions, and drug-protein interactions.

In [1]:
from torch_geometric.data import Data, GraphSAINTRandomWalkSampler, NeighborSampler, GraphSAINTEdgeSampler
from torch_geometric.nn import RGCNConv, Node2Vec, FastRGCNConv
from torch_geometric.utils import negative_sampling, contains_isolated_nodes

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, average_precision_score
import matplotlib.pyplot as plt

### Preparing Data

In [2]:
edge_index = torch.load('data/edge_index.pt')
edge_attr = torch.load('data/edge_attr.pt')
edge_type = torch.load('data/edge_type.pt')
x = torch.load('data/x.pt')
y = torch.load('data/y.pt')

train_mask = torch.load('data/train_mask.pt')
val_mask = torch.load('data/val_mask.pt')
test_mask = torch.load('data/test_mask.pt')

num_relations = edge_type.unique().size(0)

data = Data(edge_attr=edge_attr, edge_index=edge_index, edge_type=edge_type, x=x, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
data

Data(edge_attr=[1515256], edge_index=[2, 1515256], edge_type=[1515256, 1], test_mask=[1515256], train_mask=[1515256], val_mask=[1515256], x=[25455, 128], y=[25455])

In [3]:
edge_type_mapping = {
    0: 'target', 
    1: 'enzyme', 
    2: 'carrier', 
    3: 'transporter', 
    4: 'ppi', 
    5: 'target_rev',
    6: 'enzyme_rev',
    7: 'carrier_rev',
    8: 'transporter_rev'}

Here we have 9 different edge types. The last 4 edge types are the opposites of the first 4 edge types as we want our graph to be un-directional.
e.g. Drug A **targets** Protein A is equivalent to Protein A is **targeted** by Drug A

In [4]:
data_loader = GraphSAINTRandomWalkSampler(data, batch_size=256, walk_length=16, num_steps=16)

We utilize *GraphSAINTRandomWalkSampler* as it allows us to sample fully-connected sub-graphs for training.

### Constructing a GNN Model

In [5]:
class RGCN(torch.nn.Module):
    def __init__(self, in_dim, h_dim, out_dim, num_rels):
        super(RGCN, self).__init__()
        self.num_rels = num_rels
        self.conv1 = FastRGCNConv(
            in_dim, h_dim, num_rels)
        self.conv2 = FastRGCNConv(
            h_dim, out_dim, num_rels)
        self.relu = nn.ReLU()
        self.w_rels = nn.Parameter(torch.Tensor(num_rels, out_dim))
        nn.init.xavier_uniform_(self.w_rels,
                                gain=nn.init.calculate_gain('relu'))
        
    def forward(self, x, edge_index, edge_type):
        x1 = self.conv1(x, edge_index, edge_type)
        x1 = self.relu(x1)
        x2 = self.conv2(x1, edge_index, edge_type)
        out = F.log_softmax(x2, dim=1)
        return out
    
def get_metrics(model, embed, edge_index, edge_type, labels):
    probs = DistMult(embed, edge_index, edge_type, model)
    
    loss = F.binary_cross_entropy(probs, labels)

    probs = probs.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()
#     auroc = roc_auc_score(labels, probs)
#     auprc = average_precision_score(labels, probs)
    return loss, probs, labels

def DistMult(embed, edge_index, edge_type, model):
    s = embed[edge_index[0, :]]
    o = embed[edge_index[1, :]]
    r = model.w_rels[edge_type]
    scores = torch.sum(s * r * o, dim=1)
    
    return torch.sigmoid(scores)



def get_link_labels(edge_index_pos_len, edge_index_neg_len):
    '''
    first half 1, and second half 0
    '''
    
    link_labels = torch.zeros(edge_index_pos_len + edge_index_neg_len).float().to(device)
    link_labels[:int(edge_index_pos_len)] = 1.
    return link_labels
        

In [6]:
params = {'in_dim': 128, 
          'h_dim':64,
          'out_dim':64,
          'num_rels': num_relations,
          'epochs':20}

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RGCN(params['in_dim'], params['h_dim'], params['out_dim'], params['num_rels']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)

Here we construct a 2-layer RGCN with hidden dimension of 64 for both node and edge embeddings. We model it as a binary classification task that tries to minimize the loss between real edge labels and fake edge labels geneated from negative sampling. We use RGCN as the encoder for node embeddings and DistMult as the decoder.

### Model Training
Note: the data for training is sampled from GraphSaint, wheares the data for validation is the whole graph. 

In [8]:
def train(data):
    data = data.to(device)
    x = data.x
    edge_index_train_pos = data.edge_index[:, data.train_mask]
    edge_type_train = torch.squeeze(data.edge_type[data.train_mask])
    
    edge_index_train_neg = negative_sampling(edge_index_train_pos, num_neg_samples=edge_index_train_pos.size(1))

    edge_index_train_total = torch.cat([edge_index_train_pos, edge_index_train_neg], dim=-1)
    edge_type_train_total = torch.cat([edge_type_train, edge_type_train[:edge_index_train_neg.size(1)]], dim=-1)


    link_labels = get_link_labels(edge_index_train_pos.size(1), edge_index_train_neg.size(1))
    embed = model(x, edge_index_train_pos, edge_type_train)
    loss, probs, labels = get_metrics(model, embed, edge_index_train_total, edge_type_train_total, 
                                            link_labels)
    
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    loss_epoch_train.append(loss.item())
    auroc_epoch_train.append(auroc)
    
    loss.backward()
    optimizer.step()

@torch.no_grad()
def validation(data, evaluate_rel=False):
    
    data = data.to(device)
    x = data.x
    edge_index_val_pos = data.edge_index[:, data.val_mask]
    edge_type_val = torch.squeeze(data.edge_type[data.val_mask])
    
    edge_index_val_neg = negative_sampling(edge_index_val_pos, num_neg_samples=edge_index_val_pos.size(1))
    edge_index_val_total = torch.cat([edge_index_val_pos, edge_index_val_neg], dim=-1)
    edge_type_val_total = torch.cat([edge_type_val, edge_type_val[:edge_index_val_neg.size(1)]], dim=-1)
    
    link_labels = get_link_labels(edge_index_val_pos.size(1), edge_index_val_neg.size(1))
    embed = model(x, edge_index_val_pos, edge_type_val)
    loss, probs, labels = get_metrics(model, embed, edge_index_val_total, edge_type_val_total, 
                                                                link_labels)
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    edge_type_val_total = edge_type_val_total.detach().cpu()
    
    loss_epoch_val.append(loss.item())
    auroc_epoch_val.append(auroc)
    
    if not evaluate_rel:
        return
    
    for i in range(num_relations):
        mask = (edge_type_val_total == i)
        probs_per_rel = probs[mask]
        labels_per_rel = labels[mask]
        auroc_per_rel = roc_auc_score(labels_per_rel, probs_per_rel)
        auroc_edge_type[i].append(auroc_per_rel)

In [9]:
loss_train_total, loss_val_total = [], []
auroc_train_total, auroc_val_total = [], []

for epoch in range(0, params['epochs']):
    loss_epoch_train, loss_epoch_val = [], []
    auroc_epoch_train, auroc_epoch_val = [], []

    for batch in data_loader:
        optimizer.zero_grad()
        model.train()

        train(batch)
        validation(data)
    
    loss_train_total.append(np.mean(loss_epoch_train))
    auroc_train_total.append(np.mean(auroc_epoch_train))
    loss_val_total.append(np.mean(loss_epoch_val))
    auroc_val_total.append(np.mean(auroc_epoch_val))

    print('Epoch: {} | train loss: {} | train auroc: {} |'.format(epoch + 1, 
                                                                  "%.2f" % np.mean(loss_epoch_train), 
                                                                  "%.2f" % np.mean(auroc_epoch_train)))
    print('Epoch: {} | val loss: {} | val auroc: {} |'.format(epoch + 1, 
                                                              "%.2f" % np.mean(loss_epoch_val), 
                                                              "%.2f" % np.mean(auroc_epoch_val)))
    
    print('----------------------------------------------------------------------------------------------')



Epoch: 1 | train loss: 2.71 | train auroc: 0.49 |
Epoch: 1 | val loss: 3.14 | val auroc: 0.31 |
----------------------------------------------------------------------------------------------
Epoch: 2 | train loss: 1.05 | train auroc: 0.65 |
Epoch: 2 | val loss: 1.75 | val auroc: 0.47 |
----------------------------------------------------------------------------------------------
Epoch: 3 | train loss: 0.78 | train auroc: 0.73 |
Epoch: 3 | val loss: 1.27 | val auroc: 0.57 |
----------------------------------------------------------------------------------------------
Epoch: 4 | train loss: 0.69 | train auroc: 0.76 |
Epoch: 4 | val loss: 1.06 | val auroc: 0.64 |
----------------------------------------------------------------------------------------------
Epoch: 5 | train loss: 0.66 | train auroc: 0.79 |
Epoch: 5 | val loss: 0.94 | val auroc: 0.69 |
----------------------------------------------------------------------------------------------
Epoch: 6 | train loss: 0.63 | train auroc: 0.

In [10]:
auroc_edge_type = {rel:[] for rel in range(num_relations)}
validation(data, evaluate_rel=True)

for rel, values in auroc_edge_type.items():
     print('auroc for relation type {}: {}'.format(edge_type_mapping[rel], "%.2f" % np.mean(values)))

auroc for relation type target: 0.95
auroc for relation type enzyme: 0.89
auroc for relation type carrier: 0.88
auroc for relation type transporter: 0.92
auroc for relation type ppi: 0.89
auroc for relation type target_rev: 0.87
auroc for relation type enzyme_rev: 0.89
auroc for relation type carrier_rev: 0.84
auroc for relation type transporter_rev: 0.79
