# Libraries


In [2]:
# standard
import pandas as pd
import numpy as np
import math
from math import sqrt

# machine learning
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
import os
sys.path.append('..')


from embeddings import PositionalEncoding
from sparse_attention import DetSparseAttentionModule

# Sparse Decoder

In [3]:
class SparseDecoder(nn.Module):
    """Decoder head that reduces an encoder output to one value per sample.

    The encoder output is passed through a ``DetSparseAttentionModule``
    (used as query, key and value alike), averaged over axis 1, pushed
    through a position-wise feed-forward network built from two 1x1
    convolutions, and finally projected to a single feature by a linear layer.

    Args:
        d_model: Accepted for interface compatibility only. The working model
            width is always taken from ``encoder_output_dim[-1]``.
        n_heads: Number of attention heads handed to the attention module.
        encoder_output_dim: Sequence read as ``[batch_size, seq_len, d_model]``
            (index 0, 1 and -1 are used).
        forecast_horizon: Stored on the instance; not read by ``forward``.
        max_len: Forwarded to ``PositionalEncoding``.
        d_ff: Hidden width of the feed-forward network; defaults to
            ``4 * d_model`` when falsy.
        dropout: Dropout probability shared by every dropout call.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, d_model, n_heads, encoder_output_dim, forecast_horizon=1, max_len=5000, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        # The width comes from the encoder output shape, not from the d_model argument.
        self.d_model = encoder_output_dim[-1]
        self.batch_size = encoder_output_dim[0]
        self.n_heads = n_heads
        self.forecast_horizon = forecast_horizon
        d_ff = d_ff or 4 * self.d_model
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

        # Kept registered so the module tree (and any saved state) is unchanged;
        # forward() does not use it because its output never fed the forecast.
        self.pos_encoder = PositionalEncoding(self.batch_size, self.d_model, max_len)

        # Sparse attention module applied to the encoder output
        self.cross_attention = DetSparseAttentionModule(
            d_model=self.d_model,
            n_heads=self.n_heads,
            prob_sparse_factor=5,
            seq_len=encoder_output_dim[1],
        )

        # Feed-forward network components (kernel_size=1 => position-wise)
        self.conv1 = nn.Conv1d(in_channels=self.d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=self.d_model, kernel_size=1)

        # Normalization layers
        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)

        # Output layer: projects d_model features to a single value
        self.output_layer = nn.Linear(self.d_model, 1)

    def forward(self, encoder_output, attn_mask=None):
        """Compute the forecast from ``encoder_output``.

        Args:
            encoder_output: Encoder output tensor; presumably
                ``(batch, seq_len, d_model)`` -- TODO confirm against the encoder.
            attn_mask: Accepted for interface compatibility; currently unused.

        Returns:
            The linear projection with its trailing size-1 axis squeezed.
            Assuming the attention module returns ``(batch, seq_len, d_model)``
            this is ``(batch, 1)`` -- TODO confirm.
        """
        # The encoder output serves as query, key and value; the result is
        # averaged over axis 1.
        context = self.cross_attention(encoder_output, encoder_output, encoder_output).mean(1)
        # NOTE(review): the residual adds the attention output to a dropped-out
        # copy of itself rather than to a separate query -- confirm this is intended.
        context = self.norm1(context + self.dropout(context)).unsqueeze(1)

        # Feed-forward network; Conv1d expects channels on axis 1
        hidden = context.transpose(-1, 1)
        hidden = self.dropout(self.activation(self.conv1(hidden)))
        hidden = self.dropout(self.conv2(hidden))
        hidden = hidden.transpose(-1, 1)  # back to (..., 1, d_model)
        hidden = self.norm2(context + self.dropout(hidden))

        # Project to a single feature and drop that trailing axis
        return self.output_layer(hidden).squeeze(-1)

In [None]:
    # NOTE(review): this cell is a function body without its `def` line --
    # `seq_len`, `target_dim` and `input_tensor` are not defined in the cell,
    # and nothing is returned. Presumably it shortens axis 1 of `input_tensor`
    # from `seq_len` to about `target_dim` steps by average pooling -- confirm
    # the intended signature before reusing it.
    #
    # Calculate kernel size and stride
    # These calculations are made to dynamically adjust according to the input tensor's dimensions
    # NOTE(review): integer division -- when seq_len is not a multiple of
    # target_dim the pooled length may differ from target_dim, and
    # kernel_size becomes 0 when target_dim > seq_len. TODO confirm inputs.
    kernel_size = seq_len // target_dim
    stride = kernel_size  # Assuming stride equals kernel size for direct downscaling
    
    # Pooling operation
    # Transpose the tensor to match the expected input shape of avg_pool1d ([batch_size, channels, width])
    # (assumes input_tensor is [batch_size, seq_len, channels] -- TODO confirm)
    input_tensor_transposed = input_tensor.transpose(1, 2)  # Now shape [batch_size, channels, seq_len]
    
    # Apply average pooling over the last axis with non-overlapping windows
    pooled_tensor = F.avg_pool1d(input_tensor_transposed, kernel_size=kernel_size, stride=stride)
    
    # Transpose back to the original dimension order
    output_tensor = pooled_tensor.transpose(1, 2)  # Shape [batch_size, target_dim, channels]

In [9]:
class DetSparseDecoder(nn.Module):
    """Deterministic decoder: a thin wrapper around a single ``SparseDecoder``.

    Args:
        d_model: Feature width of the encoder output.
        n_heads: Number of attention heads.
        batch_size: Batch size used to describe the encoder output shape.
        seq_len: Sequence length of the encoder output.
        forecast_horizon: Forwarded to ``SparseDecoder``.
        max_len: Forwarded to ``SparseDecoder``.
        d_ff: Forwarded to ``SparseDecoder``.
        dropout: Forwarded to ``SparseDecoder``.
        activation: Forwarded to ``SparseDecoder``.
    """

    def __init__(self, d_model, n_heads, batch_size, seq_len, forecast_horizon, max_len=5000, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        # The attribute name is kept as-is so parameter names in saved state
        # dicts ("SparseDecoder.*") stay valid.
        self.SparseDecoder = SparseDecoder(
            encoder_output_dim=[batch_size, seq_len, d_model],
            forecast_horizon=forecast_horizon,
            d_model=d_model,
            n_heads=n_heads,
            # Pass the remaining hyperparameters through instead of silently
            # falling back to SparseDecoder's defaults.
            max_len=max_len,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
        )

    def forward(self, encoder_output):
        """Run the wrapped decoder and squeeze axis 1 of its result.

        Args:
            encoder_output: Tensor handed unchanged to ``SparseDecoder``.

        Returns:
            The decoder output with axis 1 removed when that axis has size 1.
        """
        # calculate attention
        return self.SparseDecoder(encoder_output).squeeze(1)

In [5]:
class ProbSparseDecoder(nn.Module):
    """Probabilistic decoder with separate ``SparseDecoder`` heads for mean and variance.

    Args:
        d_model: Forwarded to both sub-decoders.
        n_heads: Number of attention heads in both sub-decoders.
        encoder_output_dim: Shape description forwarded to both sub-decoders.
        forecast_horizon: Forwarded to both sub-decoders.
        max_len: Forwarded to both sub-decoders.
        d_ff: Forwarded to both sub-decoders.
        dropout: Forwarded to both sub-decoders.
        activation: Forwarded to both sub-decoders.
    """

    def __init__(self, d_model, n_heads, encoder_output_dim, forecast_horizon, max_len=5000, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        # Both heads share one configuration. The optional hyperparameters are
        # passed through instead of silently falling back to SparseDecoder's
        # defaults.
        shared_config = dict(
            encoder_output_dim=encoder_output_dim,
            forecast_horizon=forecast_horizon,
            d_model=d_model,
            n_heads=n_heads,
            max_len=max_len,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
        )
        # Attribute names are kept as-is so parameter names in saved state
        # dicts stay valid.
        self.SparseDecoder_mean = SparseDecoder(**shared_config)
        self.SparseDecoder_var = SparseDecoder(**shared_config)

    def forward(self, encoder_output):
        """Decode mean and variance separately and stack the results.

        Args:
            encoder_output: Indexable whose element 0 is treated as the mean
                representation and element 1 as the variance representation.

        Returns:
            Tensor stacking the mean result (index 0) and the variance result
            (index 1) along a new leading axis.
        """
        # split tensor to extract mean and variance
        mean_input = encoder_output[0]
        variance_input = encoder_output[1]

        # calculate attention
        decoded_mean = self.SparseDecoder_mean(mean_input).squeeze(1)
        decoded_var = self.SparseDecoder_var(variance_input).squeeze(1)

        # Combine the processed means and variances
        return torch.stack([decoded_mean, decoded_var], dim=0)

In [10]:
%store -r output
n_heads_global = 4
probabilistic_model = False
len_embedding_vector = 64
batch_size = 512
seq_len = 576

# determine which model to use
if probabilistic_model == True:
    model = ProbSparseDecoder(
        d_model = output.shape[-1],
        n_heads = n_heads_global,
        forecast_horizon = 72
    )
else:
    model = DetSparseDecoder(
        d_model = len_embedding_vector,
        n_heads = n_heads_global,
        forecast_horizon = 72,
        batch_size=batch_size,
        seq_len=seq_len,
    )

model(output).shape

torch.Size([512])