In [None]:
# Paper: Link Prediction Based on Graph Neural Networks (NeurIPS 2018)
# Example: https://github.com/rusty1s/pytorch_geometric/blob/99a496e077a4d41417c7d927df7730fd984004b9/examples/seal_link_pred.py#L90

In [1]:
import math
import random
import os.path as osp
from itertools import chain

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import datetime

import numpy as np
from sklearn.metrics import roc_auc_score
from scipy.sparse.csgraph import shortest_path
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss
from torch.nn import ModuleList, Linear, Conv1d, MaxPool1d

import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, global_sort_pool
from torch_geometric.data import Data, InMemoryDataset, DataLoader, Dataset
from torch_geometric.utils import (negative_sampling, add_self_loops,
                                   train_test_split_edges, k_hop_subgraph,
                                   to_scipy_sparse_matrix, to_undirected)

# Define

In [2]:
def to_list(x):
    if not isinstance(x, (tuple, list)):
        x = [x]
    return x


def files_exist(files):
    return len(files) != 0 and all(osp.exists(f) for f in files)

In [3]:
class SEALDataset(InMemoryDataset):
    def __init__(self, dataset, num_hops, split='train'):
        self.data = dataset[0]
        self.num_hops = num_hops
        super(SEALDataset, self).__init__(dataset.root)
        index = ['train', 'val', 'test'].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

    @property
    def processed_file_names(self):
        return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt']

    def process(self):
        random.seed(12345)
        torch.manual_seed(12345)

        data = train_test_split_edges(self.data)

        edge_index, _ = add_self_loops(data.train_pos_edge_index)
        
        data.train_neg_edge_index = negative_sampling(
            edge_index, num_nodes=data.num_nodes,
            num_neg_samples=data.train_pos_edge_index.size(1))

        self.__max_z__ = 0

        # Collect a list of subgraphs for training, validation and test.
        train_pos_list = self.extract_enclosing_subgraphs(
            data.train_pos_edge_index, data.train_pos_edge_index, 1)
        train_neg_list = self.extract_enclosing_subgraphs(
            data.train_neg_edge_index, data.train_pos_edge_index, 0)

        val_pos_list = self.extract_enclosing_subgraphs(
            data.val_pos_edge_index, data.train_pos_edge_index, 1)
        val_neg_list = self.extract_enclosing_subgraphs(
            data.val_neg_edge_index, data.train_pos_edge_index, 0)

        test_pos_list = self.extract_enclosing_subgraphs(
            data.test_pos_edge_index, data.train_pos_edge_index, 1)
        test_neg_list = self.extract_enclosing_subgraphs(
            data.test_neg_edge_index, data.train_pos_edge_index, 0)

        # Convert labels to one-hot features.
        for data in chain(train_pos_list, train_neg_list, val_pos_list,
                          val_neg_list, test_pos_list, test_neg_list):
            data.x = F.one_hot(data.z, self.__max_z__ + 1).to(torch.float)

            
        torch.save(self.collate(train_pos_list + train_neg_list),
                   self.processed_paths[0])
        torch.save(self.collate(val_pos_list + val_neg_list),
                   self.processed_paths[1])
        torch.save(self.collate(test_pos_list + test_neg_list),
                   self.processed_paths[2])

    def extract_enclosing_subgraphs(self, link_index, edge_index, y):
        data_list = []
        for src, dst in link_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=self.data.x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
        
            data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self.__max_z__ = max(int(z.max()), self.__max_z__)

        return z.to(torch.long)


In [4]:
class MyDataset(InMemoryDataset):
    
    def __init__(self, num_hops, root=None, split="train"):
        
        self.num_hops = num_hops
        super().__init__(root=root)
        
        index = ['train', 'valid'].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

        
    @property
    def raw_file_names(self):  
        return ["train.csv"]
    
    
    @property
    def processed_file_names(self):
        return ['SEAL_data_0.pt', 'SEAL_data_1.pt']
    
    def process(self):
        # Read node features
        content = pd.read_csv(os.path.join(self.raw_dir, "content.csv"), delimiter="\t", header=None)
        content = content.sort_values(by=[0]).loc[:, 1:].to_numpy()
        content = torch.from_numpy(content)
        num_nodes = content.size(0)

        # Read edge list
        train = pd.read_csv(os.path.join(self.raw_dir, "train.csv"))

        train_pos = train[ train["label"] == 1]
        train_neg = train[ train["label"] == 0]

        follower_pos = train_pos["from"].to_numpy().tolist()
        followee_pos = train_pos["to"].to_numpy().tolist()
        train_pos_edge = torch.tensor([follower_pos, followee_pos], dtype=torch.long)
        train_pos_edge = to_undirected(train_pos_edge)
#         train_pos_edge = add_self_loops(train_pos_edge)[0]

        self.data = Data(x=content, edge_index=train_pos_edge, num_nodes=num_nodes)

        train_pos_edge = train_pos_edge.t()

        follower_neg = train_neg["from"].to_numpy().tolist()
        followee_neg = train_neg["to"].to_numpy().tolist()
        train_neg_edge = torch.tensor([follower_neg, followee_neg], dtype=torch.long)
        train_neg_edge = to_undirected(train_neg_edge)
        
