In [3]:
import datetime
from enum import Enum
import json
import os
import random
import re
import signal
import threading
import time
import hashlib
from prettytable import PrettyTable
import sys
####
from typing import Dict, List, Type, Union
import numpy as np
import torch
from torchmetrics import MeanSquaredError, MetricCollection, MeanAbsoluteError, MeanAbsolutePercentageError
from tqdm import tqdm
import wandb
from torch_timeseries.data.scaler import *
from torch_timeseries.datasets import *
from torch_timeseries.experiments.experiment import Experiment

from torch_timeseries.datasets.dataset import TimeSeriesDataset
from torch_timeseries.datasets.splitter import SequenceRandomSplitter, SequenceSplitter
from torch_timeseries.datasets.dataloader import (
    ChunkSequenceTimefeatureDataLoader,
    DDPChunkSequenceTimefeatureDataLoader,
)
from torch_timeseries.datasets.wrapper import MultiStepTimeFeatureSet
from torch_timeseries.models.Informer import Informer
from torch.nn import MSELoss, L1Loss

from torch.optim import Optimizer, Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, Subset

from torch.nn import DataParallel
import torch.nn as nn
from dataclasses import asdict,dataclass

from torch_timeseries.nn.metric import R2, Corr, TrendAcc,RMSE, compute_corr, compute_r2
from torch_timeseries.metrics.masked_mape import MaskedMAPE
from torch_timeseries.utils.early_stopping import EarlyStopping
import json
import codecs


from torch_timeseries.layers.tcn_output8 import TCNOuputLayer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
import torch.nn as nn
from torch.nn import init
import numbers
import torch.nn.functional as F



class PaddedDilatedInceptionT(nn.Module):
    def __init__(self, cin, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(PaddedDilatedInceptionT, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = kernel_set
        # cout = int(cout/len(self.kernel_set))
        # assert cout > 0
        
        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin,cin,(1,kern),dilation=(1,dilation_factor)))

    def forward(self,input):

        # input: (B, C, N, T)
        (B, C, N, T) = input.size()
        x = []
        for i in range(len(self.kernel_set)):
            d_out = self.tconv[i](input)
            di = F.pad(d_out, (0, T - d_out.size(3)))
            x.append(di) # (B, C, N, T)
            
        # for i in range(len(self.kernel_set)):
        #     x[i] = x[i][...,-x[-1].size(3):]
        x = torch.cat(x,dim=1) # (B, k*C, N, T)
        return x
    
class PaddedDilatedInceptionS(nn.Module):
    def __init__(self, cin, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(PaddedDilatedInceptionS, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = kernel_set
        # cout = int(cout/len(self.kernel_set))
        # assert cout > 0
        
        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin,cin,(kern,1),dilation=(1,dilation_factor)))

    def forward(self,input):

        # input: (B, C, N, T)
        (B, C, N, T) = input.size()
        x = []
        for i in range(len(self.kernel_set)):
            d_out = self.tconv[i](input)
            di = F.pad(d_out, (0, 0,0, N - d_out.size(2)))
            x.append(di) # (B, C, N, T)
            
        # for i in range(len(self.kernel_set)):
        #     x[i] = x[i][...,-x[-1].size(3):]
        x = torch.cat(x,dim=1) # (B, k*C, N, T)
        return x
    
    
B = 32
C = 16
T = 168
N= 90

input1 = torch.randn((B, C, N, T),)
    
t = PaddedDilatedInceptionT(C)
s = PaddedDilatedInceptionS(C)

# t(input1).shape, s(input1).shape

In [43]:
%load_ext autoreload
%autoreload 2

