# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class GCNLinkPredictor(nn.Module):
    """Stack of ``GCNConv`` layers that turns node features into embeddings.

    Every layer outputs ``hidden_channels`` features. ReLU is applied after
    each layer except the last one, so the returned embeddings are the raw
    output of the final convolution.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Output width of every layer (and of the embeddings).
        num_layers: Total number of ``GCNConv`` layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, num_layers):
        # Reject invalid depths up front. Previously any value below 2
        # silently produced a two-layer network (input + output layer).
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()

        # Layer i maps widths[i] -> widths[i + 1]: the first layer consumes
        # the raw features, every later layer stays at hidden_channels.
        widths = [in_channels] + [hidden_channels] * num_layers
        self.convs = nn.ModuleList(
            [
                GCNConv(width_in, width_out)
                for width_in, width_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, x, edge_index):
        """Return node embeddings for features ``x`` on graph ``edge_index``."""
        *inner_convs, last_conv = self.convs
        # All layers but the last are followed by a ReLU non-linearity.
        for conv in inner_convs:
            x = F.relu(conv(x, edge_index))
        # The final layer is left without an activation.
        return last_conv(x, edge_index)

class LinkPredictionModel(nn.Module):
    """Predict link probabilities from GCN node embeddings.

    Node embeddings come from a ``GCNLinkPredictor``; each candidate pair is
    scored by the dot product of its two embeddings, squashed by a sigmoid.
    """

    def __init__(self, in_channels, hidden_channels, num_layers):
        super().__init__()
        self.gcn = GCNLinkPredictor(in_channels, hidden_channels, num_layers)

    def forward(self, x, edge_index, node_pairs):
        """Return one probability per row of ``node_pairs``.

        Columns 0 and 1 of ``node_pairs`` are used as node indices into the
        embedding matrix (assumed shape ``(num_pairs, 2)``, as in the demo).
        """
        # Embed every node of the graph.
        node_emb = self.gcn(x, edge_index)

        # Gather the embeddings of both endpoints of each pair.
        first_idx = node_pairs[:, 0]
        second_idx = node_pairs[:, 1]

        # Dot-product similarity along the feature axis.
        scores = torch.sum(node_emb[first_idx] * node_emb[second_idx], dim=-1)

        # Map the similarity to a probability in (0, 1).
        return scores.sigmoid()

# Build a toy example and run the model once.

# Model hyperparameters.
in_channels = 3
hidden_channels = 16
num_layers = 3

# A graph with 5 nodes, each carrying a 3-dimensional feature vector.
x = torch.randn(5, in_channels)
# Edge list, given as two rows of node indices.
edge_index = torch.tensor(
    [
        [0, 1, 2, 3, 4, 0],
        [1, 2, 3, 4, 0, 2],
    ],
    dtype=torch.long,
)
# Node pairs whose link probability should be predicted.
node_pairs = torch.tensor([[0, 1], [2, 3], [0, 4]], dtype=torch.long)

# Instantiate the model and try a single forward pass.
model = LinkPredictionModel(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    num_layers=num_layers,
)
output = model(x, edge_index, node_pairs)

print("Link Probabilities:", output)

# Recorded cell output: AttributeError: 'builtin_function_or_method' object has no attribute 'default'