In [1]:
import datetime
from enum import Enum
import json
import os
import random
import re
import signal
import threading
import time
import hashlib
from prettytable import PrettyTable
import sys
####
from typing import Dict, List, Type, Union
import numpy as np
import torch
from torchmetrics import MeanSquaredError, MetricCollection, MeanAbsoluteError, MeanAbsolutePercentageError
from tqdm import tqdm
import wandb
from torch_timeseries.data.scaler import *
from torch_timeseries.datasets import *
from torch_timeseries.experiments.experiment import Experiment

from torch_timeseries.datasets.dataset import TimeSeriesDataset
from torch_timeseries.datasets.splitter import SequenceRandomSplitter, SequenceSplitter
from torch_timeseries.datasets.dataloader import (
    ChunkSequenceTimefeatureDataLoader,
    DDPChunkSequenceTimefeatureDataLoader,
)
from torch_timeseries.datasets.wrapper import MultiStepTimeFeatureSet
from torch_timeseries.models.Informer import Informer
from torch.nn import MSELoss, L1Loss

from torch.optim import Optimizer, Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, Subset

from torch.nn import DataParallel
import torch.nn as nn
from dataclasses import asdict,dataclass

from torch_timeseries.nn.metric import R2, Corr, TrendAcc,RMSE, compute_corr, compute_r2
from torch_timeseries.metrics.masked_mape import MaskedMAPE
from torch_timeseries.utils.early_stopping import EarlyStopping
import json
import codecs


from torch_timeseries.layers.tcn_output8 import TCNOuputLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# device = 'device'

# dataset : TimeSeriesDataset = ExchangeRate(root='/notebooks/pytorch_timeseries/')
# scaler = StandarScaler(device='')
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project package are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

# Project dataset class, imported here rather than in the top import cell.
from torch_timeseries.datasets.dataset import TimeSeriesStaticGraphDataset


