In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Tuple
import math
from torch import Tensor
# import os
# from utils.tools import EarlyStopping, adjust_learning_rate
import time
import numpy as np
from data_provider.data_factory import data_provider

In [3]:
@dataclass
class AutoformerConfig:
    kernel_size: int        # kernel size for moving average in series decomposition
    seq_len: int            # input sequence length
    pred_len: int           # forecast horizon
    n_encoders: int         # no. of encoder layers
    n_decoders: int         # no. of decoder layers
    d_model: int            # dimension of model's hidden states and embeddings
    n_heads: int            # no. of attention heads
    d_ff: int               # dimension of feed-forward network in transformer blocks
    
    # c: Auto-correlation intensity factor
    # Controls the number of time delay steps (k = c * log(L))
    # Typically set between 1-3
    c: float

In [4]:
class AutoCorrelation(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.d_k = config.d_model // config.n_heads
        self.c = config.c
        
        # Projections for Q/K/V
        self.query_proj = nn.Linear(config.d_model, config.d_model)
        self.key_proj = nn.Linear(config.d_model, config.d_model) 
        self.value_proj = nn.Linear(config.d_model, config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def time_delay_agg(self, values: Tensor, corr: Tensor, delays: Tensor) -> Tensor:
        # Time delay aggregation
        batch_size, L, H, d_k = values.shape
        output = torch.zeros_like(values)
        
        for i, delay in enumerate(delays):
            # Roll the values by delay steps
            rolled = torch.roll(values, shifts=int(delay), dims=1)
            # Weight by correlation score
            output += rolled * corr[:, i:i+1, :, None]
            
        return output

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        batch_size, L, _ = q.shape
        
        # Project and reshape for multi-head
        q = self.query_proj(q).view(batch_size, L, self.n_heads, self.d_k)
        k = self.key_proj(k).view(batch_size, L, self.n_heads, self.d_k)
        v = self.value_proj(v).view(batch_size, L, self.n_heads, self.d_k)

        # Compute autocorrelation for each head
        fft_q = torch.fft.rfft(q, dim=1)
        fft_k = torch.fft.rfft(k, dim=1)
        
        # Cross correlation in frequency domain
        corr = fft_q * torch.conj(fft_k)
        corr = torch.fft.irfft(corr, dim=1)
        
        # Select top-k delays
        num_delays = int(self.c * math.log(L))
        delays, indices = torch.topk(corr[:, 1:], num_delays, dim=1)  # Exclude 0 lag
        delays = F.softmax(delays, dim=-1)
        
        # Time delay aggregation
        out = self.time_delay_agg(v, delays, indices + 1)
        
        # Reshape and project output
        out = out.reshape(batch_size, L, self.d_model)
        return self.out_proj(out)

In [5]:
class MovingAvg(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super().__init__()
        self.kernel_size = config.kernel_size
        self.padding_size = (self.kernel_size - 1) // 2

    def forward(self, x: Tensor) -> Tensor:
        # padding logic taken from original implementation
        padding = torch.cat([
            x[:, :1, :].repeat(1, self.padding_size, 1),  # front padding
            x,
            x[:, -1:, :].repeat(1, self.padding_size, 1)  # end padding
        ], dim=1)
        
        # compute moving average using avg_pool1d
        return F.avg_pool1d(
            padding.permute(0, 2, 1), 
            kernel_size=self.kernel_size, 
            stride=1
        ).permute(0, 2, 1)

In [6]:
class SeriesDecomp(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super(SeriesDecomp, self).__init__()
        self.moving_avg = MovingAvg(config)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend

In [7]:
class AutoformerEncoderLayer(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super().__init__()
        self.auto_correlation = AutoCorrelation(config)
        self.series_decomp = SeriesDecomp(config)
        
        # Two-layer feed forward network
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model)
        )
        
    def forward(self, x: Tensor) -> Tensor:
        # Auto-correlation with residual
        auto_out = self.auto_correlation(x, x, x)
        s1, _ = self.series_decomp(auto_out + x)
        
        # Feed forward with residual
        ff_out = self.ff(s1)
        s2, _ = self.series_decomp(ff_out + s1)
        
        return s2

In [8]:
class AutoformerEncoder(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super(AutoformerEncoder, self).__init__()
        self.encoders = nn.ModuleList([AutoformerEncoderLayer(config) for l in range(config.n_encoders)])
    
    def forward(self, x: Tensor):
        for encoder in self.encoders:
            x = encoder(x)
        return x

In [9]:
class AutoformerDecoderLayer(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super().__init__()
        self.auto_correlation = AutoCorrelation(config)
        self.cross_correlation = AutoCorrelation(config)
        self.series_decomp = SeriesDecomp(config)
        
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model)
        )
        
        # Trend projection layers
        self.trend_proj1 = nn.Linear(config.d_model, config.d_model)
        self.trend_proj2 = nn.Linear(config.d_model, config.d_model)
        self.trend_proj3 = nn.Linear(config.d_model, config.d_model)
        
    def forward(self, x_seasonal: Tensor, x_trend: Tensor, enc_out: Tensor) -> Tuple[Tensor, Tensor]:
        # Self attention
        auto_out = self.auto_correlation(x_seasonal, x_seasonal, x_seasonal)
        s1, t1 = self.series_decomp(auto_out + x_seasonal)
        trend = x_trend + self.trend_proj1(t1)
        
        # Cross attention
        cross_out = self.cross_correlation(s1, enc_out, enc_out)
        s2, t2 = self.series_decomp(cross_out + s1)
        trend = trend + self.trend_proj2(t2)
        
        # Feed forward
        ff_out = self.ff(s2)
        s3, t3 = self.series_decomp(ff_out + s2)
        trend = trend + self.trend_proj3(t3)
        
        return s3, trend

In [10]:
class AutoformerDecoder(nn.Module):
    def __init__(self, config:AutoformerConfig):
        super().__init__()
        self.config = config
        self.decoders = nn.ModuleList([AutoformerDecoderLayer(config) for l in range(config.n_decoders)])
        self.projection = nn.Linear(config.d_model, config.d_model)
        self.series_decomp = SeriesDecomp(config)

    def forward(self, x: Tensor, enc_out:Tensor) -> Tensor:
        # Initialize seasonal and trend components
        I, d = x.shape
        x_ens, x_ent = self.series_decomp(x[I//2:])
        x_des = torch.cat([x_ens, torch.zeros(self.config.pred_len, d)])
        x_det = torch.cat([x_ent, x.mean(dim=0).repeat(self.pred_len, 1)])
        
        # Progressive refinement through decoder layers
        for decoder in self.decoders:
            x_des, x_det = decoder(x_des, x_det, enc_out)
            
        # Final prediction combines both components
        return self.projection(x_des) + x_det
        


In [11]:
class Autoformer(nn.Module):
    def __init__(self, config: AutoformerConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Linear(config.d_model, config.d_model)
        self.encoder = AutoformerEncoder(config)
        self.decoder = AutoformerDecoder(config)
        self.output_proj = nn.Linear(config.d_model, 1)
        self.series_decomp = SeriesDecomp(config)
        
    def forward(self, x: Tensor) -> Tensor:
        # Split input sequence
        enc_in = x[:, :self.config.seq_len, :]
        
        # Encoder
        enc_out = self.encoder(self.embedding(enc_in))
            
        # Initialize decoder inputs
        I = self.config.seq_len
        dec_init = x[:, I//2:I, :]
        seasonal, trend = self.series_decomp(dec_init)
        
        # Pad seasonal and trend components
        pad_seasonal = torch.zeros(
            (x.size(0), self.config.pred_len, x.size(-1)), 
            device=x.device
        )
        pad_trend = x[:, -1:, :].repeat(1, self.config.pred_len, 1)
        
        seasonal = torch.cat([seasonal, pad_seasonal], dim=1)
        trend = torch.cat([trend, pad_trend], dim=1)
        
        # Decoder
        for dec_layer in self.decoder:
            seasonal, trend = dec_layer(seasonal, trend, enc_out)
            
        # Final prediction
        seasonal = self.output_proj(seasonal)
        return seasonal + trend

In [13]:
curr_config = AutoformerConfig(
    kernel_size=25,                     # default kernel size for moving average
    seq_len=96,                         # from --seq_len
    pred_len=24,                        # from --pred_len
    label_len=48,                       # from --label_len
    n_encoders=2,                       # from --e_layers
    n_decoders=1,                       # from --d_layers
    d_model=512,                        # typical transformer dimension
    n_heads=8,                          # typical transformer heads
    d_ff=2048,                          # typical transformer feed-forward dimension
    c=3.0,                              # from --factor, controls auto-correlation delay steps
    enc_in=7,                           # no. of input features for encoder 
    dec_in=7,                           # no. of input features for decoder
    c_out=7,                            # no. of output features for decoder
    root_path="./dataset/ETT-small/",   # path to dataset
    data_path="ETTh1.csv",
    model_id="ETTh1_96_24",
    model="Autoformer",
    data="ETTh1",
    features="M",
    des="Exp",
    itr=1,
    is_training=1,
    patience=7,
    device="cuda" if torch.cuda.is_available() else "cpu",
    learning_rate=0.0001,
    )

TypeError: AutoformerConfig.__init__() got an unexpected keyword argument 'label_len'

In [None]:
model = Autoformer(curr_config).float()

In [None]:
def _predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
    # decoder input
    dec_inp = torch.zeros_like(batch_y[:, -curr_config.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :curr_config.label_len, :], dec_inp], dim=1).float().to(curr_config.device)
    # encoder - decoder

    def _run_model():
        outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if curr_config.output_attention:
            outputs = outputs[0]
        return outputs

    outputs = _run_model()

    f_dim = -1 if self.args.features == 'MS' else 0
    outputs = outputs[:, -curr_config.pred_len:, f_dim:]
    batch_y = batch_y[:, -curr_config.pred_len:, f_dim:].to(curr_config.device)

    return outputs, batch_y

In [None]:
def vali(vali_data, vali_loader, criterion):
    total_loss = []
    model.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
            batch_x = batch_x.float().to(curr_config.device)
            batch_y = batch_y.float()

            batch_x_mark = batch_x_mark.float().to(curr_config.device)
            batch_y_mark = batch_y_mark.float().to(curr_config.device)

            outputs, batch_y = _predict(batch_x, batch_y, batch_x_mark, batch_y_mark)

            pred = outputs.detach().cpu()
            true = batch_y.detach().cpu()

            loss = criterion(pred, true)

            total_loss.append(loss)
    total_loss = np.average(total_loss)
    model.train()
    return total_loss

In [None]:
def train():
    train_data, train_loader = data_provider(curr_config, flag='train')
    vali_data, vali_loader = data_provider(curr_config, flag='val')
    test_data, test_loader = data_provider(curr_config, flag='test')

    time_now = time.time()
    train_steps = len(train_loader)
    # early_stopping = EarlyStopping(patience=curr_config.patience, verbose=True)

    model_optim = torch.optim.Adam(model.parameters(), lr=curr_config.learning_rate)
    criterion = nn.MSELoss()

    if curr_config.use_amp:
        scaler = torch.amp.GradScaler()

    for epoch in range(curr_config.train_epochs):
        iter_count = 0
        train_loss = []

        model.train()
        epoch_time = time.time()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()
            batch_x = batch_x.float().to(curr_config.device)

            batch_y = batch_y.float().to(curr_config.device)
            batch_x_mark = batch_x_mark.float().to(curr_config.device)
            batch_y_mark = batch_y_mark.float().to(curr_config.device)

            outputs, batch_y = _predict(batch_x, batch_y, batch_x_mark, batch_y_mark)

            loss = criterion(outputs, batch_y)
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((curr_config.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if curr_config.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = vali(vali_data, vali_loader, criterion)
        test_loss = vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
            epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        # early_stopping(vali_loss, model, path)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        # adjust_learning_rate(model_optim, epoch + 1, curr_config)
    return