In [1]:
import torch.nn as nn
import torch.nn.functional as F

class MiniTransformer(nn.Module):
    """Tiny Transformer encoder that scores the token following a sequence.

    Token ids are embedded, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, and the encoder output at the last
    sequence position is projected to log-probabilities over the vocabulary.
    No positional encoding is added inside this module.

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output projection.
        d_model: Width of the embeddings and of the encoder hidden state.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
        )
        # TransformerEncoder clones this prototype layer ``num_layers`` times.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Return next-token log-probabilities for a batch of sequences.

        Args:
            x: Token ids laid out as ``[batch, seq_len]``.

        Returns:
            Log-probabilities of shape ``[batch, vocab_size]``, computed from
            the encoder output at the final sequence position.
        """
        # The encoder is built with its default layout and expects
        # [seq_len, batch, d_model], so swap the first two axes.
        embedded = self.embedding(x)
        seq_first = embedded.permute(1, 0, 2)
        encoded = self.transformer(seq_first)
        # Only the last position of the sequence is used for the prediction.
        last_step = encoded[-1]
        logits = self.fc(last_step)
        return F.log_softmax(logits, dim=-1)

# Build the model for a 40-token vocabulary and print its layer structure.
model = MiniTransformer(vocab_size=40)
print(model)

MiniTransformer(
  (embedding): Embedding(40, 64)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=64, out_features=40, bias=True)
)


