In [2]:
# !pip install torch

In [3]:
from torch import nn

In [4]:
class PartialTransformer(nn.Module):
  """Minimal single-block transformer-style model mapping token ids to vocabulary logits.

  Pipeline: token embedding -> multi-head self-attention -> position-wise
  feed-forward network -> linear projection to `vocab_size` logits.
  There is no positional encoding, attention mask, residual connection or
  layer normalization in this block.

  Args:
    vocab_size: number of distinct token ids; inputs must lie in [0, vocab_size).
    embedding_dim: embedding / hidden size. nn.MultiheadAttention requires it
      to be divisible by `n_heads`.
    n_heads: number of attention heads.
    n_layers: stored on the instance only. The model currently builds a single
      attention + feed-forward block regardless of this value.
    dropout: dropout probability passed to nn.MultiheadAttention (applied to
      the attention weights).
    batch_first: forwarded to nn.MultiheadAttention. False (PyTorch's default)
      means token inputs are (seq_len, batch); True means (batch, seq_len).
  """

  def __init__(self, vocab_size, embedding_dim, n_heads, n_layers, dropout, batch_first=False):
    super().__init__()

    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim
    self.n_heads = n_heads
    self.n_layers = n_layers  # NOTE: not used to stack layers yet
    self.dropout = dropout
    self.batch_first = batch_first

    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.attention = nn.MultiheadAttention(
        embedding_dim, n_heads, dropout=dropout, batch_first=batch_first
    )
    self.feed_forward = nn.Sequential(
        nn.Linear(embedding_dim, embedding_dim),
        nn.ReLU(),
        nn.Linear(embedding_dim, embedding_dim)
    )
    self.out = nn.Linear(embedding_dim, vocab_size)

  def forward(self, x):
    """Return per-position vocabulary logits for a tensor of token ids.

    Args:
      x: integer tensor of token ids, shaped (seq_len, batch) when
        batch_first is False, (batch, seq_len) when it is True, or
        (seq_len,) for a single unbatched sequence.

    Returns:
      Float tensor of logits with x's shape plus a trailing `vocab_size` dim.
    """
    embedded = self.embedding(x)
    # Self-attention: query, key and value are the same tensor. MultiheadAttention
    # requires all three explicitly and returns (output, weights); the weights
    # are not needed here, so skip computing them.
    attended, _ = self.attention(embedded, embedded, embedded, need_weights=False)
    hidden = self.feed_forward(attended)
    return self.out(hidden)

In [5]:
# Build a small model: 1000-token vocabulary, 32-dim embeddings split across
# 4 attention heads, and dropout 0.5 on the attention weights.
model = PartialTransformer(vocab_size=1000, embedding_dim=32, n_heads=4,
                           n_layers=2, dropout=0.5)

In [6]:
# Display the module tree via the notebook's rich repr. `model.modules` without
# parentheses only shows the bound-method object, not the modules.
model

<bound method Module.modules of PartialTransformer(
  (embedding): Embedding(1000, 32)
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
  )
  (feed_forward): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
  )
  (out): Linear(in_features=32, out_features=1000, bias=True)
)>