# Graph Auto Encoder with PyG

In [1]:
import argparse
import os
import time

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

from torch_geometric.nn import GAE, GCNConv

In [2]:
# All computation in this notebook runs on the CPU; the model is moved to this
# device in a later cell.
device = torch.device("cpu")

In [3]:
# Name of the Planetoid dataset to load.
DATASET_NAME = "Cora"

In [35]:
# Preprocessing applied to the graph when it is read from the dataset:
#   1. normalize the node features;
#   2. randomly split the edges for link prediction, with no validation edges
#      (num_val=0.) and 10% of the edges held out for testing (num_test=0.1).
#      Positive and negative labels are kept in separate attributes
#      (split_labels=True) and no negative training edges are pre-sampled
#      (add_negative_train_samples=False).
transform = T.Compose([
    T.NormalizeFeatures(),
    T.RandomLinkSplit(num_val=0., num_test=0.1, is_undirected=True,
                      split_labels=True, add_negative_train_samples=False),
])

# Resolve the dataset directory from the GML_DATA_DIR environment variable
# (falling back to ./data) instead of a user-specific absolute path, so the
# notebook runs on any machine. Point GML_DATA_DIR at an existing download to
# reuse it.
data_root = os.environ.get("GML_DATA_DIR", "data")
path = os.path.join(data_root, 'data')
dataset = Planetoid(path, DATASET_NAME, transform=transform)

# The transform returns a (train, val, test) triple; with num_val=0. the
# validation split holds no held-out edges and is not used for evaluation.
train_data, val_data, test_data = dataset[0]

In [40]:
# The second dimension of each edge-label index is reported as the edge count.
num_train_pos = train_data.pos_edge_label_index.shape[1]
num_test_pos = test_data.pos_edge_label_index.shape[1]
num_test_neg = test_data.neg_edge_label_index.shape[1]

print(f"Train edges (positive): {num_train_pos}")
print(f"Test edges (positive): {num_test_pos}")
print(f"Test edges (negative): {num_test_neg}")

Train edges (positive): 4751
Test edges (positive): 527
Test edges (negative): 527


In [5]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder that maps node features to node embeddings.

    Args:
        num_node_features: Number of input channels per node.
        num_embedding: Number of output channels per node. The hidden layer
            uses twice this width.
    """

    def __init__(self, num_node_features, num_embedding):
        super().__init__()
        hidden_channels = 2 * num_embedding
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_embedding)

    def forward(self, x, edge_index):
        """Encode nodes: conv1 -> ReLU -> conv2 (no activation on the output).

        Args:
            x: Node feature input passed to the first convolution.
            edge_index: Graph connectivity passed to both convolutions.

        Returns:
            The output of the second convolution.
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        embeddings = self.conv2(hidden, edge_index)
        return embeddings

In [6]:
# Embedding size is a free hyperparameter; the input width is read from the
# dataset.
n_embeddings = 20
n_features = dataset.num_features

In [7]:
# Wrap the GCN encoder in a graph auto-encoder. No decoder is passed, so GAE
# falls back to its default one.
model = GAE(
    encoder=GCNEncoder(num_node_features=n_features, num_embedding=n_embeddings)
)

In [8]:
# Move the model to the target device first so the optimizer is built over the
# parameters as they live on that device.
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [9]:
# Number of full-batch optimisation steps to run.
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):  # one full-batch gradient step per epoch
    # ---- Train ----
    model.train()

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Embed the nodes using the training graph, then compute the
    # reconstruction loss on the positive training edges. No negative edges
    # are passed here; they are left to recon_loss.
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)

    loss.backward()
    optimizer.step()

    # ---- Evaluate ----
    # Metrics are computed on the held-out test edges. Gradients are not
    # needed for evaluation, so disable autograd to avoid building a graph.
    model.eval()
    with torch.no_grad():
        z = model.encode(test_data.x, test_data.edge_index)
        auc, ap = model.test(
            z, test_data.pos_edge_label_index, test_data.neg_edge_label_index
        )

    # The label says "test set" because the scores come from test_data.
    print(f"Performance on test set => AUC: {auc} AP: {ap}")

Performance on validation set => AUC: 0.7277003841874633 AP: 0.751380623617229
Performance on validation set => AUC: 0.7216693251334934 AP: 0.7450836500826495
Performance on validation set => AUC: 0.7201516586312556 AP: 0.7438195305144295
Performance on validation set => AUC: 0.7177356343773966 AP: 0.7446658389934638
Performance on validation set => AUC: 0.7144770621721173 AP: 0.7463040868711021
Performance on validation set => AUC: 0.7097044240968714 AP: 0.7458711201460098
Performance on validation set => AUC: 0.7041108418638313 AP: 0.7440737868933852
Performance on validation set => AUC: 0.7006038260318512 AP: 0.7420508132883922
Performance on validation set => AUC: 0.699102362374833 AP: 0.7411833809196392
Performance on validation set => AUC: 0.6959626110344977 AP: 0.739441047817806
Performance on validation set => AUC: 0.6908227084676068 AP: 0.7366122214404001
Performance on validation set => AUC: 0.6845666098966978 AP: 0.7315068175571388
Performance on validation set => AUC: 0.683