In [1]:
import numpy as np
import pandas as pd
import random
import os.path as osp
import networkx as nx
import pickle
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, from_networkx, train_test_split_edges
import os
import time
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius_graph, knn_graph
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool, SAGPooling, GraphNorm, GPSConv, GINEConv
from torch_geometric.nn import GINConv, JumpingKnowledge, GCNConv, Sequential, SAGEConv, GATConv, PNAConv, SimpleConv, GraphConv
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set, TopKPooling
from torch_geometric.loader import DataLoader
from gtrick.pyg import VirtualNode

In [2]:
def pkl_save(dataset_path, data):
    """Pickle ``data`` into the file at ``dataset_path`` and print the elapsed write time."""
    t0 = time.perf_counter()
    with open(dataset_path, 'wb') as fh:
        pickle.dump(data, fh)
    elapsed = time.perf_counter() - t0
    print(f"Data save {elapsed:.4f}s")
def pkl_load(dataset_path):
    """Unpickle and return the object stored at ``dataset_path``, printing the load time.

    Warning: unpickling can execute arbitrary code -- only load files you trust.
    """
    t0 = time.perf_counter()
    with open(dataset_path, 'rb') as fh:
        obj = pickle.load(fh)
    elapsed = time.perf_counter() - t0
    print(f"Data loading {elapsed:.4f}s")
    return obj
import warnings
# Suppress only the UserWarning whose message starts with "TypedStorage is deprecated"
# (presumably emitted by torch while unpickling the datasets below); other warnings still show.
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

In [4]:
# Load the full pickled list of link-subgraph samples (presumably one PyG `Data` per sample -- verify).
# pickle can execute arbitrary code on load: only open trusted files.
dataset_path = './pos_neg_link_datalist.pkl'
datalist = pkl_load(dataset_path)

Data loading 53.5346s


