In [69]:
!pip install torch-geometric



In [70]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, GINConv
from torch_geometric.datasets import Planetoid, DeezerEurope
from torch_geometric.utils import train_test_split_edges, negative_sampling, subgraph
from sklearn.model_selection import train_test_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

corafull = Planetoid(root='CoraFull', name='Cora')[0]
deezerEurope = DeezerEurope(root='DeezerEurope')[0]

datasets = {'corafull': corafull.to(device), 'deezerEurope': deezerEurope.to(device)}

In [71]:
print (corafull)
print (deezerEurope)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Data(x=[28281, 128], edge_index=[2, 185504], y=[28281])


In [72]:
# Split Datasets into train, test and validation
train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2

num_nodes = {}
train_nodes = {}
test_nodes = {}
val_nodes = {}
edge_index = {}
for dataset_name in datasets:
    num_nodes[dataset_name] = datasets[dataset_name].num_nodes
    train_n, test_n = train_test_split(range(num_nodes[dataset_name]), test_size=test_ratio, random_state=42)
    train_n, val_n = train_test_split(train_n, test_size=val_ratio / (train_ratio + val_ratio), random_state=42)
    train_nodes[dataset_name] = train_n
    test_nodes[dataset_name] = test_n
    val_nodes[dataset_name] = val_n
    edge_index[dataset_name] = datasets[dataset_name].edge_index

In [73]:
# Produce the induced subgraph of train set nodes.
train_mask = {}
test_mask = {}
val_mask = {}

for dataset_name in datasets:
    train_mask[dataset_name] = torch.zeros(num_nodes[dataset_name], dtype=torch.bool).to(device)
    train_mask[dataset_name][train_nodes[dataset_name]] = True
    
    test_mask[dataset_name] = torch.zeros(num_nodes[dataset_name], dtype=torch.bool).to(device)
    test_mask[dataset_name][test_nodes[dataset_name]] = True
    
    val_mask[dataset_name] = torch.zeros(num_nodes[dataset_name], dtype=torch.bool).to(device)
    val_mask[dataset_name][val_nodes[dataset_name]] = True

In [74]:
# Produces a separated edge index.
# Sample non-neighbours from the train portion of the graph.
# The non-neighbour set will be the same size as the training edge index.
# The mask ensures that only the nodes in the training set are used.
train_edge_index = {}
val_edge_index = {}
val_neg_edge_index = {}
test_edge_index = {}
test_neg_edge_index = {}

for dataset_name in datasets:
    train_ei, _ = subgraph(train_nodes[dataset_name], edge_index[dataset_name], relabel_nodes=True)
    train_edge_index[dataset_name] = train_ei.to(device)
    
    val_ei, _ = subgraph(val_nodes[dataset_name], edge_index[dataset_name], relabel_nodes=True)
    val_edge_index[dataset_name] = val_ei.to(device)
    
    val_neg_edge_index[dataset_name] = negative_sampling(edge_index=val_edge_index[dataset_name], 
                                                         num_nodes=num_nodes[dataset_name], 
                                                         num_neg_samples=val_edge_index[dataset_name].size(1)).to(device)

    test_ei, _ = subgraph(test_nodes[dataset_name], edge_index[dataset_name], relabel_nodes=True)
    test_edge_index[dataset_name] = test_ei.to(device)
    
    test_neg_edge_index[dataset_name] = negative_sampling(edge_index=test_edge_index[dataset_name],
                                                          num_nodes=num_nodes[dataset_name],
                                                          num_neg_samples=test_edge_index[dataset_name].size(1)).to(device)

