In [27]:
import argparse
from codecs import ignore_errors
from random import shuffle
import time
import os, sys, shutil
from tqdm import tqdm
import numpy as np
import scipy.sparse as ssp
import networkx as nx
from sklearn.metrics import roc_auc_score
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

parent_path = os.path.dirname(sys.path[0])
if parent_path not in sys.path:
    sys.path.append(parent_path)
from easylink.common.eval_utils import evaluate_auc
from easylink.model.heuristic_similarity import common_neighbors, adamic_adar, resource_allocation, local_path_index
from easylink.model.seal import SEAL, SEALDataset
from easylink.common.data_utils import load_basic_network, train_test_split
from easylink.common.seal_utils import *

In [2]:
@torch.no_grad()
def test(dataset, seal, batch_size, evaluator=None):
    seal.model.eval()
    data_loader = DataLoader(dataset, batch_size)
    pbar = tqdm(data_loader, ncols=80)
    y_pred, y_true = [], []
    for data in pbar:
        data = data.to(device)
        x = data.x if args.use_feature else None
        logits = seal.model(data.z, data.edge_index, data.batch, x)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))
    val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)
    pos_val_pred = val_pred[val_true==1]
    neg_val_pred = val_pred[val_true==0]

    result = {}
    auc = roc_auc_score(val_true, val_pred)
    print("AUC:{}".format(auc))
    result['AUC'] = auc

    if evaluator is not None:
        for K in [20, 50, 100]:
            evaluator.K = K
            valid_hits = evaluator.eval({
                    'y_pred_pos': pos_val_pred,
                    'y_pred_neg': neg_val_pred,
                })[f'hits@{K}']
            result[f'Hits@{K}'] = valid_hits
            print(f"Hits@{K}:{valid_hits}")

    return result

In [6]:
parser = argparse.ArgumentParser(description='SEAL LinkPredictor')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--dataset', type=str, default='ogbl-collab')
# DataStructure settings
parser.add_argument('--use_feature', action='store_true', 
                help="whether to use raw node features as GNN input")
parser.add_argument('--num_hops', type=int, default=2)
parser.add_argument('--max_nodes_per_hop', type=int, default=10)
# GNN settings
parser.add_argument('--model', type=str, default='SAGE')
parser.add_argument('--sortpool_k', type=float, default=0.6)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_channels', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.5)

# Training settings
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--epochs', type=int, default=5)
args = parser.parse_args(args=['--device', '0', '--dataset', 'facebook'])
print(args)

Namespace(device=0, log_steps=1, dataset='facebook', use_feature=False, num_hops=2, max_nodes_per_hop=10, model='SAGE', sortpool_k=0.6, num_layers=3, hidden_channels=64, batch_size=256, dropout=0.5, lr=0.0001, epochs=5)


In [59]:
dataset_root = '../data/facebook'
# facebook
facebook_dir = dataset_root+'/facebook.txt'
g = load_basic_network(facebook_dir)
adj = nx.adjacency_matrix(g)
train_edges, test_edges, val_edges, train_neg_edges, test_neg_edges, val_neg_edges = train_test_split(adj)

loading file: ../data/facebook/facebook.txt
#nodes: 4039 ,#edges: 88234
Negative Sampling.


In [9]:
train_edges[:10]

[(0, 148),
 (741, 1002),
 (540, 614),
 (564, 822),
 (351, 2376),
 (658, 906),
 (902, 1371),
 (2993, 3109),
 (2487, 2630),
 (391, 1022)]

In [60]:
train_edges = np.array(train_edges)
train_edges_reverse = np.array([train_edges[:,1], train_edges[:,0]]).transpose()
train_edges = np.concatenate([train_edges, train_edges_reverse], axis=0)
edge_weight = torch.ones(train_edges.shape[0], dtype=int)
A = ssp.csr_matrix(
    (edge_weight, (train_edges[:,0], train_edges[:,1])), shape = (g.number_of_nodes(), g.number_of_nodes())
)
h_predictor = common_neighbors 
val_edges = torch.tensor(val_edges)
val_neg_edges = torch.tensor(val_neg_edges)
pos_valid_pred = h_predictor(A, val_edges, batch_size=args.batch_size)
neg_valid_pred = h_predictor(A, val_neg_edges, batch_size=args.batch_size)
val_pred = torch.cat([torch.Tensor(pos_valid_pred), torch.Tensor(neg_valid_pred)])
val_true = torch.cat([torch.ones(pos_valid_pred.shape[0], dtype=int),
            torch.zeros(neg_valid_pred.shape[0], dtype=int)]) 
