In [1]:
import math
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
batch_size = 32
input_features = 16
input_window = 30
output_window = 5
output_features = 4

# SKIP CONNECTIONS ARE MISSING

# TCN

In [3]:
class TCN(nn.Module):
    def __init__(self, input_features, output_features, output_window, drop_p):
        super(TCN, self).__init__()
        self.output_features = output_features
        self.output_window = output_window

        self.conv1d_1 = nn.Conv1d(in_channels=input_features, out_channels=32, kernel_size=3, groups=input_features)
        self.bn1d_1 = nn.BatchNorm1d(32)
        self.drop_1 = nn.Dropout1d(p=drop_p)
        self.conv1d_2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, groups=32)
        self.bn1d_2 = nn.BatchNorm1d(64)
        self.drop_2 = nn.Dropout1d(p=drop_p)
        self.conv1d_3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, groups=64)
        self.bn1d_3 = nn.BatchNorm1d(128)
        self.drop_3 = nn.Dropout1d(p=drop_p)

        self.conv2d_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
        self.bn2d_1 = nn.BatchNorm2d(32)
        self.conv2d_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
        self.bn2d_2 = nn.BatchNorm2d(64)
        self.conv2d_3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3))
        self.bn2d_3 = nn.BatchNorm2d(128)
        self.conv2d_4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3, 3), stride=(1,2), dilation=(1,2))
        self.bn2d_4 = nn.BatchNorm2d(64)
        self.conv2d_5 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(3, 3), stride=(1,2), dilation=(1,2))
        self.bn2d_5 = nn.BatchNorm2d(32)
        self.conv2d_6 = nn.Conv2d(in_channels=32, out_channels=8, kernel_size=(3, 3), stride=(1,2), dilation=(1,2))
        self.bn2d_6 = nn.BatchNorm2d(8)
        self.pool_6 = nn.AvgPool2d(kernel_size=(1,2))

        self.conv2d_7 = None
        # this is the initialization we want to make, but we are making this in the first forward loop to make it dynamic
        # self.conv2d_7 = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=(x.shape[2]-output_features+1, x.shape[3]-output_window+1))

        self.pad = nn.ReplicationPad2d((1,1,0,0))
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pad(x)
        x = self.conv1d_1(x)
        x = self.bn1d_1(x)
        x = self.relu(x)
        x = self.drop_1(x)

        x = self.pad(x)
        x = self.conv1d_2(x)
        x = self.bn1d_2(x)
        x = self.relu(x)
        x = self.drop_2(x)

        x = self.conv1d_3(x)
        x = self.bn1d_3(x)
        x = self.relu(x)
        x = self.drop_3(x)

        x = x.unsqueeze(1)
        # x = torch.permute(x, [0,1,3,2])
        x = x.permute([0,1,3,2])

        x = self.conv2d_1(x)
        x = self.bn2d_1(x)
        x = self.relu(x)

        x = self.conv2d_2(x)
        x = self.bn2d_2(x)
        x = self.relu(x)

        x = self.conv2d_3(x)
        x = self.bn2d_3(x)
        x = self.relu(x)

        x = self.conv2d_4(x)
        x = self.bn2d_4(x)
        x = self.relu(x)

        x = self.conv2d_5(x)
        x = self.bn2d_5(x)
        x = self.relu(x)

        x = self.conv2d_6(x)
        x = self.bn2d_6(x)
        x = self.relu(x)
        x = self.pool_6(x)

        if self.conv2d_7:
            x = self.conv2d_7(x)
        else:
            self.conv2d_7 = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=(x.shape[2]-self.output_features+1, x.shape[3]-self.output_window+1)).to(device)
            x = self.conv2d_7(x)

        return torch.squeeze(x)
    
    def call(self, x, y):
        return self(x)

In [4]:
model = TCN(input_features=16, output_features=4, output_window=5, drop_p=0.2).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x, y)

print(x.shape)
print(output.shape)

189240
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])


# Encoder-Decoder LSTM

