In [100]:


import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from class_resolver.contrib.torch import activation_resolver
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn import FAConv,HeteroConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear
# from torch_timeseries.layers.graphsage import MyGraphSage

# from torch_timeseries.utils.norm import hetero_directed_norm

def hetero_directed_norm(edge_index, edge_weight=None, num_nodes=None,
              dtype=None):
    """Symmetrically normalise the edge weights of a directed graph.

    Every edge ``u -> v`` with weight ``w`` is rescaled to
    ``out_deg(u) ** -0.5 * w * in_deg(v) ** -0.5``, where both degrees are
    the weighted sums of ``edge_weight`` over outgoing / incoming edges.

    Args:
        edge_index: Integer (``torch.long``) tensor of shape ``[2, E]``;
            row 0 holds source ids, row 1 holds target ids.
        edge_weight: Optional tensor of shape ``[E]``. Defaults to all ones.
        num_nodes: Number of nodes. When ``None`` it is inferred as
            ``edge_index.max() + 1`` (``0`` for an empty edge list).
        dtype: dtype used for the default all-ones ``edge_weight``.

    Returns:
        Tensor of shape ``[E]`` holding the normalised edge weights.

    Raises:
        NotImplementedError: If ``edge_index`` is not a dense
            ``torch.Tensor`` (e.g. a ``SparseTensor`` object).
    """
    # Only plain torch is used here: the previous implementation referenced
    # `SparseTensor`, `maybe_num_nodes` and `scatter_add` without importing
    # them, so it failed with NameError on a fresh kernel.
    if not isinstance(edge_index, torch.Tensor):
        raise NotImplementedError("Operation of Sparse Tensor Not defined!")

    if num_nodes is None:
        # `max()` is undefined on an empty tensor, so guard that case.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    row, col = edge_index[0], edge_index[1]

    def _inv_sqrt_degree(index):
        # Weighted degree per node, then deg ** -0.5 with isolated nodes
        # (degree 0 -> inf) mapped to 0.
        deg = torch.zeros(num_nodes, dtype=edge_weight.dtype,
                          device=edge_weight.device)
        deg.scatter_add_(0, index, edge_weight)
        inv_sqrt = deg.pow(-0.5)
        return inv_sqrt.masked_fill(inv_sqrt == float('inf'), 0)

    in_deg_inv_sqrt = _inv_sqrt_degree(col)    # in-degree of every node
    out_deg_inv_sqrt = _inv_sqrt_degree(row)   # out-degree of every node

    # Source node uses its out-degree, target node uses its in-degree.
    return out_deg_inv_sqrt[row] * edge_weight * in_deg_inv_sqrt[col]





In [151]:

