# In [26]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    Args:
        attention_dim: the dimension ``d`` whose square root divides the
            raw dot-product scores.
    """

    def __init__(self, attention_dim):
        super().__init__()
        self.scaling_factor = np.sqrt(attention_dim)

    def forward(self, q, k, v):
        """Attend from ``q`` over the ``k``/``v`` pairs.

        Returns:
            ``(output, attn_weights)``: ``attn_weights`` are the softmax
            probabilities over the last axis of the score matrix, and
            ``output`` is those weights applied to ``v``.
        """
        keys_t = k.transpose(-2, -1)
        scores = (q @ keys_t) / self.scaling_factor
        weights = scores.softmax(dim=-1)
        return weights @ v, weights
    
class SimpleSelfAttention(nn.Module):
    """Single-head self-attention with learned query/key/value projections.

    Each projection is a square ``nn.Linear(embed_size, embed_size)``. Scores
    are divided by ``sqrt(embed_size)`` before the softmax. The output has
    the same shape as the input.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size

        # Created (and registered) in the order query, key, value.
        self.query, self.key, self.value = (
            nn.Linear(embed_size, embed_size) for _ in range(3)
        )

    def forward(self, x):
        q, k, v = self.query(x), self.key(x), self.value(x)

        scale = self.embed_size ** 0.5
        weights = torch.softmax(q @ k.transpose(-2, -1) / scale, dim=-1)

        # Each output row is an attention-weighted mix of the value rows.
        return weights @ v

class FeatureAttention(nn.Module):
    """Re-weight sequence positions using a learned scoring vector.

    Each position's feature vector is scored by a dot product with a learned
    vector of length ``feature_dim``. The scores are softmax-normalised over
    the last score axis (the lag axis for 3-D input). Every feature at a
    position is then scaled by that position's weight. Despite the name,
    the weighting is per position, not per feature.
    """

    def __init__(self, feature_dim):
        super().__init__()
        self.attention_weights = nn.Parameter(torch.randn(feature_dim))

    def forward(self, x):
        # Expected x: (batch_size, lag_size, feature_dim).
        scores = torch.matmul(x, self.attention_weights)  # (batch_size, lag_size)
        probs = scores.softmax(dim=-1)  # sums to 1 over lag_size

        # Broadcast one weight per position across its features.
        return probs.unsqueeze(-1) * x
    
class ResiDualNet(nn.Module):
    """Apply ``num_blocks`` ResNetBlock modules one after another.

    Args:
        config: object exposing ``input_size``, ``hidden_size`` and
            ``output_size``. It is handed unchanged to every block.
        kernel_size: convolution kernel size forwarded to each block.
        num_blocks: how many blocks to stack.
    """

    def __init__(self, config, kernel_size=3, num_blocks=3):
        super().__init__()
        self.config = config
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        blocks = [ResNetBlock(config, kernel_size) for _ in range(num_blocks)]
        self.resnet = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for layer in self.resnet:
            out = layer(out)
        return out
    


# In [66]:
from math import e


