In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [37]:
batch_size = 12
context_length = 36
embedding_dim = 72
num_heads = 6
head_dim = embedding_dim // num_heads
num_layers = 6
dropout=0.2

In [38]:
class Head(nn.Module):

    def __init__(self, head_dim, mask=False):
        super().__init__()

        self.key = nn.Linear(embedding_dim, head_dim, bias=False)
        self.query = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value = nn.Linear(embedding_dim, head_dim, bias=False)
        self.mask = mask
        if self.mask:
            self.register('tril', torch.tril(torch.ones(context_length, context_length)))
    
    def forward(self, embeddings):

        B, T, C = embeddings.shape

        key = self.key(embeddings) # (B, T, C)
        query = self.query(embeddings) # (B, T, C)

        # compute the weigts or scores
        wei = query @ key.transpose(-2, -1) * C**-0.5 # (B, T, T)
        if self.mask:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B, T, T)
        
        wei = F.softmax(wei, dim=-1)

        value = self.value(embeddings) # (B, T, C)
        output = wei @ value # (B, T, T) * (B, T, C) -> (B, T, C)

        return output

In [39]:
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_dim, mask=False, cross_head = False):
        super().__init__()

        self.heads = nn.ModuleList([Head(head_dim, mask) for _ in range(num_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, embeddings):

        output = torch.concat([head(embeddings) for head in self.heads], dim=-1) # concat along last dimension b/c the original embedding_dim is divided into n_heads times, each of size head_dim
        output = self.dropout(self.proj(output))

        return output

In [40]:
class FeedForward(nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, embeddings):
        return self.ffn(embeddings)

In [45]:
class EncoderLayer(nn.Module):

    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        self.self_mha = MultiHeadAttention(num_heads, head_dim, mask=False)
        self.ffwd = FeedForward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)
    

    def forward(self, embeddings):
        output = embeddings + self.ln1(self.self_mha(embeddings))
        output = output + self.ln2(self.ffwd(output))

        return output

In [48]:
class Encoder(nn.Module):

    def __init__(self, num_layers):
        super().__init__()

        self.encoders = nn.Sequential(*[EncoderLayer(embedding_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, inputs):

        output = self.encoders(inputs)
        return output

In [49]:
x = torch.normal(mean=0.0, std=1.0, size=(batch_size, context_length, embedding_dim))
encoder = Encoder(num_layers)
encoder(x)


tensor([[[ 5.8216e+00,  1.8666e+00,  6.8943e+00,  ..., -1.4715e+00,
          -1.7789e+00, -4.2688e+00],
         [ 3.8751e+00, -6.1958e+00,  4.7511e+00,  ...,  4.7486e+00,
          -4.3046e-04, -1.5259e+00],
         [ 1.9800e+00, -6.3879e-01,  6.6451e+00,  ...,  3.6647e+00,
           3.9955e+00, -5.2388e+00],
         ...,
         [ 3.2471e+00, -5.4973e+00,  4.0067e+00,  ...,  2.5814e+00,
          -1.1707e+00, -3.1103e+00],
         [ 3.7453e+00, -7.0919e+00,  3.0443e+00,  ...,  8.3878e-01,
           7.8253e-01, -3.3646e+00],
         [ 1.2099e+00, -5.6506e+00,  5.8804e+00,  ...,  1.7355e+00,
           1.6730e+00, -3.5347e+00]],

        [[ 3.0355e+00, -5.7101e+00,  6.4742e+00,  ..., -3.0717e+00,
           4.1893e+00, -7.8306e+00],
         [ 1.1288e+00, -2.5011e-01,  4.1230e+00,  ..., -9.0750e-01,
           2.8177e+00, -5.2242e+00],
         [ 1.6335e+00, -4.4204e+00,  6.5933e+00,  ..., -1.0236e+00,
           2.6325e+00, -7.5299e+00],
         ...,
         [ 6.0442e+00, -4