In [None]:
# Graph Convolutional Network
# https://www.youtube.com/watch?v=8qTnNXdkF1Q&list=PLSgGvve8UweGx4_6hhrF3n4wpHf_RV76_&index=6

# Karate club: 34 nodes which are people
# edges represent social interactions that occured outside of the context of the Karate club
#

In [34]:
import numpy as np
import networkx as nx
from networkx.algorithms.community.modularity_max import greedy_modularity_communities
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy.linalg import sqrtm

In [3]:
g = nx.karate_club_graph()

In [4]:
g.number_of_nodes(), g.number_of_edges()

(34, 78)

In [5]:
A = nx.to_numpy_matrix(g)

In [30]:
A

matrix([[0., 4., 5., ..., 2., 0., 0.],
        [4., 0., 6., ..., 0., 0., 0.],
        [5., 6., 0., ..., 0., 2., 0.],
        ...,
        [2., 0., 0., ..., 0., 4., 4.],
        [0., 0., 2., ..., 4., 0., 5.],
        [0., 0., 0., ..., 4., 5., 0.]])

In [35]:
A_mod = A + np.eye(g.number_of_nodes()) # self-connections

# Normalization
D_mod = np.zeros_like(A_mod) # Degree Matrix
np.fill_diagonal(D_mod, np.asarray(A_mod.sum(axis=1)).flatten()) # number of connections for each node

D_mod_invroot = np.linalg.inv(sqrtm(D_mod))

# Note: @ -- is used for matrix multiplication, matrix * matrix brings a element-wise multiplication known as a Hadamard product
A_hat = D_mod_invroot @ A_mod @ D_mod_invroot # A_hat_i_j = 1/sqrt(d_i * d_j) * A_hat_i_j


In [44]:
# Features Matrix -- since we have no node features, we just use the identity matrix
X = np.eye(g.number_of_nodes()) # identity matrix effectively map each graph to a column of learnable parameters resulting in a full learnable embeddings

# Now we have a labels, normalized Adjacency matrix and input features => GCN implementation

In [None]:
# GCN layer is implemented in GCN layer class which can be able to stack into a larger GCN model
from turtle import forward


class GCN():
    def __init__(self, n_inputs, n_outputs, activation=None, name=''):
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.w = glorot_init(self.n_outputs, self.n_inputs)
        self.activation = activation
        self.name = name
    
    def __repr__(self):
        return f"GCN: W({'_'+self.name if self.name else ''} ({self.n_inputs}, {self.n_outputs}"

    def forward(self, A, X, W=None):
        '''
        Assumes A is (bs, bs) adjacency matrix and X is (bs, D)
        where bs = 'batch size' and D = input features length
        '''
        self._X = (A @ X).T # for calculating gradients (D, bs)

        if W is None:
            W = self.w
        
        H = W @ self._X # (h, D)*(D, bs) -> (h, bs)
        if self.activation is not None:
            H = self.activation()
        self._H = H # (h, bs)
        return self._H.T # (bs, h)

    def backward(self, optim, update=True):
        dtanh = 1 - np.asarray(self._H.T)**2 # (bs, out_dim)
        d2 = np.multiply(optim.out, dtanh) # (bs, out_dim) * element-wise *(bs, out_dim)
        


In [1]:
# Deep Graph Library
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import dgl.data

dataset = dgl.data.CoraGraphDataset()
print(f'Number of categories: {dataset.num_classes}')

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Number of categories: 7


In [6]:
g = dataset[0]

In [16]:
g.edata

{}

In [21]:
from dgl.nn import GraphConv, ChebConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes, k=2):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)
        self.specConv = ChebConv(h_feats, num_classes, k)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        h = self.specConv(g, h)
        h = F.relu(h)
        return h

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)

In [22]:
model

GCN(
  (conv1): GraphConv(in=1433, out=16, normalization=both, activation=None)
  (conv2): GraphConv(in=16, out=16, normalization=both, activation=None)
  (specConv): ChebConv(
    (linear): Linear(in_features=32, out_features=7, bias=True)
  )
)

In [30]:
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(500):
        # Forward
        logits = model(g, features)
        #print(f'Output from forward pass: {logits.shape}')
        # Compute predictions -- i.e. gather the maximal
        pred = logits.argmax(1)

        # Compute loss -- you should compute the losses of the nodes in the training set
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = ( pred[train_mask] == labels[train_mask] ).float().mean()
        val_acc = ( pred[val_mask] == labels[val_mask] ).float().mean()
        test_acc = ( pred[test_mask] == labels[test_mask] ).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print(f'In epoch {e}, loss {loss}, vall_acc: {val_acc}, best acc: {best_val_acc}, test acc: {test_acc}, best test acc: {best_test_acc}')

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)


