# Libraries


# In [None]:
# standard
import pandas as pd
import numpy as np
import math
from math import sqrt

# machine learning
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sparse Decoder

# In [None]:
class SparseDecoder(nn.Module):
    """Decode a fixed-length forecast from encoder states in a single pass.

    The decoder takes no autoregressive input. Positional encodings for
    ``forecast_horizon`` steps serve as the queries of a cross-attention over
    the encoder output. The attention result goes through a residual
    connection and LayerNorm. It then passes through a position-wise
    feed-forward network (two kernel-size-1 ``Conv1d`` layers) with its own
    residual connection and LayerNorm. A final ``Linear(d_model, 1)`` head
    maps each step to one value.

    Args:
        d_model: Feature width of queries, keys and values.
        n_heads: Number of attention heads in the cross attention.
        encoder_output_dim: Accepted for interface compatibility; not used here.
        forecast_horizon: Number of future steps produced per sample.
        max_len: Passed to ``PositionalEncoding`` as its second argument.
        d_ff: Hidden width of the feed-forward network; defaults to
            ``4 * d_model``.
        dropout: Dropout probability used for every dropout in this block.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, d_model, n_heads, encoder_output_dim, forecast_horizon=22,
                 max_len=5000, d_ff=None, dropout=0.1, activation="relu"):
        super(SparseDecoder, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.forecast_horizon = forecast_horizon
        d_ff = d_ff or 4 * d_model
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

        # Positional encodings that serve as the decoder's queries.
        self.pos_encoder = PositionalEncoding(d_model, max_len)

        # Cross attention: decoder queries attend over encoder keys/values.
        self.cross_attention = DetSparseAttentionModule(d_model, n_heads, prob_sparse_factor=5)

        # Position-wise feed-forward network expressed as 1x1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

        self.norm1 = nn.LayerNorm(d_model)  # after the cross-attention residual
        self.norm2 = nn.LayerNorm(d_model)  # after the feed-forward residual

        # Projects each step's d_model features to a single forecast value.
        self.output_layer = nn.Linear(d_model, 1)

    def forward(self, encoder_output, attn_mask=None):
        """Produce forecasts from encoder states.

        Args:
            encoder_output: Encoder states, used as both keys and values.
                Assumed batch-first ``(batch, seq_len, d_model)`` — TODO confirm.
            attn_mask: Optional mask forwarded unchanged to the cross attention.

        Returns:
            ``output_layer`` applied per step, with the trailing size-1 axis
            squeezed. Presumably ``(batch, forecast_horizon)``.
        """
        batch_size = encoder_output.size(0)

        # Positional encodings for forecast_horizon steps, taken from an
        # all-zero input. new_zeros puts that input on the same device/dtype
        # as encoder_output instead of always on the CPU.
        dummy_input = encoder_output.new_zeros(1, self.forecast_horizon, self.d_model)
        pos_encodings = self.pos_encoder(dummy_input)
        # Share the same query positions across the batch so the queries'
        # batch dimension matches the keys/values.
        queries = pos_encodings.expand(batch_size, -1, -1)

        # Sub-layer 1: cross attention. The residual is the sub-layer's input
        # (the queries), not the attention output added to itself.
        attn_output = self.cross_attention(queries, encoder_output, encoder_output, attn_mask)
        hidden = self.norm1(queries + self.dropout(attn_output))

        # Sub-layer 2: feed-forward network. Conv1d expects channels on dim 1,
        # hence the transposes around it.
        ff = hidden.transpose(-1, 1)
        ff = self.dropout(self.activation(self.conv1(ff)))
        ff = self.dropout(self.conv2(ff))
        ff = ff.transpose(-1, 1)
        # ff has already had dropout applied; do not apply it a second time.
        hidden = self.norm2(hidden + ff)

        return self.output_layer(hidden).squeeze(-1)

# In [None]:
class DetSparseAttentionModule(nn.Module):
    """Single attention branch built from ``AttentionLayer`` + ``ProbAttention``.

    This is the one-layer variant of ``ProbSparseAttentionModule``: one
    attention layer rather than separate mean/variance layers. ``SparseDecoder``
    instantiates it under this name. The class used to be declared as
    ``DetSparseDecoder`` while its ``super()`` call named
    ``DetSparseAttentionModule``, so construction failed. The old name is kept
    as an alias below.

    Args:
        d_model: Model width passed to ``AttentionLayer``.
        n_heads: Number of heads passed to ``AttentionLayer``.
        prob_sparse_factor: ``factor`` forwarded to ``ProbAttention``.
        attention_dropout: Dropout forwarded to ``ProbAttention``.
    """

    def __init__(self, d_model, n_heads, prob_sparse_factor=5, attention_dropout=0.1):
        super().__init__()
        self.attention_layer = AttentionLayer(
            ProbAttention(mask_flag=False, factor=prob_sparse_factor, scale=None,
                          attention_dropout=attention_dropout),
            d_model=d_model, n_heads=n_heads,
        )

    def forward(self, queries, keys, values, attn_mask=None):
        """Apply the attention layer and return only its first output.

        ``attn_mask`` is forwarded unchanged. The second element returned by
        ``AttentionLayer`` (presumably the attention weights) is discarded.
        """
        attention_output, _ = self.attention_layer(queries, keys, values, attn_mask)
        return attention_output


# Backward-compatible alias for the class's previous name.
DetSparseDecoder = DetSparseAttentionModule

# In [None]:
class ProbSparseAttentionModule(nn.Module):
    """Run two independent ProbAttention-based attention layers over paired inputs.

    ``queries``, ``keys`` and ``values`` are each indexed as ``x[0]`` (the means
    stream) and ``x[1]`` (the variances stream). Each stream goes through its
    own ``AttentionLayer``. The two results are stacked along a new leading
    axis.

    Args:
        d_model: Model width passed to both attention layers.
        n_heads: Number of heads for both attention layers.
        prob_sparse_factor: ``factor`` forwarded to each ``ProbAttention``.
        attention_dropout: Dropout forwarded to each ``ProbAttention``.
    """

    def __init__(self, d_model, n_heads, prob_sparse_factor=5, attention_dropout=0.1):
        super().__init__()

        def build_branch():
            # Both branches share one configuration but are separate module instances.
            inner = ProbAttention(
                mask_flag=False,
                factor=prob_sparse_factor,
                scale=None,
                attention_dropout=attention_dropout,
            )
            return AttentionLayer(inner, d_model=d_model, n_heads=n_heads)

        # Build means first, then variances, matching the original creation order.
        self.attention_layer_means = build_branch()
        self.attention_layer_vars = build_branch()

    def forward(self, queries, keys, values, attn_mask=None):
        """Attend over the means and variances streams separately.

        Args:
            queries, keys, values: Indexable inputs; element 0 is the means
                stream and element 1 the variances stream.
            attn_mask: Optional mask passed unchanged to both layers.

        Returns:
            ``torch.stack`` of the two attention outputs along dim 0
            (index 0 = means, index 1 = variances).
        """
        pairs = [(item[0], item[1]) for item in (queries, keys, values)]
        (q_mu, q_var), (k_mu, k_var), (v_mu, v_var) = pairs

        mu_out, _unused = self.attention_layer_means(q_mu, k_mu, v_mu, attn_mask)
        var_out, _unused = self.attention_layer_vars(q_var, k_var, v_var, attn_mask)

        return torch.stack((mu_out, var_out), dim=0)