In [5]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=3, drop_p=0.2):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=drop_p)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        outputs, (hidden, cell) = self.lstm(x)
        
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers=2, drop_p=0.2, teacher_forcing_ratio=1.0):
        super(Decoder, self).__init__()
        self.output_size = output_size
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.lstm = nn.LSTM(input_size=output_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=drop_p)
        self.fc = nn.Linear(hidden_size, self.output_size)

    def forward(self, hidden, cell, output_window, target=None):
        batch_size = hidden.size(1)

        # Initial decoder input: zeros
        decoder_input = torch.zeros(batch_size, 1, self.output_size, device=hidden.device)
        outputs = []

        for t in range(output_window):
            out, (hidden, cell) = self.lstm(decoder_input, (hidden, cell))
            pred = self.fc(out.squeeze(1))  # [batch, output_size]
            outputs.append(pred.unsqueeze(1))  # [batch, 1, output_size]

            if target is not None and torch.rand(1) < self.teacher_forcing_ratio:
                decoder_input = target[:, t].unsqueeze(1)  # [batch, 1, output_size]
            else:
                decoder_input = pred.unsqueeze(1)
        
        return torch.cat(outputs, dim=1)  # [batch, output_window, output_size]

class EncoderDecoderLSTM(nn.Module):
    def __init__(self, input_size, output_size, output_window, hidden_size=128, num_layers=3, drop_p=0.2, teacher_forcing_ratio=1.0):
        super(EncoderDecoderLSTM, self).__init__()
        self.encoder = Encoder(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p)
        self.decoder = Decoder(output_size=output_size, hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, teacher_forcing_ratio=teacher_forcing_ratio)
        self.output_window = output_window

    def forward(self, x, target=None):
        if target is not None:
            target = target.permute(0, 2, 1)
        
        hidden, cell = self.encoder(x)
        out = self.decoder(hidden, cell, self.output_window, target)
        out = out.permute(0, 2, 1)
        return out
    
    def set_teacher_forcing_ratio(self, new_value):
        self.decoder.teacher_forcing_ratio = new_value

    def call(self, x, y):
        return self(x, y)

In [6]:
model = EncoderDecoderLSTM(input_size=input_features, output_size=output_features, output_window=output_window, num_layers=3, drop_p=0.2).to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x, y)

print(x.shape)
print(y.shape)
print(output.shape)

model = EncoderDecoderLSTM(input_size=input_features, output_size=output_features, output_window=output_window, num_layers=3, drop_p=0.2).to(device)
model = model.eval()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
with torch.no_grad():
    output = model.call(x, None)

print(x.shape)
print(output.shape)

672260
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])
torch.Size([32, 4, 5])
672260
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])


# Encoder-Decoder GRU

In [7]:
class GRUEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=3, drop_p=0.2):
        super(GRUEncoder, self).__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=drop_p)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [batch, seq_len, features]
        outputs, hidden = self.gru(x)  # no cell state in GRU
        return hidden  # [num_layers, batch, hidden_size]

class GRUDecoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers=2, drop_p=0.2, teacher_forcing_ratio=1.0):
        super(GRUDecoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.gru = nn.GRU(input_size=output_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=drop_p)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, hidden, output_window, target=None):
        batch_size = hidden.size(1)
        decoder_input = torch.zeros(batch_size, 1, self.output_size, device=hidden.device)
        outputs = []

        for t in range(output_window):
            out, hidden = self.gru(decoder_input, hidden)
            pred = self.fc(out.squeeze(1))  # [batch, output_size]
            outputs.append(pred.unsqueeze(1))

            if target is not None and torch.rand(1).item() < self.teacher_forcing_ratio:
                decoder_input = target[:, t].unsqueeze(1)
            else:
                decoder_input = pred.unsqueeze(1)

        return torch.cat(outputs, dim=1)  # [batch, output_window, output_size]

class EncoderDecoderGRU(nn.Module):
    def __init__(self, input_size, output_size, output_window, hidden_size=128, num_layers=3, drop_p=0.2, teacher_forcing_ratio=1.0):
        super(EncoderDecoderGRU, self).__init__()
        self.encoder = GRUEncoder(input_size, hidden_size, num_layers=num_layers, drop_p=drop_p)
        self.decoder = GRUDecoder(output_size, hidden_size, num_layers=num_layers, drop_p=drop_p, teacher_forcing_ratio=teacher_forcing_ratio)
        self.output_window = output_window

    def forward(self, x, target=None):
        if target is not None:
            target = target.permute(0, 2, 1)  # [batch, output_window, output_size]

        hidden = self.encoder(x)
        out = self.decoder(hidden, self.output_window, target)
        return out.permute(0, 2, 1)  # [batch, output_size, output_window]

    def set_teacher_forcing_ratio(self, new_value):
        self.decoder.teacher_forcing_ratio = new_value

    def call(self, x, y):
        return self(x, y)

In [8]:
model = EncoderDecoderGRU(input_size=input_features, output_size=output_features, output_window=output_window, num_layers=3, drop_p=0.2).to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x, y)