#         neg_data = Data(edge_index=train_neg_edge, num_nodes=num_nodes)
#         train_neg_edge = add_self_loops(neg_data.edge_index)[0]
        
        train_neg_edge = train_neg_edge.t()

        train_pos_edge, valid_pos_edge = train_test_split(train_pos_edge, shuffle=True)
        train_neg_edge, valid_neg_edge = train_test_split(train_neg_edge, shuffle=True)

        train_pos_edge = train_pos_edge.t()
        valid_pos_edge = valid_pos_edge.t()
        train_neg_edge = train_neg_edge.t()
        valid_neg_edge = valid_neg_edge.t()


        self.__max_z__ = 0


        train_pos_list = self.extract_enclosing_subgraphs(
            train_pos_edge, train_pos_edge, 1)
        train_neg_list = self.extract_enclosing_subgraphs(
            train_neg_edge, train_pos_edge, 0)
    

        val_pos_list = self.extract_enclosing_subgraphs(
            valid_pos_edge, train_pos_edge, 1)
        val_neg_list = self.extract_enclosing_subgraphs(
            valid_neg_edge, train_pos_edge, 0)


        # Convert labels to one-hot features.
        for data in chain(train_pos_list, train_neg_list, 
                          val_pos_list, val_neg_list):
            z = F.one_hot(data.z, self.__max_z__ + 1).to(torch.float)
            data.x = torch.cat([z, data.x], 1)


        
            
        torch.save(self.collate(train_pos_list + train_neg_list),
                   self.processed_paths[0])
        torch.save(self.collate(val_pos_list + val_neg_list),
                   self.processed_paths[1])


#         if self.pre_filter is not None and not self.pre_filter(data):
#             continue

#         if self.pre_transform is not None:
#             data = self.pre_transform(data)


#         torch.save(self.collate(train_pos_list + train_neg_list),
#                    self.processed_paths[0])
#         torch.save(self.collate(val_pos_list + val_neg_list),
#                    self.processed_paths[1])

    
#     def len(self):
#         return len(self.processed_file_names)
    
    
    def extract_enclosing_subgraphs(self, link_index, edge_index, y):
        data_list = []
        for src, dst in link_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=self.data.x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
            data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]
        
#         print(src, adj_wo_dst.shape)
#         print(adj_wo_dst[src])
        

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self.__max_z__ = max(int(z.max()), self.__max_z__)

        return z.to(torch.long)


# Model

In [5]:
class DGCNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):
        super(DGCNN, self).__init__()

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted([data.num_nodes for data in train_dataset])
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)

        self.convs = ModuleList()
        self.convs.append(GNN(train_dataset.num_features, hidden_channels))
        for i in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, x, edge_index, batch):
        xs = [x]
        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.

        x = global_sort_pool(x, batch, self.k)

        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)

        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        
        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return x


In [6]:
def train(loader):
    model.train()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.batch)
        loss = BCEWithLogitsLoss()(logits.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_dataset)

In [7]:
@torch.no_grad()
def test(loader):
    model.eval()

    y_pred, y_true = [], []
    for data in loader:
#         print(data)
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

#     print(y_pred, y_true)
    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))

# Usage

In [8]:
from torch_geometric.datasets import Planetoid

In [11]:
cora = Planetoid(root=os.getcwd(), name='Cora')

cora_train_dataset = SEALDataset(cora, num_hops=2, split='train')
cora_val_dataset = SEALDataset(cora, num_hops=2, split='val')
cora_test_dataset = SEALDataset(cora, num_hops=2, split='test')

cora_train_loader = DataLoader(cora_train_dataset, batch_size=32, shuffle=True)
cora_val_loader = DataLoader(cora_val_dataset, batch_size=32)
cora_test_loader = DataLoader(cora_test_dataset, batch_size=32)

Processing...
Done!


In [12]:
cora_train_dataset

SEALDataset(17952)