In epoch 0, loss 1.9472980499267578, vall_acc: 0.057999998331069946, best acc: 0.057999998331069946, test acc: 0.06400000303983688, best test acc: 0.06400000303983688
In epoch 5, loss 1.9188668727874756, vall_acc: 0.20600000023841858, best acc: 0.20600000023841858, test acc: 0.19599999487400055, best test acc: 0.19599999487400055
In epoch 10, loss 1.8453208208084106, vall_acc: 0.28200000524520874, best acc: 0.28200000524520874, test acc: 0.2639999985694885, best test acc: 0.2639999985694885
In epoch 15, loss 1.7023981809616089, vall_acc: 0.29600000381469727, best acc: 0.29600000381469727, test acc: 0.27900001406669617, best test acc: 0.27900001406669617
In epoch 20, loss 1.5109180212020874, vall_acc: 0.29600000381469727, best acc: 0.29600000381469727, test acc: 0.2800000011920929, best test acc: 0.27900001406669617
In epoch 25, loss 1.3367013931274414, vall_acc: 0.32199999690055847, best acc: 0.32199999690055847, test acc: 0.31299999356269836, best test acc: 0.31299999356269836
In epoc

In [31]:
# How Does DGL Represent a Graph
import dgl
import numpy as np
import torch

g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# Equivalently LongTensor also works
g = dgl.graph((torch.LongTensor([0, 0, 0, 0, 0]), torch.LongTensor([1, 2, 3, 4, 5])), num_nodes=6)

# You can omit the number of nodes argument if you can tell the number of nodes from the edge list alone
g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))


In [34]:
g.edges()[0].shape

torch.Size([5])

In [54]:
# Assign a 3-dimensional
g.ndata['x'] = torch.rand(6, 3)
# Assign a 4-dimensional edge feature vector for each edge
g.edata['a'] = torch.rand(5, 4)
# Assign a 5x4 node feature matrix for each node. Node and edge features in DGL can be multi-dimensional
g.ndata['y'] = torch.rand(6, 5, 4)

print(g.edata['a'])

tensor([[0.3221, 0.6773, 0.2418, 0.0775],
        [0.1706, 0.8468, 0.5743, 0.0085],
        [0.2340, 0.3074, 0.9905, 0.7342],
        [0.6691, 0.9219, 0.1042, 0.6407],
        [0.3158, 0.3419, 0.0140, 0.7359]])


In [43]:
# Querying Graph Structures
print(g.num_nodes())
print(g.num_edges())
# Out degrees of the center node
print(g.out_degrees(0))
# In degress of the center node - note that the graph is directed so the in degree should be zero
print(g.in_degrees(0))

6
5
5
0


In [44]:
# Graph Transofrmations 
# Introduce a subgraph from node 0. node 1 and node 3 from the original graph
sg1 = g.subgraph([0, 1, 3])
# Introduce a sobgraph from edge, edge 1 and edge 3 from the original graph
sg2 = g.edge_subgraph([0, 1, 3])

In [50]:
# The original IDs of each node in sg1
print(sg1.ndata[dgl.NID])
# The original IDs of each edge in sg1
print(sg1.edata[dgl.EID])
# The original IDs of each node in sg2
print(sg2.ndata[dgl.NID])
# The original IDs of each edge in sg2
print(sg2.edata[dgl.EID])

tensor([0, 1, 3])
tensor([0, 2])
tensor([0, 1, 2, 4])
tensor([0, 1, 3])


In [52]:
# subgraph and edge_subgraph also copies the original features to the subgraph
# The original node feature of each node in sg1
print(sg1.ndata['x'])
# The original edge feature of each node in sq1
print(sg1.edata['a'])
# The original node feature of each node in sg2
print(sg2.ndata['x'])
# The original edge feature of each node in sg2
print(sg2.edata['a'])

tensor([[0.9325, 0.1277, 0.3511],
        [0.7269, 0.4787, 0.5622],
        [0.0375, 0.8109, 0.0934],
        [0.6426, 0.2974, 0.7814]])