print(x.shape)
print(y.shape)
print(output.shape)

model = EncoderDecoderGRU(input_size=input_features, output_size=output_features, output_window=output_window, num_layers=3, drop_p=0.2).to(device)
model = model.eval()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
with torch.no_grad():
    output = model.call(x, None)

print(x.shape)
print(output.shape)

504324
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])
torch.Size([32, 4, 5])
504324
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])


# Encoder-Decoder Transformer

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # [d_model/2]
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class EncoderDecoderTransformer(nn.Module):
    def __init__(self, input_features, output_features, input_window, output_window, d_model=128, nhead=8, num_layers=3, dropout=0.1):
        super().__init__()
        self.input_window = input_window
        self.output_window = output_window
        self.d_model = d_model

        # Project input and output features into d_model space
        self.input_proj = nn.Linear(input_features, d_model)
        self.output_proj = nn.Linear(output_features, d_model)
        self.output_linear = nn.Linear(d_model, output_features)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        self.pos_decoder = PositionalEncoding(d_model, dropout=dropout)

        # Transformer
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dim_feedforward=512, dropout=dropout, batch_first=True)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def forward(self, src, tgt=None):
        """
        src: [batch, input_features, input_window]
        tgt: [batch, output_features, output_window]
        """
        batch_size = src.size(0)
        src = src.permute(0, 2, 1)  # [batch, input_window, input_features]
        src = self.input_proj(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        if tgt is not None:
            # === Training mode ===
            tgt = tgt.permute(0, 2, 1)  # [batch, output_window, output_features]
            tgt = self.output_proj(tgt) * math.sqrt(self.d_model)
            tgt = self.pos_decoder(tgt)
            tgt_mask = self.generate_square_subsequent_mask(self.output_window).to(src.device)
            out = self.transformer(src, tgt, tgt_mask=tgt_mask)
            out = self.output_linear(out)
            return out.permute(0, 2, 1)  # [batch, output_features, output_window]
        else:
            # === Inference mode ===
            output = torch.zeros(batch_size, self.output_window, self.output_linear.out_features, device=src.device)

            for t in range(self.output_window):
                decoder_input = output.clone()  # [batch, output_window, output_features]
                decoder_input = self.output_proj(decoder_input) * math.sqrt(self.d_model)
                decoder_input = self.pos_decoder(decoder_input)

                tgt_mask = self.generate_square_subsequent_mask(self.output_window).to(src.device)
                out = self.transformer(src, decoder_input, tgt_mask=tgt_mask)
                out = self.output_linear(out)
                output[:, t] = out[:, t]  # Take the t-th step output

            return output.permute(0, 2, 1)  # [batch, output_features, output_window]
        
    def call(self, x, y):
        return self(x, y)

In [None]:
model = EncoderDecoderTransformer(input_features=input_features, output_features=output_features, input_window=input_window, output_window=output_window).to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x, y)

print(x.shape)
print(y.shape)
print(output.shape)

model = EncoderDecoderTransformer(input_features=input_features, output_features=output_features, input_window=input_window, output_window=output_window).to(device)
model = model.eval()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
output = model.call(x, None)

print(x.shape)
print(output.shape)

1392388
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])
torch.Size([32, 4, 5])
1392388
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])


# Coin-wise Cross Attention LSTM

In [11]:
class CoinWiseCrossAttentionLSTM(nn.Module):
    def __init__(self, input_window, output_window, hidden_dim=128, output_features=4, drop_p=0.2, num_layers=1, num_heads=4, target_coin_index=0):
        super().__init__()
        self.input_window = input_window
        self.output_window = output_window
        self.hidden_dim = hidden_dim
        self.output_features = output_features
        self.num_coins = 4
        self.features_per_coin = 4
        self.target_coin_index = target_coin_index

        # Create one LSTM per coin
        self.lstm_blocks = nn.ModuleList([
            nn.LSTM(input_size=self.features_per_coin, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=drop_p
            ) for _ in range(self.num_coins)
        ])

        # Attention: Q, K, V are all the LSTM outputs [batch, 4, hidden_dim]
        self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=drop_p)

        # Final projection to output
        self.fc = nn.Linear(hidden_dim, output_features * output_window)

    def forward(self, x):
        """
        x: [batch_size, input_features=16, input_window]
        Returns: [batch_size, output_features, output_window]
        """
        batch_size = x.size(0)

        # Split input into 4 parts: one per coin
        x = x.view(batch_size, self.num_coins, self.features_per_coin, self.input_window)
        lstm_outputs = []

        for i in range(self.num_coins):
            coin_input = x[:, i]  # [batch, 4, input_window]
            coin_input = coin_input.permute(0, 2, 1)  # [batch, input_window, 4]
            _, (h_n, _) = self.lstm_blocks[i](coin_input)  # h_n: [num_layers, batch, hidden_dim]
            lstm_outputs.append(h_n[-1])  # Take last layer: [batch, hidden_dim]

        # Stack LSTM outputs: [batch, 4, hidden_dim]
        lstm_stack = torch.stack(lstm_outputs, dim=1)

        # Self-attention across the 4 coin representations
        attn_out, _ = self.attention(lstm_stack, lstm_stack, lstm_stack)  # [batch, 4, hidden_dim]

        # Pool across coins (e.g., mean pooling)
        target_coin = attn_out[:, self.target_coin_index, :]
        
        # Predict
        output = self.fc(target_coin)  # [batch, output_features * output_window]
        output = output.view(batch_size, self.output_features, self.output_window)

        return output
    
    def call(self, x, y):
        return self(x)

