In [25]:
import itertools
import os
os.environ['DGL_BACKEND'] = 'pytorch'
import dgl
import dgl.data
from dgl.nn import SAGEConv
import dgl.function as fn
import numpy as np
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Load the Cora dataset through DGL's built-in loader and take its first
# element as the graph `g` that the cells below operate on.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [48]:
# Split the edges of `g` into train and test sets for link prediction.
# NOTE(review): np.random is not seeded in this cell, so the split differs
# between runs unless a seed is set earlier in the notebook.
u, v = g.edges()

eids = np.arange(g.num_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)  # hold out 10% of the edges for testing
train_size = g.num_edges() - test_size

# positive examples (edge exists)
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

# negative examples (edge does not exist)
# The shape is passed explicitly: otherwise scipy infers it from the largest
# index in u/v, and the subtraction with np.eye(num_nodes) below fails
# whenever the highest-numbered nodes have no edges.
adj = sp.coo_matrix(
    (np.ones(len(u)), (u.numpy(), v.numpy())),
    shape=(g.num_nodes(), g.num_nodes()),
)
# toarray() yields a plain ndarray (todense() returns the legacy np.matrix).
adj_neg = 1 - adj.toarray() - np.eye(g.num_nodes())
# An entry is positive only for an off-diagonal pair with no edge. Testing
# `> 0` (rather than `!= 0`) keeps self-loops and duplicated edges, which
# produce negative entries, out of the negative set.
neg_u, neg_v = np.where(adj_neg > 0)

# split negative edges
# Sample without replacement so that no negative pair is repeated or shared
# between the train and test slices; fall back to sampling with replacement
# only when there are fewer candidate pairs than requested.
neg_eids = np.random.choice(
    len(neg_u), g.num_edges(), replace=len(neg_u) < g.num_edges()
)
test_neg_u, test_neg_v = (
    neg_u[neg_eids[:test_size]],
    neg_v[neg_eids[:test_size]],
)
train_neg_u, train_neg_v = (
    neg_u[neg_eids[test_size:]],
    neg_v[neg_eids[test_size:]],
)

In [49]:
# Remove the held-out test edges (the first `test_size` shuffled edge ids)
# from `g`; the resulting graph is the one the model is run on during training.
train_g = dgl.remove_edges(g, eids[:test_size])

In [50]:
class GraphSAGE(nn.Module):
    """Two stacked ``SAGEConv`` layers ('mean' aggregator) with a ReLU between them."""

    def __init__(self, in_feats, h_feats):
        """Create the layers: ``in_feats -> h_feats`` followed by ``h_feats -> h_feats``."""
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        """Apply both convolutions on graph ``g`` and return the second layer's output."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [51]:
# Build one graph per (split, label) combination. Every graph is created
# with num_nodes=g.num_nodes(), so node ids line up with `g`; the edges of
# each graph are simply the node pairs to be scored:
#   - positive graphs hold node pairs that are edges of `g`
#   - negative graphs hold the sampled node pairs that are not edges of `g`
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.num_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.num_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.num_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.num_nodes())

In [52]:
class DotPredictor(nn.Module):
    """Score every edge of a graph by the dot product of its endpoint embeddings."""

    def forward(self, g, h):
        """Return one score per edge of ``g``.

        Args:
            g: graph whose edges are the node pairs to score.
            h: node embeddings; stored temporarily as node feature ``'h'``.

        Returns:
            Column 0 of the per-edge ``'score'`` feature produced by
            ``fn.u_dot_v('h', 'h', 'score')``.
        """
        # local_scope() confines the temporary 'h' and 'score' features to this
        # call, so the caller's graph is left unmodified (same pattern as
        # MLPPredictor.forward).
        with g.local_scope():
            g.ndata['h'] = h
            # compute edge feature 'score' from source h and destination h
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # the score feature is indexed as 2-D; keep its first column
            return g.edata['score'][:, 0]

In [53]:
class MLPPredictor(nn.Module):
    """Score every edge with a two-layer MLP over the concatenated endpoint embeddings."""

    def __init__(self, h_feats):
        """Create the MLP: ``Linear(2 * h_feats, h_feats)`` -> ReLU -> ``Linear(h_feats, 1)``."""
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """Edge function handed to ``g.apply_edges``.

        Args:
            edges: edge batch exposing ``src['h']`` and ``dst['h']``.

        Returns:
            Dict with key ``'score'`` holding one value per edge.
        """
        # concat features from edge sources and destinations
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        # W2 outputs a single column; squeeze(1) drops it
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    # The method used to be named `apply_edge` while forward() referenced
    # `self.apply_edges`, which raised AttributeError. The old name is kept as
    # an alias so any existing caller keeps working.
    apply_edge = apply_edges

    def forward(self, g, h):
        """Return one score per edge of ``g`` given node embeddings ``h``."""
        # local_scope() keeps the temporary 'h' / 'score' features off the
        # caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']

In [54]:
def compute_loss(pos_score, neg_score):
    """Binary cross-entropy (on logits) over positive and negative edge scores.

    Args:
        pos_score: scores of node pairs labelled 1.
        neg_score: scores of node pairs labelled 0.

    Returns:
        Scalar tensor from ``F.binary_cross_entropy_with_logits``.
    """
    scores = torch.cat([pos_score, neg_score])
    # Build the labels from the score tensors themselves so they share the
    # scores' dtype and device instead of always being CPU float32.
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    """ROC AUC of the edge scores, with positives labelled 1 and negatives 0.

    Args:
        pos_score: scores of node pairs labelled 1.
        neg_score: scores of node pairs labelled 0.

    Returns:
        The value returned by ``roc_auc_score(labels, scores)``.
    """
    # detach().cpu() makes the NumPy conversion work even when the scores
    # still require grad or live on another device.
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)

In [55]:
# Encoder: GraphSAGE from the node feature width of train_g to 16 hidden units.
model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
# Edge scorer: the dot-product predictor is the active choice; the MLP-based
# alternative is left commented out on the next line.
# pred = MLPPredictor(16)
pred = DotPredictor()

In [57]:
# Optimise the encoder and the edge predictor jointly with Adam.
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

# NOTE(review): nothing in this cell appends to `all_logits`; the binding is
# kept in case another cell reads it.
all_logits = []
for epoch in range(100):
    # forward pass: embed the nodes on the training graph, then score the
    # positive and negative training pairs
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 5 == 0:
        print(f'Epoch {epoch}, loss: {loss}')

# Evaluate on the held-out pairs.
with torch.no_grad():
    # Recompute the embeddings with the final weights: the `h` left over from
    # the loop was produced before the last optimizer.step().
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print('AUC', compute_auc(pos_score, neg_score))

Epoch 0, loss: 0.7125867009162903
Epoch 5, loss: 0.6900863647460938
Epoch 10, loss: 0.6723355054855347
Epoch 15, loss: 0.6246037483215332
Epoch 20, loss: 0.5648944973945618
Epoch 25, loss: 0.5238027572631836
Epoch 30, loss: 0.4803387224674225
Epoch 35, loss: 0.45928460359573364
Epoch 40, loss: 0.4357365071773529
Epoch 45, loss: 0.41412580013275146
Epoch 50, loss: 0.3938811123371124
Epoch 55, loss: 0.3742051422595978
Epoch 60, loss: 0.35424721240997314
Epoch 65, loss: 0.3339700996875763
Epoch 70, loss: 0.31376728415489197
Epoch 75, loss: 0.2935040593147278
Epoch 80, loss: 0.2733391225337982
Epoch 85, loss: 0.2531461715698242
Epoch 90, loss: 0.2334868609905243
Epoch 95, loss: 0.21408666670322418
AUC 0.8475047730284584