class DilatedTTInception(nn.Module):
    def __init__(self, cin, N, T, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(DilatedTTInception, self).__init__()
        

        self.out_conv = nn.Conv2d(cin*len(kernel_set), cin, (1,1))
        
        self.filter_conv = PaddedDilatedInceptionT(cin, dilation_factor, kernel_set)
        self.gate_conv = PaddedDilatedInceptionT(cin, dilation_factor, kernel_set)
        self.layer_norm = nn.LayerNorm([cin , N, T])
        
    
    
    def forward(self, input_x):
        filter = self.filter_conv(input_x)
        gate = self.gate_conv(input_x)
        x = torch.sigmoid(filter) * torch.tanh(gate)
        x = self.out_conv(x)
        x = x + input_x
        x = self.layer_norm(x)
        return x
        
        
class DilatedSSInception(nn.Module):
    def __init__(self, cin, N, T, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(DilatedSSInception, self).__init__()
        

        self.out_conv = nn.Conv2d(cin*len(kernel_set), cin, (1,1))
        
        self.filter_conv = PaddedDilatedInceptionS(cin, dilation_factor, kernel_set)
        self.gate_conv = PaddedDilatedInceptionS(cin, dilation_factor, kernel_set)
        self.layer_norm = nn.LayerNorm([cin , N, T])
        
    
    
    def forward(self, input_x):
        filter = self.filter_conv(input_x)
        gate = self.gate_conv(input_x)
        x = torch.sigmoid(filter) * torch.tanh(gate)
        x = self.out_conv(x)
        x = x + input_x
        x = self.layer_norm(x)
        return x
        
class DilatedTSInception(nn.Module):
    def __init__(self, cin, N, T, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(DilatedTSInception, self).__init__()
        

        self.out_conv = nn.Conv2d(cin*len(kernel_set), cin, (1,1))
        
        self.filter_conv = PaddedDilatedInceptionT(cin, dilation_factor, kernel_set)
        self.gate_conv = PaddedDilatedInceptionS(cin, dilation_factor, kernel_set)
        self.layer_norm = nn.LayerNorm([cin , N, T])
        
    
    def forward(self, filter_x, gate_x):
        filter = self.filter_conv(filter_x)
        gate = self.gate_conv(gate_x)
        x = torch.tanh(filter) * torch.sigmoid(gate)
        x = self.out_conv(x)
        x = x + filter_x
        x = self.layer_norm(x)
        return x
        
        
class DilatedSTInception(nn.Module):
    def __init__(self, cin, N, T, dilation_factor=2, kernel_set=[2,3,6,7]):
        super(DilatedSTInception, self).__init__()
        

        self.out_conv = nn.Conv2d(cin*len(kernel_set), cin, (1,1))
        
        self.filter_conv = PaddedDilatedInceptionS(cin, dilation_factor, kernel_set)
        self.gate_conv = PaddedDilatedInceptionT(cin, dilation_factor, kernel_set)
        self.layer_norm = nn.LayerNorm([cin , N, T])
        
    
    def forward(self, filter_x, gate_x):
        filter = self.filter_conv(filter_x)
        gate = self.gate_conv(gate_x)
        x = torch.tanh(filter) * torch.sigmoid(gate)
        x = self.out_conv(x)
        x = x + filter_x
        x = self.layer_norm(x)
        return x

class STMixedConv(nn.Module):
    def __init__(self, cin, N, T, dilation_factor=2, kernel_set=[2,3,6,7]) -> None:
        super(STMixedConv, self).__init__()
        
        self.out_conv = nn.Conv2d(cin*4, cin, (1,1))

        self.tt_inception = DilatedTTInception(cin, N, T, dilation_factor=dilation_factor, kernel_set=kernel_set)
        self.ss_inception = DilatedSSInception(cin, N, T, dilation_factor=dilation_factor, kernel_set=kernel_set)
        self.ts_inception = DilatedTSInception(cin, N, T, dilation_factor=dilation_factor, kernel_set=kernel_set)
        self.st_inception = DilatedSTInception(cin, N, T, dilation_factor=dilation_factor, kernel_set=kernel_set)
        
        
    def forward(self, input_x):
        # input_x : (B, C, N, T)
        t_out = self.tt_inception(input_x)
        s_out = self.ss_inception(input_x)
        
        
        ts = self.ts_inception(t_out, s_out)
        st = self.st_inception(s_out, t_out)
        
        all_out = torch.concat([t_out, s_out, ts, st],1)
        out = self.out_conv(all_out)
        out = torch.relu(out)
        return out
    
    
    
    
class STCN(nn.Module):
    def __init__(self, N, T, O,  hidden_channel, latent_dim=128, n_layers=3, in_dim=1,out_dim=1, dilation_factor=2, kernel_set=[2,3,6,7]) -> None:
        super(STCN, self).__init__()


        self.spatial_projection = nn.Sequential(
            nn.Linear( T, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
        self.temporal_projection =nn.GRU(
            N,latent_dim, batch_first=True
        )
        
        # self.freq_projection = nn.Sequential(
        #     nn.Linear(num_nodes, self.tcn_input_dim),
        #     nn.ReLU(),
        #     nn.Linear(self.tcn_input_dim, self.tcn_input_dim)
        # )

        

        
        self.feature_rebuild = nn.Sequential(
            nn.Linear(latent_dim, T),
            torch.nn.ELU(),
            nn.Linear(T, T),
        )

        
        self.time_rebuild = nn.GRU(latent_dim, N, batch_first=True)
        
        self.start_conv = nn.Conv2d(in_dim, hidden_channel, (1,1))
        self.st_convs = nn.ModuleList()
        self.n_layers = n_layers
        for i in range(n_layers):
            self.st_convs.append(
                STMixedConv(hidden_channel, N, T, dilation_factor, kernel_set)
            )
        self.end_conv = nn.Conv2d(hidden_channel, out_dim, (1, 1 ))
        self.mlp = nn.Sequential(
            nn.Linear(T, T), 
            nn.ReLU(),
            nn.Linear(T, O)
        )
    def forward(self, x):
        # x : (B, N, T)
        
        seq_last = x[:,:,-1:].detach()
        x = x - seq_last

        Xs = self.spatial_projection(x) # (B, N, latent_dim)
        Xt, _ = self.temporal_projection(x.transpose(1, 2)) # (B, T, latent_dim)
        # Xf = torch.abs(torch.fft.fft(x, dim=2, norm='forward'))# (B, N, T)
        
        # Xf = self.freq_projection(Xf.transpose(1,2 ))  # （B, T, latent_dim)
        
        Zs = self.feature_rebuild(Xs)  # (B, N, T)
        Zt, _ = self.time_rebuild(Xt)
        Zt = Zt.transpose(1,2) # (B, N, T)
        # Zf = self.freq_rebuild(Xf.transpose(1, 2)).transpose(1, 2) # (B, N, tcn_input_dim)
        
        Z = torch.stack([Zs, Zt, x], dim=1)
        
        

        x = self.start_conv(Z) # (B, C, N, T)
        
        # skip = x
        
        for i in range(self.n_layers):
            x = self.st_convs[i](x)
            
            
        x = self.end_conv(torch.relu(x))
        x = self.mlp(x) # (B , 1, N, O)
        
        x = (x.squeeze(1)+ seq_last).transpose(1,2)
            # skip = skip + x
        return x
        

@dataclass
class STCNExperiment(Experiment):
    
    hidden_channels: int = 16
    n_layers : int = 1
    dilated_factor : int = 2
    latent_dim : int = 16

    def _process_one_batch(self, batch_x, batch_y, batch_x_date_enc, batch_y_date_enc):
        # inputs:
        # batch_x: (B, T, N)
        # batch_y: (B, O, N)
        # ouputs:
        # - pred: (B, N)/(B, O, N)
        # - label: (B, N)/(B, O, N)
        batch_size = batch_x.size(0)
        batch_x = batch_x.to(self.device, dtype=torch.float32)
        batch_y = batch_y.to(self.device, dtype=torch.float32)
        batch_x_date_enc = batch_x_date_enc.to(self.device).float()
        batch_y_date_enc = batch_y_date_enc.to(self.device).float()
        batch_x = batch_x.transpose(1,2) # （B, N, T)
        
        outputs = self.model(batch_x)  # torch.Size([batch_size, num_nodes])
        # single step prediction
        return outputs, batch_y


    def _init_model(self):
        self.model = STCN(
            self.dataset.num_features,
            self.windows,
            self.pred_len,
            self.hidden_channels,
            self.latent_dim,
            self.n_layers,
            3,
            self.pred_len,
            self.dilated_factor,
            
        )
        self.model = self.model.to(self.device)





The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [44]:
model = STCN(N, T,3,  C, in_dim=3)
input2 = torch.randn((B, N, T))
model(input2).shape

torch.Size([32, 3, 90])

In [48]:
exp = STCNExperiment(
    hidden_channels=4,
    dataset_type="SolarEnergy",
    dilated_factor=1,
    n_layers=3,
    epochs=100,
    horizon=24,
    windows=168,
    device="cuda:4",
    data_path='/notebooks/pytorch_timeseries/data/',
)
# model = STI(168, 128, 9, 1)
# out = model(data)

exp.run(seed=234)

Using downloaded and verified file: /notebooks/pytorch_timeseries/data/solar_AL/solar_AL.txt.gz
Extracting /notebooks/pytorch_timeseries/data/solar_AL/solar_AL.txt.gz to /notebooks/pytorch_timeseries/data/solar_AL
train steps: 36601
val steps: 10321
test steps: 5065
torch.get_default_dtype() torch.float32
Creating running results saving dir: './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
run : 0 in seed: 234
+----------------------------------------------------+------------+
|                      Modules                       | Parameters |
+----------------------------------------------------+------------+
|            spatial_projection.0.weight             |    2688    |
|             spatial_projection.0.bias              |     16     |
|            spatial_projection.2.weight             |    256     |
|             spatial_projection.2.bias              |     16     |
|          temporal_projection.weight_ih_l0          |    6576    |
|          tempora

100%|██████████| 36601/36601 [04:34<00:00, 133.10it/s, epoch=0, loss=0.0914, lr=0.0003]


Epoch: 1 cost time: 275.0020716190338
Traininng loss : 0.20994998700916767
Evaluating .... 


100%|██████████| 10321/10321 [00:30<00:00, 341.83it/s]


vali_results: {'corr': 0.3775153160095215, 'mae': 0.22028812766075134, 'mse': 0.14802412688732147, 'r2': 0.8423250317573547, 'r2_weighted': 0.8427611589431763}
Testing .... 


100%|██████████| 5065/5065 [00:15<00:00, 322.89it/s]


test_results: {'corr': 0.31939759850502014, 'mae': 0.22899308800697327, 'mse': 0.15650640428066254, 'r2': 0.7610507607460022, 'r2_weighted': 0.7635198831558228}
Validation loss decreased (inf --> 0.148024).  Saving model ...
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:33<00:00, 133.74it/s, epoch=1, loss=0.167, lr=0.0003] 


Epoch: 2 cost time: 594.8347315788269
Traininng loss : 0.14914005703025765
Evaluating .... 


100%|██████████| 10321/10321 [00:29<00:00, 350.60it/s]


vali_results: {'corr': 0.27678728103637695, 'mae': 0.22768355906009674, 'mse': 0.14584636688232422, 'r2': 0.844618022441864, 'r2_weighted': 0.8450745344161987}
Testing .... 


100%|██████████| 5065/5065 [00:16<00:00, 311.07it/s]


test_results: {'corr': 0.1825583428144455, 'mae': 0.2417413890361786, 'mse': 0.1549934297800064, 'r2': 0.7632052302360535, 'r2_weighted': 0.7658059597015381}
Validation loss decreased (0.148024 --> 0.145846).  Saving model ...
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:30<00:00, 135.50it/s, epoch=2, loss=0.226, lr=0.0003] 


Epoch: 3 cost time: 910.949385881424
Traininng loss : 0.1422859501627671
Evaluating .... 


100%|██████████| 10321/10321 [00:27<00:00, 371.84it/s]


vali_results: {'corr': 0.4074137210845947, 'mae': 0.24390935897827148, 'mse': 0.15038971602916718, 'r2': 0.8397960066795349, 'r2_weighted': 0.840248167514801}
Testing .... 


100%|██████████| 5065/5065 [00:15<00:00, 325.72it/s]


test_results: {'corr': 0.3453538715839386, 'mae': 0.24903534352779388, 'mse': 0.1547049731016159, 'r2': 0.764063835144043, 'r2_weighted': 0.7662417888641357}
EarlyStopping counter: 1 out of 5
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:22<00:00, 139.30it/s, epoch=3, loss=0.156, lr=0.000299] 


Epoch: 4 cost time: 1217.2029941082
Traininng loss : 0.13797226795679204
Evaluating .... 


100%|██████████| 10321/10321 [00:30<00:00, 341.17it/s]


vali_results: {'corr': 0.42725977301597595, 'mae': 0.21580471098423004, 'mse': 0.1535366028547287, 'r2': 0.8366211652755737, 'r2_weighted': 0.8369054794311523}
Testing .... 


100%|██████████| 5065/5065 [00:15<00:00, 318.57it/s]


test_results: {'corr': 0.40604981780052185, 'mae': 0.21351811289787292, 'mse': 0.15183526277542114, 'r2': 0.7697080373764038, 'r2_weighted': 0.7705779671669006}
EarlyStopping counter: 2 out of 5
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:34<00:00, 133.58it/s, epoch=4, loss=0.061, lr=0.000299] 


Epoch: 5 cost time: 1537.6035895347595
Traininng loss : 0.13608024431021323
Evaluating .... 


100%|██████████| 10321/10321 [00:30<00:00, 343.48it/s]


vali_results: {'corr': 0.43036940693855286, 'mae': 0.21429631114006042, 'mse': 0.1552933305501938, 'r2': 0.8346232771873474, 'r2_weighted': 0.8350394368171692}
Testing .... 


100%|██████████| 5065/5065 [00:16<00:00, 310.64it/s]


test_results: {'corr': 0.4058297872543335, 'mae': 0.21331733465194702, 'mse': 0.1536165326833725, 'r2': 0.7658516764640808, 'r2_weighted': 0.7678865194320679}
EarlyStopping counter: 3 out of 5
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:35<00:00, 132.93it/s, epoch=5, loss=0.0878, lr=0.000298]


