In [2]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing

class PyG_GATConv(MessagePassing):
    def __init__(self, in_channel, out_channel):
        super(PyG_GATConv, self).__init__(aggr='add')
        self.lin = nn.Linear(in_channel, out_channel)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, edge_index):
        return x_j

class PyG_SAGEConv(MessagePassing):
    def __init__(self, in_channel, out_channel):
        super(PyG_SAGEConv, self).__init__(aggr='mean')
        self.lin = nn.Linear(in_channel, out_channel)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j


In [4]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)


tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn

class DGL_GATConv(nn.Module):
    def __init__(self, in_channel, out_channel, num_heads):
        super(DGL_GATConv, self).__init__()
        self.fc = nn.Linear(in_channel, out_channel * num_heads)
        self.attn_fc = nn.Linear(out_channel * 2, 1)  # 修改了这里的输入维度

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = self.fc(h)
            g.apply_edges(fn.u_add_v('h', 'h', 'h_sum'))
            g.edata['attn'] = F.leaky_relu(self.attn_fc(g.edata['h_sum']))
            g.update_all(fn.u_mul_e('h', 'attn', 'm'), fn.sum('m', 'h'))
            return g.ndata.pop('h')

class DGL_SAGEConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DGL_SAGEConv, self).__init__()
        self.fc = nn.Linear(in_channel * 2, out_channel)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(message_func=fn.copy_u('h', 'm'),
                         reduce_func=fn.mean('m', 'h_neigh'))
            h_src = g.ndata['h']  # original node features
            h_dst = h_src[:g.num_dst_nodes()]  # features of the destination nodes
            h = torch.cat((h_src, h_dst), dim=1)  # concatenate features
            return F.relu(self.fc(h))


In [22]:
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4, num_heads=2)  # 增加num_heads=2
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[ 0.1215, -0.0206, -0.1269,  0.0274,  0.0818, -0.0645,  0.0467,  0.0457],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.2430, -0.0412, -0.2538,  0.0548,  0.1635, -0.1290,  0.0935,  0.0915],
        [ 0.2430, -0.0412, -0.2538,  0.0548,  0.1635, -0.1290,  0.0935,  0.0915],
        [ 0.1215, -0.0206, -0.1269,  0.0274,  0.0818, -0.0645,  0.0467,  0.0457]],
       grad_fn=<GSpMMBackward>)
tensor([[0.1816, 0.3301, 0.0000, 0.9639],
        [0.1816, 0.3301, 0.0000, 0.9639],
        [0.1816, 0.3301, 0.0000, 0.9639],
        [0.1816, 0.3301, 0.0000, 0.9639],
        [0.1816, 0.3301, 0.0000, 0.9639]], grad_fn=<ReluBackward0>)
