In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding for time-major input.

    A table of shape (max_len, 1, d_model) is precomputed once and stored as
    the buffer ``pe``. Even feature columns hold ``sin(pos * freq)`` and odd
    feature columns hold ``cos(pos * freq)``, where
    ``freq_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        d_model: Size of the feature dimension. Both even and odd values are
            supported.
        max_len: Largest sequence length the table can serve.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even columns
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle table; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd columns

        pe = pe.unsqueeze(1)  # (max_len, 1, d_model) so it broadcasts over the batch axis
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Tensor laid out as (T, B, d_model).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            RuntimeError: If T exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise RuntimeError(
                f"sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

class CTTSModel(nn.Module):
    """CNN + Transformer-encoder model producing one prediction per input window.

    Pipeline: two Conv1d+ReLU layers over the time axis -> sinusoidal
    positional encoding -> TransformerEncoder -> Linear head applied to the
    representation of the last time step.

    Args:
        input_dim: Number of input features per time step.
        output_dim: Size of the prediction vector.
        cnn_channels: Output channels of the two conv layers; the second value
            is also the transformer's ``d_model``.
        kernel_size: Kernel width shared by both conv layers.
        nhead: Number of attention heads (PyTorch requires it to divide
            ``cnn_channels[1]``).
        num_layers: Number of stacked transformer encoder layers.
        dropout: Dropout probability inside each transformer encoder layer.
        max_len: Maximum sequence length supported by the positional encoding.
    """

    def __init__(self, input_dim, output_dim=1, cnn_channels=(64, 32), kernel_size=3,
                 nhead=4, num_layers=2, dropout=0.1, max_len=5000):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # CNN encoder.
        # padding='same' keeps the time axis at length T for every kernel size.
        # The previous padding=kernel_size // 2 did so only for odd kernels and
        # added one step per layer for even ones; for odd kernels the two are
        # identical.
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=cnn_channels[0], kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv1d(in_channels=cnn_channels[0], out_channels=cnn_channels[1], kernel_size=kernel_size, padding='same'),
            nn.ReLU()
        )

        # Positional encoding + Transformer (time-major layout, batch_first left at its default)
        self.pos_encoder = PositionalEncoding(d_model=cnn_channels[1], max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(d_model=cnn_channels[1], nhead=nhead, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Decoder
        self.decoder = nn.Linear(cnn_channels[1], output_dim)

    def forward(self, x):
        """Run the model on a batch of windows.

        Args:
            x: Tensor of shape (B, T, input_dim), where input_dim covers the
                target plus exogenous features.

        Returns:
            Tensor of shape (B, output_dim).
        """
        # Conv1d expects channels before time: (B, input_dim, T)
        x = x.permute(0, 2, 1)
        cnn_out = self.cnn(x)  # (B, d_model, T)

        # Transformer expects (T, B, d_model)
        cnn_out = cnn_out.permute(2, 0, 1)
        transformer_input = self.pos_encoder(cnn_out)
        transformer_out = self.transformer_encoder(transformer_input)  # (T, B, d_model)

        # Only the last time step's representation feeds the prediction head.
        last_step = transformer_out[-1]  # (B, d_model)
        output = self.decoder(last_step)  # (B, output_dim)
        return output


In [2]:
# Fix the RNG so weight init, the random input and dropout are reproducible
# under Restart & Run All.
torch.manual_seed(0)

model = CTTSModel(input_dim=10, output_dim=1)  # e.g., 9 exogenous + 1 target

x = torch.randn(64, 30, 10)  # batch_size=64, sequence_len=30
# The model is left in training mode here, so dropout is active in this pass.
y_hat = model(x)  # (64, 1)

# Smoke check: one prediction vector per window in the batch.
assert y_hat.shape == (64, 1), f"unexpected output shape: {tuple(y_hat.shape)}"


