In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [3]:
class LSTMBlock(nn.Module):
    """Single-layer LSTM encoder that exposes only the per-step outputs.

    Args:
        input_size: number of features in each time step of the input.
        hidden_size: number of features in the LSTM hidden state (and in
            each time step of the returned sequence).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # batch_first=True: batched tensors are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        """Run the LSTM over ``x`` and return its output sequence.

        The final ``(h_n, c_n)`` state tuple returned by ``nn.LSTM`` is
        discarded; only the first element (the output for every step) is kept.
        """
        return self.lstm(x)[0]

class CNNBlock(nn.Module):
    """1-D convolution + ReLU over a ``(batch, seq_len, channels)`` sequence.

    The convolution is padded so that the output keeps the input's
    ``seq_len`` for every ``kernel_size``. The original hard-coded
    ``padding=1`` only did so for ``kernel_size == 3``; other kernel sizes
    silently changed the sequence length.

    Args:
        input_size: number of input channels (features per time step).
        hidden_size: number of output channels produced by the convolution.
        kernel_size: width of the convolution kernel. Defaults to 3.
    """

    def __init__(self, input_size, hidden_size, kernel_size=3):
        super(CNNBlock, self).__init__()
        # "same" padding keeps seq_len unchanged; for the default kernel of 3
        # it is exactly the previous padding=1.
        self.conv = nn.Conv1d(input_size, hidden_size, kernel_size, padding="same")
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the convolution along the sequence axis.

        Args:
            x: tensor laid out as (batch_size, seq_len, input_size).

        Returns:
            Tensor laid out as (batch_size, seq_len, hidden_size).
        """
        # Conv1d convolves over the last axis, so move channels in front of it.
        x = x.transpose(1, 2)  # (batch_size, channels, seq_len)
        x = self.conv(x)
        x = self.relu(x)
        # Restore the sequence-major layout used by the rest of the model.
        x = x.transpose(1, 2)  # (batch_size, seq_len, channels)
        return x

class AttentionBlock(nn.Module):
    """Unscaled dot-product attention of ``query`` over ``context``.

    Args:
        hidden_size: feature width used to size the ``attention`` linear
            layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # NOTE(review): this projection is registered (and therefore shows up
        # in state_dict / parameters()) but forward() never applies it.
        self.attention = nn.Linear(hidden_size, hidden_size)
        # Normalise over the last axis, i.e. over the context positions.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, context):
        """Return, for every query step, a softmax-weighted sum of ``context``.

        Both arguments must be 3-D (``torch.bmm`` requires it) and share the
        batch and feature dimensions.

        Args:
            query: tensor laid out as (batch_size, query_len, hidden_size).
            context: tensor laid out as (batch_size, context_len, hidden_size).

        Returns:
            Tensor laid out as (batch_size, query_len, hidden_size).
        """
        # Pairwise dot products: (batch_size, query_len, context_len).
        similarity = query.bmm(context.transpose(1, 2))
        weights = self.softmax(similarity)
        # Mix the context rows with the weights of each query step.
        return weights.bmm(context)


In [4]:
class ResidualCombinedModel(nn.Module):
    """Three-branch model over seasonality, trend and residual inputs.

    Each input goes through its own LSTM and CNN encoder, whose outputs are
    multiplied element-wise and added to the raw input. The branches are then
    refined one after another with attention over the other two branches,
    summed, and projected to a single value per time step.

    NOTE(review): the skip connections add the raw input (last dimension
    ``input_size``) to a ``hidden_size``-wide tensor, so they rely on
    broadcasting. This appears to require ``input_size == 1`` or
    ``input_size == hidden_size`` -- confirm against the callers.

    Args:
        input_size: feature width of each of the three input sequences.
        hidden_size: feature width produced by every LSTM / CNN encoder.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Submodules are created in a fixed order so parameter initialisation
        # under a given RNG seed and state_dict ordering stay stable.
        self.seasonality_lstm = LSTMBlock(input_size, hidden_size)
        self.seasonality_cnn = CNNBlock(input_size, hidden_size)

        self.trend_lstm = LSTMBlock(input_size, hidden_size)
        self.trend_cnn = CNNBlock(input_size, hidden_size)

        self.residual_lstm = LSTMBlock(input_size, hidden_size)
        self.residual_cnn = CNNBlock(input_size, hidden_size)

        self.attention_seasonality = AttentionBlock(hidden_size)
        self.attention_trend = AttentionBlock(hidden_size)
        self.attention_residual = AttentionBlock(hidden_size)

        # Maps each time step's hidden vector to one output value.
        self.fc = nn.Linear(hidden_size, 1)

    def _encode(self, lstm_block, cnn_block, series):
        """Gate the LSTM output with the CNN output and add the raw input back."""
        return lstm_block(series) * cnn_block(series) + series

    def forward(self, seasonality, trend, residual):
        """Compute the per-time-step output for the three input sequences.

        Args:
            seasonality: seasonality input sequence.
            trend: trend input sequence.
            residual: residual input sequence.

        Returns:
            Output of the final linear layer (last dimension of size 1).
        """
        # Stage 1: independent LSTM * CNN encoding with a skip connection.
        season_enc = self._encode(self.seasonality_lstm, self.seasonality_cnn, seasonality)
        trend_enc = self._encode(self.trend_lstm, self.trend_cnn, trend)
        resid_enc = self._encode(self.residual_lstm, self.residual_cnn, residual)

        # Stage 2: inter-component attention. The branches are refined
        # sequentially: seasonality attends to the stage-1 encodings only ...
        season_mix = (
            season_enc
            * self.attention_seasonality(season_enc, trend_enc)
            * self.attention_seasonality(season_enc, resid_enc)
        ) + seasonality

        # ... trend attends to the already-refined seasonality ...
        trend_mix = (
            trend_enc
            * self.attention_trend(trend_enc, season_mix)
            * self.attention_trend(trend_enc, resid_enc)
        ) + trend

        # ... and residual attends to both refined branches.
        resid_mix = (
            resid_enc
            * self.attention_residual(resid_enc, season_mix)
            * self.attention_residual(resid_enc, trend_mix)
        ) + residual

        # Stage 3: sum the branches and project to the output.
        return self.fc(season_mix + trend_mix + resid_mix)