Epoch: 6 cost time: 1859.5271019935608
Traininng loss : 0.1307150977403186
Evaluating .... 


100%|██████████| 10321/10321 [00:28<00:00, 357.99it/s]


vali_results: {'corr': 0.42688626050949097, 'mae': 0.22351773083209991, 'mse': 0.15816693007946014, 'r2': 0.8316601514816284, 'r2_weighted': 0.8319869637489319}
Testing .... 


100%|██████████| 5065/5065 [00:16<00:00, 309.93it/s]


test_results: {'corr': 0.37635013461112976, 'mae': 0.22961583733558655, 'mse': 0.16661570966243744, 'r2': 0.7473635673522949, 'r2_weighted': 0.7482448220252991}
EarlyStopping counter: 4 out of 5
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [04:29<00:00, 135.67it/s, epoch=6, loss=0.0795, lr=0.000297]


Epoch: 7 cost time: 2174.6854321956635
Traininng loss : 0.12359967287401896
Evaluating .... 


100%|██████████| 10321/10321 [00:30<00:00, 343.24it/s]


vali_results: {'corr': 0.05398860201239586, 'mae': 0.2302062213420868, 'mse': 0.16709120571613312, 'r2': 0.8219844102859497, 'r2_weighted': 0.8225070834159851}
Testing .... 


