In [2]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.utils import negative_sampling, add_self_loops, remove_self_loops, train_test_split_edges

In [3]:
# Seed every RNG the pipeline may draw from so that "Restart Kernel -> Run All"
# is repeatable. torch's generator drives weight initialisation and dropout;
# Python's `random` and NumPy are seeded too because library code may sample
# from them (e.g. PyG negative sampling, depending on the installed version).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class LinkPredictionGNN(nn.Module):
    """Stack of graph convolutions (encoder) with a dot-product edge decoder.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of every layer except the last.
        out_channels: Size of the final node embeddings.
        num_layers: Total number of convolution layers (>= 1). With a single
            layer the model maps ``in_channels`` directly to ``out_channels``.
        layer_type: One of ``'GCN'``, ``'GAT'`` or ``'SAGE'``.
        dropout: Dropout probability applied after every non-final layer.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
        KeyError: If ``layer_type`` is not a supported layer name.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
                 layer_type='GCN', dropout=0.5):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        conv_dict = {'GCN': GCNConv, 'GAT': GATConv, 'SAGE': SAGEConv}
        try:
            conv_layer = conv_dict[layer_type]
        except KeyError:
            raise KeyError(
                f"Unknown layer_type {layer_type!r}; expected one of {sorted(conv_dict)}"
            ) from None

        # Layer widths: in -> hidden -> ... -> hidden -> out. Building the list
        # from num_layers makes the layer count exact; the previous
        # "first + middle + last" construction always created at least two
        # layers, even when num_layers=1 was requested.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = nn.ModuleList(
            [conv_layer(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

        self.dropout = dropout

    def encode(self, x, edge_index):
        """Run the convolution stack and return one embedding per node.

        ReLU and dropout follow every layer except the last, so the returned
        embeddings are the raw output of the final convolution.
        """
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        return self.convs[-1](x, edge_index)

    def decode(self, z, edge_index):
        """Score each edge as the dot product of its two endpoint embeddings.

        ``edge_index`` is unpacked into a source row and a destination row;
        the result holds one raw (un-squashed) score per column.
        """
        src, dst = edge_index
        return (z[src] * z[dst]).sum(dim=1)

    def decode_all(self, z):
        """Return the matrix of dot-product scores for every node pair."""
        return torch.matmul(z, z.t())

    def forward(self, x, edge_index):
        """Alias for :meth:`encode`; returns the node embeddings."""
        return self.encode(x, edge_index)

In [None]:
def train_link_pred(model, optimizer, data):
    """Run one optimisation step of link-prediction training.

    Node embeddings are computed on ``data.train_pos_edge_index``; the training
    edges are scored as positives against an equal number of freshly sampled
    negative edges, and the model is updated with binary cross-entropy on the
    raw scores.

    Args:
        model: Object exposing ``encode(x, edge_index)`` and
            ``decode(z, edge_index)``.
        optimizer: Optimizer over ``model``'s parameters.
        data: Graph object providing ``x``, ``train_pos_edge_index`` and
            ``num_nodes``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Message passing sees the training graph exactly as the caller prepared it
    # (including any self-loops that were added on purpose).
    z = model.encode(data.x, data.train_pos_edge_index)

    # Supervise only on real links. If self-loops were appended to the training
    # graph they must not become positive targets: a self-loop score is just
    # the squared norm of one embedding, and the val/test edge sets are built
    # before any self-loops are added.
    pos_edge_index, _ = remove_self_loops(data.train_pos_edge_index)

    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1))

    pos_score = model.decode(z, pos_edge_index)
    neg_score = model.decode(z, neg_edge_index)

    scores = torch.cat([pos_score, neg_score], dim=0)
    # *_like keeps the labels on the same device (and dtype) as the scores;
    # plain torch.ones/zeros would live on the CPU and break the loss on CUDA.
    labels = torch.cat([torch.ones_like(pos_score),
                        torch.zeros_like(neg_score)], dim=0)

    loss = F.binary_cross_entropy_with_logits(scores, labels)
    loss.backward()
    optimizer.step()

    return loss.item()