auc = roc_auc_score(val_true, val_pred)
print("Heuristic AUC:{}".format(auc)) 

100%|███████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 1643.81it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 2368.44it/s]

Heuristic AUC:0.982454120657644





In [61]:
train_edges = np.array(train_edges)
train_edges_reverse = np.array([train_edges[:,1], train_edges[:,0]]).transpose()
train_edges = np.concatenate([train_edges, train_edges_reverse], axis=0)
train_edges = torch.tensor(train_edges)
val_edges = torch.tensor(val_edges)
val_neg_edges = torch.tensor(val_neg_edges)

  val_edges = torch.tensor(val_edges)
  val_neg_edges = torch.tensor(val_neg_edges)


In [62]:
train_edges.t().size(1)

247060

In [41]:
edge_index = train_edges.t()
edge_weight = torch.ones(edge_index.size(1), dtype=int)
A = ssp.csr_matrix(
            (edge_weight, (edge_index[0], edge_index[1])), 
            shape=(g.number_of_nodes(),g.number_of_nodes()))
print('Graph Ajc Shape:', A.shape)

Graph Ajc Shape: (4039, 4039)


In [23]:
pos_edge_index = train_edges.t()

In [28]:
pos_edge, neg_edge = get_pos_neg_edges(pos_edge_index, g.number_of_nodes(), None)

In [44]:
pos_edge[:,0]

tensor([2126, 3370], dtype=torch.int32)

In [42]:
pos_list = extract_enclosing_subgraphs(pos_edge, A, None, 1, 2)

  self._set_intXint(row, col, x.flat[0])
100%|██████████████████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 820.46it/s]


In [50]:
pos_list[0].edge_weight

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [63]:
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_z = 1000
lr = 0.001
epochs = 10
seal = SEAL('SAGE', False, lr, args.hidden_channels, args.num_layers, max_z, args.dropout)
train_dataset_dir = dataset_root+'_seal'
shutil.rmtree(train_dataset_dir, ignore_errors=True)
train_dataset = SEALDataset(train_dataset_dir, train_edges.t(), train_edges, 
                g.number_of_nodes(), args.num_hops, args.max_nodes_per_hop, node_feat=None)
seal.train(train_dataset, epochs, args.batch_size, device)


Total number of parameters is 92993
Processing dataset.


Processing...
100%|████████████████████████████████████████████████████████████████████████| 247060/247060 [05:07<00:00, 804.10it/s]
  self._set_intXint(row, col, x.flat[0])
100%|████████████████████████████████████████████████████████████████████████| 247060/247060 [05:56<00:00, 692.39it/s]
Done!
Epoch 0, Loss: 0.000738: 100%|██████████████| 1931/1931 [03:03<00:00, 10.51it/s]
Epoch 1, Loss: 0.000702: 100%|██████████████| 1931/1931 [02:32<00:00, 12.70it/s]
Epoch 2, Loss: 0.000686: 100%|██████████████| 1931/1931 [02:32<00:00, 12.62it/s]
Epoch 3, Loss: 0.000677: 100%|██████████████| 1931/1931 [02:36<00:00, 12.33it/s]
Epoch 4, Loss: 0.00067: 100%|███████████████| 1931/1931 [02:35<00:00, 12.45it/s]
Epoch 5, Loss: 0.000665: 100%|██████████████| 1931/1931 [02:33<00:00, 12.62it/s]
Epoch 6, Loss: 0.000662: 100%|██████████████| 1931/1931 [02:32<00:00, 12.70it/s]
Epoch 7, Loss: 0.000659: 100%|██████████████| 1931/1931 [02:32<00:00, 12.66it/s]
Epoch 8, Loss: 0.000657: 100%|██████████████| 1931/19

0.0006546536139641155

In [64]:
# Test
val_dataset_dir = dataset_root+"_seal_val"
shutil.rmtree(val_dataset_dir, ignore_errors=True)
val_dataset = SEALDataset(val_dataset_dir, train_edges.t(), val_edges,
                        g.number_of_nodes(), args.num_hops, args.max_nodes_per_hop,
                        node_feat=None, neg_edges=val_neg_edges)
test(val_dataset, seal, args.batch_size, None)

Processing...


Processing dataset.


100%|████████████████████████████████████████████████████████████████████████████| 8823/8823 [00:13<00:00, 671.11it/s]
100%|██████████████████████████████████████████████████████████████████████████| 17646/17646 [00:25<00:00, 701.68it/s]
Done!
100%|█████████████████████████████████████████| 104/104 [00:06<00:00, 16.64it/s]

AUC:0.9803714555564406





{'AUC': 0.9803714555564406}