In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    Args:
        attention_dim: the ``d`` in the ``1 / sqrt(d)`` score scaling.

    ``forward`` returns ``(output, attn_weights)``. The softmax runs over the
    last axis of the score matrix, i.e. over ``k``'s second-to-last axis.
    """

    def __init__(self, attention_dim):
        super().__init__()
        # Plain attribute (not a parameter/buffer): divisor for the raw scores.
        self.scaling_factor = np.sqrt(attention_dim)

    def forward(self, q, k, v):
        """Attend ``q`` over ``k`` and mix ``v`` with the resulting weights.

        ``k``'s last two axes are swapped before the product, so ``q`` and ``k``
        must agree on their last dim, and ``k`` and ``v`` on their
        second-to-last dim.
        """
        keys_t = k.transpose(-2, -1)
        scores = (q @ keys_t) / self.scaling_factor
        weights = scores.softmax(dim=-1)
        return weights @ v, weights
    
class ResiDualNet(nn.Module):
    """A chain of ``num_blocks`` ResNetBlock modules applied one after another.

    Args:
        config: object exposing ``input_size``, ``hidden_size`` and
            ``output_size``; it is also passed unchanged to every block.
        kernel_size: convolution kernel size forwarded to each block.
        num_blocks: number of blocks in the chain.

    NOTE(review): each block's output is fed to the next block as input, so
    presumably ``output_size`` must equal ``input_size`` — confirm.
    """

    def __init__(self, config, kernel_size=3, num_blocks=3):
        super().__init__()
        self.config = config
        self.input_size, self.hidden_size, self.output_size = (
            config.input_size,
            config.hidden_size,
            config.output_size,
        )

        # ResNetBlock is defined in a later cell; the name is only resolved
        # here, when an instance is constructed.
        blocks = [ResNetBlock(config, kernel_size) for _ in range(num_blocks)]
        self.resnet = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the last output."""
        out = x
        for stage in self.resnet:
            out = stage(out)
        return out
    


In [4]:
class ResNetBlock(nn.Module):
    """Residual block: parallel BiLSTM and Conv1d encoders, a BiLSTM decoder and
    a linear head, with the block input added back before a final ReLU.

    Args:
        config: object exposing ``input_size``, ``hidden_size``,
            ``output_size`` and ``lag_size``.
        kernel_size: kernel size of the Conv1d layers.
        padding: zero-padding of the Conv1d layers. ``None`` (default) uses
            ``kernel_size // 2``, which keeps the sequence length unchanged for
            odd kernel sizes. That is required because the conv features are
            concatenated time-step-wise with the LSTM features. For the
            default ``kernel_size=3`` this resolves to ``1``.

    Shapes:
        Input ``x`` must be ``(batch, lag_size, input_size)``: the LSTM is
        ``batch_first``, the conv branch transposes dims 1 and 2, and the
        linear head is sized for exactly ``lag_size`` time steps. The output is
        ``(batch, lag_size, output_size)``; because ``x`` is added to it,
        ``output_size`` must equal ``input_size``.
    """

    def __init__(self, config, kernel_size=3, padding=None):
        super(ResNetBlock, self).__init__()
        self.config = config
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # Conv padding must preserve sequence length so the conv features line
        # up with the LSTM features in forward().
        if padding is None:
            padding = kernel_size // 2

        # NOTE(review): batchNorm1, batchNorm4, decoder_conv2, batchNorm5,
        # decoder_linear, attention, attention_lstm and attention_cnn are never
        # used in forward(). They are kept so parameter names/order (and any
        # saved state_dict) stay unchanged, but they still add parameters.
        # Registration order below is preserved for the same reason.
        self.encoder_lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True, bidirectional=True)
        self.batchNorm1 = nn.BatchNorm1d(2 * self.hidden_size)

        self.encoder_conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=self.hidden_size * 2, kernel_size=kernel_size, padding=padding)
        self.batchNorm2 = nn.BatchNorm1d(self.hidden_size * 2)

        self.encoder_conv2 = nn.Conv1d(in_channels=self.hidden_size * 2, out_channels=self.hidden_size, kernel_size=kernel_size, padding=padding)
        self.batchNorm3 = nn.BatchNorm1d(self.hidden_size)

        # Input = BiLSTM features (2 * hidden) + conv features (hidden).
        self.decoder_lstm = nn.LSTM(input_size=self.hidden_size * 3, hidden_size=self.hidden_size, batch_first=True, bidirectional=True)
        # NOTE(review): sized for hidden_size, but decoder_lstm is bidirectional
        # and emits 2 * hidden_size features — would mismatch if ever wired in.
        self.batchNorm4 = nn.BatchNorm1d(self.hidden_size)

        self.decoder_conv2 = nn.Conv1d(in_channels=self.hidden_size * 3, out_channels=self.hidden_size, kernel_size=kernel_size, padding=padding)
        self.batchNorm5 = nn.BatchNorm1d(self.hidden_size)

        self.decoder_linear = nn.Linear(self.hidden_size * 2, self.hidden_size * 4)
        # Maps the flattened (lag_size, 2 * hidden) decoder output to
        # lag_size * output_size values.
        self.decoder_linear2 = nn.Linear(2 * self.hidden_size * self.config.lag_size, self.config.lag_size * self.output_size)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.flatten = nn.Flatten()

        self.attention = ScaledDotProductAttention(self.hidden_size)

        # Attention projections (currently unused, see note above).
        self.attention_lstm = nn.Linear(self.hidden_size * 2, self.hidden_size * 2)
        self.attention_cnn = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x):
        """Run the block on ``x`` of shape ``(batch, lag_size, input_size)``.

        Returns:
            ReLU(head(x) + x), shape ``(batch, lag_size, output_size)``.
        """
        # Take the batch size from the data, not the config, so smaller
        # (e.g. final or inference) batches work.
        batch_size = x.size(0)
        res = x

        # Recurrent encoder branch -> (batch, lag, 2 * hidden).
        lstm, _ = self.encoder_lstm(x)
        lstm = self.dropout(lstm)

        # Convolutional encoder branch; Conv1d expects (batch, channels, length).
        conv_input = torch.transpose(x, 1, 2)
        encoded = self.batchNorm2(self.encoder_conv1(conv_input))
        encoded = self.batchNorm3(self.encoder_conv2(encoded))
        # NOTE(review): there is no activation between or after the two convs,
        # so this branch is affine (conv -> BN -> conv -> BN). Confirm intended.
        decoded = torch.transpose(encoded, 1, 2)  # (batch, lag, hidden)

        # Per-time-step concat -> (batch, lag, 3 * hidden) == decoder_lstm input.
        combined = torch.cat((lstm, decoded), -1)

        final, _ = self.decoder_lstm(combined)  # (batch, lag, 2 * hidden)
        final = self.flatten(final)             # (batch, lag * 2 * hidden)
        final = self.decoder_linear2(final)     # (batch, lag * output_size)
        final = final.view(batch_size, self.config.lag_size, -1)

        # Out-of-place residual add: avoids mutating an autograd intermediate.
        final = final + res

        return self.activation(final)