In [56]:
# Another common transformation is to add a reverse edge for each edge in the original graph with dgl.add_reverse_edges
# But if you have an undirected graph (have edges that do not have a direction and indicates two-way relationship) it is better to convert it into 
# a bidirectional graph first via adding reverse edges (bidirectional might have different edge weights between two adjacent vertices depending on the direction)

newg = dgl.add_reverse_edges(g)
newg.edges()

(tensor([0, 0, 0, 0, 0, 1, 2, 3, 4, 5]),
 tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0]))

In [62]:
# Loading and Saving Graphs
# You can save a graph or a list of graphs via dgl.save_graphs and load them back with dgl.load_graphs
# Save Graphs
dgl.save_graphs('graph.dgl', g)
dgl.save_graphs('graphs.dgl', [g, sg1, sg2])

# Load Graphs
(g,), _ = dgl.load_graphs('graph.dgl')
print(g)
(g, sg1, sg2), _ = dgl.load_graphs('graphs.dgl')
print(g)
print(sg1)
print(sg2)

Graph(num_nodes=6, num_edges=5,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'y': Scheme(shape=(5, 4), dtype=torch.float32)}
      edata_schemes={'a': Scheme(shape=(4,), dtype=torch.float32)})


In [64]:
# Write Your Own GNN module
# Sometimes your model goes beyond simply stacking existing GNN modules. E.g. you would like to invent a new way  of aggregating neighbor information by considering node importance or edge weights.
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

In [65]:
# Message Passing and GNNs
import dgl.function as fn

class SAGEConv(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output
        self.linear = nn.Linear(in_feat*2, out_feat)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a mesage passing API
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
            

In [68]:
# Afterwards, you can stack your own GraphSAGE convolution layers to form a multi-layer GraphSAGE network
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# Training Loop
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        #  Compute loss -- one should only compute the losses of the nodes in the trainng set, i.e. with train_mask = 1
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training / validation / test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean() 
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e %  5 == 0:
            print(f'In epoch {e}, loss: {loss}, vall acc: {val_acc}, best: {best_val_acc}, test acc: {test_acc}, best: {best_test_acc}')

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.




In epoch 0, loss: 1.9515306949615479, vall acc: 0.15600000321865082, best: 0.15600000321865082, test acc: 0.14399999380111694, best: 0.14399999380111694
In epoch 5, loss: 1.8879996538162231, vall acc: 0.17800000309944153, best: 0.17800000309944153, test acc: 0.15800000727176666, best: 0.15800000727176666
In epoch 10, loss: 1.7564455270767212, vall acc: 0.5220000147819519, best: 0.5220000147819519, test acc: 0.49300000071525574, best: 0.49300000071525574
In epoch 15, loss: 1.5496182441711426, vall acc: 0.5460000038146973, best: 0.550000011920929, test acc: 0.5479999780654907, best: 0.5249999761581421
In epoch 20, loss: 1.2737606763839722, vall acc: 0.5879999995231628, best: 0.5879999995231628, test acc: 0.5789999961853027, best: 0.5789999961853027
In epoch 25, loss: 0.9600477814674377, vall acc: 0.6340000033378601, best: 0.6340000033378601, test acc: 0.6309999823570251, best: 0.6309999823570251
In epoch 30, loss: 0.6582165956497192, vall acc: 0.6800000071525574, best: 0.6800000071525574

In [69]:
# More Customization: DGL has got many built-in message and reduce functions under the dgl.function package
# https://docs.dgl.ai/api/python/dgl.function.html#apifunction
# E.g. Add Edge Weights
class WeightedSAGEConv(nn.Module):
    def __init__(self,in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # a linear submodule for projecting the input and neighbor feature to rhe output
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)



In epoch 0, loss: 1.951414942741394, vall acc: 0.3160000145435333, best: 0.3160000145435333, test acc: 0.3190000057220459, best: 0.3190000057220459
In epoch 5, loss: 1.8783222436904907, vall acc: 0.5680000185966492, best: 0.5680000185966492, test acc: 0.5559999942779541, best: 0.5559999942779541
In epoch 10, loss: 1.7482794523239136, vall acc: 0.2639999985694885, best: 0.5680000185966492, test acc: 0.26100000739097595, best: 0.5559999942779541
In epoch 15, loss: 1.5574053525924683, vall acc: 0.2980000078678131, best: 0.5680000185966492, test acc: 0.2939999997615814, best: 0.5559999942779541
In epoch 20, loss: 1.3111858367919922, vall acc: 0.4000000059604645, best: 0.5680000185966492, test acc: 0.39100000262260437, best: 0.5559999942779541
In epoch 25, loss: 1.0292094945907593, vall acc: 0.5199999809265137, best: 0.5680000185966492, test acc: 0.492000013589859, best: 0.5559999942779541
In epoch 30, loss: 0.7458023428916931, vall acc: 0.5820000171661377, best: 0.5820000171661377, test ac

In [None]:
# Even More Customization by user-defined Function
# DGL allows user-defined message and reduce functio for the maximal expresivness. Here is a user-defined message funcion that is equivalent to fn.u_mul_e('h', 'w', 'm')
def u_mul_e_udf(edges):
    return {'m': edges.src['h'] * edges.data['w']}

In [71]:
# Link Prediction using GNN, i.e. predicting the existence of an edge between two arbitrary nodes in a graph
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp

In [72]:
# Many applications such as soical recommendation, item recommendation, knowledge graph competion can be formulated as link prediction, which predicts
# whether an edge exists between two particular nodes.
# Example: prediciton whether a citation raltionship, either citing or being cited, between two papers exists in a citation network
#    - binary classification
#    - treat the edgess in the graph as positive examples
#    - sample a number of non-existing edges (i.e. node pairs with no edges between them) as negative examples
#    - divide the positive examples and negative examples into a training set and a test set
#    - evaluate the model with any binary classification metric such as AUC

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [78]:
# Split edge set for training and testing
u,v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1) # 10 % for testing purposes
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