In [5]:
def dataset_split(datalist, ratio=0.95, seed=2023):
    """Reproducibly shuffle ``datalist`` and split it into two lists.

    The input is copied, not modified, and the global ``random`` state is left untouched:
    the shuffle uses a private ``random.Random(seed)``. That produces the same permutation
    as ``random.seed(seed); random.shuffle(...)``.

    Args:
        datalist: Sequence of samples to split.
        ratio: Fraction in [0, 1] of the samples that go into the first part.
        seed: Seed for the shuffle, so the split is reproducible.

    Returns:
        ``(first, second)``. ``first`` holds the first ``int(len(datalist) * ratio)``
        shuffled samples and ``second`` holds the rest.

    Raises:
        ValueError: If ``ratio`` is outside [0, 1].
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    shuffled = list(datalist)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * ratio)
    return shuffled[:cut], shuffled[cut:]

In [6]:
# Hold out a test split first, then carve a validation split of about the same size out of the rest.
train_all, test = dataset_split(datalist)
train, val = dataset_split(train_all, ratio=1 - len(test) / len(train_all))
# The three splits must partition the original data.
assert len(train) + len(val) + len(test) == len(datalist)
print(f"Train: {len(train)}; Val: {len(val)}; Test: {len(test)}")
train_path = './train.pkl'
val_path = './val.pkl'
test_path = './test.pkl'
pkl_save(train_path, train)
pkl_save(val_path, val)
pkl_save(test_path, test)

Train: 19534; Val: 1086; Test: 1086
Data save 39.0926s
Data save 2.1683s
Data save 1.9010s


In [7]:
# --- Model / optimisation hyper-parameters -----------------------------------
conv = 'gcn'          # 'gcn' (builds GraphConv layers) or 'gin'
hidden = 512
layer = 5
lr = 1e-3
num_epochs = 100
batch_size = 32

# --- Hardware and data loading ------------------------------------------------
gpu = 0
cpus = 16
prefetch_factor = 2
train_path = './train.pkl'
val_path = './val.pkl'
test_path = './test.pkl'


def _make_loader(data_list, shuffle):
    """Build a DataLoader over ``data_list`` using the shared batching/worker settings above."""
    return DataLoader(data_list, batch_size=batch_size, pin_memory=True, num_workers=cpus,
                      prefetch_factor=prefetch_factor, persistent_workers=True, shuffle=shuffle)


# Only the training loader shuffles; a fixed order keeps validation/test passes reproducible.
train_data_list = pkl_load(train_path)
train_loader = _make_loader(train_data_list, shuffle=True)
print(f"train dataset length: {len(train_data_list)} link subgraphs.")
val_data_list = pkl_load(val_path)
val_loader = _make_loader(val_data_list, shuffle=False)
print(f"val loader length: {len(val_data_list)} link subgraphs.")
test_data_list = pkl_load(test_path)
test_loader = _make_loader(test_data_list, shuffle=False)
print(f"test loader length: {len(test_data_list)} link subgraphs.")

Data loading 33.2874s
train dataset length: 19534 link subgraphs.
Data loading 1.3468s
val loader length: 1086 link subgraphs.
Data loading 1.0788s
test loader length: 1086 link subgraphs.


In [12]:
# Sanity check: shape of one sample's node-feature matrix `x`. Its last dimension is the
# input width (`in_fea`) that the first GNN layer must accept.
test_data_list[5].x.shape

torch.Size([2520, 113])

In [None]:
class GNN(torch.nn.Module):
    """Stack of message-passing layers with virtual nodes, followed by a top-k readout.

    Each layer does the following:
    - injects the virtual-node embedding into the node features;
    - applies a graph convolution, batch-norm, ReLU and dropout (dropout only in training);
    - except after the last layer, updates the virtual-node embedding from the new node
      features, to be used by the next layer.

    After the last layer, ``TopKPooling`` keeps the top-scoring node(s) of every graph.
    These are summed into one vector per graph and projected to ``out_channels`` logits.

    Args:
        in_fea: Width of the input node features ``data.x``.
        hidden_channels: Width of every hidden layer.
        num_layers: Number of message-passing layers.
        dropout: Dropout probability used after every layer and inside the virtual nodes.
        conv_type: ``'gcn'`` (builds PyG ``GraphConv``) or ``'gin'`` (``GINConv`` over a linear layer).
        out_channels: Number of output logits per graph.

    Raises:
        ValueError: If ``conv_type`` is not supported.
    """

    _CONV_TYPES = ('gcn', 'gin')

    def __init__(self, in_fea, hidden_channels, num_layers, dropout, conv_type, out_channels=20):
        super().__init__()
        if conv_type not in self._CONV_TYPES:
            # Previously an unknown type left `conv` unbound and crashed with UnboundLocalError.
            raise ValueError(f"conv_type must be one of {self._CONV_TYPES}, got {conv_type!r}")
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.vns = nn.ModuleList()
        for i in range(num_layers):
            # The first layer consumes raw node features; later layers consume hidden features.
            layer_in = in_fea if i == 0 else hidden_channels
            if conv_type == 'gcn':
                # NOTE: 'gcn' maps to GraphConv (not GCNConv). Output width is hidden_channels
                # for every layer; the first layer used to read the global `hidden` by mistake.
                conv = GraphConv(layer_in, hidden_channels)
            else:
                conv = GINConv(nn.Linear(layer_in, hidden_channels))
            self.convs.append(conv)
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
            self.vns.append(VirtualNode(layer_in, hidden_channels, dropout=dropout))
        # ratio=1e-4 keeps ceil(1e-4 * num_nodes) nodes per graph: a single node for graphs
        # of up to roughly 10,000 nodes.
        self.pool = TopKPooling(hidden_channels, ratio=1e-4)
        self.mlp = nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        for conv, bn, vn in zip(self.convs, self.batch_norms, self.vns):
            conv.reset_parameters()
            bn.reset_parameters()
            vn.reset_parameters()
        self.pool.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, data):
        """Return one row of ``out_channels`` logits per graph in the batched ``data``.

        Reads ``data.x``, ``data.edge_index`` and ``data.batch``. ``data.edge_attr`` is not used.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        vx = None  # virtual-node embeddings, carried from layer to layer
        for i in range(self.num_layers):
            # Add the current virtual-node embedding to the nodes of each graph.
            # On the first layer vx is None, letting VirtualNode create its initial embedding.
            x, vx = self.vns[i].update_node_emb(x, edge_index, batch, vx)
            x = self.convs[i](x, edge_index)
            x = self.batch_norms[i](x)
            # F.dropout defaults to training=True; tie it to the module mode so eval is deterministic.
            x = F.dropout(F.relu(x), p=self.dropout, training=self.training)
            if i < self.num_layers - 1:
                # Aggregate the new node features into the virtual node for the next layer.
                vx = self.vns[i].update_vn_emb(x, batch, vx)
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, batch=batch)
        # Sum the surviving nodes per graph, so the output has exactly one row per graph
        # even when TopKPooling keeps more than one node of a very large graph.
        x = global_add_pool(x, batch)
        return self.mlp(x)

In [None]:
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')
# The model's input width must match the last dimension of the node features `x`
# (previously `in_fea` was never defined, so this cell raised NameError).
in_fea = train_data_list[0].x.size(-1)
dropout_rate = 0.9
model = GNN(in_fea=in_fea, hidden_channels=hidden, num_layers=layer, dropout=dropout_rate, conv_type=conv).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=0.00001)
train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []
# NOTE(review): train_one_epoch / eval_one_epoch are not defined in the cells shown here;
# a cell defining them must run before this one.
for epoch in range(1, num_epochs + 1):
    start = time.perf_counter()
    train_loss, train_acc = train_one_epoch(model, optimizer, criterion, train_loader, device)
    end = time.perf_counter()
    print(
        f"Train | {(end - start):.4f}s | Epoch {epoch} | Loss:{train_loss:.4f} | train accuracy: {train_acc:.4f}| {len(train_loader.dataset) / (end - start):.0f} samples/s ")

    start = time.perf_counter()
    # Validate on the validation split so the test split stays unseen during training.
    val_loss, val_acc = eval_one_epoch(model, val_loader, criterion, device)
    end = time.perf_counter()
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    # scheduler.step(val_loss)
    current_lr = optimizer.param_groups[-1]['lr']
    if epoch % 10 == 0:
        print(f"Valid | {(end - start):.4f}s| Epoch {epoch}| Loss:{val_loss}|valid accuracy: {val_acc:.4f}| lr: {current_lr}")