100%|██████████| 5065/5065 [00:16<00:00, 309.69it/s]


test_results: {'corr': -0.018981575965881348, 'mae': 0.23476339876651764, 'mse': 0.17481611669063568, 'r2': 0.7337695956230164, 'r2_weighted': 0.7358540296554565}
EarlyStopping counter: 5 out of 5
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/5653879eeddeac0d52a2046e5f502361'.
Run state saved ... 
loss no decreased for 5 epochs,  early stopping ....
Testing .... 


100%|██████████| 5065/5065 [00:16<00:00, 314.14it/s]

test_results: {'corr': 0.1825583428144455, 'mae': 0.2417413890361786, 'mse': 0.1549934297800064, 'r2': 0.7632052302360535, 'r2_weighted': 0.7658059597015381}





{'corr': 0.1825583428144455,
 'mae': 0.2417413890361786,
 'mse': 0.1549934297800064,
 'r2': 0.7632052302360535,
 'r2_weighted': 0.7658059597015381}

In [35]:
Xf = torch.fft.fft(data, dim=2, norm='forward')[:, :, :]
amp = torch.abs(Xf)
amp

tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0., 

tensor([-0.3273,  0.1516,  0.0161, -0.8483,  0.0766,  0.7280,  0.6113, -0.0470,
         0.9393,  0.6909,  0.0947,  0.8419,  0.1639, -0.1819,  1.0376,  0.0166,
        -0.7076,  0.9921,  0.2978, -0.4104, -1.0616,  0.6393,  0.1030,  0.5463,
        -0.3269,  0.1066, -0.1008,  0.2158, -0.1032, -0.3987, -0.4024, -0.3123],
       grad_fn=<SelectBackward0>) tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)
torch.Size([7, 32]) torch.Size([7, 32])


In [48]:
bi = 0

edge_nt = torch.stack((
    edge_index[bi][0][edge_index[bi][0] < self.node_num], # source
    edge_index[bi][1][edge_index[bi][1] >= self.node_num] # target
    ))
edge_tn = torch.stack((
    edge_index[bi][0][edge_index[bi][0] >= self.node_num],
    edge_index[bi][1][edge_index[bi][1] < self.node_num]
    ))               


tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)