# 链路预测

给定两个节点，预测两个节点之间有没有边

# 负采样  
将图中已知存在的边（源节点，目标节点）作为正样本；  
对每条已知边的源节点，从所有节点中均匀随机选取 k 个节点作为目标节点，组成 k 条负样本边（采样时不排除偶然采到真实存在的边）

实现如下:  

In [21]:
import torch
import dgl
import torch.nn as nn
import numpy as np

In [22]:
def construct_negative_graph(graph, k):
    """Build a graph of uniformly sampled negative edges for link prediction.

    Every edge ``(u, v)`` of ``graph`` contributes ``k`` negative edges
    ``(u, w)``, where each ``w`` is drawn uniformly at random from all node
    IDs of ``graph``. Because the sources are produced with
    ``repeat_interleave``, the ``k`` negatives of positive edge ``i`` are
    stored consecutively (edge IDs ``i*k`` .. ``i*k + k - 1``).

    Note: samples are not filtered against existing edges, so a negative edge
    may occasionally coincide with a real edge or be a self-loop.

    Args:
        graph: DGL graph whose edges are the positive samples.
        k: number of negative samples per positive edge.

    Returns:
        A new DGL graph with ``graph.num_nodes()`` nodes and
        ``graph.num_edges() * k`` edges.
    """
    src, _ = graph.edges()
    neg_src = src.repeat_interleave(k)
    # Sample on the same device and with the same integer dtype as the edge
    # IDs. Otherwise dgl.graph receives mismatched tensors when ``graph``
    # lives on a GPU or uses int32 node IDs.
    neg_dst = torch.randint(
        0,
        graph.num_nodes(),
        (neg_src.shape[0],),
        device=src.device,
        dtype=src.dtype,
    )
    return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())

In [23]:
import dgl.function as fn
class DotProductPredictor(nn.Module):
    """Score every edge by the dot product of its endpoints' representations."""

    def forward(self, graph, h):
        """Return one score per edge of ``graph``.

        Args:
            graph: DGL graph whose edges are scored.
            h: node representations with one row per node of ``graph``
                (in this notebook, the output of ``SAGE``).

        Returns:
            The ``'score'`` edge data written by ``fn.u_dot_v``.
        """
        # local_scope keeps the temporary 'h' / 'score' fields off the
        # caller's graph once the block exits.
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_message = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_message)
            edge_scores = graph.edata['score']
        return edge_scores

class SAGE(nn.Module):
    """Two-layer GraphSAGE encoder that produces node representations.

    Unlike a plain MLP, each layer aggregates neighbour features over the
    graph ``g``, so the resulting embeddings reflect graph structure. That is
    what a link-prediction encoder needs.

    Args:
        in_features: size of the input node features.
        hidden_features: size of the hidden layer.
        out_features: size of the output node representations.
        aggregator_type: neighbour aggregator passed to ``SAGEConv``
            (``'mean'``, ``'gcn'``, ``'pool'`` or ``'lstm'``); defaults to
            ``'mean'``.
    """

    def __init__(self, in_features, hidden_features, out_features,
                 aggregator_type='mean'):
        super().__init__()
        # Local import: dgl is already a dependency of this notebook; this
        # pulls in its message-passing layer.
        from dgl.nn import SAGEConv

        self.layers = nn.ModuleList([
            SAGEConv(in_features, hidden_features,
                     aggregator_type=aggregator_type),
            SAGEConv(hidden_features, out_features,
                     aggregator_type=aggregator_type),
        ])

    def forward(self, g, x):
        """Compute node representations.

        Args:
            g: DGL graph whose edges define each node's neighbourhood.
            x: input node features, one row per node of ``g``.

        Returns:
            Node representations of size ``out_features`` per node.
        """
        h = x
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(g, h)
            # ReLU between layers only; the output layer stays linear so
            # dot-product scores can take any sign.
            if i != last:
                h = torch.relu(h)
        return h

