In [41]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
global X
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
        self.h = x
        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self,x_i, x_j, norm):
        # x_j has shape [|E|, out_channels]
        # x_j denotes a lifted tensor, which contains the source node features of each edge, source_node(如果flow 是 source_to_target)
        # 要从从有向图的角度来解释 edge_index 有几个，就有几个x_j
        print(x_i[0], x_j[0])
        print(x_i.shape, x_j.shape)
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j


In [42]:
edge_index = torch.concat([torch.randint(low=1, high=9,size=(10,1)), torch.randint(low=1, high=9,size=(10,1))],axis=1)

In [43]:

x = torch.randn(8, 128)

# edge_index = torch.rand(2,10)
# edge_index = torch.bernoulli(edge_index)
edge_index = torch.concat([torch.randint(low=1, high=8,size=(1,7)), torch.randint(low=1, high=8,size=(1,7))],axis=0)
edge_index =  edge_index.to(dtype=torch.int64)
edge_index



tensor([[5, 5, 6, 5, 2, 3, 5],
        [1, 4, 6, 6, 4, 3, 4]])

In [44]:
conv = GCNConv(128, 32)
h = conv(x, edge_index)


tensor([-0.3273,  0.1516,  0.0161, -0.8483,  0.0766,  0.7280,  0.6113, -0.0470,
         0.9393,  0.6909,  0.0947,  0.8419,  0.1639, -0.1819,  1.0376,  0.0166,
        -0.7076,  0.9921,  0.2978, -0.4104, -1.0616,  0.6393,  0.1030,  0.5463,
        -0.3269,  0.1066, -0.1008,  0.2158, -0.1032, -0.3987, -0.4024, -0.3123],
       grad_fn=<SelectBackward0>) tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)
torch.Size([7, 32]) torch.Size([7, 32])


In [48]:
bi = 0

edge_nt = torch.stack((
    edge_index[bi][0][edge_index[bi][0] < self.node_num], # source
    edge_index[bi][1][edge_index[bi][1] >= self.node_num] # target
    ))
edge_tn = torch.stack((
    edge_index[bi][0][edge_index[bi][0] >= self.node_num],
    edge_index[bi][1][edge_index[bi][1] < self.node_num]
    ))               


tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)