In [12]:
model = CoinWiseCrossAttentionLSTM(input_window, output_window, hidden_dim=128, output_features=4, drop_p=0.2, num_layers=2, num_heads=4, target_coin_index=0).to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x, y)

print(x.shape)
print(output.shape)

871444
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])


# Feature-wise Cross Attention LSTM

In [None]:
class FeatureWiseCrossAttentionLSTM(nn.Module):
    def __init__(self, input_window, output_window, output_features=4, drop_p=0.2, hidden_dim=128, num_layers=1, num_heads=4, target_coin_index=0):
        super().__init__()
        self.input_window = input_window
        self.output_window = output_window
        self.output_features = output_features
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_features = 16  # Fixed in this design
        self.group_size = 4
        self.num_groups = self.num_features // self.group_size
        self.target_coin_index = target_coin_index

        # One LSTM per feature
        self.feature_lstms = nn.ModuleList([
            nn.LSTM(input_size=1, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=drop_p)
            for _ in range(self.num_features)
        ])

        # One attention module per group of 4 features
        self.group_attentions = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=drop_p)
            for _ in range(self.num_groups)
        ])

        # Final attention to merge 4 group embeddings
        self.final_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True, dropout=drop_p)

        # Output projection
        self.fc = nn.Linear(hidden_dim, output_features * output_window)

    def forward(self, x):
        """
        x: [batch_size, 16, input_window]
        Returns: [batch_size, output_features, output_window]
        """
        batch_size = x.size(0)
        feature_representations = []

        # Process each feature through its own LSTM
        for i in range(self.num_features):
            feature_input = x[:, i].unsqueeze(-1)  # [batch, input_window, 1]
            _, (h_n, _) = self.feature_lstms[i](feature_input)
            feature_representations.append(h_n[-1])  # [batch, hidden_dim]

        # Stack into [batch, 16, hidden_dim]
        features_stack = torch.stack(feature_representations, dim=1)

        # Group features into 4 groups, run attention in each group
        group_outputs = []
        for i in range(self.num_groups):
            group = features_stack[:, i * self.group_size:(i + 1) * self.group_size]  # [batch, 4, hidden_dim]
            attn_out, _ = self.group_attentions[i](group, group, group)
            group_pooled = attn_out.mean(dim=1)  # [batch, hidden_dim]
            group_outputs.append(group_pooled)

        # Stack group-level outputs: [batch, 4, hidden_dim]
        groups_stack = torch.stack(group_outputs, dim=1)

        # Final attention to merge 4 groups
        final_attn_out, _ = self.final_attention(groups_stack, groups_stack, groups_stack)
        final_embedding = final_attn_out[:, self.target_coin_index, :]

        # Project to output
        output = self.fc(final_embedding)  # [batch, output_features * output_window]
        output = output.view(batch_size, self.output_features, self.output_window)

        return output
    
    def call(self, x, y):
        return self(x)

In [14]:
model = FeatureWiseCrossAttentionLSTM(input_window, output_window, hidden_dim=128, output_features=4, drop_p=0.2, num_layers=2, num_heads=4).to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

x = torch.randn(batch_size, input_features, input_window).to(device)
y = torch.randn(batch_size, output_features, output_window).to(device)
with torch.no_grad():
    output = model.call(x,y)

print(x.shape)
print(output.shape)

3519508
torch.Size([32, 16, 30])
torch.Size([32, 4, 5])
