In [6]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from positional_encoding import PositionalEncoding  # Make sure to import from the correct file or directory

class TransformerModel(nn.Module):
    """Transformer-encoder classifier over a fixed-length sequence of embeddings.

    Pipeline: ``positional_encoding`` -> ``num_encoder_layers`` stacked
    batch-first ``TransformerEncoderLayer``s -> flatten to
    ``(batch, seq_len * embedding_size)`` -> Linear(1024) -> BatchNorm1d ->
    Linear(num_classes) -> task-specific activation.

    The output activation is chosen from ``cfg['task']`` on every forward call:
    ``tanh`` for ``"VA"``, softmax over classes for ``"EXPR"``, and element-wise
    ``sigmoid`` for any other value.

    Args:
        seq_len: Number of steps per sample. It is fixed, because the head's first
            Linear layer consumes the flattened ``seq_len * embedding_size`` vector.
        embedding_size: Feature size per step (the encoder's ``d_model``).
        nhead: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        num_classes: Output width of the head.
        cfg: Mapping that must contain a ``'task'`` key (read in ``forward``).
    """

    def __init__(self, seq_len, embedding_size, nhead, num_encoder_layers, num_classes, cfg):
        super().__init__()
        self.cfg = cfg
        self.seq_len = seq_len
        self.embedding_size = embedding_size
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_classes = num_classes
        # Project-defined module; assumed to preserve the (batch, seq, dim) shape -- TODO confirm.
        self.positional_encoding = PositionalEncoding(self.embedding_size)
        # NOTE(review): TransformerEncoder deep-copies this layer into its own stack,
        # so self.encoder_layer's parameters are registered (they appear in
        # parameters()/state_dict) but are never used in forward. Kept so that
        # existing checkpoint keys stay loadable.
        self.encoder_layer = TransformerEncoderLayer(d_model=self.embedding_size, nhead=self.nhead, batch_first=True)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=self.num_encoder_layers)
        # NOTE(review): there is no nonlinearity between BatchNorm1d and the final
        # Linear. In eval mode BatchNorm is affine, so the head collapses to one
        # linear map. Adding an activation would shift the Sequential indices in
        # state_dict, so it is left unchanged here.
        self.Linear = nn.Sequential(
            nn.Linear(self.seq_len * self.embedding_size, 1024),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, self.num_classes))

    def forward(self, x):
        """Map a batch of embedding sequences to task-specific outputs.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embedding_size)``
                (batch-first, matching the encoder's ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch_size, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 3-D or its trailing dimensions differ
                from ``(seq_len, embedding_size)``.
        """
        # Fail early with a readable message instead of a matmul shape error in the head.
        if x.dim() != 3 or tuple(x.shape[1:]) != (self.seq_len, self.embedding_size):
            raise ValueError(
                f"expected input of shape (batch, {self.seq_len}, {self.embedding_size}), "
                f"got {tuple(x.shape)}"
            )
        # Still recorded on the module in case external code reads it.
        self.batch_size = x.shape[0]
        x = self.positional_encoding(x)
        x = self.transformer_encoder(x)
        # flatten() copies when needed; view() raises on non-contiguous tensors
        # and on zero-size batches with an inferred (-1) dimension.
        x = torch.flatten(x, start_dim=1)
        x = self.Linear(x)
        task = self.cfg['task']
        if task == "VA":
            return torch.tanh(x)
        if task == "EXPR":
            # NOTE(review): returns probabilities. If training uses nn.CrossEntropyLoss
            # (which expects raw logits), softmax ends up applied twice -- confirm.
            return torch.softmax(x, dim=1)
        return torch.sigmoid(x)


In [7]:
# 50 steps of 768-d embeddings, 8 attention heads, 6 encoder layers, 8 output classes.
model = TransformerModel(
    seq_len=50,
    embedding_size=768,
    nhead=8,
    num_encoder_layers=6,
    num_classes=8,
    cfg={'task': 'EXPR'},
)

In [8]:
# Random stand-in batch: 32 sequences x 50 steps x 768 features.
x = torch.randn((32, 50, 768))

In [9]:
# Shape check only. eval() stops BatchNorm from updating its running statistics
# and disables dropout; no_grad() skips building an autograd graph.
# Call model.train() before any subsequent training.
model.eval()
with torch.no_grad():
    out = model(x)

In [10]:
# Expected: (batch=32, num_classes=8)
out.size()

torch.Size([32, 8])