In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
from torch_geometric.utils import negative_sampling
from torch_geometric.seed import seed_everything
import pandas as pd
import numpy as np
import tqdm
import networkx as nx

from combsage.combsage import IDRSAGEJK
from combsage.utils import evaluate, convert_to_heterograph_group_isolates
from combsage.utils import HetEdgePredictionSampler, HomoNeighborSampler

  _torch_pytree._register_pytree_node(
  from pandas.core import (
  warn(f"Failed to load image Python extension: {e}")


In [2]:
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [3]:
config = {'r1':15, 'r2':10, 'lr':0.001, 'batch_size':256, 'dropout':0.1}
device = 'cpu'

In [4]:
# split data
neg_edge_index = negative_sampling(edge_index=torch.vstack(g.edges()),
                                num_nodes=g.number_of_nodes(), 
                                num_neg_samples=g.number_of_edges())

# Split edge set for training and testing
u, v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
val_size = int(len(eids) * 0.1)
val_pos_u, val_pos_v = u[eids[:val_size]], v[eids[:val_size]]

# Find all negative edges and split them for training and testing
neg_u, neg_v = neg_edge_index[0], neg_edge_index[1]

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
val_neg_u, val_neg_v = neg_u[neg_eids[:val_size]], neg_v[neg_eids[:val_size]]

val_pos_g = dgl.graph((val_pos_u, val_pos_v), num_nodes=g.number_of_nodes())
val_neg_g = dgl.graph((val_neg_u, val_neg_v), num_nodes=g.number_of_nodes())

train_g = dgl.remove_edges(g, eids[:val_size])

In [5]:
edge_list = pd.DataFrame(torch.vstack(train_g.edges()).T)
edge_list.columns = ['source', 'target']
G = nx.from_pandas_edgelist(edge_list)

# convert graph to heterogenous graph to annotate edges according to community membership 
g_hetero = convert_to_heterograph_group_isolates(G, n_nodes = train_g.number_of_nodes()).to(device)

n_types = max([int(t) for t in g_hetero.etypes])
e_tensors = [g_hetero.edges(etype = etype) for etype in sorted(g_hetero.etypes, key = int)]
src = torch.hstack([e[0] for e in e_tensors])
dst = torch.hstack([e[1] for e in e_tensors])

# preserve homogenous version of the graph
g_homo = dgl.heterograph({('paper','1','paper'):(src,dst)})
g_hetero.ndata['feat'] = train_g.ndata['feat']
g_hetero.ndata['feat'] = g_hetero.ndata['feat']

g_homo.to(device)
g_hetero.to(device);

Graph(num_nodes={'paper': 2708},
      num_edges={('paper', '1', 'paper'): 4856, ('paper', '10', 'paper'): 2, ('paper', '11', 'paper'): 2, ('paper', '2', 'paper'): 4925, ('paper', '3', 'paper'): 507, ('paper', '4', 'paper'): 79, ('paper', '5', 'paper'): 35, ('paper', '6', 'paper'): 20, ('paper', '7', 'paper'): 16, ('paper', '8', 'paper'): 10, ('paper', '9', 'paper'): 6},
      metagraph=[('paper', 'paper', '1'), ('paper', 'paper', '10'), ('paper', 'paper', '11'), ('paper', 'paper', '2'), ('paper', 'paper', '3'), ('paper', 'paper', '4'), ('paper', 'paper', '5'), ('paper', 'paper', '6'), ('paper', 'paper', '7'), ('paper', 'paper', '8'), ('paper', 'paper', '9')])

In [6]:
model = IDRSAGEJK(g_hetero.ndata['feat'].shape[1], 256,
            n_types, dropout = config['dropout'])
model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=config['lr'])
edge_dict = {etype: g_hetero.edges(etype = etype, form = 'all')[-1] for etype in g_hetero.etypes}

sampler = HomoNeighborSampler([config['r1'],config['r2']], prefetch_node_feats=['feat'])
sampler = HetEdgePredictionSampler(
        sampler, g_homo = g_homo,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))

dataloader = dgl.dataloading.DataLoader(
        g_hetero, edge_dict, sampler,
        device=device, batch_size= config['batch_size'], shuffle=True,
        drop_last=False, num_workers=10)

In [7]:
best_loss = 100
for epoch in range(10):
    model.train()
    # with dataloader.enable_cpu_affinity():
    with tqdm.tqdm(dataloader) as tq:
        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(tq):
            tq.set_description('Epoch: {}'.format(epoch))
            x = {'paper':blocks[0].srcdata['feat']}
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            pos_label = torch.ones_like(pos_score)
            neg_label = torch.zeros_like(neg_score)
            score = torch.cat([pos_score, neg_score])
            labels = torch.cat([pos_label, neg_label])
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tq.set_postfix({'loss':'{:.3f}'.format(loss.item())})
    model.eval()
    val_loss, val_auc, _, _ = evaluate(model, g_hetero, val_pos_g, val_neg_g)
    if val_loss < best_loss:
        best_loss = val_loss
        best_params = model.state_dict()
    # torch.save(best_params,params_path)
print("Finished Training")

  assert input.numel() == input.storage().size(), "Cannot convert view " \
Epoch: 0: 100%|████████████████████████████████████████████████████████| 41/41 [00:56<00:00,  1.38s/it, loss=0.654]
Inference.: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.14it/s]
Inference.: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.26it/s]
Epoch: 1: 100%|████████████████████████████████████████████████████████| 41/41 [00:59<00:00,  1.46s/it, loss=0.579]
Inference.: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.18it/s]
Inference.: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.08it/s]
Epoch: 2: 100%|████████████████████████████████████████████████████████| 41/41 [00:57<00:00,  1.40s/it, loss=0.579]
Inference.: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.17it/

Finished Training





In [8]:
val_auc

0.9252173131780508