In [37]:
import numpy as np
import torch
import torch_geometric.transforms as T
from torch_geometric import seed_everything
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, VGAE

# Seed the Python, NumPy and PyTorch RNGs up front so that the random link
# split below, the model initialisation and the negative sampling during
# training give the same numbers on every "Restart Kernel -> Run All".
SEED = 42
seed_everything(SEED)

device = 'cpu'

# Row-normalise node features, then split the edges into train/val/test.
# `split_labels=True` makes the split expose `pos_edge_label_index` /
# `neg_edge_label_index`, which the train/eval code reads.
# `add_negative_train_samples=False` leaves the training split without
# pre-sampled negatives.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        is_undirected=True,
        split_labels=True,
        add_negative_train_samples=False,
    ),
])

dataset = Planetoid('.', name='Cora', transform=transform)

# The transform runs when the graph is indexed; RandomLinkSplit yields a
# (train, val, test) triple for the single Cora graph.
train_data, val_data, test_data = dataset[0]
train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device)

In [38]:
class Encoder(torch.nn.Module):
    """Two-layer GCN encoder producing the two tensors a VGAE needs.

    A shared `GCNConv` layer with ReLU feeds two parallel `GCNConv` heads,
    one per output tensor.

    Args:
        dim_in: Number of input features per node.
        hidden_dim: Width of the shared hidden layer.
        dim_out: Dimensionality of each output head (latent size).
    """

    def __init__(self, dim_in, hidden_dim=32, dim_out=16):
        super().__init__()
        # Shared first layer.
        self.fc1 = GCNConv(dim_in, hidden_dim)
        # Two heads on top of the shared representation.
        # NOTE(review): despite the name, PyG's VGAE appears to treat the
        # second head's output as a log *standard deviation* -- confirm
        # against the VGAE documentation before relying on the name.
        self.fc_average = GCNConv(hidden_dim, dim_out)
        self.fc_log_variance = GCNConv(hidden_dim, dim_out)

    def forward(self, x, edge_index):
        """Return `(fc_average output, fc_log_variance output)` for the graph."""
        hidden = torch.relu(self.fc1(x, edge_index))
        mean = self.fc_average(hidden, edge_index)
        log_spread = self.fc_log_variance(hidden, edge_index)
        return mean, log_spread
 
# Wrap the GCN encoder in a VGAE; both the hidden layer and the latent space
# are 16-dimensional here.
model = VGAE(
    encoder=Encoder(dim_in=dataset.num_features, hidden_dim=16, dim_out=16),
).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [39]:
def train():
    """Run one optimisation step on `train_data` and return the loss.

    The loss is the reconstruction loss on the positive training edges plus
    the KL term scaled by `1 / num_nodes`.

    Returns:
        The scalar training loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    kl_weight = 1 / train_data.num_nodes
    latent = model.encode(train_data.x, train_data.edge_index)
    objective = (
        model.recon_loss(latent, train_data.pos_edge_label_index)
        + kl_weight * model.kl_loss()
    )

    objective.backward()
    optimizer.step()
    return objective.item()

def test(data):
    """Evaluate link prediction on one split without tracking gradients.

    Args:
        data: A split exposing `x`, `edge_index`, `pos_edge_label_index`
            and `neg_edge_label_index`.

    Returns:
        Whatever `model.test` returns for the split's positive and negative
        edges (unpacked by callers as `(auc, ap)`).
    """
    with torch.no_grad():
        model.eval()
        latent = model.encode(data.x, data.edge_index)
        metrics = model.test(
            latent,
            data.pos_edge_label_index,
            data.neg_edge_label_index,
        )
    return metrics


# Training schedule: epochs 0..200 inclusive, with a report every 50 epochs.
NUM_EPOCHS = 201
LOG_EVERY = 50

for epoch in range(NUM_EPOCHS):
    loss = train()
    # Evaluate only on the epochs that are actually reported; the metrics
    # were previously computed every epoch and discarded 49 times out of 50.
    if epoch % LOG_EVERY == 0:
        test_auc, test_ap = test(test_data)
        print(f'Train loss : {loss}')
        print(f'Test AUC : {test_auc} | Test AP : {test_ap}')

Train loss : 3.416130304336548
Test AUC : 0.668675219368521 | Test AP : 0.6958206661143584
Train loss : 1.3364050388336182
Test AUC : 0.635670023656154 | Test AP : 0.670486426553272
Train loss : 1.299009919166565
Test AUC : 0.637009458860976 | Test AP : 0.6720657173192042
Train loss : 1.1670433282852173
Test AUC : 0.740761677750613 | Test AP : 0.7413474020737971
Train loss : 1.0603376626968384
Test AUC : 0.8202852420885107 | Test AP : 0.8246690433361138


In [34]:
# Evaluate on the validation split. Results go into their own names so the
# test-set metrics from the training loop are not silently overwritten, and
# they are printed so the cell actually shows what it computed.
val_auc, val_ap = test(val_data)
print(f'Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')


Test AUC: 0.8688 | Test AP: 0.8601
