In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx as onnx


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TransformerL(nn.Module):
    """
    Transformer-encoder model for time series forecasting.

    Each time step of the input is projected to ``embed_dim`` by a linear
    layer, passed through ``num_layers`` standard ``nn.TransformerEncoderLayer``
    blocks, and mapped to a single value per time step by a final linear layer.

    Note: no positional encoding is added anywhere in this model, so the
    encoder only sees time-step order through whatever the input features
    themselves carry.

    Parameters
    ----------
    input_dim : int
        Number of features per time step.
    embed_dim : int
        Model width (``d_model`` of the encoder layers). Must be divisible by
        ``num_heads`` (required by ``nn.MultiheadAttention``).
    num_layers : int
        Number of stacked encoder layers.
    num_heads : int
        Number of attention heads per encoder layer.
    dropout : float
        Dropout probability used inside each encoder layer.
    """
    def __init__(self, input_dim: int, embed_dim: int, num_layers: int, num_heads: int, dropout: float):
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout

        self.embedding = nn.Linear(self.input_dim, self.embed_dim)
        # Layers use the default batch_first=False, i.e. they expect
        # (seq_length, batch_size, embed_dim); forward() handles the layout.
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                self.embed_dim,
                self.num_heads,
                dim_feedforward=4 * self.embed_dim,
                dropout=self.dropout,
            )
            for _ in range(self.num_layers)
        ])
        self.output_layer = nn.Linear(self.embed_dim, 1)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Transformer model.

        Parameters
        ----------
        src : torch.Tensor
            Input tensor of shape (batch_size, seq_length, input_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, seq_length, 1).
        """
        # (batch, seq, input_dim) -> (batch, seq, embed_dim)
        hidden = self.embedding(src)
        # Encoder layers are sequence-first: (seq, batch, embed_dim).
        hidden = hidden.permute(1, 0, 2)

        for encoder in self.encoder_layers:
            hidden = encoder(hidden)

        # Restore the batch-first layout; without this the result would be
        # (seq_length, batch_size, 1), contradicting the documented shape.
        hidden = hidden.permute(1, 0, 2)
        return self.output_layer(hidden)

In [3]:
# Seed torch so the randomly initialised weights (and the dummy input) are
# reproducible: re-running this cell exports the same model weights.
SEED = 42
torch.manual_seed(SEED)

ONNX_PATH = "multi-transformerL.onnx"

model_m = TransformerL(input_dim=10, embed_dim=64, num_layers=2, num_heads=8, dropout=0.01)

# Dummy input tensor (batch_size, seq_length, input_dim); only used to trace the graph.
dummy_input = torch.randn(8, 5, 10)

# Export the model to ONNX. Batch and sequence axes are declared dynamic so the
# exported graph accepts other sizes. The output axis names assume the model
# returns (batch_size, seq_length, 1). By default torch.onnx.export exports in
# eval mode (training=TrainingMode.EVAL), so dropout is not active in the graph.
onnx.export(
    model_m,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "seq_length"},
        "output": {0: "batch_size", 1: "seq_length"},
    },
)