In [1]:
import torch
import torch.nn as nn

In [15]:
class Seq2SeqConv1D(nn.Module):
    """Forecaster: a stack of ``ResNetBlock`` modules followed by a linear head.

    The input is passed through ``num_blocks`` residual blocks, flattened, and
    projected by a linear layer to ``future_step * output_size`` values. These
    are reshaped to ``(batch, future_step, output_size)`` and passed through
    ReLU, so every prediction is clamped to be non-negative.

    The linear head expects the (block-processed) input to flatten to
    ``lag_window * output_size`` features per sample.

    Args:
        model_parameter: object exposing ``future_step``, ``output_size``,
            ``lag_window``, ``num_blocks`` and ``batch_size`` attributes
            (plus whatever ``ResNetBlock`` reads from it).
    """

    def __init__(self, model_parameter):
        super().__init__()
        self.future_step = model_parameter.future_step
        self.output_size = model_parameter.output_size
        self.lag_window = model_parameter.lag_window
        self.num_blocks = model_parameter.num_blocks
        # Kept for backward compatibility; forward() takes the batch size
        # from its input instead of relying on this configured value.
        self.batch_size = model_parameter.batch_size

        self.resnet = nn.ModuleList(
            [ResNetBlock(model_parameter) for _ in range(self.num_blocks)]
        )

        self.linear = nn.Linear(
            self.output_size * self.lag_window,
            self.future_step * self.output_size,
        )
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return predictions of shape ``(batch, future_step, output_size)``."""
        for block in self.resnet:
            x = block(x)

        out = self.linear(self.flatten(x))
        # Reshape using the runtime batch size so a short final batch or a
        # different inference batch size is handled correctly.
        out = out.view(out.size(0), self.future_step, -1)
        return self.activation(out)

In [17]:
class ResNetBlock(nn.Module):
    """Residual block that fuses a bidirectional LSTM branch and a Conv1d branch.

    Both branches read the same input. Their per-timestep features are
    concatenated to ``3 * hidden_size`` channels, flattened, and projected by a
    linear layer to ``lag_window * output_size`` values. The result is
    reshaped to ``(batch, lag_window, output_size)``, added to the block input
    (skip connection), and passed through ReLU.

    The input is expected to be ``(batch, lag_window, input_size)``. The
    residual add requires ``output_size == input_size``.

    Args:
        model_parameter: object exposing ``input_size``, ``hidden_size``,
            ``output_size``, ``kernel_size``, ``future_step``, ``lag_window``
            and ``batch_size`` attributes.
        padding: zero-padding used by the Conv1d layers. The conv branch must
            preserve the sequence length so it can be concatenated with the
            LSTM branch, which requires ``2 * padding == kernel_size - 1``.
            The default of 1 matches ``kernel_size=3``.
    """

    def __init__(self, model_parameter, padding=1):
        super().__init__()
        self.model_parameter = model_parameter
        self.input_size = model_parameter.input_size
        self.hidden_size = model_parameter.hidden_size
        self.output_size = model_parameter.output_size
        self.kernel_size = model_parameter.kernel_size
        self.future_step = model_parameter.future_step
        self.lag_window = model_parameter.lag_window
        # Kept for backward compatibility; forward() takes the batch size
        # from its input instead of relying on this configured value.
        self.batch_size = model_parameter.batch_size
        self.padding = padding

        # --- Layers used by forward() ---
        self.encoder_lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.encoder_conv1 = nn.Conv1d(
            in_channels=self.input_size,
            out_channels=self.hidden_size,
            kernel_size=self.kernel_size,
            padding=padding,
        )
        self.batchNorm2 = nn.BatchNorm1d(self.hidden_size)
        # LSTM branch (2 * hidden) + conv branch (hidden), per timestep.
        self.decoder_linear2 = nn.Linear(
            3 * self.hidden_size * self.lag_window,
            self.lag_window * self.output_size,
        )

        # --- Layers NOT used by forward() ---
        # They still appear in parameters() and state_dict(); kept so that
        # existing checkpoints keep loading with matching keys.
        self.batchNorm1 = nn.BatchNorm1d(2 * self.hidden_size)
        self.encoder_conv2 = nn.Conv1d(
            in_channels=self.hidden_size * 2,
            out_channels=self.hidden_size,
            kernel_size=self.kernel_size,
            padding=padding,
        )
        self.batchNorm3 = nn.BatchNorm1d(self.hidden_size)
        self.decoder_lstm = nn.LSTM(
            input_size=self.hidden_size * 3,
            hidden_size=self.hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        # NOTE(review): decoder_lstm is bidirectional (2 * hidden_size output
        # features) but batchNorm4 normalizes hidden_size channels; reconcile
        # before wiring these into forward().
        self.batchNorm4 = nn.BatchNorm1d(self.hidden_size)
        self.decoder_conv2 = nn.Conv1d(
            in_channels=self.hidden_size * 3,
            out_channels=self.hidden_size,
            kernel_size=self.kernel_size,
            padding=padding,
        )
        self.batchNorm5 = nn.BatchNorm1d(self.hidden_size)
        self.decoder_linear = nn.Linear(self.hidden_size * 2, self.hidden_size * 4)

        # --- Stateless helpers ---
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, lag_window, input_size)``.

        Returns a tensor of shape ``(batch, lag_window, output_size)`` with
        non-negative entries (ReLU is applied after the residual add).
        """
        batch = x.size(0)
        residual = x

        # Recurrent branch: (batch, lag_window, 2 * hidden_size).
        lstm_out, _ = self.encoder_lstm(x)
        lstm_out = self.dropout(lstm_out)

        # Convolutional branch: Conv1d expects (batch, channels, length).
        conv_out = self.encoder_conv1(x.transpose(1, 2))
        conv_out = self.batchNorm2(conv_out)
        conv_out = conv_out.transpose(1, 2)  # (batch, lag_window, hidden_size)

        # Fuse both branches per timestep: (batch, lag_window, 3 * hidden_size).
        fused = torch.cat((lstm_out, conv_out), dim=-1)

        out = self.decoder_linear2(self.flatten(fused))
        # Reshape using the runtime batch size so a short final batch or a
        # different inference batch size is handled correctly.
        out = out.view(batch, self.lag_window, -1)

        # Out-of-place add: avoids mutating the linear layer's output view.
        out = out + residual
        return self.activation(out)