In [2]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
class STHeteroConv(MessagePassing):
    def __init__(self,node_num, in_s_channels, i_t_channels, hidden_channels, out_channels, eps=1.0):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.s_lin = Linear(in_s_channels, hidden_channels, bias=False)
        self.t_lin = Linear(i_t_channels, hidden_channels, bias=False)
        # self.bias = Parameter(torch.empty(out_channels))
        self.eps = eps
        self.s_trans = torch.nn.ELU()
        self.t_trans = torch.nn.ReLU()

        # self.spatial_trans = torch.nn.Parameter(hidden_channels, out_)
        
        self.node_num = node_num
        self.reset_parameters()

    def reset_parameters(self):
        self.s_lin.reset_parameters()
        self.t_lin.reset_parameters()
        # self.bias.data.zero_()

    def forward(self, x, edge_index, edge_weight=None):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # edge_index has shape [E, weight_dim]
        print(x.shape)
        xs = self.s_lin(x[:self.node_num, :]) # diff 
        xt = self.t_lin(x[self.node_num:, :]) # not diff
        
        x = torch.concat([xs, xt], axis=0)  # (N + T , out_channels)
        
        out = self.propagate(edge_index, x=x)
        
        # plus 
        out = out + x*self.eps
        
        return out

    def message(self,x_i, x_j):
        # x_j has shape [|E|, out_channels] , The first n edges are edges of  spatial nodes
        # x_j denotes a lifted tensor, which contains the source node features of each edge, source_node(如果flow 是 source_to_target)
        # 要从从有向图的角度来解释 edge_index 有几个，就有几个x_j
        x_j = torch.concat([
            self.s_trans(x_j[self.node_num:]), # information aggregate to temporal node
            self.t_trans( x_j[:self.node_num]) # information aggregate to spatial node
        ], axis=0)
        
        return x_j


In [3]:
x = torch.randn(8, 128)

# edge_index = torch.rand(2,10)
# edge_index = torch.bernoulli(edge_index)
edge_index = torch.concat([torch.randint(low=1, high=8,size=(1,7)), torch.randint(low=1, high=8,size=(1,7))],axis=0)
edge_index =  edge_index.to(dtype=torch.int64)
edge_index



tensor([[4, 7, 7, 4, 4, 5, 7],
        [2, 3, 7, 1, 2, 1, 3]])

In [5]:
conv = STHeteroConv(3, 128, 128, 128, 128)
conv(x, edge_index)

torch.Size([8, 128])


tensor([[ 0.0171,  0.6879,  0.8599,  ...,  0.2926,  0.2404, -0.5903],
        [-1.5831,  1.6342, -0.8668,  ..., -1.1076, -1.5462,  1.4954],
        [ 1.7546, -0.5609,  0.7079,  ..., -0.1980, -0.2070, -0.8605],
        ...,
        [-0.0771,  0.3812,  0.6522,  ...,  0.0877,  0.2224,  0.7678],
        [ 0.1084,  0.3915, -0.5542,  ...,  0.0057, -0.3616,  0.2899],
        [-1.0070,  0.7585,  0.5640,  ..., -0.9495, -0.9691,  1.1126]],
       grad_fn=<AddBackward0>)

In [32]:
x

tensor([[ 0.0650,  0.9773, -0.1229,  ...,  0.9194,  1.0958,  1.1487],
        [ 1.6506, -0.4570,  0.1573,  ..., -0.9247, -0.6433,  1.4225],
        [-1.9737,  1.1208,  0.0160,  ...,  1.2821, -2.5599,  1.0460],
        ...,
        [-0.7306, -1.5670, -0.4000,  ..., -0.4804, -1.5296,  0.5681],
        [ 0.3970,  0.7127,  0.3103,  ..., -0.6843,  0.8229,  0.8038],
        [ 0.6945,  0.1936,  0.6469,  ..., -1.1334, -0.2593, -0.5858]])