In [1]:
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, Actor
from torch_geometric.transforms import NormalizeFeatures
from copy import deepcopy
import itertools

In [2]:
from torch import nn

In [3]:
import os

os.chdir('/'.join(os.getcwd().split('/')[:-2]))
from src import *

# CORA

## GraphSAGE

### Mean Agg

In [4]:
# Cora model:
cora = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g = create_train_test_split_edge(cora[0])

model = GraphSAGE(train_g.ndata["x"].shape[1], 32)
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [5]:
train_link_pred(1000, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

In epoch 5, loss: 0.6881262063980103
In epoch 10, loss: 0.64781653881073
In epoch 15, loss: 0.5861732959747314
In epoch 20, loss: 0.5601806640625
In epoch 25, loss: 0.5454561710357666
In epoch 30, loss: 0.5283210277557373
In epoch 35, loss: 0.5114594101905823
In epoch 40, loss: 0.4908011257648468
In epoch 45, loss: 0.47150540351867676
In epoch 50, loss: 0.44645392894744873
In epoch 55, loss: 0.419136643409729
In epoch 60, loss: 0.39415863156318665
In epoch 65, loss: 0.3690391480922699
In epoch 70, loss: 0.35145696997642517
In epoch 75, loss: 0.3312009871006012
In epoch 80, loss: 0.31218698620796204
In epoch 85, loss: 0.29176223278045654
In epoch 90, loss: 0.2872620224952698
In epoch 95, loss: 0.27275702357292175
In epoch 100, loss: 0.25841566920280457
AUC 0.802616293434559
In epoch 105, loss: 0.24204814434051514
In epoch 110, loss: 0.23821844160556793
In epoch 115, loss: 0.22772371768951416
In epoch 120, loss: 0.21526092290878296
In epoch 125, loss: 0.21264280378818512
In epoch 130, lo

### Pool Agg

In [6]:
model = GraphSAGE(train_g.ndata["x"].shape[1], 32, agg='pool')
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [7]:
train_link_pred(1000, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

In epoch 5, loss: 0.6897978186607361
In epoch 10, loss: 0.6589993238449097
In epoch 15, loss: 0.5966556072235107
In epoch 20, loss: 0.5643706917762756
In epoch 25, loss: 0.543643057346344
In epoch 30, loss: 0.5310526490211487
In epoch 35, loss: 0.5145068168640137
In epoch 40, loss: 0.5015579462051392
In epoch 45, loss: 0.48903724551200867
In epoch 50, loss: 0.4756827652454376
In epoch 55, loss: 0.4582759737968445
In epoch 60, loss: 0.4319886565208435
In epoch 65, loss: 0.42698249220848083
In epoch 70, loss: 0.38241657614707947
In epoch 75, loss: 0.3673381507396698
In epoch 80, loss: 0.33755603432655334
In epoch 85, loss: 0.3145856261253357
In epoch 90, loss: 0.29491522908210754
In epoch 95, loss: 0.2786233723163605
In epoch 100, loss: 0.26284417510032654
AUC 0.7924314368500258
In epoch 105, loss: 0.2461032122373581
In epoch 110, loss: 0.23793070018291473
In epoch 115, loss: 0.2817286252975464
In epoch 120, loss: 0.22139926254749298
In epoch 125, loss: 0.21593394875526428
In epoch 130, 

### LSTM Agg

In [8]:
model = GraphSAGE(train_g.ndata["x"].shape[1], 32, agg='lstm')
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [9]:
train_link_pred(1000, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

In epoch 5, loss: 0.6883365511894226
In epoch 10, loss: 0.6506872773170471
In epoch 15, loss: 0.5944363474845886
In epoch 20, loss: 0.557481586933136
In epoch 25, loss: 0.533795177936554
In epoch 30, loss: 0.5168571472167969
In epoch 35, loss: 0.49965983629226685
In epoch 40, loss: 0.48354214429855347
In epoch 45, loss: 0.4647616744041443
In epoch 50, loss: 0.44313862919807434
In epoch 55, loss: 0.42548108100891113
In epoch 60, loss: 0.3932667076587677
In epoch 65, loss: 0.36693188548088074
In epoch 70, loss: 0.34153419733047485
In epoch 75, loss: 0.31989893317222595
In epoch 80, loss: 0.3068668842315674
In epoch 85, loss: 0.29772743582725525
In epoch 90, loss: 0.2787810266017914
In epoch 95, loss: 0.2638188600540161
In epoch 100, loss: 0.25307178497314453
AUC 0.7940109161968509
In epoch 105, loss: 0.24520859122276306
In epoch 110, loss: 0.2375347763299942
In epoch 115, loss: 0.22642888128757477
In epoch 120, loss: 0.22982603311538696
In epoch 125, loss: 0.21471457183361053
In epoch 13

### GCN Agg

In [10]:
model = GraphSAGE(train_g.ndata["x"].shape[1], 32, agg='gcn')
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [11]:
train_link_pred(1000, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

In epoch 5, loss: 0.6881817579269409
In epoch 10, loss: 0.653206467628479
In epoch 15, loss: 0.5981858372688293
In epoch 20, loss: 0.5748456716537476
In epoch 25, loss: 0.5513438582420349
In epoch 30, loss: 0.5288108587265015
In epoch 35, loss: 0.5084074139595032
In epoch 40, loss: 0.49039632081985474
In epoch 45, loss: 0.47301188111305237
In epoch 50, loss: 0.45799973607063293
In epoch 55, loss: 0.4435465931892395
In epoch 60, loss: 0.4325772523880005
In epoch 65, loss: 0.42378830909729004
In epoch 70, loss: 0.41575106978416443
In epoch 75, loss: 0.4000803232192993
In epoch 80, loss: 0.39086830615997314
In epoch 85, loss: 0.3740013837814331
In epoch 90, loss: 0.3617696762084961
In epoch 95, loss: 0.34915027022361755
In epoch 100, loss: 0.3380183279514313
AUC 0.7963477909301228
In epoch 105, loss: 0.3275896906852722
In epoch 110, loss: 0.3128474950790405
In epoch 115, loss: 0.3060632050037384
In epoch 120, loss: 0.28964951634407043
In epoch 125, loss: 0.2735818922519684
In epoch 130, l

## GraphEVE

In [13]:
model = GraphEVE(train_g.ndata["x"].shape[1], 32)
pred = MLPPredictor(32)
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [14]:
train_link_pred(1000, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

  graph.dstdata['eve']=self.relu(self.dw_conv(tt.T).T)[:,:,0]


In epoch 5, loss: 0.7084175944328308
In epoch 10, loss: 0.6981925964355469
In epoch 15, loss: 0.6932471990585327
In epoch 20, loss: 0.6938855648040771
In epoch 25, loss: 0.6939219832420349
In epoch 30, loss: 0.6933499574661255
In epoch 35, loss: 0.6932055950164795
In epoch 40, loss: 0.6932910680770874
In epoch 45, loss: 0.6932910084724426
In epoch 50, loss: 0.6932286024093628
In epoch 55, loss: 0.6931902170181274
In epoch 60, loss: 0.6931816339492798
In epoch 65, loss: 0.6931800246238708
In epoch 70, loss: 0.6931745409965515
In epoch 75, loss: 0.6931652426719666
In epoch 80, loss: 0.6931540369987488
In epoch 85, loss: 0.6931411623954773
In epoch 90, loss: 0.6931257247924805
In epoch 95, loss: 0.693105936050415
In epoch 100, loss: 0.6930797696113586
AUC 0.6892801150019091
In epoch 105, loss: 0.6930434107780457
In epoch 110, loss: 0.6929918527603149
In epoch 115, loss: 0.692916214466095
In epoch 120, loss: 0.6928010582923889
In epoch 125, loss: 0.6926175355911255
In epoch 130, loss: 0.69

# Actors

In [15]:
#Actor model
actor = Actor(root='data/Actor', transform=NormalizeFeatures())
train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g = create_train_test_split_edge(actor[0])

model = GraphSAGE(train_g.ndata["x"].shape[1], 16, agg='mean')
pred = DotPredictor()
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [16]:
train_link_pred(100, model, pred, optimizer, train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g)

In epoch 5, loss: 0.6714571714401245
In epoch 10, loss: 0.6475620269775391
In epoch 15, loss: 0.6315845847129822
In epoch 20, loss: 0.6195549964904785
In epoch 25, loss: 0.6060764789581299
In epoch 30, loss: 0.5928593277931213
In epoch 35, loss: 0.580198347568512
In epoch 40, loss: 0.5656442642211914
In epoch 45, loss: 0.5499475598335266
In epoch 50, loss: 0.5341180562973022
In epoch 55, loss: 0.5176094770431519
In epoch 60, loss: 0.5032098889350891
In epoch 65, loss: 0.4871290326118469
In epoch 70, loss: 0.47250697016716003
In epoch 75, loss: 0.4589647948741913
In epoch 80, loss: 0.4460175335407257
In epoch 85, loss: 0.4332176148891449
In epoch 90, loss: 0.4206375777721405
In epoch 95, loss: 0.40838930010795593
In epoch 100, loss: 0.3964161276817322
AUC 0.7034637237992756
