## $\text{Link Prediction using Graph Neural Network} $

Models such as GCN, proposed by Kipf and Welling, and the later GraphSAGE were introduced as models for node classification.

This tutorial instead covers the workflow for link prediction:

* Prepare training and testing sets for link prediction task.
* Build a GNN-based link prediction model.
* Train the model and verify the result.

In [73]:
import dgl 
import time 
import torch 
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F 

import numpy as np 
import scipy.sparse as sp 

In [2]:
from utils import load_zachery

g = load_zachery()

## $\text{Prepare training and testing set} $

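There are two kinds of edges: `positive` edges, which exist in the graph, and `negative` edges, which are node pairs with no edge between them.

To train and evaluate the model, this example splits the edges into a training set and a test set: 50 edges are held out for testing and the rest are used for training.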

In [36]:
u, v = g.edges()
eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)

test_pos_u, test_pos_v = u[eids[:50]], v[eids[:50]]
train_pos_u, train_pos_v = u[eids[50:]], v[eids[50:]]

In [37]:
# coo_matrix is used because it stores a matrix with many zeros compactly.
# An alternative is g.adjacency_matrix(transpose = False, scipy_fmt='csr').todense(); either may be used.
# This step finds the node pairs without an edge so that they can serve as negative examples.
num_nodes = g.number_of_nodes()
# No edge weights were specified, so every edge gets the value 1 via np.ones(len(u)).
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(num_nodes, num_nodes))
adj_neg = 1 - adj.todense() - np.eye(num_nodes) # np.eye is the identity matrix; subtracting it excludes self-loops
neg_u, neg_v = np.where(adj_neg != 0)
neg_eids = np.random.choice(len(neg_u), 200, replace=False) # sample without replacement to avoid duplicate pairs
test_neg_u, test_neg_v = neg_u[neg_eids[:50]], neg_v[neg_eids[:50]]
train_neg_u, train_neg_v = neg_u[neg_eids[50:]], neg_v[neg_eids[50:]]
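
The negative-sampling step can be checked by hand on a small graph. The following is a minimal sketch using a hypothetical 4-node edge list (not the karate club graph) and NumPy only; the names `src`, `dst`, `neg_mask`, and `sampled_neg` are introduced here purely for illustration.

```python
import numpy as np

# Toy directed edge list on 4 nodes (hypothetical data).
num_nodes = 4
src = np.array([0, 1, 1, 2])
dst = np.array([1, 0, 2, 1])

# Dense adjacency matrix: adj[i, j] == 1 when the edge i -> j exists.
adj = np.zeros((num_nodes, num_nodes), dtype=int)
adj[src, dst] = 1

# Candidate negatives: pairs with no edge, excluding the diagonal (self-loops).
neg_mask = (adj == 0) & ~np.eye(num_nodes, dtype=bool)
neg_src, neg_dst = np.where(neg_mask)

# Sample without replacement so that no negative pair is drawn twice.
rng = np.random.default_rng(0)
picked = rng.choice(len(neg_src), size=4, replace=False)
sampled_neg = list(zip(neg_src[picked].tolist(), neg_dst[picked].tolist()))
```

Of the 16 possible ordered pairs, 4 are edges and 4 lie on the diagonal, which leaves 8 candidate negative pairs to sample from.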

In [44]:
# Label 1 marks a positive (existing) edge and 0 a negative one, matching the target y in the loss.
train_u = torch.cat([torch.as_tensor(train_pos_u), torch.as_tensor(train_neg_u)])
train_v = torch.cat([torch.as_tensor(train_pos_v), torch.as_tensor(train_neg_v)])
train_label = torch.cat([torch.ones(len(train_pos_u)), torch.zeros(len(train_neg_u))])

test_u = torch.cat([torch.as_tensor(test_pos_u), torch.as_tensor(test_neg_u)])
test_v = torch.cat([torch.as_tensor(test_pos_v), torch.as_tensor(test_neg_v)])
test_label = torch.cat([torch.ones(len(test_pos_u)), torch.zeros(len(test_neg_u))])

## $ \text{Define a GraphSAGE model} $

### $$ h^k_{\mathcal{N}(v)} \leftarrow \text{AGGREGATE}_k(\{h^{k-1}_u, \forall u \in \mathcal{N}(v)\})$$
### $$ h^k_v \leftarrow \sigma(W^k \cdot \text{CONCAT} (h^{k-1}_v, h^k_{\mathcal{N}(v)})) $$

DGL provides many neighbor aggregation modules. Simply invoke the module you want to use.
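
To make the two update equations concrete, here is a minimal NumPy sketch of one layer with a mean aggregator on a hypothetical 3-node graph. It illustrates the formulas only and is not DGL's implementation; the names `neighbors`, `h_prev`, `W`, and `h_next` are assumptions made for this example, and ReLU stands in for $\sigma$.

```python
import numpy as np

# Toy undirected graph on 3 nodes: 0 - 1 and 1 - 2 (hypothetical data).
neighbors = {0: [1], 1: [0, 2], 2: [1]}

# Previous-layer features h^{k-1}, one 2-dimensional row per node.
h_prev = np.array([[1.0, 0.0],
                   [0.0, 2.0],
                   [4.0, 4.0]])

# AGGREGATE_k as a mean over each node's neighbors.
h_neigh = np.stack([h_prev[neighbors[v]].mean(axis=0) for v in range(3)])

# CONCAT(h_v^{k-1}, h_N(v)^k), then the linear map W^k and a ReLU as sigma.
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))  # output dim 3, input dim 2 + 2
h_cat = np.concatenate([h_prev, h_neigh], axis=1)
h_next = np.maximum(h_cat @ W.T, 0.0)
```

Node 1 has neighbors 0 and 2, so its aggregated vector is the mean of `[1, 0]` and `[4, 4]`, that is `[2.5, 2.0]`. Applying one weight matrix to the concatenation is mathematically the same as applying separate weight matrices to the self and neighbor vectors and summing the results.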

In [102]:
from dgl.nn import SAGEConv 
# build a two-layer GraphSAGE model 
class GraphSAGE(nn.Module):
    def __init__(self, num_nodes, embed_dim, h_feats):
        super(GraphSAGE, self).__init__()
        self.num_nodes = num_nodes 
        self.embeded = nn.Embedding(num_nodes, embed_dim)
        self.conv1 = SAGEConv(embed_dim, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        self._init_weight()
        
    def forward(self, g):
        embed = self.embeded.weight
        output = self.conv1(g, embed)
        output = self.relu(output)
        output = self.conv2(g, output)
        
        return output
    
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Embedding):
                nn.init.xavier_uniform_(m.weight)


num_nodes = g.number_of_nodes()
embed_dim = 5
h_feats = 16
models = GraphSAGE(num_nodes, embed_dim, h_feats)
        

We then optimize the model using the following loss function.

$$ \hat{y}_{u \sim v} = \sigma(h^T_u h_v) $$
$$ \mathcal{L} = - \sum_{u \sim v \in \mathcal{D}} ( y_{u \sim v} \log (\hat{y}_{u \sim v}) + (1-y_{u \sim v}) \log (1 - \hat{y}_{u \sim v})) $$

The model built above predicts an edge score by taking the dot product of the representations of the two nodes.

It then computes a binary cross entropy loss against a target $y$ of 0 or 1, which indicates whether the edge is positive.
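
The score and loss can be worked through by hand on a few node pairs. The sketch below uses NumPy and hypothetical 2-dimensional representations, and assumes that label 1 marks an existing edge, as in the formula for $\mathcal{L}$. It averages over the pairs instead of summing, which is also the default behaviour of `nn.BCELoss`.

```python
import numpy as np

def sigmoid(x):
    """Logistic function applied element-wise."""
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical 2-dimensional node representations for 3 nodes.
h = np.array([[1.0, 0.0],
              [2.0, 0.0],
              [-1.0, 0.0]])

# Candidate pairs (u, v) and labels: 1 = edge exists, 0 = no edge.
u = np.array([0, 0])
v = np.array([1, 2])
y = np.array([1.0, 0.0])

# Edge score: sigmoid of the row-wise dot product h_u^T h_v.
logits = (h[u] * h[v]).sum(axis=1)
y_hat = sigmoid(logits)

# Binary cross entropy, averaged over the pairs.
loss = -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
```

Here the dot products are 2 for the pair (0, 1) and -1 for the pair (0, 2), so the positive pair receives a probability above 0.5 and the negative pair one below 0.5, and the loss is small.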

In [103]:
optimizer = optim.Adam(models.parameters(), lr = 1e-2)
criterion = nn.BCELoss()

def calc_accuracy(pred, true):
    return ((pred >= 0.5) == true).sum().item() / len(pred)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time 
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = elapsed_time - elapsed_mins * 60 
    return elapsed_mins, elapsed_secs

In [104]:
all_logits = []
num_epochs = 100

for epoch in range(1, num_epochs + 1):
    
    models.train()
    start_time = time.time()
    pred = models(g)
    optimizer.zero_grad()
    pred_proba = torch.sigmoid((pred[train_u] * pred[train_v]).sum(dim=1))
    train_loss = criterion(pred_proba, train_label)
    
    train_loss.backward()
    optimizer.step()
    
    all_logits.append(pred_proba.detach())
    train_acc = calc_accuracy(pred_proba, train_label)
    
    if epoch % 5 == 0 :
        with torch.no_grad():
            models.eval()
            pred = models(g)
            test_proba = torch.sigmoid((pred[test_u] * pred[test_v]).sum(dim=1))
            test_loss = criterion(test_proba, test_label)
            
            test_acc = calc_accuracy(test_proba, test_label)
            
            end_time = time.time()
            elapsed_mins, elapsed_secs = epoch_time(start_time, end_time)
            print(f'epoch [{epoch}/{num_epochs}] | elapsed time {elapsed_mins}m, {elapsed_secs:.2f}s')
            print(f'train loss: {train_loss:.4f}\t train acc: {train_acc*100:.2f}%')
            print(f'test loss: {test_loss:.4f}\t test acc: {test_acc*100:.2f}% \n')

epoch [5/100] | elapsed time 0m, 0.03s
train loss: 0.5798	 train acc: 66.80%
test loss: 0.6922	 test acc: 53.00% 

epoch [10/100] | elapsed time 0m, 0.01s
train loss: 0.4882	 train acc: 78.52%
test loss: 0.6721	 test acc: 61.00% 

epoch [15/100] | elapsed time 0m, 0.01s
train loss: 0.4100	 train acc: 79.69%
test loss: 0.6300	 test acc: 67.00% 

epoch [20/100] | elapsed time 0m, 0.01s
train loss: 0.3209	 train acc: 85.55%
test loss: 0.5630	 test acc: 68.00% 

epoch [25/100] | elapsed time 0m, 0.01s
train loss: 0.2500	 train acc: 89.06%
test loss: 0.5799	 test acc: 70.00% 

epoch [30/100] | elapsed time 0m, 0.01s
train loss: 0.1843	 train acc: 91.41%
test loss: 0.6130	 test acc: 75.00% 

epoch [35/100] | elapsed time 0m, 0.01s
train loss: 0.1288	 train acc: 95.70%
test loss: 0.6224	 test acc: 80.00% 

epoch [40/100] | elapsed time 0m, 0.01s
train loss: 0.0792	 train acc: 96.88%
test loss: 0.7371	 test acc: 88.00% 

epoch [45/100] | elapsed time 0m, 0.01s
train loss: 0.0448	 train