In [155]:
import dgl
import torch

# --- Problem sizes and hyper-parameters -------------------------------------
seed = 0                  # seeds torch's global RNG (synthetic data + later weight init)
n = 1000                  # number of samples (first axis of X and Y)
Q = 3                     # input feature size per node
P = 2                     # size of each target vector
heads = 8                 # attention heads used by the GAT layers
embedding_dimension = 8   # hidden feature size of the GAT layers
lr = 5e-4                 # learning rate passed to Adam
nodes = 8                 # nodes per sample (second axis of X, node count of g1)

device = "cpu"

# Seed before any random draw so that Restart & Run All reproduces the same
# data and the same initial model weights.
torch.manual_seed(seed)

# Synthetic data: one (nodes, Q) feature matrix and one P-dimensional target per sample.
X = torch.randn(n, nodes, Q).to(device)
Y = torch.randn(n, P).to(device)

# First graph. Edges are given as parallel tensors: u1[i] -> v1[i].
u1 = torch.tensor([0,1,2,3,
                    4,5,6,7,
                    0,4,
                    1,5,
                    2,6,
                    3,7])
v1 = torch.tensor([0,0,0,0,
              5,5,5,5,
              4,4,
              1,1,
              2,2,
              3,3])
# Pass the node count explicitly so the graph always matches the `nodes` rows
# of each sample in X instead of relying on the largest id seen in the edges.
g1 = dgl.graph((u1,v1), num_nodes=nodes).to(device)

# Second graph: nodes 0..5 all point at node 6.
u2 = torch.tensor([0,1,2,3,4,5])
v2 = torch.tensor([6,6,6,6,6,6])
g2 = dgl.graph((u2,v2)).to(device)

In [156]:
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class GAT(nn.Module):
    """Two graph-attention stages followed by a linear read-out.

    Stage 1 runs ``GATConv`` on ``graph1``; a prefix of the resulting node
    embeddings is then used as the input of stage 2 on ``graph2``, whose last
    node acts as the read-out node. The read-out embedding is projected to
    ``out_feats`` values by a linear layer.

    Args:
        in_feats: Size of each input node feature vector.
        hidden_feats: Output size (per head) of both attention layers.
        out_feats: Size of the final output vector.
        num_heads: Number of attention heads for both layers. When ``None``
            (the default) the notebook-level ``heads`` setting is used, which
            is what the original three-argument constructor did.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_heads=None):
        super().__init__()
        if num_heads is None:
            # Backward-compatible default: read the notebook-level setting.
            num_heads = heads
        # allow_zero_in_degree=True: both graphs in this notebook contain
        # nodes without incoming edges (e.g. nodes 6 and 7 of g1).
        self.conv1 = dglnn.GATConv(
            in_feats=in_feats, out_feats=hidden_feats,
            num_heads=num_heads, allow_zero_in_degree=True)
        self.conv2 = dglnn.GATConv(
            in_feats=hidden_feats, out_feats=hidden_feats,
            num_heads=num_heads, allow_zero_in_degree=True)
        self.mlp = torch.nn.Linear(hidden_feats, out_feats)

    def forward(self, graph1, graph2, inputs):
        """Compute the output vector for one sample.

        Args:
            graph1: Graph whose nodes carry ``inputs``.
            graph2: Graph used for the second stage; its last node is the
                read-out node and starts with an all-zero feature.
            inputs: Node features for ``graph1`` (one row per node).

        Returns:
            The result of the linear layer applied to the read-out node's
            embedding.
        """
        h = self.conv1(graph1, inputs)
        # GATConv yields one embedding per head (head axis at dim 1 per the
        # DGL docs); merge the heads by summation.
        h = torch.sum(h, 1)
        # Every node of graph2 except the last one takes its feature from the
        # leading rows of the stage-1 embeddings (6 rows for the notebook's g2).
        h = h[:graph2.num_nodes() - 1]
        # Zero row for the read-out node, created with h's own dtype/device so
        # the model does not depend on the notebook-global `device`.
        h = torch.vstack((h, h.new_zeros(1, h.shape[-1])))
        h = self.conv2(graph2, h)
        h = torch.sum(h, 1)
        h = h[-1]  # embedding of the read-out (last) node
        h = self.mlp(h)
        return h

In [157]:
def evaluate(model, graph, features, labels, mask):
    """Return the fraction of masked rows whose arg-max prediction equals its label.

    The model is switched to evaluation mode (and left in it) and is called
    as ``model(graph, features)`` with gradients disabled.

    NOTE(review): the ``GAT`` class in this notebook has the forward signature
    ``(graph1, graph2, inputs)``, so it cannot be passed to this helper as-is.

    Args:
        model: Callable module producing one row of scores per node.
        graph: Forwarded unchanged as the model's first argument.
        features: Forwarded unchanged as the model's second argument.
        labels: Integer class labels, indexable by ``mask``.
        mask: Index or boolean mask selecting the rows to score.

    Returns:
        Accuracy over the selected rows as a Python float.
    """
    model.eval()
    with torch.no_grad():
        masked_scores = model(graph, features)[mask]
        masked_labels = labels[mask]
        # Index of the highest score in each row is the predicted class.
        predicted = masked_scores.max(dim=1).indices
        num_correct = (predicted == masked_labels).sum().item()
    return float(num_correct) / len(masked_labels)

In [158]:
def data_iter(X,Y):
    """Yield ``(X[i], Y[i])`` one pair at a time for ``i`` in ``0 .. len(X) - 1``.

    Only ``len(X)`` bounds the iteration; ``Y`` is simply indexed alongside.
    """
    for idx in range(len(X)):
        yield X[idx], Y[idx]
    

In [159]:
model = GAT(in_feats=Q, hidden_feats=embedding_dimension,out_feats=P).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=lr)
loss = nn.MSELoss()

num_epochs = 100
# Number of samples consumed per epoch (positive int), or None to iterate over
# the whole dataset. The default of 1 reproduces the previous behaviour, where
# a bare `break` ended every epoch after the first sample, so only (X[0], Y[0])
# is ever used for training.
steps_per_epoch = 1

model.train()
for epoch in range(num_epochs):
    for step, (x, y) in enumerate(data_iter(X,Y), start=1):
        # Forward pass on a single sample: x holds the features of all nodes of g1.
        logits = model(g1, g2, x)
        # Mean-squared error against this sample's target.
        l = loss(input=logits, target=y)
        # Back-propagate and apply one optimizer update.
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        print(l.item())
        if steps_per_epoch is not None and step >= steps_per_epoch:
            break


0.14518514275550842
0.12449261546134949
0.1055426150560379
0.08828239887952805
0.07270374894142151
0.05878599360585213
0.04650646448135376
0.03584323823451996
0.026740631088614464
0.019135532900691032
0.012960558757185936
0.008133267052471638
0.0045553757809102535
0.002111895941197872
0.0006593635771423578
4.668055407819338e-05
0.0001091367521439679
0.0006737431976944208
0.0015665313694626093
0.002621683292090893
0.00369146722368896
0.004654976073652506
0.005423668771982193
0.005942920222878456
0.006190318148583174
0.006171228364109993
0.0059128799475729465
0.005457884632050991
0.004857689142227173
0.0041667805053293705
0.003437851322814822
0.00271802581846714
0.002046378096565604
0.00145246391184628
0.0009558270103298128
0.0005663233459927142
0.00028511430718936026
0.00010596645734040067
1.6996766134980135e-05
2.444196297801682e-06
4.4421871280064806e-05
0.00012453441740944982
0.00022524964879266918
0.0003309870953671634
0.0004288750351406634
0.0005091900820843875
0.000565531314350664