In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

### Transformer Model

In [2]:
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer over a single shared token vocabulary.

    Source and target token ids are embedded with the same ``nn.Embedding``
    plus a learned positional embedding. They are fed through a batch-first
    ``nn.Transformer`` and projected to per-position vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        d_model: Model / embedding width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and, separately, of decoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward sublayer.
        dropout: Dropout probability used inside ``nn.Transformer``.
        max_len: Longest source/target length the positional table supports.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1, max_len=100):
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embedding: one zero-initialised row per
        # position, with a leading singleton dim so it broadcasts over the batch.
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq, feature)
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def _embed(self, tokens):
        """Embed (batch, seq) token ids and add the positional rows for 0..seq-1."""
        seq_len = tokens.size(1)
        if seq_len > self.max_len:
            # Without this check, slicing the table yields max_len rows and the
            # addition fails with an opaque broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )
        return self.embedding(tokens) + self.pos_encoder[:, :seq_len]

    def forward(self, src, tgt, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Return vocabulary logits of shape (batch, tgt_len, vocab_size).

        Args:
            src: (batch, src_len) source token ids.
            tgt: (batch, tgt_len) target token ids (decoder input).
            tgt_mask: Optional (tgt_len, tgt_len) decoder self-attention mask.
                Defaults to a causal mask, so position i only attends to
                positions <= i. Without it, teacher forcing leaks future tokens.
            src_key_padding_mask: Optional (batch, src_len) mask, True at padding.
                It is used for both encoder self-attention and decoder memory
                attention.
            tgt_key_padding_mask: Optional (batch, tgt_len) mask, True at padding.
        """
        if tgt_mask is None:
            tgt_len = tgt.size(1)
            # Boolean mask: True above the diagonal means "may not attend".
            tgt_mask = torch.ones(
                tgt_len, tgt_len, dtype=torch.bool, device=tgt.device
            ).triu(diagonal=1)
        out = self.transformer(
            self._embed(src),
            self._embed(tgt),
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.fc_out(out)

In [3]:
# Fix the RNG so weight initialisation and the dummy batches below are reproducible.
torch.manual_seed(0)

vocab_size = 1000  # number of distinct token ids (not a sequence length)
batch_size = 2     # sequences per dummy batch

model = Seq2SeqTransformer(vocab_size)
summary(model, depth=4)

Layer (type:depth-idx)                                                 Param #
Seq2SeqTransformer                                                     51,200
├─Embedding: 1-1                                                       512,000
├─Transformer: 1-2                                                     --
│    └─TransformerEncoder: 2-1                                         --
│    │    └─ModuleList: 3-1                                            --
│    │    │    └─TransformerEncoderLayer: 4-1                          3,152,384
│    │    │    └─TransformerEncoderLayer: 4-2                          3,152,384
│    │    │    └─TransformerEncoderLayer: 4-3                          3,152,384
│    │    │    └─TransformerEncoderLayer: 4-4                          3,152,384
│    │    │    └─TransformerEncoderLayer: 4-5                          3,152,384
│    │    │    └─TransformerEncoderLayer: 4-6                          3,152,384
│    │    └─LayerNorm: 3-2                              

In [4]:
# Dummy data: random token ids in [0, vocab_size).
src_len, tgt_len = 10, 10
src_seq = torch.randint(0, vocab_size, (batch_size, src_len))
tgt_seq = torch.randint(0, vocab_size, (batch_size, tgt_len))

# Shape check only, so skip building an autograd graph.
with torch.no_grad():
    output = model(src_seq, tgt_seq)

# The decoder emits one logit vector per target position.
assert output.shape == (batch_size, tgt_len, vocab_size)
print("Output shape:", output.shape)  # (batch, seq, vocab_size)

Output shape: torch.Size([2, 10, 1000])
