In [1]:
import sys
sys.path.append('../')

from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
# from models import Informer, Autoformer, Transformer, Reformer
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric

In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim

import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

warnings.filterwarnings('ignore')

In [3]:
from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from torch.utils.data import DataLoader

## Dataset

In [4]:
# Experiment configuration for the ETTh1 walkthrough.
Data = Dataset_ETT_hour

# The time-feature encoding must agree with the embedding type the model uses
# (the embedding cell below uses embed='timeF'). timeenc=0 together with 'timeF'
# would feed raw calendar integers into a linear time-feature embedding.
# NOTE(review): mirrors the data_provider factory convention
# (timeenc = 0 if embed != 'timeF' else 1) — confirm against data_provider/data_factory.py.
embed = 'timeF'
timeenc = 0 if embed != 'timeF' else 1

# DataLoader behaviour
shuffle_flag = True
drop_last = True
batch_size = 32
freq = 'h'

# Data location / split
root_path = '../dataset'
data_path = 'ETTh1.csv'
flag = 'train'

# Window sizes: encoder input, known decoder prefix, forecast horizon
seq_len = 96
label_len = 48
pred_len = 48

# 'S' = univariate: forecast only the `target` column
features = 'S'
target = 'OT'

In [5]:
# Dataset split selected by `flag`, sliced into windows of [seq_len, label_len, pred_len].
data_set = Data(
    flag=flag,
    size=[seq_len, label_len, pred_len],
    root_path=root_path,
    data_path=data_path,
    features=features,
    target=target,
    freq=freq,
    timeenc=timeenc,
)

In [6]:
# Batched access to the dataset; drop_last discards a trailing partial batch.
data_loader = DataLoader(
    data_set,
    batch_size=batch_size,
    drop_last=drop_last,
    shuffle=shuffle_flag,
    num_workers=10,
)

In [7]:
# Pull a single batch so shapes can be traced through the model by hand.
(batch_x, batch_y, batch_x_mark, batch_y_mark) = next(iter(data_loader))

In [8]:
# Shapes of the input window, target window and their time-feature marks, one per line.
print(*(t.shape for t in (batch_x, batch_y, batch_x_mark, batch_y_mark)), sep='\n')

torch.Size([32, 96, 1])
torch.Size([32, 96, 1])
torch.Size([32, 96, 4])
torch.Size([32, 96, 4])


In [9]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [10]:
# Cast every batch tensor to float32 and move it to the selected device.
batch_x, batch_y, batch_x_mark, batch_y_mark = [
    t.float().to(device) for t in (batch_x, batch_y, batch_x_mark, batch_y_mark)
]

In [11]:
# Encoder input window length (first entry of `size` passed to the dataset).
seq_len

96

In [12]:
# Number of known target steps the decoder input starts with.
label_len

48

In [13]:
# Forecast horizon: number of zero placeholder steps appended to the decoder input.
pred_len

48

In [14]:
# Decoder input: the first label_len steps of batch_y (known history) followed
# by pred_len zeros standing in for the horizon to be forecast.
dec_inp = torch.cat(
    [batch_y[:, :label_len, :], torch.zeros_like(batch_y[:, -pred_len:, :]).float()],
    dim=1,
).float().to(device)

In [15]:
# label_len known steps + pred_len zero placeholders along the time axis.
dec_inp.shape

torch.Size([32, 96, 1])

## Model

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Embed import DataEmbedding, DataEmbedding_wo_pos
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
import math
import numpy as np

In [18]:
# Window length of the moving average used by the series decomposition.
kernel_size = moving_avg = 25

In [19]:
# Rename the batch tensors to the argument names of the model's forward().
x_enc, x_mark_enc, x_dec, x_mark_dec = batch_x, batch_x_mark, dec_inp, batch_y_mark

In [20]:
# Encoder/decoder values and their time marks, one shape per line.
print(*(t.shape for t in (x_enc, x_mark_enc, x_dec, x_mark_dec)), sep='\n')

torch.Size([32, 96, 1])
torch.Size([32, 96, 4])
torch.Size([32, 96, 1])
torch.Size([32, 96, 4])


