In [3]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import FAConv,HeteroConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, matmul, mul
from torch_sparse import sum as sparsesum
from torch_sparse import SparseTensor, fill_diag
def hetero_directed_norm(edge_index, edge_weight=None, num_nodes=None,
                         dtype=None):
    """Symmetrically normalise the edge weights of a directed graph.

    Every edge ``u -> v`` receives the weight
    ``out_deg(u) ** -0.5 * w_uv * in_deg(v) ** -0.5``, where both degrees are
    the sums of ``edge_weight`` over the outgoing / incoming edges of a node.

    Args:
        edge_index: Tensor whose row 0 holds the source index and row 1 the
            target index of each edge.
        edge_weight: Optional per-edge weights. All-ones weights are created
            when this is ``None``.
        num_nodes: Optional node count; resolved through ``maybe_num_nodes``.
        dtype: dtype of the default all-ones weights (only used when
            ``edge_weight`` is ``None``).

    Returns:
        A tensor with one normalised weight per edge.

    Raises:
        NotImplementedError: If ``edge_index`` is a ``SparseTensor``.
    """
    # Guard clause: only the dense [2, E]-style edge list is supported.
    if isinstance(edge_index, SparseTensor):
        raise NotImplementedError("Operation of Sparse Tensor Not defined!")

    def _inv_sqrt(degree):
        # In-place degree ** -0.5; a zero degree becomes +inf and is reset to
        # 0 so isolated nodes contribute nothing.
        degree.pow_(-0.5)
        degree.masked_fill_(degree == float('inf'), 0)
        return degree

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    src, dst = edge_index[0], edge_index[1]

    # Weighted in-degree (summed over targets) and out-degree (over sources).
    in_norm = _inv_sqrt(
        scatter_add(edge_weight, dst, dim=0, dim_size=num_nodes))
    out_norm = _inv_sqrt(
        scatter_add(edge_weight, src, dim=0, dim_size=num_nodes))

    # Source node contributes its out-degree, target node its in-degree.
    return out_norm[src] * edge_weight * in_norm[dst]





In [4]:


