In [21]:
import torch
import dgl
import numpy as np
import torch.nn as nn

In [22]:
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# 无向图
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# synthetic node and edge features, as well as edge labels  
# 生成点、边特征
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# synthetic train-validation-test splits
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

In [23]:
import dgl.function as fn
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

In [24]:
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h):
        # h contains the node representations computed from the GNN defined
        # in the node classification section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

In [25]:
# 假设这是SAGE模型的定义
class SAGE(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # 定义模型的具体结构，例如多个线性层和激活函数
        self.layers = nn.ModuleList([
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        ])

    def forward(self, g, x):
        # 实现图数据的前向传播
        for layer in self.layers:
            x = layer(x)
        return x

In [26]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g, h)

In [27]:
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
print(model)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    pred = model(edge_pred_graph, node_features)
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

Model(
  (sage): SAGE(
    (layers): ModuleList(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): ReLU()
      (2): Linear(in_features=20, out_features=5, bias=True)
    )
  )
  (pred): DotProductPredictor()
)
0.9796330332756042
0.9772471189498901
0.9751790761947632
0.9733945727348328
0.9718599915504456
0.9705431461334229
0.9694191217422485
0.9684544801712036
0.9676247835159302
0.9669097661972046
