In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [51]:
batch_size = 12
context_length = 36
embedding_dim = 72
num_heads = 6
head_dim = embedding_dim // num_heads
num_layers = 6
dropout=0.2

In [52]:
class Head(nn.Module):
    def __init__(self, head_dim, mask=False, cross_head = False):
        super().__init__()
        
        self.key = nn.Linear(embedding_dim, head_dim, bias=False)
        self.query = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value = nn.Linear(embedding_dim, head_dim, bias=False)
        self.mask = mask

        if self.mask:
            self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))
    
    def forward(self, embeddings, encoder_embeddings = None):

        B, T, C = embeddings.shape

        key = self.key(encoder_embeddings) if encoder_embeddings is not None else self.key(embeddings)
        value = self.value(encoder_embeddings)  if encoder_embeddings is not None else self.value(embeddings)
        query = self.query(embeddings)

        wei = query @ key.transpose(-2, -1) * C ** -0.5
        if self.mask:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        
        wei = F.softmax(wei, dim=-1)

        value = self.value(embeddings)
        output = wei @ value

        return output

In [53]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        
        self.heads = nn.ModuleList([Head(head_dim, mask=True) for _ in range(num_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, embeddings):

        output = torch.concat([head(embeddings) for head in self.heads], dim=-1)
        output = self.dropout(self.proj(output))

        return output

In [54]:
class MultiCrossHeadAttention(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        
        self.heads = nn.ModuleList([Head(head_dim) for _ in range(num_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, embeddings, encoder_embeddings):

        output = torch.concat([head(embeddings, encoder_embeddings) for head in self.heads], dim=-1)
        output = self.dropout(self.proj(output))

        return output

In [55]:
class FeedForward(nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, embeddings):
        return self.ffn(embeddings)

In [56]:
class Decoder(nn.Module):

    def __init__(self, head_dim, embedding_dim):
        super().__init__()
        
        self.self_mha = MultiHeadAttention(num_heads, head_dim)
        self.cross_mha = MultiCrossHeadAttention(num_heads, head_dim)
        self.ffwd = FeedForward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)
        self.ln3 = nn.LayerNorm(embedding_dim)
    
    def forward(self, embeddings, encoder_embeddings):

        output = embeddings + self.ln1(self.self_mha(embeddings))
        output = output + self.ln2(self.cross_mha(output, encoder_embeddings))
        output = output + self.ln3(self.ffwd(output))

        return output

In [57]:
x = torch.normal(mean=0.0, std=1.0, size=(batch_size, context_length, embedding_dim))
y = torch.normal(mean=0.0, std=5.0, size=(batch_size, context_length, embedding_dim))
decoder = Decoder(head_dim, embedding_dim)
output = decoder(x, y)
output

tensor([[[ 2.2106, -0.1051,  1.0015,  ..., -0.0353,  2.1705, -3.4710],
         [-3.2287, -0.9642,  0.4708,  ...,  0.8362,  2.0187,  1.5654],
         [ 2.1574, -0.8670, -0.5740,  ...,  4.6818,  4.9625, -1.6910],
         ...,
         [-1.0304,  0.0716, -0.9657,  ...,  0.9160,  4.8802,  1.8321],
         [-1.1211,  0.2472,  0.6258,  ..., -1.1739,  3.9222,  0.7250],
         [-1.5589, -1.0994, -1.1518,  ...,  0.7450,  2.0824, -1.5496]],

        [[ 0.6546, -0.0282, -1.4039,  ..., -1.3663,  0.7442,  0.0474],
         [-0.4554, -0.8557, -2.5814,  ...,  1.0194,  1.5900, -0.5650],
         [-0.7117,  1.2064, -0.0846,  ..., -1.2891,  0.2806, -0.4523],
         ...,
         [-1.7036, -2.1728, -1.6704,  ..., -4.3779, -2.0408, -1.6914],
         [ 1.1053,  1.1844, -2.1042,  ..., -0.5678,  0.7903, -1.7722],
         [ 1.0782, -0.7808, -0.1235,  ..., -4.1783,  1.4987, -0.1619]],

        [[-2.6997, -1.7621, -0.6285,  ...,  2.2566,  3.1865,  0.9203],
         [-2.3394, -0.5988,  0.0254,  ..., -1