In [21]:
# Trend placeholder for the horizon: per-sample, per-channel mean of the encoder
# input (time axis kept as size 1), repeated pred_len times along time.
mean = x_enc.mean(dim=1, keepdim=True).repeat(1, pred_len, 1)


In [22]:
# Seasonal placeholder for the horizon: zeros with the decoder's batch and
# channel sizes, pred_len steps long, on the encoder input's device.
zeros = torch.zeros(x_dec.shape[0], pred_len, x_dec.shape[2], device=x_enc.device)

In [23]:
# Both placeholders span pred_len steps.
print(zeros.shape, mean.shape, sep='\n')

torch.Size([32, 48, 1])
torch.Size([32, 48, 1])


In [24]:
# Series decomposition, step 1: the moving-average operator.
# Stride-1 average pooling without implicit padding; the series is padded by
# repeating its end points in the next cell so the pooled output keeps its length.
avg = nn.AvgPool1d(kernel_size, stride=1, padding=0)

In [25]:
# Replicate-pad the series at both ends so the stride-1 average keeps the
# original length: the total padding must be kernel_size - 1. Splitting it as
# (kernel_size - 1) // 2 in front and kernel_size // 2 at the end equals the
# symmetric split for odd kernels and also keeps the length for even kernels
# (a symmetric (k - 1) // 2 split would lose one step there).
front = x_enc[:, 0:1, :].repeat(1, (kernel_size - 1) // 2, 1)
end = x_enc[:, -1:, :].repeat(1, kernel_size // 2, 1)

In [26]:
# (kernel_size - 1) // 2 copies of the first time step.
front.shape

torch.Size([32, 12, 1])

In [27]:
# Padding made of copies of the last time step.
end.shape

torch.Size([32, 12, 1])

In [28]:
# Padded series along time: front copies + original input + end copies.
x = torch.cat((front, x_enc, end), dim=1)

In [29]:
# seq_len plus the front and end padding.
x.shape

torch.Size([32, 120, 1])

In [30]:
# AvgPool1d pools over the last axis, so swap time and channels first:
# (B, L_padded, C) -> (B, C, L_padded) -> pooled (B, C, seq_len).
x = avg(x.transpose(1, 2))
x.shape

torch.Size([32, 1, 96])

In [31]:
# Swap back to (B, seq_len, C): x now holds the moving average of x_enc.
x = x.transpose(1, 2)
x.shape

torch.Size([32, 96, 1])

In [32]:
# Series decomposition, step 2: trend = moving average, seasonal = residual.
# `x` was overwritten with the moving average in the cells above, so the residual
# must be taken against the original input x_enc; `x - moving_mean` would be
# identically zero and leave the seasonal part empty.
moving_mean = x
res = x_enc - moving_mean

print(moving_mean.shape)
print(res.shape)

torch.Size([32, 96, 1])
torch.Size([32, 96, 1])


In [33]:
# Seasonal and trend components of the encoder input.
seasonal_init, trend_init = res, moving_mean

In [34]:
# Decoder initialisation: keep the last label_len decomposed steps and append
# the pred_len placeholders (encoder mean for the trend, zeros for the seasonal part).
# NOTE: not idempotent. Re-running this cell re-slices the already-extended
# tensors, so re-run the decomposition cells first.
seasonal_init = torch.cat((seasonal_init[:, -label_len:, :], zeros), dim=1)
trend_init = torch.cat((trend_init[:, -label_len:, :], mean), dim=1)

In [35]:
# Channel counts come from the tensors themselves (1 for features='S'), so
# switching `features` in the config cell does not silently mismatch the embedding.
enc_in = x_enc.shape[-1]
dec_in = seasonal_init.shape[-1]
d_model = 512
embed = 'timeF'
dropout = 0.05

# Value + time-feature embeddings without positional encoding, one for each side.
enc_embedding = DataEmbedding_wo_pos(enc_in, d_model, embed, freq,
                                            dropout).to(device)
dec_embedding = DataEmbedding_wo_pos(dec_in, d_model, embed, freq,
                                            dropout).to(device)

enc_out = enc_embedding(x_enc, x_mark_enc)
# The decoder embeds only the seasonal part; the trend is carried separately.
dec_out = dec_embedding(seasonal_init, x_mark_dec)


In [36]:
# Encoder embedding of x_enc with its time marks: (batch, seq_len, d_model).
enc_out.shape

torch.Size([32, 96, 512])

In [37]:
# Decoder embedding of seasonal_init with x_mark_dec.
dec_out.shape

torch.Size([32, 96, 512])

### Encoder

In [107]:
#-- EncoderLayer

# new_x, attn = self.attention(
#     x, x, x,
#     attn_mask=attn_mask
# )
# x = x + self.dropout(new_x)
# x, _ = self.decomp1(x)
# y = x
# y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
# y = self.dropout(self.conv2(y).transpose(-1, 1))
# res, _ = self.decomp2(x + y)

In [38]:
# AutoCorrelationLayer, unrolled by hand:
#   new_x, attn = attention(x, x, x, attn_mask=attn_mask)
# (the AutoCorrelation constructor would normally take factor / dropout /
#  output_attention from configs)
x, attn_mask = enc_out, None

n_heads = 8
d_model = 512
# Per-head widths: an explicit value wins, otherwise split d_model evenly.
d_keys, d_values = None, None
if not d_keys:
    d_keys = d_model // n_heads
if not d_values:
    d_values = d_model // n_heads

In [39]:
# Per-head linear projections of AutoCorrelationLayer (the inner AutoCorrelation
# block is unrolled in the following cells).
query_projection = nn.Linear(d_model, d_keys * n_heads).to(device)
key_projection = nn.Linear(d_model, d_keys * n_heads).to(device)
value_projection = nn.Linear(d_model, d_values * n_heads).to(device)
# Merges the concatenated head outputs back to d_model.
out_projection = nn.Linear(d_values * n_heads, d_model).to(device)

In [40]:
# Maps d_model features to d_keys * n_heads (split per head below).
query_projection

Linear(in_features=512, out_features=512, bias=True)

In [41]:
# Self-attention: queries, keys and values all start as the encoder embedding.
queries = keys = values = enc_out

H = n_heads
B, L, _ = queries.shape
_, S, _ = keys.shape

In [42]:
# Encoder embedding before the query projection.
queries.shape

torch.Size([32, 96, 512])

In [43]:
# Same tensor as queries here (self-attention on enc_out).
keys.shape

torch.Size([32, 96, 512])

In [44]:
# Project and split the feature axis into heads: (B, len, d_model) -> (B, len, H, per-head width).
queries, keys, values = (
    query_projection(queries).view(B, L, H, -1),
    key_projection(keys).view(B, S, H, -1),
    value_projection(values).view(B, S, H, -1),
)

In [45]:
# Per-head layout of queries, keys and values, one shape per line.
print(*(t.shape for t in (queries, keys, values)), sep='\n')

torch.Size([32, 96, 8, 64])
torch.Size([32, 96, 8, 64])
torch.Size([32, 96, 8, 64])


In [46]:
# Inner AutoCorrelation block, unrolled:
#   out, attn = inner_correlation(queries, keys, values, attn_mask)
# Hyper-parameters fixed by hand here; `factor` scales the number of lags kept (top_k).
mask_flag, factor, scale = False, 1, None
attention_dropout, output_attention = 0.1, False


In [47]:
# Query length L, heads H and key width E; key/value length S and value width D.
(B, L, H, E), (_, S, _, D) = queries.shape, values.shape

In [48]:
# Align keys/values to the query length L so the FFT correlation compares
# equal-length sequences: truncate when S >= L, zero-pad along time when S < L.
# Afterwards keys are (B, L, H, E) and values are (B, L, H, D).
if L > S:
    values = torch.cat([values, values.new_zeros((B, L - S, H, D))], dim=1)
    keys = torch.cat([keys, keys.new_zeros((B, L - S, H, E))], dim=1)
else:
    values = values[:, :L, :, :]
    keys = keys[:, :L, :, :]

In [49]:
# period-based dependencies
q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
res = q_fft * torch.conj(k_fft)
corr = torch.fft.irfft(res, dim=-1)

In [50]:
# rfft along time keeps L // 2 + 1 non-negative frequency bins per (batch, head, channel).
q_fft.shape

torch.Size([32, 8, 64, 49])

In [51]:
# Same layout as q_fft.
k_fft.shape

torch.Size([32, 8, 64, 49])

In [52]:
# Cross-spectrum q_fft * conj(k_fft), still in the frequency domain.
res.shape

torch.Size([32, 8, 64, 49])

In [53]:
# Correlation at every lag, back in the time domain: (B, H, E, L).
corr.shape

torch.Size([32, 8, 64, 96])

In [54]:
# # time delay agg
# if self.training:
#     V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
# else:
#     V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

# if self.output_attention:
#     return (V.contiguous(), corr.permute(0, 3, 1, 2))
# else:
#     return (V.contiguous(), None)

In [55]:
# (B, L, H, D) after the length alignment.
values.shape

torch.Size([32, 96, 8, 64])

In [56]:
# Move time to the last axis, (B, L, H, D) -> (B, H, D, L), to match corr's layout.
values_ = values.movedim(1, -1).contiguous()

In [57]:
# Time now on the last axis: (B, H, D, L).
values_.shape

torch.Size([32, 8, 64, 96])

In [58]:
# Same layout as values_, so lags index the last axis of both.
corr.shape

torch.Size([32, 8, 64, 96])

In [59]:
# Head count, per-head channels and sequence length of the time-last values.
head, channel, length = values_.shape[1:]

In [60]:
# Pick the top_k lags with the highest correlation averaged over heads, channels
# and the batch; `weights` holds each sample's own score at those lags, (B, top_k).
# top_k is clamped to [1, length]: int(factor * log(length)) is 0 for very short
# series (stacking zero weights would fail) and can exceed `length` for a large
# factor (torch.topk would fail).
top_k = min(max(int(factor * math.log(length)), 1), length)
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
weights = mean_value[:, index]


In [61]:
# (B, top_k): each sample's mean correlation at the selected lags.
weights.shape

torch.Size([32, 4])

In [78]:
# Time-delay aggregation: softmax the top_k scores into per-sample weights,
# roll the values left by each selected lag and accumulate the weighted sum.
tmp_corr = torch.softmax(weights, dim=-1)
tmp_values = values_
delays_agg = torch.zeros_like(values_).float()
for i in range(top_k):
    pattern = torch.roll(tmp_values, -int(index[i]), -1)
    # The (B,) weight broadcasts over (head, channel, length); there is no need
    # to materialise a repeated copy of it on every iteration.
    delays_agg = delays_agg + pattern * tmp_corr[:, i][:, None, None, None]


In [79]:
# Same shape as values_: (B, H, D, L).
delays_agg.shape

torch.Size([32, 8, 64, 96])

In [80]:
# Unchanged by the aggregation; only its averaged scores (mean_value) were used.
corr.shape

torch.Size([32, 8, 64, 96])

In [81]:
# Bring lags back to axis 1: (B, H, E, L) -> (B, L, H, E), the layout used for attention outputs.
corr_ = corr.movedim(-1, 1)
corr_.shape

torch.Size([32, 96, 8, 64])

In [88]:
# Aggregated values back to (B, L, H, D).
V = delays_agg.movedim(-1, 1)
V.shape

torch.Size([32, 96, 8, 64])

In [89]:
# Copy the permuted view into contiguous memory (the permute only changed strides);
# presumably needed before a later .view() when merging heads. Confirm.
V = V.contiguous()
V.shape

torch.Size([32, 96, 8, 64])

In [85]:
# Per-head queries / aligned keys / aligned values, one shape per line.
print(*(t.shape for t in (queries, keys, values)), sep='\n')

torch.Size([32, 96, 8, 64])
torch.Size([32, 96, 8, 64])
torch.Size([32, 96, 8, 64])


In [87]:
# Layer input shape (B, L, d_model). V (B, L, H, D) has not yet been merged back
# through out_projection in this notebook. TODO: presumably V.view(B, L, -1) then out_projection.
enc_out.shape

torch.Size([32, 96, 512])