In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [37]:
batch_size = 12
context_length = 36
embedding_dim = 72
num_heads = 6
head_dim = embedding_dim // num_heads
num_layers = 6
dropout=0.2

In [38]:
class Head(nn.Module):

    def __init__(self, head_dim, mask=False):
        super().__init__()

        self.key = nn.Linear(embedding_dim, head_dim, bias=False)
        self.query = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value = nn.Linear(embedding_dim, head_dim, bias=False)
        self.mask = mask
        if self.mask:
            self.register('tril', torch.tril(torch.ones(context_length, context_length)))
    
    def forward(self, embeddings):

        B, T, C = embeddings.shape

        key = self.key(embeddings) # (B, T, C)
        query = self.query(embeddings) # (B, T, C)

        # compute the weigts or scores
        wei = query @ key.transpose(-2, -1) * C**-0.5 # (B, T, T)
        if self.mask:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B, T, T)
        
        wei = F.softmax(wei, dim=-1)

        value = self.value(embeddings) # (B, T, C)
        output = wei @ value # (B, T, T) * (B, T, C) -> (B, T, C)

        return output

In [39]:
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_dim, mask=False, cross_head = False):
        super().__init__()

        self.heads = nn.ModuleList([Head(head_dim, mask) for _ in range(num_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, embeddings):

        output = torch.concat([head(embeddings) for head in self.heads], dim=-1) # concat along last dimension b/c the original embedding_dim is divided into n_heads times, each of size head_dim
        output = self.dropout(self.proj(output))

        return output

In [40]:
class FeedForward(nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, embeddings):
        return self.ffn(embeddings)

In [41]:
class Encoder(nn.Module):

    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        self.self_mha = MultiHeadAttention(num_heads, head_dim, mask=False)
        self.ffwd = FeedForward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)
    

    def forward(self, embeddings):
        output = embeddings + self.ln1(self.self_mha(embeddings))
        output = output + self.ln2(self.ffwd(output))

        return output

In [42]:
x = torch.normal(mean=0.0, std=1.0, size=(batch_size, context_length, embedding_dim))
encoder = Encoder(embedding_dim, num_heads)
encoder(x)


tensor([[[-0.6426,  0.7637,  1.4020,  ..., -1.4449, -0.8595,  1.2576],
         [-0.5024, -0.9455, -0.0903,  ..., -2.3918, -2.6005, -2.4849],
         [-0.9345,  2.1450, -0.6613,  ...,  2.1165, -1.9769, -3.4450],
         ...,
         [-0.9818, -0.7390, -0.7188,  ..., -0.6580,  0.1123, -1.3307],
         [ 0.1014, -0.3141,  1.5314,  ..., -2.5026,  0.3968,  2.0615],
         [-1.4914, -0.8695,  0.5274,  ..., -1.8936, -0.6168, -4.8527]],

        [[ 1.8072,  0.7790, -0.8845,  ...,  0.2791, -2.7844,  2.8419],
         [ 0.0472, -2.3096,  0.1101,  ..., -3.2924, -1.3823,  1.8936],
         [-0.6778, -2.4964,  0.6566,  ..., -0.7618, -3.4384, -0.2994],
         ...,
         [ 1.6338, -1.4597,  1.5779,  ..., -1.1601, -0.7695, -1.7222],
         [ 0.9130, -0.3275,  3.5388,  ..., -1.5530, -3.3836,  0.9193],
         [ 1.5087, -0.8400,  1.3757,  ...,  2.0936, -4.1108, -1.4147]],

        [[-0.0892, -0.5195,  1.3635,  ...,  0.4821, -0.9348,  2.0546],
         [-0.0410,  1.4158,  3.3364,  ..., -0