class Model(nn.Module):
    """Link-prediction model: a node encoder followed by a dot-product scorer."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # Encoder mapping input node features to node representations.
        self.sage = SAGE(in_features, hidden_features, out_features)
        # Scorer turning pairs of node representations into edge scores.
        self.pred = DotProductPredictor()

    def forward(self, g, neg_g, x):
        """Score the positive graph ``g`` and the negative graph ``neg_g``.

        Both graphs are scored with the same node representations, which are
        computed once from ``g`` and ``x``.

        Returns:
            ``(positive_scores, negative_scores)``.
        """
        node_repr = self.sage(g, x)
        positive_scores = self.pred(g, node_repr)
        negative_scores = self.pred(neg_g, node_repr)
        return positive_scores, negative_scores

In [24]:
def compute_loss(pos_score, neg_score):
    """Margin (hinge) loss between positive and negative edge scores.

    For positive edge ``i`` with score ``p_i`` and its ``k`` negative scores
    ``n_i1 .. n_ik``, each pair contributes ``max(0, 1 - p_i + n_ij)``. The
    result is the mean over all ``E * k`` pairs.

    Args:
        pos_score: scores of the ``E`` positive edges, shape ``(E,)`` or
            ``(E, 1)``.
        neg_score: ``E * k`` negative scores, ordered so that the ``k``
            negatives of positive edge ``i`` are consecutive.

    Returns:
        A scalar tensor.
    """
    n_edges = pos_score.shape[0]
    # Force the positives to (E, 1) so they broadcast row-by-row against the
    # (E, k) negatives. Using ``unsqueeze(1)`` on an (E, 1) score would give
    # (E, 1, 1), which broadcasts to (E, E, k). That compares every positive
    # edge with every other edge's negatives and needs O(E^2) memory.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.reshape(n_edges, -1)
    return (1 - pos + neg).clamp(min=0).mean()

In [25]:
# Train the link-prediction model on the Citeseer graph with negative sampling.
dataset=dgl.data.CiteseerGraphDataset()
graph=dataset[0]  # use the first graph of the dataset
node_features = graph.ndata['feat']
n_features = node_features.shape[1]
k = 5  # negative samples per positive edge
model = Model(n_features, 100, 100)  # hidden size 100, output size 100
opt = torch.optim.Adam(model.parameters())  # Adam with its default learning rate
# NOTE(review): only 2 epochs; the printed losses (~1.0) show training has barely started.
for epoch in range(2):
    # Resample negatives every epoch so the model sees fresh negative edges.
    negative_graph = construct_negative_graph(graph, k)
    print(negative_graph)  # debug output: node/edge counts of the sampled negative graph
    pos_score, neg_score = model(graph, negative_graph, node_features)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Graph(num_nodes=3327, num_edges=46140,
      ndata_schemes={}
      edata_schemes={})
0.9999901652336121
Graph(num_nodes=3327, num_edges=46140,
      ndata_schemes={}
      edata_schemes={})
0.9996460676193237


In [27]:
# Extract the learned node embeddings from the trained encoder. This is
# inference only, so autograd tracking is disabled; no backward graph is built
# or kept alive by `node_embeddings`.
with torch.no_grad():
    node_embeddings = model.sage(graph, node_features)
print(node_embeddings[0])
print(node_embeddings)

tensor([ 0.0260, -0.0329, -0.0731,  0.0776,  0.0530, -0.0408,  0.0330,  0.0058,
        -0.0871,  0.0825, -0.0064,  0.0271,  0.0863,  0.0952, -0.0990,  0.0908,
        -0.1024, -0.0319,  0.0936, -0.0856,  0.0683,  0.0523,  0.0714, -0.0238,
         0.0259,  0.0970,  0.0190,  0.0052, -0.0721, -0.0275, -0.0562, -0.0294,
        -0.0717, -0.0064, -0.0307,  0.0978,  0.0197,  0.0734, -0.0373, -0.0694,
        -0.0831, -0.0556,  0.0809, -0.0593,  0.0630,  0.0694,  0.0074,  0.0430,
         0.0245, -0.0051,  0.0976, -0.0670, -0.0213, -0.0802,  0.0912,  0.0341,
        -0.0620,  0.0745, -0.0718, -0.0968,  0.0870, -0.0324, -0.0786, -0.0785,
        -0.0366,  0.0567, -0.0058, -0.0109,  0.0714,  0.0661, -0.0844, -0.0333,
         0.0798, -0.0249,  0.0546, -0.0656, -0.0011, -0.0432, -0.0754,  0.1029,
        -0.0164, -0.0683, -0.0565, -0.0832, -0.0385,  0.0259, -0.0928,  0.0224,
         0.0868,  0.0173,  0.0581, -0.0587, -0.0322,  0.0279,  0.0708, -0.0845,
         0.0441, -0.0203,  0.0631, -0.00