class TSConv(MessagePassing):
    """Attention-gated message passing for the relation ``<t -> s>``.

    The two feature blocks passed to ``forward`` are concatenated into one
    node tensor, messages are propagated along ``edge_index`` with
    ``hetero_directed_norm`` weights scaled by a learned tanh gate, and an
    ``eps``-weighted copy of the input features is added to the result.

    Args:
        in_channels: Feature size of every node; the gate consumes
            ``2 * in_channels`` values per edge.
        out_channels: Accepted for signature compatibility; not used by
            this class.
        eps: Weight of the residual term ``eps * x``.
    """

    def __init__(self, in_channels, out_channels, eps=0.9):
        # Incoming messages of a node are summed.
        super().__init__(aggr='add')
        self.eps = eps
        # Maps the concatenated (x_i, x_j) pair of an edge to a scalar score.
        self.att_g = Linear(2 * in_channels, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the gate's linear layer."""
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Run one propagation step.

        Args:
            x: Indexable pair of feature tensors. ``x[0]`` is treated as the
                ``t`` block and ``x[1]`` as the ``s`` block; they are
                concatenated as ``[xs; xt]`` along dim 0.
            edge_index: Edge list with shape ``[2, E]``; it is used to index
                the concatenated tensor.
            edge_weight: Optional per-edge weights forwarded to
                ``hetero_directed_norm``.

        Returns:
            Tensor with the same shape as the concatenated ``[xs; xt]``
            features (not only the destination block).
        """
        x_t, x_s = x[0], x[1]
        x = torch.cat([x_s, x_t], dim=0)

        num_nodes = x.size(self.node_dim)
        norm = hetero_directed_norm(edge_index, edge_weight, num_nodes,
                                    dtype=x.dtype)
        aggregated = self.propagate(edge_index, x=x, edge_weight=norm)

        # Residual: keep an eps-scaled copy of the input features.
        return aggregated + x * self.eps

    def message(self, x_i, x_j, edge_weight):
        """Gate and scale the source features of every edge.

        ``x_j`` is the lifted tensor with one row per edge holding the source
        node features (when the flow is source_to_target); ``x_i`` holds the
        matching target node features. Read this from the directed-graph
        point of view: there is one row per column of ``edge_index``.
        The s->t and t->s relations each define their own gate.
        """
        pair = torch.cat([x_i, x_j], dim=1)
        gate = torch.tanh(self.att_g(pair)).squeeze(-1)  # (|E|, )
        scale = gate * edge_weight
        return x_j * scale.view(-1, 1)




class STConv(MessagePassing):
    """Attention-gated message passing for the relation ``<s -> t>``.

    The two feature blocks passed to ``forward`` are concatenated into one
    node tensor, messages are propagated along ``edge_index`` with
    ``hetero_directed_norm`` weights scaled by a learned tanh gate, and an
    ``eps``-weighted copy of the input features is added to the result.

    Args:
        in_channels: Feature size of every node; the gate consumes
            ``2 * in_channels`` values per edge.
        out_channels: Accepted for signature compatibility; not used by
            this class.
        eps: Weight of the residual term ``eps * x``.
    """

    def __init__(self, in_channels, out_channels, eps=0.9):
        super().__init__(aggr='add')  # incoming messages are summed
        self.eps = eps
        # NOTE(review): these activations are constructed but never applied
        # anywhere in this class; kept so existing attribute access still works.
        self.s_trans = torch.nn.ELU()
        self.t_trans = torch.nn.ReLU()

        # Maps the concatenated (x_i, x_j) pair of an edge to a scalar score.
        self.att_g = Linear(2*in_channels, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the gate's linear layer."""
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Run one propagation step.

        Args:
            x: Indexable pair of feature tensors. ``x[0]`` is treated as the
                ``s`` block and ``x[1]`` as the ``t`` block; they are
                concatenated as ``[xs; xt]`` along dim 0.
            edge_index: Edge list with shape ``[2, E]``; it is used to index
                the concatenated tensor.
            edge_weight: Optional per-edge weights forwarded to
                ``hetero_directed_norm``. ``None`` means all-ones weights.

        Returns:
            Tensor with the same shape as the concatenated ``[xs; xt]``
            features (not only the destination block).
        """
        xs = x[0]
        xt = x[1]

        x = torch.concat([xs, xt], dim=0)

        # Forward the caller's edge weights (previously a hard-coded None
        # discarded them, unlike TSConv).
        edge_weight = hetero_directed_norm(  # yapf: disable
            edge_index, edge_weight, x.size(self.node_dim), dtype=x.dtype)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

        # Residual: keep an eps-scaled copy of the input features.
        out = out + x*self.eps

        return out

    def message(self, x_i, x_j, edge_weight):
        """Gate and scale the source features of every edge.

        ``x_j`` is the lifted tensor with one row per edge holding the source
        node features (when the flow is source_to_target); ``x_i`` holds the
        matching target node features. Read this from the directed-graph
        point of view: there is one row per column of ``edge_index``.
        The s->t and t->s relations each define their own gate.
        """
        alpha_i_j = self.att_g(torch.concat([x_i, x_j], dim=1)).tanh().squeeze(-1)  # ( |E|, )

        return x_j * (alpha_i_j * edge_weight).view(-1, 1)

In [5]:
# Toy problem size: number of spatial nodes and of sequence (temporal) steps.
n_nodes = 2
seq_len = 4

# Random 128-dimensional features for both node groups.
xn = torch.randn(n_nodes, 128)
xt = torch.randn(seq_len, 128)

# Dense adjacency over all n_nodes + seq_len nodes. Only the two off-diagonal
# blocks are set, so every spatial node is linked to every sequence node in
# both directions and there are no edges inside a group.
tmp = torch.zeros(n_nodes + seq_len, n_nodes + seq_len)
tmp[n_nodes:, :n_nodes] = 1
tmp[:n_nodes, n_nodes:] = 1

# Convert the adjacency matrix into a [2, E] edge list.
edge_index = tmp.nonzero().t()
edge_index


tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
        [2, 3, 4, 5, 2, 3, 4, 5, 0, 1, 0, 1, 0, 1, 0, 1]])

In [6]:
# Split the combined edge list by relation. Both endpoints of an edge are
# tested with ONE per-edge mask so that source and target stay paired; masking
# the two rows independently only works while the graph is strictly bipartite.
is_s2t = (edge_index[0] < n_nodes) & (edge_index[1] >= n_nodes)
is_t2s = (edge_index[0] >= n_nodes) & (edge_index[1] < n_nodes)

edge_nt = edge_index[:, is_s2t]  # row 0: source (spatial), row 1: target (sequence)
edge_tn = edge_index[:, is_t2s]  # row 0: source (sequence), row 1: target (spatial)


In [7]:
# Relation keys in (source type, relation name, target type) form.
rel_s2t = ('s', 's2t', 't')
rel_t2s = ('t', 't2s', 's')

# One convolution per relation; outputs that share a target type are summed.
hetero_conv = HeteroConv(
    {
        rel_s2t: STConv(128, 128, eps=0.9),
        rel_t2s: TSConv(128, 128, eps=0.9),
    },
    aggr='sum',
)

# Node features and edge lists keyed by node type / relation.
x_dict = dict(s=xn, t=xt)
edge_index_dict = {rel_s2t: edge_nt, rel_t2s: edge_tn}

out_dict = hetero_conv(x_dict, edge_index_dict)


In [14]:
# Shape of the 't' output. It has n_nodes + seq_len rows (6), not seq_len (4),
# because STConv.forward returns the whole concatenated [xs; xt] tensor.
out_dict['t'].shape

torch.Size([6, 128])

In [12]:
# Stack both per-type outputs along dim 0. Each output already has
# n_nodes + seq_len rows, so the result has 2 * (n_nodes + seq_len) = 12 rows.
torch.cat([out_dict['s'], out_dict['t']], dim=0).shape

torch.Size([12, 128])