In [None]:
import random

import torch
import torch.nn as nn
import torch_geometric.transforms as T
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import VGAE, GraphConv
from torch_geometric.utils import train_test_split_edges
from tqdm import tqdm

In [2]:
# Seed Python's and torch's RNGs before the random edge split so that a fresh
# Restart & Run All starts from the same random state. (GPU kernels may still
# introduce small run-to-run differences.)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# T.NormalizeFeatures rescales the node *feature* matrix row-wise; it does not
# normalize the adjacency matrix.
data = Planetoid(root='', name='Cora', transform=T.NormalizeFeatures())[0]
# The node-classification masks are not used for link prediction.
data.train_mask = data.val_mask = data.test_mask = None
# Split the edges into train / validation / test sets (positive and negative).
data = train_test_split_edges(data)



In [3]:
# Inspect the split result. "pos" edge sets hold edges that exist in the
# original graph; "neg" edge sets hold node pairs that are not connected in it.
data

Data(x=[2708, 1433], y=[2708], val_pos_edge_index=[2, 263], test_pos_edge_index=[2, 527], train_pos_edge_index=[2, 8976], train_neg_adj_mask=[2708, 2708], val_neg_edge_index=[2, 263], test_neg_edge_index=[2, 527])

In [4]:
# Training hyper-parameters.
epochs = 200
lr = 0.01
# Width of the shared hidden layer and of the latent embedding.
hidden1 = 32
hidden2 = 16
# Per the author's note the paper uses no dropout; the knob is kept because
# Encoder takes it as an argument. It was previously left undefined, which
# made the model-construction cell fail with NameError on a fresh kernel.
# 0.0 disables dropout -- raise it to experiment.
dropout = 0.0

input_channels = data.num_node_features
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
input_channels, device

(1433, device(type='cuda', index=0))

In [5]:
class Encoder(nn.Module):
    """Graph-convolutional encoder producing the parameters of a latent Gaussian.

    A shared first ``GraphConv`` layer feeds two parallel ``GraphConv`` heads
    whose outputs are returned as the mean and the log standard deviation.

    Args:
        input_dim: Number of input features per node.
        hidden1: Output width of the shared first layer.
        hidden2: Output width of both heads (the latent dimension).
        dropout: Dropout probability applied to the input of the shared layer;
            only active while the module is in training mode.
    """

    def __init__(self, input_dim, hidden1, hidden2, dropout):
        super().__init__()
        # Shared first layer; its output is the input of both heads below.
        self.conv1 = GraphConv(input_dim, hidden1)
        # Head for the mean (mu) of the Gaussian.
        self.conv_mu = GraphConv(hidden1, hidden2)
        # Head for the log standard deviation of the Gaussian.
        self.conv_logstd = GraphConv(hidden1, hidden2)
        self.dropout = dropout
        # Edge index seen by the most recent forward() call. Kept only so that
        # existing code reading it, or calling conv_layer() without an explicit
        # edge_index, keeps working.
        self.edge_index = None

    def conv_layer(self, x, conv, edge_index=None):
        """Apply dropout, the convolution ``conv`` and a ReLU to ``x``.

        Args:
            x: Node feature tensor.
            conv: Callable invoked as ``conv(x, edge_index)``.
            edge_index: Graph connectivity handed to ``conv``. When omitted,
                the edge index stored by the last ``forward`` call is used.

        Returns:
            The activated output of ``conv``.

        Raises:
            ValueError: If no edge index is given and none has been stored.
        """
        if edge_index is None:
            edge_index = self.edge_index
        if edge_index is None:
            raise ValueError(
                'edge_index must be passed to conv_layer() or set by a prior forward() call')
        x = F.dropout(x, p=self.dropout, training=self.training)
        hidden = conv(x, edge_index)
        # In-place ReLU: overwrites the freshly computed convolution output.
        return F.relu(hidden, inplace=True)

    def forward(self, x, edge_index):
        """Encode node features into ``(mu, logstd)``.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity used by all three convolutions.

        Returns:
            A tuple ``(mu, logstd)`` from the two output heads.
        """
        self.edge_index = edge_index
        # Pass the edge index explicitly instead of relying on stored state.
        hidden = self.conv_layer(x, self.conv1, edge_index)
        return self.conv_mu(hidden, edge_index), self.conv_logstd(hidden, edge_index)


