In [1]:
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, Actor
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_networkx
from copy import deepcopy
from scipy.sparse import coo_matrix
from sklearn.metrics import roc_auc_score
import itertools
import dgl
from dgl.nn import SAGEConv
import dgl.function as fn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cora = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())

def remove_edges(G, edges):
    G_new = deepcopy(G)
    G_new.remove_edges_from(edges)
    return G_new

def create_train_test_split_edge(data):
    data = cora[0]

    # Create a list of positive and negative edges
    u, v = data.edge_index.numpy()

    adj = coo_matrix((np.ones(data.num_edges), data.edge_index.numpy()))
    adj_neg = 1 - adj.todense() - np.eye(data.num_nodes)
    neg_u, neg_v = np.where(adj_neg != 0)

    # Create train/test edge split
    test_size = int(np.floor(data.num_edges * 0.1))
    eids = np.random.permutation(np.arange(data.num_edges)) # Create an array of 'edge IDs'

    train_pos_u, train_pos_v = data.edge_index[:, eids[test_size:]]
    test_pos_u, test_pos_v   = data.edge_index[:, eids[:test_size]]

    # Sample an equal amount of negative edges from  the graph, split into train/test
    neg_eids = np.random.choice(len(neg_u), data.num_edges)
    test_neg_u, test_neg_v = (
        neg_u[neg_eids[:test_size]],
        neg_v[neg_eids[:test_size]],
    )
    train_neg_u, train_neg_v = (
        neg_u[neg_eids[test_size:]],
        neg_v[neg_eids[test_size:]],
    )

    # Remove test edges from original graph
    G = to_networkx(data, node_attrs=data.node_attrs(), to_undirected=data.is_undirected())
    G_train = remove_edges(G, np.column_stack([test_pos_u, test_pos_v])) 

    train_g = dgl.from_networkx(G_train, node_attrs=list(G.nodes[0].keys()))

    train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=data.num_nodes)
    train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=data.num_nodes)

    test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=data.num_nodes)
    test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=data.num_nodes)

    return train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    return F.binary_cross_entropy_with_logits(scores, labels)


def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)

In [3]:
# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, "mean")
        self.conv2 = SAGEConv(h_feats, h_feats, "mean")

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    

class DotPredictor(torch.nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]
        

class MLPPredictor(torch.nn.Module):
    def __init__(self, in_feats):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_feats, 32)
        self.linear2 = torch.nn.Linear(32, 1)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(fn.u_add_v("h", "h", "score"))

            score = g.edata['score']

            score = self.linear1(score)
            score = F.relu(score)
            score = self.linear2(score)
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return score.squeeze()

In [4]:
# Cora model:
train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g = create_train_test_split_edge(cora[0])

model = GraphSAGE(train_g.ndata["x"].shape[1], 32)
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [5]:
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(1001):
    # forward
    h = model(train_g, train_g.ndata["x"])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))

    # ----------- 5. check results ------------------------ #
    if e % 100 == 0:
        with torch.no_grad():
            pos_score = pred(test_pos_g, h)
            neg_score = pred(test_neg_g, h)
            print("AUC", compute_auc(pos_score, neg_score))

In epoch 0, loss: 0.6941514015197754
AUC 0.4900258305069518
In epoch 5, loss: 0.6813181638717651
In epoch 10, loss: 0.6224610805511475
In epoch 15, loss: 0.6055197715759277
In epoch 20, loss: 0.5536144971847534
In epoch 25, loss: 0.5455098748207092
In epoch 30, loss: 0.5327537059783936
In epoch 35, loss: 0.5170523524284363
In epoch 40, loss: 0.5034334063529968
In epoch 45, loss: 0.487308144569397
In epoch 50, loss: 0.46476295590400696
In epoch 55, loss: 0.4387530982494354
In epoch 60, loss: 0.4108203947544098
In epoch 65, loss: 0.3884579837322235
In epoch 70, loss: 0.36885544657707214
In epoch 75, loss: 0.34917932748794556
In epoch 80, loss: 0.3297191262245178
In epoch 85, loss: 0.3122040927410126
In epoch 90, loss: 0.2994556725025177
In epoch 95, loss: 0.28140905499458313
In epoch 100, loss: 0.2747761607170105
AUC 0.7697859437119562
In epoch 105, loss: 0.25600486993789673
In epoch 110, loss: 0.24686764180660248
In epoch 115, loss: 0.2364881932735443
In epoch 120, loss: 0.2271429300308

In [6]:
#Actor model
actor = Actor(root='data/Actor', transform=NormalizeFeatures())

# Cora model:
train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g = create_train_test_split_edge(actor[0])

model = GraphSAGE(train_g.ndata["x"].shape[1], 16)
pred = DotPredictor()
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [7]:
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
    # forward
    h = model(train_g, train_g.ndata["x"])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))

# ----------- 5. check results ------------------------ #


with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print("AUC", compute_auc(pos_score, neg_score))

In epoch 0, loss: 0.705256998538971
In epoch 5, loss: 0.6895120143890381
In epoch 10, loss: 0.6700869798660278
In epoch 15, loss: 0.6250309348106384
In epoch 20, loss: 0.5564284920692444
In epoch 25, loss: 0.5070099830627441
In epoch 30, loss: 0.4828763008117676
In epoch 35, loss: 0.4588315188884735
In epoch 40, loss: 0.43904662132263184
In epoch 45, loss: 0.4211689829826355
In epoch 50, loss: 0.4033316969871521
In epoch 55, loss: 0.38598501682281494
In epoch 60, loss: 0.36767756938934326
In epoch 65, loss: 0.3482709527015686
In epoch 70, loss: 0.327592670917511
In epoch 75, loss: 0.30565235018730164
In epoch 80, loss: 0.2826615273952484
In epoch 85, loss: 0.2592843472957611
In epoch 90, loss: 0.23580880463123322
In epoch 95, loss: 0.21299971640110016
AUC 0.8322247927944115