class TSConv(MessagePassing):
    """Gated sum-aggregation convolution for the relation ``t -> s``.

    Each message is the source feature ``x_j`` scaled by
    ``tanh(att_g([x_i, x_j])) * edge_weight``; messages are summed per
    target node (``aggr='add'``). ``out_channels`` is accepted but unused and
    ``eps`` is only stored on the instance.
    """

    def __init__(self, in_channels, out_channels, eps=0.9):
        super().__init__(aggr='add')
        self.eps = eps
        # One scalar gate per edge, computed from the concatenated
        # (x_i, x_j) feature pair.
        self.att_g = Linear(in_channels * 2, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class state and the gate projection."""
        super().reset_parameters()
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate along ``t -> s`` edges.

        Args:
            x: Pair whose first item is read as the ``t`` features and whose
                second item is read as the ``s`` features.
            edge_index: Edge list handed to the normalisation and to
                ``propagate``; presumably ``[2, E]`` with ids addressing the
                stacked ``[s; t]`` matrix - TODO confirm against the caller.
            edge_weight: Optional per-edge weights (``[E]``).

        Returns:
            The aggregated messages, one row per row of the stacked
            ``[s; t]`` feature matrix.
        """
        feat_t, feat_s = x[0], x[1]

        # Stack as [s; t] so that both node types live in one matrix.
        stacked = torch.cat((feat_s, feat_t), dim=0)

        norm_weight = hetero_directed_norm(
            edge_index, edge_weight, stacked.size(self.node_dim),
            dtype=stacked.dtype)

        return self.propagate(edge_index, x=stacked, edge_weight=norm_weight)

    def message(self, x_i, x_j, edge_weight):
        # x_j is the lifted tensor of per-edge source-node features (when the
        # flow is source_to_target); x_i is its target-side counterpart.
        # Read this from the directed-graph viewpoint: there is one row per
        # edge in edge_index. Defined separately for st and ts.
        pair = torch.cat((x_i, x_j), dim=1)
        gate = torch.tanh(self.att_g(pair)).squeeze(-1)  # (|E|,)
        scale = gate * edge_weight
        return x_j * scale.view(-1, 1)


class STConv(MessagePassing):
    """Gated sum-aggregation convolution for the relation ``s -> t``.

    Each message is the source feature ``x_j`` scaled by
    ``tanh(att_g([x_i, x_j])) * edge_weight``; messages are summed per
    target node (``aggr='add'``). ``out_channels`` is accepted but unused and
    ``eps`` is only stored on the instance.
    """

    def __init__(self, in_channels, out_channels, eps=0.9):
        super().__init__(aggr='add')
        self.eps = eps
        # One scalar gate per edge, computed from the concatenated
        # (x_i, x_j) feature pair.
        self.att_g = Linear(in_channels * 2, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class state and the gate projection."""
        super().reset_parameters()
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate along ``s -> t`` edges.

        Args:
            x: Pair whose first item is read as the ``s`` features and whose
                second item is read as the ``t`` features.
            edge_index: Edge list handed to the normalisation and to
                ``propagate``; presumably ``[2, E]`` with ids addressing the
                stacked ``[s; t]`` matrix - TODO confirm against the caller.
            edge_weight: Optional per-edge weights (``[E]``).

        Returns:
            The aggregated messages, one row per row of the stacked
            ``[s; t]`` feature matrix.
        """
        feat_s, feat_t = x[0], x[1]

        # Stack as [s; t] so that both node types live in one matrix.
        stacked = torch.cat((feat_s, feat_t), dim=0)

        norm_weight = hetero_directed_norm(
            edge_index, edge_weight, stacked.size(self.node_dim),
            dtype=stacked.dtype)

        return self.propagate(edge_index, x=stacked, edge_weight=norm_weight)

    def message(self, x_i, x_j, edge_weight):
        # x_j is the lifted tensor of per-edge source-node features (when the
        # flow is source_to_target); x_i is its target-side counterpart.
        # Read this from the directed-graph viewpoint: there is one row per
        # edge in edge_index. Defined separately for st and ts.
        pair = torch.cat((x_i, x_j), dim=1)
        gate = torch.tanh(self.att_g(pair)).squeeze(-1)  # (|E|,)
        scale = gate * edge_weight
        return x_j * scale.view(-1, 1)



class SSConv(MessagePassing):
    """Gated sum-aggregation convolution for the relation ``s -> s``.

    Edge weights are normalised with ``gcn_norm`` (optionally adding self
    loops); each message is the source feature ``x_j`` scaled by
    ``tanh(att_g([x_i, x_j])) * edge_weight``. ``out_channels`` is accepted
    but unused and ``eps`` is only stored on the instance.
    """

    def __init__(self, in_channels, out_channels,add_self_loops=True, eps=0.9):
        super().__init__(aggr='add')
        self.eps = eps
        self.add_self_loops = add_self_loops
        # One scalar gate per edge, computed from the concatenated
        # (x_i, x_j) feature pair.
        self.att_g = Linear(in_channels * 2, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class state and the gate projection."""
        super().reset_parameters()
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate along ``s -> s`` edges.

        Args:
            x: Node feature matrix; its size along ``self.node_dim`` is used
                as the node count.
            edge_index: Edge list, passed through ``gcn_norm`` first.
            edge_weight: Optional per-edge weights.

        Returns:
            The aggregated messages returned by ``propagate``.
        """
        node_count = x.size(self.node_dim)
        edge_index, edge_weight = gcn_norm(
            edge_index,
            edge_weight,
            add_self_loops=self.add_self_loops,
            num_nodes=node_count,
            dtype=x.dtype,
        )
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_i, x_j, edge_weight):
        # x_j is the lifted tensor of per-edge source-node features (when the
        # flow is source_to_target); x_i is its target-side counterpart.
        # Read this from the directed-graph viewpoint: there is one row per
        # edge in edge_index.
        pair = torch.cat((x_i, x_j), dim=1)
        gate = torch.tanh(self.att_g(pair)).squeeze(-1)  # (|E|,)
        scale = gate * edge_weight
        return x_j * scale.view(-1, 1)


class TTConv(MessagePassing):
    """Gated sum-aggregation convolution for the relation ``t -> t``.

    Edge weights are normalised with ``gcn_norm`` (optionally adding self
    loops); each message is the source feature ``x_j`` scaled by
    ``tanh(att_g([x_i, x_j])) * edge_weight``. ``out_channels`` is accepted
    but unused and ``eps`` is only stored on the instance.
    """

    def __init__(self, in_channels, out_channels,add_self_loops=True, eps=0.9):
        super().__init__(aggr='add')
        self.eps = eps
        self.add_self_loops = add_self_loops
        # One scalar gate per edge, computed from the concatenated
        # (x_i, x_j) feature pair.
        self.att_g = Linear(in_channels * 2, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class state and the gate projection."""
        super().reset_parameters()
        self.att_g.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate along ``t -> t`` edges.

        Args:
            x: Node feature matrix; its size along ``self.node_dim`` is used
                as the node count.
            edge_index: Edge list, passed through ``gcn_norm`` first.
            edge_weight: Optional per-edge weights.

        Returns:
            The aggregated messages returned by ``propagate``.
        """
        node_count = x.size(self.node_dim)
        edge_index, edge_weight = gcn_norm(
            edge_index,
            edge_weight,
            add_self_loops=self.add_self_loops,
            num_nodes=node_count,
            dtype=x.dtype,
        )
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_i, x_j, edge_weight):
        # x_j is the lifted tensor of per-edge source-node features (when the
        # flow is source_to_target); x_i is its target-side counterpart.
        # Read this from the directed-graph viewpoint: there is one row per
        # edge in edge_index.
        pair = torch.cat((x_i, x_j), dim=1)
        gate = torch.tanh(self.att_g(pair)).squeeze(-1)  # (|E|,)
        scale = gate * edge_weight
        return x_j * scale.view(-1, 1)







In [157]:



class HeteroFASTGCN(nn.Module):
    """Stack of alternating intra-type and inter-type graph convolutions.

    The nodes of every sample are split into two types: the first
    ``node_num`` rows are ``'s'`` nodes, the remaining rows are ``'t'``
    nodes. Even layers (0, 2, ...) propagate along ``s2s`` / ``t2t`` edges,
    odd layers along ``s2t`` / ``t2s`` edges. Each layer adds its result to
    its input (residual update), so the feature width stays ``in_channels``
    from input to output.

    Args:
        node_num: Number of ``'s'`` nodes per sample.
        seq_len: Stored on the instance; not otherwise used by this class.
        in_channels: Feature width of ``x``; sizes every attention gate.
        hidden_channels: Stored on the instance and forwarded as the
            ``out_channels`` argument of the non-final convolutions.
        n_layers: Total number of layers, at least 2.
        out_channels: Stored as ``self.out_channels`` (defaults to
            ``hidden_channels``) and forwarded to the final convolution.
        dropout: Dropout probability applied between layers.
        norm: Optional module, deep-copied once per non-final layer.
        act: Activation spec resolved by ``activation_resolver``.
        n_first: Stored on the instance; not otherwise used by this class.
        act_first: Apply the activation before (True) or after (False) the
            normalisation.
        eps: Forwarded to every convolution.
        **kwargs: Ignored.
    """

    def __init__(
        self,node_num,seq_len, in_channels, hidden_channels, n_layers, out_channels=None,
        dropout=0, norm=None, act='relu',n_first=True, act_first=False, eps=0.9, **kwargs
    ):
        # Initialise nn.Module before touching any attribute.
        super().__init__()

        self.node_num = node_num
        self.seq_len = seq_len
        self.n_first = n_first

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = n_layers

        self.dropout = dropout
        self.act = activation_resolver.make(act)
        self.act_first = act_first
        self.eps = eps

        if out_channels is not None:
            self.out_channels = out_channels
        else:
            self.out_channels = hidden_channels

        assert n_layers >= 2 , "intra and inter conv layers must greater than or equals to 2 "

        # Layer i is an intra conv when i is even and an inter conv when i is
        # odd; forward() relies on the same parity rule.
        #
        # Every conv is sized with `in_channels`: the convolutions only
        # rescale and sum neighbour features and forward() adds the result to
        # its input, so each layer sees `in_channels`-wide features. Sizing
        # the inter-layer gates with `hidden_channels` (as before) fails with
        # a shape mismatch whenever hidden_channels != in_channels.
        self.convs = nn.ModuleList()
        for layer in range(n_layers):
            is_last = layer == n_layers - 1
            # Use the resolved self.out_channels, never a possibly-None arg.
            layer_out = self.out_channels if is_last else hidden_channels
            if layer % 2 == 0:
                self.convs.append(self.init_intra_conv(in_channels, layer_out))
            else:
                self.convs.append(self.init_inter_conv(in_channels, layer_out))

        self.norms = None
        if norm is not None:
            self.norms = nn.ModuleList()
            for _ in range(n_layers - 1):
                self.norms.append(copy.deepcopy(norm))

    def init_intra_conv(self, in_channels, out_channels, **kwargs):
        """Build the HeteroConv handling the ``s2s`` and ``t2t`` relations."""
        intrast_homo_conv = HeteroConv({
            ('s', 's2s', 's'): SSConv(in_channels, out_channels, eps=self.eps),
            ('t', 't2t', 't'): TTConv(in_channels, out_channels, eps=self.eps),
        }, aggr='sum')
        return intrast_homo_conv

    def init_inter_conv(self, in_channels, out_channels, **kwargs):
        """Build the HeteroConv handling the ``s2t`` and ``t2s`` relations."""
        interst_biparte_conv = HeteroConv({
            ('s', 's2t', 't'): STConv(in_channels, out_channels, eps=self.eps),
            ('t', 't2s', 's'): TSConv(in_channels, out_channels, eps=self.eps),
        }, aggr='sum')
        return interst_biparte_conv

    def _split_edges(self, edge_index_bi):
        """Partition one sample's edge list into intra and inter relations."""
        src_is_s = edge_index_bi[0] < self.node_num
        dst_is_s = edge_index_bi[1] < self.node_num

        intra = {
            ('s', 's2s', 's'): edge_index_bi[:, src_is_s & dst_is_s],
            # Shift t ids so that the t2t edge list starts at index 0.
            ('t', 't2t', 't'): edge_index_bi[:, ~src_is_s & ~dst_is_s] - self.node_num,
        }
        # Inter edges keep their global ids: STConv / TSConv stack the
        # features as [s; t] before propagating.
        inter = {
            ('s', 's2t', 't'): edge_index_bi[:, src_is_s & ~dst_is_s],
            ('t', 't2s', 's'): edge_index_bi[:, ~src_is_s & dst_is_s],
        }
        return intra, inter

    def forward(self, x, edge_index, edge_attr=None):
        """Run all layers on a batch of graphs.

        Args:
            x: Node features, ``B * (N+T) * C``.
            edge_index: Per-sample edge lists, ``B * 2 * E``, with global
                node ids (``'s'`` ids below ``node_num``).
            edge_attr: Accepted for interface compatibility; not used.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch = x.shape[0]

        # The edge partition does not depend on the layer, so compute it
        # once per sample instead of once per sample per layer.
        intra_edges, inter_edges = [], []
        for bi in range(batch):
            intra, inter = self._split_edges(edge_index[bi])
            intra_edges.append(intra)
            inter_edges.append(inter)

        last_layer = self.num_layers - 1
        for i in range(self.num_layers):
            is_intra = i % 2 == 0
            xs = list()
            for bi in range(batch):
                x_dict = {
                    's': x[bi][:self.node_num, :],
                    't': x[bi][self.node_num:, :]
                }
                if is_intra:
                    out_dict = self.convs[i](x_dict, intra_edges[bi])
                    # Intra outputs are per-type; restack them as [s; t].
                    xi = x[bi] + torch.concat([out_dict['s'], out_dict['t']], dim=0)
                else:
                    out_dict = self.convs[i](x_dict, inter_edges[bi])
                    # Inter outputs already cover all N+T rows; combine the
                    # spatial and temporal information by summation.
                    xi = x[bi] + out_dict['s'] + out_dict['t']
                xs.append(xi)
            x = torch.stack(xs)

            # No activation / norm / dropout after the final layer.
            if i == last_layer:
                break

            if self.act_first:
                x = self.act(x)
            if self.norms is not None:
                x = self.norms[i](x)
            if not self.act_first:
                x = self.act(x)

            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

In [158]:
# Toy problem size: 2 spatial nodes, 4 temporal nodes, 8 samples.
n_nodes = 2
seq_len = 4
batch_size = 8

# Random 128-d features, stacked as [spatial; temporal].
xn = torch.randn(n_nodes, 128)
xt = torch.randn(seq_len, 128)
xi = torch.cat((xn, xt), dim=0)

# Fully connected graph over all N+T nodes without self loops: (N+T, N+T).
tmp = torch.ones((n_nodes + seq_len, n_nodes + seq_len))
tmp = tmp - torch.eye(tmp.size(0))
edge_index = tmp.nonzero().t()

# Broadcast the single sample / edge list over the batch dimension.
batch_x = xi.expand(batch_size, -1, -1)
batch_indices = edge_index.expand(batch_size, -1, -1)



In [159]:
# Edge list of the first sample in the batch.
edge_index_bi = batch_indices[0]

# Keep the edges whose source and target ids are both below n_nodes.
edge_nn = edge_index_bi[:, (edge_index_bi < n_nodes).all(dim=0)]

In [160]:
# Inspect the spatial-to-spatial edges selected in the previous cell.
edge_nn

tensor([[0, 1],
        [1, 0]])

In [163]:
# Three layers -> intra, inter, intra.
gcn = HeteroFASTGCN(
    node_num=n_nodes,
    seq_len=seq_len,
    in_channels=128,
    hidden_channels=128,
    n_layers=3,
)

init_intra
init_inter
init_intra


In [164]:
# Smoke test: one forward pass of the model on the toy batch.
gcn(batch_x, batch_indices)

prop intra
prop intra
prop intra
prop intra
prop intra
prop intra
prop intra
prop intra
prop inter
prop inter
prop inter
prop inter
prop inter
prop inter
prop inter
prop inter
prop intra
prop intra
prop intra
prop intra
prop intra
prop intra
prop intra
prop intra


tensor([[[ 0.0000e+00,  3.5309e-01,  0.0000e+00,  ...,  6.5688e-01,
          -3.2479e-01,  3.9053e-01],
         [ 0.0000e+00, -6.9949e-02,  0.0000e+00,  ..., -1.3013e-01,
           7.6174e-01, -7.7367e-02],
         [ 0.0000e+00,  7.5263e-02,  1.1883e-01,  ...,  1.4307e-01,
           2.5159e-02,  1.3688e+00],
         [ 0.0000e+00,  6.2398e-02,  3.7416e-01,  ..., -2.0323e-03,
           7.1230e-02, -2.0018e-01],
         [ 0.0000e+00, -4.1840e-04, -5.2703e-03,  ...,  7.3938e-01,
          -4.8131e-03,  8.9747e-01],
         [ 0.0000e+00,  2.2863e-02,  2.2337e-01,  ...,  7.8138e-03,
           1.2037e-01,  3.1702e-03]],

        [[ 0.0000e+00,  3.5309e-01,  0.0000e+00,  ...,  6.5688e-01,
          -3.2479e-01,  3.9053e-01],
         [ 0.0000e+00, -6.9949e-02,  0.0000e+00,  ..., -1.3013e-01,
           7.6174e-01, -7.7367e-02],
         [ 0.0000e+00,  7.5263e-02,  1.1883e-01,  ...,  1.4307e-01,
           2.5159e-02,  1.3688e+00],
         [ 0.0000e+00,  6.2398e-02,  3.7416e-01,  ...

In [86]:
# NOTE(review): in the cells shown, `out_dict` is only assigned inside
# HeteroFASTGCN.forward, so this cell depends on a notebook-global left over
# from a cell that is not visible here - confirm or delete.
out_dict['t']

6

In [213]:
import torch

# Sizes for the adjacency experiment: 10 spatial + 10 temporal nodes.
n_nodes = 10
t_nodes = 10
all_nt = 20

# Learnable spatial-node embedding (not referenced by the cells shown here).
n_embedding = torch.nn.Parameter(torch.rand((n_nodes, 16)))

# Random batch of 64 samples with 16 features per node.
data = torch.randn((64, all_nt, 16))


def build_adj(node_embs):
    """Build a dense, batched adjacency from pairwise feature similarity.

    The similarity of two nodes is the ReLU of the dot product of their
    feature vectors. Each row is passed through a softmax, after which the
    entries whose similarity was exactly 0 are reset to 0 (so rows no longer
    necessarily sum to 1).

    Args:
        node_embs: Tensor of shape ``[B, N, F]``.

    Returns:
        Tensor of shape ``[B, N, N]``.
    """
    # Use the argument: the previous version ignored `node_embs` and read the
    # notebook-global `data` instead.
    scores = torch.relu(torch.einsum("bnf, bmf -> bnm", node_embs, node_embs))
    weights = torch.softmax(scores, dim=-1)

    # Positions that were 0 before the softmax stay 0 in the output. The mask
    # is applied out of place: writing into the softmax output in place
    # invalidates the tensor autograd saved for the backward pass.
    return weights.masked_fill(scores == 0, 0.0)


In [214]:
# Build the batched adjacency from the random `data` tensor.
adj = build_adj(data)

tensor([[[9.7022e-01, 1.4890e-03, 1.9415e-03,  ..., 0.0000e+00,
          0.0000e+00, 3.0381e-04],
         [9.3984e-08, 9.9999e-01, 2.9526e-06,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0251e-06, 2.4698e-05, 9.9975e-01,  ..., 0.0000e+00,
          1.5246e-04, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 8.6907e-01,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 3.1669e-03,  ..., 0.0000e+00,
          9.9311e-01, 0.0000e+00],
         [4.2778e-05, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 9.8225e-01]],

        [[9.9403e-01, 4.0011e-04, 2.9565e-06,  ..., 0.0000e+00,
          2.6809e-06, 1.2316e-04],
         [4.3335e-03, 9.9048e-01, 0.0000e+00,  ..., 0.0000e+00,
          2.8735e-05, 0.0000e+00],
         [2.1323e-09, 0.0000e+00, 9.9983e-01,  ..., 3.1886e-09,
          1.6799e-08, 8.8486e-06],
         ...,
         [0.0000e+00, 0.0000e+00, 9.6847e-10,  ..., 9.9994e-01,
          0.000

In [210]:
# Show the adjacency matrix of the first batch element.
adj[0]

tensor([[16.8180,  1.4516,  0.0000,  0.0000,  8.7246,  3.7397,  0.0000,  0.0000,
          0.0000,  0.0000,  7.9727,  0.0000,  0.0000,  4.1917,  7.5985,  4.5027,
          0.0000,  0.0000,  6.6029,  0.0000],
        [ 1.4516, 14.9157,  0.0000,  2.6948,  1.9211,  5.7597,  0.0000,  0.0000,
          3.7293,  0.0000,  5.8095,  0.0000,  0.0000,  0.0000,  2.6841,  0.7558,
          0.0000,  0.0000,  0.0000,  1.5335],
        [ 0.0000,  0.0000, 19.8097,  0.0000,  1.1402,  0.0000,  4.5790,  1.7512,
          9.9575,  5.9772,  0.0000,  5.6550,  3.3784,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  2.6948,  0.0000, 12.7121,  0.8827,  0.0000,  0.0000,  0.9208,
          0.0000,  0.0000,  0.0000,  0.0000,  5.6144,  0.2867,  0.0000,  0.0000,
          0.2528,  0.0000,  0.0000,  8.2223],
        [ 8.7246,  1.9211,  1.1402,  0.8827, 15.8133,  2.0579,  0.8128,  0.0000,
          1.7998,  0.0000,  4.5066,  0.0000,  3.7069,  0.0000,  2.4055,  0.2129,
      

In [170]:
# Work on a detached copy so `adj` itself is left untouched.
biadj = adj.detach().clone()

# Zero the two diagonal blocks (s-s and t-t); only s<->t entries remain.
biadj[:, :n_nodes, :n_nodes].zero_()
biadj[:, n_nodes:, n_nodes:].zero_()

# Five largest entries of every row (along the last dimension).
topk_values, topk_indices = biadj.topk(5)

In [171]:
# Shape of the per-row top-k index tensor computed in the previous cell.
topk_indices.shape

torch.Size([64, 20, 5])

In [178]:
# Number of samples to prune; `data` was created with a leading dimension of 64.
batch_size = 64
def topk_elements(matrix, k):
    """Keep the ``k`` largest entries of ``matrix`` and zero out the rest.

    Args:
        matrix: Tensor of any shape; it is not modified.
        k: Number of entries to keep. Values larger than the number of
            entries are clamped, in which case all entries are kept.

    Returns:
        A new tensor with the shape of ``matrix`` that holds the ``k``
        largest values at their original positions and 0 elsewhere.
    """
    # reshape() also handles non-contiguous inputs (e.g. a transposed
    # matrix), for which view(-1) raises a RuntimeError.
    flat = matrix.reshape(-1)
    k = min(k, flat.numel())
    values, indices = torch.topk(flat, k)

    flat_result = torch.zeros_like(flat)
    flat_result[indices] = values

    return flat_result.reshape(matrix.shape)
# Fraction of the 2 * n_nodes * t_nodes candidate entries to keep per sample.
rate = 0.5
k = int(rate * 2 * n_nodes * t_nodes)
# Prune every sample independently; `bi` stays defined after the loop and is
# used by the next cell.
for bi in range(batch_size):
    biadj[bi] = topk_elements(biadj[bi], k)
    

In [179]:
# `bi` is the loop variable left over from the pruning loop in the previous
# cell, so this displays the last pruned sample.
biadj[bi] 

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9221e-03, 2.5567e-03,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         2.9786e-04, 4.3862e-04],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         6.7577e-03, 0.0000e+00, 2.0685e-03, 0.0000e+00, 0.0000e+00, 2.1023e-04,
         9.7227e-04, 2.7818e-04],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         3.3110e-04, 1.0145e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 8.3319e-04],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.9336e-05, 1.1405e-02,
       

In [148]:
# NOTE(review): `ind` is not defined in any cell shown here - this cell only
# works with leftover kernel state; confirm where `ind` comes from or delete.
biadj[ind] = 0

In [81]:
# NOTE(review): `ind` is not defined in any cell shown here - confirm where it
# comes from before relying on this cell.
adj[ind] = 0

mask = torch.zeros_like(adj)

# TODO: the cell originally ended with an unfinished `mask = ` assignment,
# which is a SyntaxError; fill in the intended mask construction here.



In [83]:
# Check the shape of the batched adjacency.
adj.shape

torch.Size([64, 10, 10])

In [None]:
# Bipartite build: select the top-k s<->t entries.
# Start from a detached copy and clear both diagonal (s-s, t-t) blocks.
biadj = adj.detach().clone()
biadj[:, :n_nodes, :n_nodes].zero_()
biadj[:, n_nodes:, n_nodes:].zero_()
def topk_elements(matrix, k):
    """Keep the ``k`` largest entries of ``matrix`` and zero out the rest.

    Args:
        matrix: Tensor of any shape; it is not modified.
        k: Number of entries to keep. Values larger than the number of
            entries are clamped, in which case all entries are kept.

    Returns:
        A new tensor with the shape of ``matrix`` that holds the ``k``
        largest values at their original positions and 0 elsewhere.
    """
    # reshape() also handles non-contiguous inputs (e.g. a transposed
    # matrix), for which view(-1) raises a RuntimeError.
    flat = matrix.reshape(-1)
    k = min(k, flat.numel())
    values, indices = torch.topk(flat, k)

    flat_result = torch.zeros_like(flat)
    flat_result[indices] = values

    return flat_result.reshape(matrix.shape)

# Keep the `edge_rate` fraction of the strongest s<->t entries per sample.
# NOTE(review): `edge_rate` is not defined in the cells shown here, and this
# cell uses `seq_len` where the executed cell above uses `t_nodes` - confirm
# both before running.
k = int(edge_rate * 2 * n_nodes * seq_len)
for bi in range(batch_size):
    biadj[bi] = topk_elements(biadj[bi], k)

# Write the pruned bipartite (off-diagonal) blocks back into `adj`. The
# diagonal blocks of `biadj` were zeroed before the top-k selection, so
# copying those (as the cell did before) erased the s-s / t-t entries of
# `adj` and left the s<->t entries unpruned.
adj[:, :n_nodes, n_nodes:] = biadj[:, :n_nodes, n_nodes:]
adj[:, n_nodes:, :n_nodes] = biadj[:, n_nodes:, :n_nodes]



In [218]:
# NOTE(review): left commented out; when run it raised
# "ValueError: too many values to unpack (expected 2)". Presumably `adj` is
# batched (3-D), so `nonzero().t()` yields three index rows (batch, source,
# target) instead of two - confirm and unpack three names or index one sample.
# source_nodes, target_nodes = adj.nonzero().t()

ValueError: too many values to unpack (expected 2)