@torch.no_grad()
def test_link_pred(model, data):
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    
    results = {}
    
    pos_edge_index = data.test_pos_edge_index
    neg_edge_index = data.test_neg_edge_index
    
    pos_score = model.decode(z, pos_edge_index)
    neg_score = model.decode(z, neg_edge_index)
    
    scores = torch.cat([pos_score, neg_score], dim=0)
    labels = torch.cat([torch.ones(pos_score.size(0)), 
                       torch.zeros(neg_score.size(0))], dim=0)
    
    scores_np = scores.cpu().numpy()
    labels_np = labels.cpu().numpy()
    
    results['test_auc'] = roc_auc_score(labels_np, scores_np)
    results['test_ap'] = average_precision_score(labels_np, scores_np)
    
    pos_edge_index = data.val_pos_edge_index
    neg_edge_index = data.val_neg_edge_index
    
    pos_score = model.decode(z, pos_edge_index)
    neg_score = model.decode(z, neg_edge_index)
    
    scores = torch.cat([pos_score, neg_score], dim=0)
    labels = torch.cat([torch.ones(pos_score.size(0)), 
                       torch.zeros(neg_score.size(0))], dim=0)
    
    scores_np = scores.cpu().numpy()
    labels_np = labels.cpu().numpy()
    
    results['val_auc'] = roc_auc_score(labels_np, scores_np)
    results['val_ap'] = average_precision_score(labels_np, scores_np)
    
    return results

def run_experiment(dataset_name, gnn_type, use_self_loops=False, embedding_dim=64,
                   hidden_channels=128, lr=0.01, epochs=200):
    """Train and evaluate one link-prediction model on a Planetoid dataset.

    The graph's edges are split with ``train_test_split_edges``; a two-layer
    ``LinkPredictionGNN`` is trained with Adam, evaluated after every epoch,
    and the weights with the highest validation AUC are restored before the
    final evaluation.

    Args:
        dataset_name: Planetoid dataset name (e.g. ``'Cora'``).
        gnn_type: Layer type forwarded to ``LinkPredictionGNN``.
        use_self_loops: If True, self-loops are appended to the training edges.
        embedding_dim: Size of the output node embeddings.
        hidden_channels: Width of the hidden layer.
        lr: Adam learning rate.
        epochs: Number of training epochs.

    Returns:
        dict with ``train_losses`` and ``val_aucs`` (one entry per epoch) and
        ``final_results`` (metrics of the restored best model).
    """
    dataset = Planetoid(root='.', name=dataset_name)
    data = dataset[0]

    data = train_test_split_edges(data)

    if use_self_loops:
        data.train_pos_edge_index, _ = add_self_loops(data.train_pos_edge_index,
                                                      num_nodes=data.num_nodes)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = LinkPredictionGNN(
        in_channels=dataset.num_features,
        hidden_channels=hidden_channels,
        out_channels=embedding_dim,
        num_layers=2,
        layer_type=gnn_type
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    data = data.to(device)

    train_losses = []
    val_aucs = []

    # Start below any reachable AUC so the first epoch always yields a snapshot.
    best_val_auc = float('-inf')
    best_model = None

    for epoch in range(1, epochs + 1):
        loss = train_link_pred(model, optimizer, data)
        train_losses.append(loss)

        results = test_link_pred(model, data)
        val_aucs.append(results['val_auc'])

        if results['val_auc'] > best_val_auc:
            best_val_auc = results['val_auc']
            # clone() is essential: on a CPU run `.cpu()` returns the live
            # parameter tensor itself, so without a copy the "best" snapshot
            # would keep changing as training continues.
            best_model = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}

        if epoch % 10 == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {results["val_auc"]:.4f}, '
                  f'Val AP: {results["val_ap"]:.4f}')

    # No snapshot exists when no epoch ran (or every AUC was NaN); in that case
    # evaluate the model as it is. load_state_dict copies values into the
    # existing parameters, so the model stays on `device`.
    if best_model is not None:
        model.load_state_dict(best_model)
    final_results = test_link_pred(model, data)

    loops_str = "with" if use_self_loops else "without"
    print(f"\nFinal Results for {dataset_name} with {gnn_type} {loops_str} self-loops:")
    print(f"Test AUC: {final_results['test_auc']:.4f}")
    print(f"Test AP: {final_results['test_ap']:.4f}")

    return {
        'train_losses': train_losses,
        'val_aucs': val_aucs,
        'final_results': final_results
    }