class ResNetBlock(nn.Module):
    """Hybrid BiLSTM / CNN block with a residual connection.

    The input ``(batch, lag_size, input_size)`` is processed by two parallel
    branches:

    * a bidirectional LSTM encoder followed by dropout and ReLU, and
    * two ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d`` stages over the
      time axis.

    Both branches produce ``(batch, lag_size, 2 * hidden_size)``. They are
    combined by element-wise multiplication, then passed through a
    bidirectional LSTM decoder. The result is flattened and linearly
    projected to ``(batch, lag_size, output_size)``. Finally the block input
    is added back, so ``config.output_size`` must equal ``config.input_size``.

    Args:
        config: object exposing ``input_size``, ``hidden_size``,
            ``output_size`` and ``lag_size``.
        kernel_size: Conv1d kernel size. Use an odd value so that the
            convolutions preserve the sequence length.
        padding: Conv1d padding. ``None`` (the default) uses
            ``kernel_size // 2``, which is 1 for the default kernel size.

    Note:
        ``batchNorm1``, ``batchNorm4``, ``decoder_conv2``, ``batchNorm5``,
        ``decoder_linear``, ``attention_lstm``, ``attention_cnn`` and
        ``attention`` are not used in ``forward``. They are still registered
        submodules, so removing them would break loading existing
        ``state_dict`` checkpoints.
    """

    def __init__(self, config, kernel_size=3, padding=None):
        super(ResNetBlock, self).__init__()
        if padding is None:
            # "Same" padding for odd kernels: the conv branch must keep the
            # sequence length so it can be multiplied with the LSTM branch.
            padding = kernel_size // 2

        self.config = config
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        self.encoder_lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size,
                                    batch_first=True, bidirectional=True)
        self.batchNorm1 = nn.BatchNorm1d(self.hidden_size)  # unused in forward

        self.encoder_conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=self.hidden_size * 2,
                                       kernel_size=kernel_size, padding=padding)
        self.batchNorm2 = nn.BatchNorm1d(self.hidden_size * 2)

        self.encoder_conv2 = nn.Conv1d(in_channels=self.hidden_size * 2, out_channels=self.hidden_size * 2,
                                       kernel_size=kernel_size, padding=padding)
        self.batchNorm3 = nn.BatchNorm1d(self.hidden_size * 2)

        self.decoder_lstm = nn.LSTM(input_size=self.hidden_size * 2, hidden_size=self.hidden_size,
                                    batch_first=True, bidirectional=True)
        self.batchNorm4 = nn.BatchNorm1d(self.hidden_size)  # unused in forward

        # unused in forward
        self.decoder_conv2 = nn.Conv1d(in_channels=self.hidden_size * 2, out_channels=self.hidden_size,
                                       kernel_size=kernel_size, padding=padding)
        self.batchNorm5 = nn.BatchNorm1d(self.hidden_size)  # unused in forward

        self.decoder_linear = nn.Linear(self.hidden_size * 2, self.hidden_size * 4)  # unused in forward
        # Flattened decoder output (lag_size * 2 * hidden) -> lag_size * output_size.
        self.decoder_linear2 = nn.Linear(2 * self.hidden_size * self.config.lag_size,
                                         self.config.lag_size * self.output_size)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.flatten = nn.Flatten()
        # kernel 3 / stride 1 / padding 1 keeps the sequence length unchanged.
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)

        # Attention modules: registered but not used in forward.
        self.attention_lstm = FeatureAttention(self.hidden_size * 2)
        self.attention_cnn = FeatureAttention(self.hidden_size * 2)
        self.attention = FeatureAttention(self.hidden_size * 2)

    def forward(self, x):
        """Run the block.

        Args:
            x: tensor of shape ``(batch, lag_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, lag_size, output_size)``: the decoder
            projection plus the input residual.
        """
        # Take the batch size from the input rather than the config, so the
        # final (smaller) batch of an epoch and inference batches also work.
        batch_size = x.size(0)
        res = x

        # Recurrent branch -> (batch, lag, 2 * hidden).
        lstm, _ = self.encoder_lstm(x)
        lstm = self.activation(self.dropout(lstm))

        # Convolutional branch: Conv1d expects (batch, channels, length).
        encoded = torch.transpose(x, 1, 2)
        encoded = self.maxpool(self.activation(self.batchNorm2(self.encoder_conv1(encoded))))
        encoded = self.maxpool(self.activation(self.batchNorm3(self.encoder_conv2(encoded))))
        decoded = torch.transpose(encoded, 1, 2)  # back to (batch, lag, 2 * hidden)

        # Gate the recurrent features element-wise with the convolutional ones.
        fused = lstm * decoded

        final, _ = self.decoder_lstm(fused)
        final = self.flatten(final)
        final = self.decoder_linear2(final)
        final = final.view(batch_size, self.config.lag_size, -1)

        # Out-of-place residual add (avoids mutating the view created above).
        return final + res