In [6]:
# Only the encoder is supplied: VGAE uses its default inner-product decoder
# when no decoder is given.
encoder = Encoder(input_channels, hidden1, hidden2, dropout)
model = VGAE(encoder)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=lr)

In [7]:
# Move the tensors used at every training/evaluation step onto the compute device.
train_pos_edge_index = data.train_pos_edge_index.to(device)
x = data.x.to(device)

In [8]:
def train():
    """Run one full-batch optimisation step.

    Encodes the graph from the training edges, then minimises the
    reconstruction loss plus the KL term divided by the number of nodes.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    latent = model.encode(x, train_pos_edge_index)
    reconstruction = model.recon_loss(latent, train_pos_edge_index)
    kl_term = model.kl_loss() / data.num_nodes
    total = reconstruction + kl_term
    total.backward()
    optimizer.step()
    return total.item()

def test(pos_edge_index, neg_edge_index):
    """Score link prediction on the given positive and negative edges.

    Args:
        pos_edge_index: Edges that exist in the original graph.
        neg_edge_index: Node pairs that are not connected in the original graph.

    Returns:
        Whatever ``model.test`` returns; the caller unpacks it as ``(auc, ap)``.
    """
    model.eval()
    with torch.no_grad():
        # Transductive setting: embeddings are always computed from the
        # training edges, whichever edge set is being scored.
        embeddings = model.encode(x, train_pos_edge_index)
        scores = model.test(embeddings, pos_edge_index, neg_edge_index)
    return scores

In [9]:
# Use TensorBoard to record the training process. The run directory name is
# built from `epochs` so it stays in sync with the configured epoch count.
writer = SummaryWriter('runs/VGAE_experiment_{} epochs'.format(epochs))

In [10]:
# Train for `epochs` full-batch steps. After each step, score link prediction
# on the held-out test edges and log the curves to TensorBoard.
for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    # The training loss was computed but never recorded before.
    writer.add_scalar('loss train', loss, epoch)
    # These metrics are computed on the test edges, so tag them as test
    # metrics (they were previously mislabeled as 'train').
    writer.add_scalar('auc test', auc, epoch)
    writer.add_scalar('ap test', ap, epoch)
# Push buffered events to disk so TensorBoard shows the complete run.
writer.flush()

Epoch: 001, AUC: 0.5933, AP: 0.6404
Epoch: 002, AUC: 0.5867, AP: 0.6374
Epoch: 003, AUC: 0.5906, AP: 0.6390
Epoch: 004, AUC: 0.6018, AP: 0.6453
Epoch: 005, AUC: 0.6138, AP: 0.6564
Epoch: 006, AUC: 0.6546, AP: 0.6947
Epoch: 007, AUC: 0.7001, AP: 0.7233
Epoch: 008, AUC: 0.7018, AP: 0.7170
Epoch: 009, AUC: 0.7008, AP: 0.7191
Epoch: 010, AUC: 0.7038, AP: 0.7311
Epoch: 011, AUC: 0.6971, AP: 0.7354
Epoch: 012, AUC: 0.6965, AP: 0.7360
Epoch: 013, AUC: 0.7089, AP: 0.7416
Epoch: 014, AUC: 0.7221, AP: 0.7480
Epoch: 015, AUC: 0.7179, AP: 0.7466
Epoch: 016, AUC: 0.7179, AP: 0.7467
Epoch: 017, AUC: 0.7182, AP: 0.7478
Epoch: 018, AUC: 0.7225, AP: 0.7528
Epoch: 019, AUC: 0.7266, AP: 0.7571
Epoch: 020, AUC: 0.7304, AP: 0.7610
Epoch: 021, AUC: 0.7330, AP: 0.7645
Epoch: 022, AUC: 0.7378, AP: 0.7686
Epoch: 023, AUC: 0.7411, AP: 0.7704
Epoch: 024, AUC: 0.7447, AP: 0.7721
Epoch: 025, AUC: 0.7495, AP: 0.7758
Epoch: 026, AUC: 0.7532, AP: 0.7797
Epoch: 027, AUC: 0.7549, AP: 0.7818
Epoch: 028, AUC: 0.7578, AP: