In [1]:
import torch

In [3]:
class TransformerLayer(torch.nn.Module):
    """A single pre-norm Transformer block (self-attention + MLP, each with a residual).

    Computes::

        x = x + SelfAttention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    where MLP is ``Linear(embed_dim, 4*embed_dim) -> ReLU -> Linear(4*embed_dim, embed_dim)``.

    Args:
        embed_dim: Size of the last (feature) dimension of the input. PyTorch's
            ``MultiheadAttention`` requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads

        # batch_first=True: attention expects inputs laid out as (batch, seq, embed_dim).
        self._self_att = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self._mlp = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, 4 * embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * embed_dim, embed_dim),
        )
        self._in_norm = torch.nn.LayerNorm(embed_dim)
        self._mlp_norm = torch.nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, embed_dim); the output has the same shape."""
        x_norm = self._in_norm(x)
        # Only the attention output is used, so ask MultiheadAttention not to build
        # the attention-weight matrix at all (need_weights=False). That also lets
        # PyTorch take its optimized scaled-dot-product path. The second tuple
        # element is then None and is discarded.
        attn_out, _ = self._self_att(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self._mlp(self._mlp_norm(x))
        return x

class Transformer(torch.nn.Module):
    """A stack of ``num_layers`` identical-shape ``TransformerLayer`` blocks applied in sequence.

    Args:
        embed_dim: Feature size passed to every layer.
        num_heads: Attention heads per layer.
        num_layers: How many layers to stack.
    """

    def __init__(self, embed_dim, num_heads, num_layers):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerLayer(embed_dim, num_heads))
        # Each block gets fresh parameters; Sequential names them "0", "1", ...
        self._network = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the final activations."""
        out = self._network(x)
        return out

# Smoke test: build a small model and check that it preserves the input shape.
torch.manual_seed(0)  # reproducible weights and input on Restart & Run All

EMBED_DIM, NUM_HEADS, NUM_LAYERS = 128, 8, 4
BATCH_SIZE, SEQ_LEN = 16, 10

net = Transformer(EMBED_DIM, NUM_HEADS, NUM_LAYERS)
sample = torch.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = net(sample)
assert out.shape == sample.shape, f"unexpected output shape {out.shape}"
print(out.shape)
net

torch.Size([16, 10, 128])


Transformer(
  (_network): Sequential(
    (0): TransformerLayer(
      (_self_att): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (_mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (_in_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (_mlp_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (_self_att): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (_mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (_in_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (_mlp_norm): Lay