In [75]:
# GNN Model Definition
class GNNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels=64, out_channels=64, gnn_type="GCN"):
        super(GNNModel, self).__init__()
        if gnn_type == "GCN":
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        elif gnn_type == "GAT":
            self.conv1 = GATConv(in_channels, hidden_channels)
            self.conv2 = GATConv(hidden_channels, out_channels)
        elif gnn_type == "GIN":
            self.conv1 = GINConv(torch.nn.Sequential(
                torch.nn.Linear(in_channels, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels)
            ))
            self.conv2 = GINConv(torch.nn.Sequential(
                torch.nn.Linear(hidden_channels, out_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(out_channels, out_channels)
            ))
        else:
            raise ValueError("Unsupported GNN type")

    def forward(self, input, edge_index):
        hidden_output = F.relu(self.conv1(input, edge_index))
        output = self.conv2(hidden_output, edge_index)
        return output

In [76]:
def auc_loss(model, data, train_edge_index, neg_edge_index, margin=1.0):
    node_embeddings = model(data.x, train_edge_index)

    # Positive edge scores
    pos_u, pos_v = train_edge_index
    pos_scores = F.cosine_similarity(node_embeddings[pos_u], node_embeddings[pos_v])

    # Negative edge scores
    neg_u, neg_v = neg_edge_index
    neg_scores = F.cosine_similarity(node_embeddings[neg_u], node_embeddings[neg_v])

    # Margin-based ranking loss
    loss = F.relu(margin + neg_scores - pos_scores).mean()
    return loss

In [87]:
from rich.console import Console
from rich.table import Table

console = Console()

def train(model, data, optimizer, train_edge_index, val_edge_index, val_neg_edge_index, epochs=200, val_interval=20):
    model.to(device)
    model.train()

    for epoch in range(1, epochs + 1):
        train_neg_edge_index = negative_sampling(edge_index=train_edge_index, 
                                                 num_nodes=data.num_nodes, 
                                                 num_neg_samples=train_edge_index.size(1)).to(device)

        loss = auc_loss(model, data, train_edge_index, train_neg_edge_index)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_loss = auc_loss(model, data, val_edge_index, val_neg_edge_index)
            model.train()

            console.print(f"Epoch {epoch:3d} | [cyan]Train Loss:[/] {loss:.5f} | [green]Val Loss:[/] {val_loss:.5f}")

In [88]:
models = {}

for dataset_name in datasets:
    models[dataset_name] = {'GCN': None, 'GAT': None, 'GIN': None}
    
    for gnn_model in models[dataset_name]:
        console.rule(f"[bold blue]Dataset: {dataset_name}, GNN Model: {gnn_model}")
        
        model = GNNModel(datasets[dataset_name].num_features, gnn_type=gnn_model)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    
        train(model, datasets[dataset_name], optimizer, 
              train_edge_index[dataset_name], 
              val_edge_index[dataset_name], val_neg_edge_index[dataset_name])
    
        model.eval()
        with torch.no_grad():
            test_loss = auc_loss(model, datasets[dataset_name], test_edge_index[dataset_name], test_neg_edge_index[dataset_name])

        console.print(f"[bold red]Final Test Loss: {test_loss:.5f}[/]\n")
        
        models[dataset_name][gnn_model] = model

In [89]:
import torch
from torch_geometric.utils import to_dense_adj
from sklearn.metrics import pairwise_distances
import numpy as np

def compute_set_Q(edge_index, num_nodes):
    """
    Identify nodes in Q: nodes part of at least one triangle in the graph.
    """
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0)
    triangles = torch.mm(adj, adj) * adj  # Triangle adjacency
    Q = torch.nonzero(triangles.sum(dim=1) > 0).squeeze()
    return Q[:1000]

def cosine_similarity(u, v):
    """
    Compute cosine similarity between two sets of embeddings.
    """
    u = F.normalize(u, dim=-1)
    v = F.normalize(v, dim=-1)
    return torch.mm(u, v.t())

def precision_at_k(predictions, labels, k):
    """
    Calculate Precision@K.
    """
    top_k = predictions[:, :k]
    correct = (labels.gather(1, top_k) > 0).sum(dim=1).float()
    return (correct / k).mean().item()

def mean_reciprocal_rank(predictions, labels):
    """
    Calculate Mean Reciprocal Rank (MRR).
    """
    ranks = torch.arange(1, predictions.size(1) + 1, device=predictions.device).float()
    mask = labels.gather(1, predictions) > 0
    reciprocal_rank = (mask.float() * (1.0 / ranks)).sum(dim=1)
    return reciprocal_rank.mean().item()

In [92]:
def evaluate_inference(model, data, test_edge_index, Q, K_values):
    """
    Perform inference and evaluate Precision@K and MRR for the trained GNN model.
    """
    model.eval()
    with torch.no_grad():
        embeddings = model(data.x, test_edge_index)
        
        # Rank candidate node pairs for each node in Q
        precision_scores = {k: [] for k in K_values}
        mrr_scores = []
        adj = to_dense_adj(test_edge_index, max_num_nodes=data.num_nodes).squeeze(0)
        
        for q in Q:
            scores = cosine_similarity(embeddings[q].unsqueeze(0), embeddings)
            
            scores = scores.squeeze().argsort(descending=True)  # Sorted indices
            labels = adj[q].long()  # Ground truth edges
            
            # Precision@K
            for k in K_values:
                precision_scores[k].append(precision_at_k(scores.unsqueeze(0), labels.unsqueeze(0), k))
            
            # MRR
            mrr_scores.append(mean_reciprocal_rank(scores.unsqueeze(0), labels.unsqueeze(0)))
        
        # Average metrics
        precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
        mrr_avg = np.mean(mrr_scores)
        
    return precision_avg, mrr_avg

In [95]:
Q = {}
K_values = [1, 5, 10]