def compare_self_loops(seed=42):
    """Compare link prediction with and without self-loops.

    Runs ``run_experiment`` for every dataset / layer-type combination, once
    without and once with self-loops, then plots the training loss and
    validation AUC curves per dataset (also saved as a PNG) and prints a
    summary of the test AUCs.

    Args:
        seed: Value used to reset the Python, NumPy and torch RNGs before
            every run, so that the two runs of a pair start from the same
            random state (edge split, weight initialisation). Pass ``None``
            to leave the RNGs untouched.

    Returns:
        Nested dict ``all_results[dataset][gnn_type][variant]`` where
        ``variant`` is ``'without_loops'`` or ``'with_loops'`` and each leaf is
        the dict returned by ``run_experiment``.
    """
    datasets = ['Cora', 'PubMed']
    gnn_types = ['GCN', 'GAT', 'SAGE']
    variants = (('without_loops', False, 'without'), ('with_loops', True, 'with'))

    def _reseed():
        # Without a reset the "with" run would get a different random edge
        # split and initialisation than the "without" run, confounding the
        # comparison with split/initialisation noise.
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    all_results = {}

    for dataset_name in datasets:
        all_results[dataset_name] = {}
        for gnn_type in gnn_types:
            runs = {}
            for key, flag, label in variants:
                print(f"\nRunning {gnn_type} on {dataset_name} {label} self-loops...")
                _reseed()
                runs[key] = run_experiment(dataset_name, gnn_type, use_self_loops=flag)
            all_results[dataset_name][gnn_type] = runs

    for dataset_name in datasets:
        # squeeze=False keeps `axes` two-dimensional for any number of rows.
        fig, axes = plt.subplots(len(gnn_types), 2, figsize=(15, 5*len(gnn_types)),
                                 squeeze=False)
        fig.suptitle(f'Link Prediction Performance on {dataset_name}', fontsize=16)

        for i, gnn_type in enumerate(gnn_types):
            results = all_results[dataset_name][gnn_type]

            axes[i, 0].plot(results['without_loops']['train_losses'], label='Without Self-Loops')
            axes[i, 0].plot(results['with_loops']['train_losses'], label='With Self-Loops')
            axes[i, 0].set_title(f'{gnn_type} - Training Loss')
            axes[i, 0].set_xlabel('Epoch')
            axes[i, 0].set_ylabel('Loss')
            axes[i, 0].legend()

            axes[i, 1].plot(results['without_loops']['val_aucs'], label='Without Self-Loops')
            axes[i, 1].plot(results['with_loops']['val_aucs'], label='With Self-Loops')
            axes[i, 1].set_title(f'{gnn_type} - Validation AUC')
            axes[i, 1].set_xlabel('Epoch')
            axes[i, 1].set_ylabel('AUC')
            axes[i, 1].legend()

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f'{dataset_name}_link_prediction_comparison.png')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print("\n--- Summary Results ---")
    for dataset_name in datasets:
        print(f"\n{dataset_name} Dataset")
        for gnn_type in gnn_types:
            results = all_results[dataset_name][gnn_type]
            without_auc = results['without_loops']['final_results']['test_auc']
            with_auc = results['with_loops']['final_results']['test_auc']
            # The difference is scaled by 100, i.e. reported in AUC percentage points.
            print(f"{gnn_type}: Without self-loops AUC: {without_auc:.4f}, "
                  f"With self-loops AUC: {with_auc:.4f}, "
                  f"Improvement: {(with_auc - without_auc) * 100:.2f}%")

    return all_results

In [None]:
# Keep the result in a named variable: the run already prints its own progress
# and final metrics, and a bare last expression would echo the whole result
# dict (two per-epoch lists) into the cell output and then discard it.
cora_gcn_results = run_experiment('Cora', 'GCN', use_self_loops=False)