In [94]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = DGCNN(hidden_channels=32, num_layers=3).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [95]:
best_val_auc = test_auc = 0
for epoch in range(1, 51):
    loss = train(cora_train_loader)
    val_auc = test(cora_val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = test(cora_test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

Batch(batch=[2155], edge_index=[2, 7254], x=[2155, 136], y=[32], z=[2155])
Batch(batch=[1894], edge_index=[2, 5896], x=[1894, 136], y=[32], z=[1894])
Batch(batch=[2382], edge_index=[2, 7896], x=[2382, 136], y=[32], z=[2382])
Batch(batch=[1698], edge_index=[2, 5468], x=[1698, 136], y=[32], z=[1698])
Batch(batch=[2191], edge_index=[2, 6692], x=[2191, 136], y=[32], z=[2191])
Batch(batch=[2028], edge_index=[2, 6402], x=[2028, 136], y=[32], z=[2028])
Batch(batch=[2287], edge_index=[2, 7266], x=[2287, 136], y=[32], z=[2287])
Batch(batch=[2826], edge_index=[2, 9558], x=[2826, 136], y=[32], z=[2826])
Batch(batch=[1590], edge_index=[2, 4832], x=[1590, 136], y=[32], z=[1590])
Batch(batch=[1997], edge_index=[2, 6262], x=[1997, 136], y=[32], z=[1997])
Batch(batch=[2060], edge_index=[2, 6396], x=[2060, 136], y=[32], z=[2060])
Batch(batch=[1421], edge_index=[2, 4062], x=[1421, 136], y=[32], z=[1421])
Batch(batch=[1619], edge_index=[2, 4910], x=[1619, 136], y=[32], z=[1619])
Batch(batch=[2441], edge_

Batch(batch=[3520], edge_index=[2, 12316], x=[3520, 136], y=[32], z=[3520])
Batch(batch=[3038], edge_index=[2, 10136], x=[3038, 136], y=[32], z=[3038])
Batch(batch=[1753], edge_index=[2, 5532], x=[1753, 136], y=[32], z=[1753])
Batch(batch=[1644], edge_index=[2, 4840], x=[1644, 136], y=[32], z=[1644])
Batch(batch=[1725], edge_index=[2, 5102], x=[1725, 136], y=[32], z=[1725])
Batch(batch=[1985], edge_index=[2, 6220], x=[1985, 136], y=[32], z=[1985])
Batch(batch=[1914], edge_index=[2, 5922], x=[1914, 136], y=[32], z=[1914])
Batch(batch=[1310], edge_index=[2, 3846], x=[1310, 136], y=[32], z=[1310])
Batch(batch=[1330], edge_index=[2, 3762], x=[1330, 136], y=[32], z=[1330])
Batch(batch=[1344], edge_index=[2, 3742], x=[1344, 136], y=[32], z=[1344])
Batch(batch=[1525], edge_index=[2, 4568], x=[1525, 136], y=[32], z=[1525])
Batch(batch=[1487], edge_index=[2, 4046], x=[1487, 136], y=[32], z=[1487])
Batch(batch=[1662], edge_index=[2, 5012], x=[1662, 136], y=[32], z=[1662])
Batch(batch=[1891], edg

KeyboardInterrupt: 

# My Usage

In [9]:
train_dataset =  MyDataset(num_hops=2, root=os.path.join(os.getcwd(), "hw2_data", "dataset1"), split="train")
val_dataset =  MyDataset(num_hops=2, root=os.path.join(os.getcwd(), "hw2_data", "dataset1"), split="valid")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [10]:
train_dataset

MyDataset(12897)

In [11]:
train_dataset[0]

Data(edge_index=[2, 0], x=[2, 1518], y=[1], z=[2])

In [12]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = DGCNN(hidden_channels=32, num_layers=3).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
date_time = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M")

In [None]:
best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(train_loader)
    val_auc = test(val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
#         test_auc = test(test_loader)

        checkpoint = {
            'model_stat': model.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        
        torch.save(checkpoint, os.path.join(os.getcwd(), 
                                            "hw2_data", 
                                            "dataset1", 
                                            "{}.pth".format(date_time)))
        print("Save Model\n")
        
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}')

Save Model

Epoch: 01, Loss: 0.6093, Val: 0.8282
Save Model

Epoch: 02, Loss: 0.4805, Val: 0.8583
Save Model

Epoch: 03, Loss: 0.4386, Val: 0.8754
Save Model

Epoch: 04, Loss: 0.4081, Val: 0.8832
Save Model

Epoch: 05, Loss: 0.3877, Val: 0.8887
Save Model

Epoch: 06, Loss: 0.3698, Val: 0.8939
Save Model

Epoch: 07, Loss: 0.3534, Val: 0.8953
Save Model

Epoch: 08, Loss: 0.3411, Val: 0.8976
Save Model

Epoch: 09, Loss: 0.3294, Val: 0.9006
Epoch: 10, Loss: 0.3170, Val: 0.9006
Save Model

Epoch: 11, Loss: 0.3072, Val: 0.9033
Epoch: 12, Loss: 0.3000, Val: 0.9031
Save Model

Epoch: 13, Loss: 0.2875, Val: 0.9066
Save Model

Epoch: 14, Loss: 0.2788, Val: 0.9071
Save Model

Epoch: 15, Loss: 0.2705, Val: 0.9087
Save Model

Epoch: 16, Loss: 0.2684, Val: 0.9104
Save Model

Epoch: 17, Loss: 0.2578, Val: 0.9126
Save Model

Epoch: 18, Loss: 0.2515, Val: 0.9149
Save Model

Epoch: 19, Loss: 0.2481, Val: 0.9166
Save Model

Epoch: 20, Loss: 0.2435, Val: 0.9188
Epoch: 21, Loss: 0.2375, Val: 0.9174
Epoch: 