class STI(torch.nn.Module):
    """Three-view (spatial / temporal / frequency) encoder followed by a TCN head.

    The input window is re-centred on its last time step, projected through
    three small MLP branches, and the three resulting ``(B, N, tcn_input_dim)``
    tensors are stacked on a new channel axis and passed to ``TCNOuputLayer``.
    The last observed value is added back onto the TCN output.

    Args:
        seq_len: Length ``T`` of the last axis of the input ``x``.
        latent_dim: Output width of the spatial projection and hidden width of
            the temporal / frequency rebuild MLPs.
        num_nodes: Size ``N`` of the node axis of ``x``.
        out_seq_len: Forwarded to ``TCNOuputLayer``; presumably the forecast
            length ``O`` of the returned ``(B, O, N)`` tensor (confirm against
            ``TCNOuputLayer``).
        tcn_layers, dilated_factor, tcn_channel, kernel_set, d0: Forwarded
            unchanged to ``TCNOuputLayer``.
        layer_norm_affline: Forwarded to ``TCNOuputLayer``. Earlier revisions
            ignored this argument and always passed ``True``.
    """

    def __init__(self, seq_len, latent_dim, num_nodes, out_seq_len, tcn_layers=5,
                 dilated_factor=2, tcn_channel=16, kernel_set=[2,3,6,7], d0=1,
                 layer_norm_affline=True) -> None:
        super().__init__()

        # One input channel per stacked view (Zs, Zt, Zf) built in forward().
        num_views = 3

        def build_tcn(input_len):
            # Each layer gets its own copy of kernel_set so the shared default
            # list in the signature can never be mutated through a sub-module.
            return TCNOuputLayer(
                input_len, num_nodes, out_seq_len,
                tcn_layers, num_views, dilated_factor,
                tcn_channel, kernel_set=list(kernel_set), d0=d0,
                layer_norm_affline=layer_norm_affline,
            )

        # A throw-away layer is built only to read the receptive field that
        # results from these hyper-parameters. It is still constructed (rather
        # than computed analytically) so the RNG stream consumed before the real
        # layers are initialised stays the same as before.
        probe = build_tcn(seq_len)
        self.tcn_input_dim = probe.tcn.receptive_field
        self.tcn = build_tcn(self.tcn_input_dim)

        # Spatial branch: acts on the time axis of (B, N, T).
        self.spatial_projection = nn.Sequential(
            nn.Linear(seq_len, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
        # Temporal and frequency branches: act on the node axis of (B, T, N).
        self.temporal_projection = nn.Sequential(
            nn.Linear(num_nodes, self.tcn_input_dim),
            nn.ReLU(),
            nn.Linear(self.tcn_input_dim, self.tcn_input_dim)
        )
        self.freq_projection = nn.Sequential(
            nn.Linear(num_nodes, self.tcn_input_dim),
            nn.ReLU(),
            nn.Linear(self.tcn_input_dim, self.tcn_input_dim)
        )

        # Rebuild heads bring every branch to a common (B, N, tcn_input_dim).
        self.spatial_rebuild = nn.Sequential(
            nn.Linear(latent_dim, self.tcn_input_dim),
            nn.ReLU(),
            nn.Linear(self.tcn_input_dim, self.tcn_input_dim)
        )
        self.temporal_rebuild = nn.Sequential(
            nn.Linear(seq_len, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, num_nodes)
        )
        self.freq_rebuild = nn.Sequential(
            nn.Linear(seq_len, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, num_nodes)
        )

    def forward(self, x):
        """Run the model on a window ``x`` of shape ``(B, N, T)``.

        Returns:
            The TCN output with the last observed value of each series added
            back; the original code annotates it as ``(B, O, N)``.
        """
        # Subtract the last observation so the branches see a re-centred
        # window; it is detached so no gradient flows through the offset.
        seq_last = x[:, :, -1:].detach()  # (B, N, 1)
        x = x - seq_last

        Xs = self.spatial_projection(x)                    # (B, N, latent_dim)
        Xt = self.temporal_projection(x.transpose(1, 2))   # (B, T, tcn_input_dim)

        # Amplitude spectrum along the time axis; norm='forward' scales by 1/T.
        Xf = torch.abs(torch.fft.fft(x, dim=2, norm='forward'))  # (B, N, T)
        Xf = self.freq_projection(Xf.transpose(1, 2))      # (B, T, tcn_input_dim)

        Zs = self.spatial_rebuild(Xs)                                   # (B, N, tcn_input_dim)
        Zt = self.temporal_rebuild(Xt.transpose(1, 2)).transpose(1, 2)  # (B, N, tcn_input_dim)
        Zf = self.freq_rebuild(Xf.transpose(1, 2)).transpose(1, 2)      # (B, N, tcn_input_dim)

        Z = torch.stack([Zs, Zt, Zf], dim=1)  # (B, 3, N, tcn_input_dim)
        O = self.tcn(Z)                       # (B, O, N)

        # Restore the level removed above: (B, O, N) + (B, 1, N).
        return O + seq_last.transpose(1, 2)
    
    
    
    
@dataclass
class TSExperiment(Experiment):
    """Experiment subclass that trains the ``STI`` model defined in this notebook.

    Only the model construction and the per-batch forward pass are defined here;
    everything else (data loading, training loop, metrics, checkpointing) comes
    from the ``Experiment`` base class.
    """

    # Forwarded to STI as ``latent_dim``.
    latent_dim: int = 1024
    # Forwarded to STI as ``tcn_layers``.
    tcn_layers : int = 5

    def _process_one_batch(self, batch_x, batch_y, batch_x_date_enc, batch_y_date_enc):
        """Run the model on one batch and return ``(prediction, target)``.

        Args:
            batch_x: Input window, annotated in the original code as (B, T, N).
            batch_y: Target, annotated in the original code as (B, O, N).
            batch_x_date_enc: Unused by STI; accepted because the base class
                passes it.
            batch_y_date_enc: Unused by STI; accepted because the base class
                passes it.

        Returns:
            Tuple of the model output and ``batch_y`` moved to ``self.device``
            as float32.
        """
        batch_x = batch_x.to(self.device, dtype=torch.float32)
        batch_y = batch_y.to(self.device, dtype=torch.float32)

        # STI.forward expects (B, N, T), so swap the time and node axes.
        outputs = self.model(batch_x.transpose(1, 2))  # (B, O, N) per STI.forward
        return outputs, batch_y

    def _init_model(self):
        """Build the STI model from the experiment settings and move it to the device.

        Earlier revisions also built a normalised, padded adjacency matrix (via
        ``torch.inverse`` of the degree matrix) and a ``temporal_embed_dim``
        value here. Neither was ever passed to the model, and the inverse could
        raise on a singular degree matrix, so that dead code was removed.
        """
        # self.windows / self.pred_len / self.dataset are expected to be
        # provided by the Experiment base class; they are not defined here.
        self.model = STI(
            tcn_layers=self.tcn_layers,
            seq_len=self.windows,
            latent_dim=self.latent_dim,
            num_nodes=self.dataset.num_features,
            out_seq_len=self.pred_len,
            dilated_factor=2,
            tcn_channel=16,
            kernel_set=[2,3,6,7],
            d0=1,
            layer_norm_affline=True,
        )
        self.model = self.model.to(self.device)


In [4]:
# Settings for a single STI training run on the SolarEnergy dataset.
# NOTE(review): device and data_path are machine-specific; adjust before re-running.
run_config = {
    "device": "cuda:0",
    "latent_dim": 16,
    "dataset_type": "SolarEnergy",
    "tcn_layers": 5,
    "horizon": 24,
    "epochs": 100,
    "windows": 168,
    "lr": 0.0001,
    "data_path": '/notebooks/pytorch_timeseries/data/',
}
exp = TSExperiment(**run_config)

# Scratch left from a manual smoke test of the model:
# model = STI(168, 128, 9, 1)
# out = model(data)

# Start training and evaluation with a fixed seed.
exp.run(seed=662)

Using downloaded and verified file: /notebooks/pytorch_timeseries/data/solar_AL/solar_AL.txt.gz
Extracting /notebooks/pytorch_timeseries/data/solar_AL/solar_AL.txt.gz to /notebooks/pytorch_timeseries/data/solar_AL
train steps: 36601
val steps: 10321
test steps: 5065
torch.get_default_dtype() torch.float32
Creating running results saving dir: './results/runs/SolarEnergy/w168h24s1/b2324a7b7ac74be9469227cc8ea12a11'.
run : 0 in seed: 662
+---------------------------------------+------------+
|                Modules                | Parameters |
+---------------------------------------+------------+
|        tcn.channel_layer.weight       |     48     |
|         tcn.channel_layer.bias        |     16     |
| tcn.tcn.filter_convs.0.tconv.0.weight |    128     |
|  tcn.tcn.filter_convs.0.tconv.0.bias  |     4      |
| tcn.tcn.filter_convs.0.tconv.1.weight |    192     |
|  tcn.tcn.filter_convs.0.tconv.1.bias  |     4      |
| tcn.tcn.filter_convs.0.tconv.2.weight |    384     |
|  tcn.tcn.f

100%|██████████| 36601/36601 [02:41<00:00, 226.06it/s, epoch=0, loss=0.108, lr=0.0005] 


Epoch: 1 cost time: 161.91121077537537
Val on train....


100%|██████████| 36601/36601 [00:53<00:00, 688.44it/s]


Val on train result: {'corr': 0.39743703603744507, 'mae': 0.24525612592697144, 'mse': 0.16543780267238617, 'r2': 0.8433681130409241, 'r2_weighted': 0.8433005809783936}
Evaluating .... 


100%|██████████| 10321/10321 [00:16<00:00, 644.45it/s]


vali_results: {'corr': 0.43304458260536194, 'mae': 0.2397182732820511, 'mse': 0.15938986837863922, 'r2': 0.8301283717155457, 'r2_weighted': 0.8306878805160522}
Testing .... 


100%|██████████| 5065/5065 [00:09<00:00, 516.73it/s]


test_results: {'corr': 0.4295472204685211, 'mae': 0.25305303931236267, 'mse': 0.1654668003320694, 'r2': 0.7468786239624023, 'r2_weighted': 0.7499808073043823}
Validation loss decreased (inf --> 0.159390).  Saving model ...
Saving run checkpoint to './results/runs/SolarEnergy/w168h24s1/b2324a7b7ac74be9469227cc8ea12a11'.
Run state saved ... 
torch.get_default_dtype() torch.float32


100%|██████████| 36601/36601 [02:48<00:00, 217.00it/s, epoch=1, loss=0.105, lr=0.0005] 


Epoch: 2 cost time: 409.89287209510803
Val on train....


 22%|██▏       | 8064/36601 [00:13<00:44, 638.28it/s]

In [35]:
# Amplitude spectrum of `data` along axis 2 (so `data` must be at least 3-D).
# norm='forward' scales the transform by 1/n.
# NOTE(review): `data` is not defined in any visible cell; this relies on kernel state.
Xf = torch.fft.fft(data, norm='forward', dim=2)
amp = Xf.abs()
amp

tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0., 

tensor([-0.3273,  0.1516,  0.0161, -0.8483,  0.0766,  0.7280,  0.6113, -0.0470,
         0.9393,  0.6909,  0.0947,  0.8419,  0.1639, -0.1819,  1.0376,  0.0166,
        -0.7076,  0.9921,  0.2978, -0.4104, -1.0616,  0.6393,  0.1030,  0.5463,
        -0.3269,  0.1066, -0.1008,  0.2158, -0.1032, -0.3987, -0.4024, -0.3123],
       grad_fn=<SelectBackward0>) tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)
torch.Size([7, 32]) torch.Size([7, 32])


In [48]:
# Split the edges of one batch element into node->time and time->node sets.
# NOTE(review): `edge_index` and `self` are not defined in any visible cell; this
# looks like a fragment pasted out of a method and relies on kernel/debugger state.
bi = 0

# Assumes edge_index[bi] has two rows, sources then targets, one column per edge
# (TODO confirm). Indices below self.node_num are treated as one kind of vertex
# and indices at or above it as the other.
src, dst = edge_index[bi][0], edge_index[bi][1]

# An edge is selected only when BOTH endpoints satisfy the condition. Filtering
# the two rows with independent masks can pair a source with the wrong target,
# or make torch.stack fail when the rows end up with different lengths.
nt_mask = (src < self.node_num) & (dst >= self.node_num)
tn_mask = (src >= self.node_num) & (dst < self.node_num)

edge_nt = torch.stack((src[nt_mask], dst[nt_mask]))  # (source, target) pairs
edge_tn = torch.stack((src[tn_mask], dst[tn_mask]))


tensor([-0.4796, -0.6486, -1.6920, -1.0470, -0.4733, -0.2575, -0.1732, -0.4943,
         0.1441, -0.6750,  0.2937, -0.0437, -0.0987,  0.2391,  0.2676, -0.3876,
         0.9132,  0.7786, -0.3593, -0.1810, -1.0035, -1.1220, -0.5471,  0.9383,
        -0.9157,  0.2677, -0.0436,  0.3239,  0.2225,  0.7437,  0.0334,  0.9020],
       grad_fn=<SelectBackward0>)