for dataset_name in datasets:
    Q[dataset_name] = compute_set_Q(test_edge_index[dataset_name], num_nodes[dataset_name])
    
    for gnn_model in models[dataset_name]:
        console.rule(f"[bold blue]Dataset: {dataset_name}, GNN Model: {gnn_model}")
        
        model = models[dataset_name][gnn_model]
        precision_avg, mrr_avg = evaluate_inference(model, datasets[dataset_name], test_edge_index[dataset_name], Q[dataset_name], K_values)

        print(f"{gnn_model} - Precision@K: {precision_avg}, MRR: {mrr_avg}")
        # console.print(f"[bold red]Final Test Loss: {test_loss:.5f}[/]\n")

AAAA 24


torch.Size([2708, 64])
2708
torch.Size([2, 416])
GCN - Precision@K: {1: 0.20833333333333334, 5: 0.32500000670552254, 10: 0.17083333681027094}, MRR: 0.7822519254987128


torch.Size([2708, 64])
2708
torch.Size([2, 416])
GAT - Precision@K: {1: 0.2916666666666667, 5: 0.32500000732640427, 10: 0.1791666711991032}, MRR: 0.7931965226307511


torch.Size([2708, 64])
2708
torch.Size([2, 416])
GIN - Precision@K: {1: 0.2916666666666667, 5: 0.42500000819563866, 10: 0.22500000521540642}, MRR: 0.9761710288003087
AAAA 528


torch.Size([28281, 64])
28281
torch.Size([2, 7536])
GCN - Precision@K: {1: 0.0625, 5: 0.517045467471083, 10: 0.33068182514133776}, MRR: 1.1074701686359758


torch.Size([28281, 64])
28281
torch.Size([2, 7536])
GAT - Precision@K: {1: 0.06060606060606061, 5: 0.4534091017575878, 10: 0.30189394569871103}, MRR: 0.997229561876421


torch.Size([28281, 64])
28281
torch.Size([2, 7536])
GIN - Precision@K: {1: 0.058712121212121215, 5: 0.29166667264970864, 10: 0.20643939815856743}, MRR: 0.7103986088186502


In [None]:
def adamic_adar(edge_index, Q, num_nodes):
    """
    Compute Adamic-Adar scores for candidate pairs.
    """
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0).numpy()
    degree = adj.sum(axis=0)
    scores = {}
    for nodes in Q:
        q = nodes.item()
        scores[q] = {}
        for other in range(num_nodes):
            if q != other and not adj[q, other]:
                common_neighbors = np.where(np.logical_and(adj[q], adj[other]))[0]
                scores[q][other] = sum(1.0 / np.log(degree[cn]) for cn in common_neighbors)
    return scores

def common_neighbors(edge_index, Q, num_nodes):
    """
    Compute Common Neighbors scores for candidate pairs.
    """
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0).numpy()
    scores = {}
    
    for nodes in Q:
        q = nodes.item()
        scores[q] = {}
        for other in range(num_nodes):
            if q != other and not adj[q, other]:
                scores[q][other] = np.logical_and(adj[q], adj[other]).sum()
    return scores

In [None]:
aa_scores = adamic_adar(test_edge_index, Q, data.num_nodes)
cn_scores = common_neighbors(test_edge_index, Q, data.num_nodes)

In [None]:
def evaluate_non_trainable(scores, test_edge_index, Q, K_values, num_nodes):
    """
    Evaluate Precision@K and MRR for Adamic-Adar or Common Neighbors scores.
    """
    adj = to_dense_adj(test_edge_index, max_num_nodes=num_nodes).squeeze(0)
    precision_scores = {k: [] for k in K_values}
    mrr_scores = []
    
    for q in Q:
        # Extract scores for node `q`
        node_scores = scores[q.item()]
        sorted_candidates = sorted(node_scores.items(), key=lambda x: x[1], reverse=True)
        sorted_nodes = torch.tensor([candidate[0] for candidate in sorted_candidates])
        # Get ground truth labels for node `q`
        labels = adj[q].long()
        # Precision@K
        for k in K_values:
            precision_scores[k].append(precision_at_k(sorted_nodes.unsqueeze(0), labels.unsqueeze(0), k))
        # MRR
        mrr_scores.append(mean_reciprocal_rank(sorted_nodes.unsqueeze(0), labels.unsqueeze(0)))

    # Average metrics
    precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
    mrr_avg = np.mean(mrr_scores)

    return precision_avg, mrr_avg

# Example usage for Adamic-Adar
precision_avg_aa, mrr_avg_aa = evaluate_non_trainable(aa_scores, test_edge_index, Q, K_values, data.num_nodes)
print(f"Adamic-Adar - Precision@K: {precision_avg_aa}, MRR: {mrr_avg_aa}")

# Example usage for Common Neighbors
precision_avg_cn, mrr_avg_cn = evaluate_non_trainable(cn_scores, test_edge_index, Q, K_values, data.num_nodes)
print(f"Common Neighbors - Precision@K: {precision_avg_cn}, MRR: {mrr_avg_cn}")