# Find all negative edges and split them fo training and testing
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]


In [79]:
# remove edges in the test set from the original graph
train_g = dgl.remove_edges(g, eids[:test_size])

In [80]:
# Define a GrapSAGE Model
from dgl.nn import SAGEConv

#----------------2. create model -----------------------------#
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

In [81]:
# Now, the model predicts probability of existence of an edge by computing a score between the representations of both incident nodes with a function (e.g. an MLP or a dot product)

# Positive graph, negative graph
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes()) 

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())

In [82]:
import dgl.function as fn

class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a new edge feature names 'score' by a dot-product between the source node feature 'h' and destination node feature 'h'
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it
            return g.edata['score'][:, 0]

In [83]:
# One can also write own function if it is complex, e.g. the following module produces a scalar score on each edge by concatenating the incident nodes features and passing it to MLP
class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linera(h_feats, 1) # output is a scalar

    def apply_edges(self, edges):
        '''
        Computes a scalar score for each edge of the given graph.
        Parameters
        ----------
        edges:
            Has three members ''src'', ''dst'' and ''data'', each of which is a dictionary representing the features of the source nodes, the destination nodes and the edges themselves.
        Returns
        -------
        dict
            A dictionary of new edge features
        '''
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']

In [86]:
# Training Loop
# loss function: binary cross-entropy
from sklearn.metrics import roc_auc_score

model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
# One can replace DotPredictor with MLPPredictor
# pred = MLPPredictor(16)
pred = DotPredictor()

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

In [87]:
# -------------- 3. set up loss and optimizer -------------------#
# in this case, loss will in training loop
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=0.01)

# -------------- 4. training ---------------------------------- #
all_logits = []
for e in range(100):
    # forward pass
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print(f'In epoch {e}, loss: {loss}')

# -------------5. check the results -------------------------- #
with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print('AUC', compute_auc(pos_score, neg_score))



In epoch 0, loss: 0.6929958462715149
In epoch 5, loss: 0.664452314376831
In epoch 10, loss: 0.5835233330726624
In epoch 15, loss: 0.5399207472801208
In epoch 20, loss: 0.5070629119873047
In epoch 25, loss: 0.48428264260292053
In epoch 30, loss: 0.4538520574569702
In epoch 35, loss: 0.4277080297470093
In epoch 40, loss: 0.4002706706523895
In epoch 45, loss: 0.37552323937416077
In epoch 50, loss: 0.35206684470176697
In epoch 55, loss: 0.32759401202201843
In epoch 60, loss: 0.30359822511672974
In epoch 65, loss: 0.27989843487739563
In epoch 70, loss: 0.25594303011894226
In epoch 75, loss: 0.23273035883903503
In epoch 80, loss: 0.20979833602905273
In epoch 85, loss: 0.18734630942344666
In epoch 90, loss: 0.16602729260921478
In epoch 95, loss: 0.145292729139328
AUC 0.8608620650928774
