In [25]:
import torch
import torch.nn as nn
import dgl
import dgl.function as fn

class DGL_conv(nn.Module):
    """Sum-aggregation graph convolution implemented with DGL message passing.

    For every node v, the features h[u] of its in-neighbours u are summed (each
    scaled by the weight of edge u -> v when ``edge_weight`` is given), and the
    affine map ``h_neigh @ W + b`` is applied to the sum. A node's own feature is
    not included unless the graph contains a self-loop on it; a node with no
    in-edges therefore outputs just ``b``.

    Args:
        in_channel: size of the input node-feature dimension.
        out_channel: size of the output node-feature dimension.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Initialised to ones (not randomly) so the layer's output is deterministic.
        self.W = nn.Parameter(torch.ones(in_channel, out_channel))
        self.b = nn.Parameter(torch.ones(out_channel))

    def forward(self, g, h, edge_weight=None):
        """Aggregate in-neighbour features over ``g`` and apply ``W`` and ``b``.

        Args:
            g: DGL graph to propagate over. It is not modified: all temporary
                node/edge features live inside ``g.local_scope()``.
            h: node features with one row per node; presumably shape
                (num_nodes, in_channel) — TODO confirm for batched use.
            edge_weight: optional per-edge weights, one entry per edge
                (e.g. shape (num_edges,)). When None, messages are unweighted.

        Returns:
            ``h_neigh @ W + b`` where ``h_neigh`` is the summed messages per node.
        """
        # local_scope() discards every ndata/edata field written inside it, so the
        # caller's graph is not left holding 'h', 'w' or 'h_neigh' afterwards.
        with g.local_scope():
            g.ndata['h'] = h
            if edge_weight is not None:
                g.edata['w'] = edge_weight
                message_fn = fn.u_mul_e('h', 'w', 'm')
            else:
                # copy_u is the supported replacement for the deprecated copy_src.
                message_fn = fn.copy_u('h', 'm')
            g.update_all(message_fn, fn.sum('m', 'h_neigh'))
            h_neigh = g.ndata['h_neigh']
        return torch.matmul(h_neigh, self.W) + self.b

In [26]:
import numpy as np

# Toy directed graph with 5 nodes and 6 edges (src[i] -> dst[i]); node 1 has no in-edges.
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
g = dgl.graph((src, dst))
h = torch.ones((5, 8))
# One scalar weight per edge, sized from the graph so it stays in sync with src/dst.
edge_weight = torch.full((g.num_edges(),), 2.0)
conv = DGL_conv(8, 4)
output = conv(g, h, edge_weight)

# With h, W and b all ones, each output entry is 8 * (sum of incoming edge weights) + 1:
# nodes 0 and 4 have one in-edge (17), nodes 2 and 3 have two (33), node 1 has none (1).
expected = np.array([[17., 17., 17., 17.],
                     [ 1.,  1.,  1.,  1.],
                     [33., 33., 33., 33.],
                     [33., 33., 33., 33.],
                     [17., 17., 17., 17.]])
# assert_allclose reports the mismatching entries on failure; tolerances match np.allclose.
np.testing.assert_allclose(output.detach().numpy(), expected, rtol=1e-5, atol=1e-8)

In [27]:
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing

class PyG_conv(MessagePassing):
    """Sum-aggregation graph convolution implemented with PyG's ``MessagePassing``.

    Messages are source-node features ``x_j``, optionally scaled by the weight of
    their edge. They are summed (``aggr='add'``) at each target node, and the
    affine map ``aggr @ W + b`` is then applied in ``update``. A node with no
    incoming edges therefore outputs just ``b``.

    Args:
        in_channel: size of the input node-feature dimension.
        out_channel: size of the output node-feature dimension.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__(aggr='add')
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Initialised to ones (not randomly) so the layer's output is deterministic.
        self.W = nn.Parameter(torch.ones((in_channel, out_channel)))
        self.b = nn.Parameter(torch.ones(out_channel))

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate ``x`` along ``edge_index``.

        Args:
            x: node features with one row per node; presumably shape
                (num_nodes, in_channel).
            edge_index: 2 x num_edges tensor of (source, target) node indices.
            edge_weight: optional per-edge weights of shape (num_edges,).
                When None, messages are unweighted (consistent with DGL_conv).

        Returns:
            ``aggr @ W + b`` where ``aggr`` is the per-node sum of messages.
        """
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        """Return source features, scaled per edge when weights are provided."""
        if edge_weight is None:
            return x_j
        # view(-1, 1) turns the (num_edges,) weights into a column that broadcasts over features.
        return edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Apply the affine transform to the aggregated messages."""
        return torch.matmul(aggr_out, self.W) + self.b


In [28]:
import numpy as np

# Same toy graph as the DGL demo: row 0 = source nodes, row 1 = target nodes.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 4], [2, 0, 2, 3, 4, 3]])
x = torch.ones((5, 8))
# One scalar weight per edge, sized from edge_index so the two stay in sync.
edge_weight = torch.full((edge_index.size(1),), 2.0)
conv = PyG_conv(8, 4)
output = conv(x, edge_index, edge_weight)

# With x, W and b all ones, each output entry is 8 * (sum of incoming edge weights) + 1:
# nodes 0 and 4 have one in-edge (17), nodes 2 and 3 have two (33), node 1 has none (1).
expected = np.array([[17., 17., 17., 17.],
                     [ 1.,  1.,  1.,  1.],
                     [33., 33., 33., 33.],
                     [33., 33., 33., 33.],
                     [17., 17., 17., 17.]])
# assert_allclose reports the mismatching entries on failure; tolerances match np.allclose.
np.testing.assert_allclose(output.detach().numpy(), expected, rtol=1e-5, atol=1e-8)