In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [60]:
# Hyperparameters shared by every cell below.
batch_size = 12        # leading dimension of the random tensors in the smoke-test cell
context_length = 36    # side length of the causal mask buffer; also the sequence length of the test tensors
embedding_dim = 72     # width of the token embeddings
num_heads = 6          # attention heads per attention block

# The per-head outputs are concatenated and projected back to embedding_dim,
# so the embedding width has to split evenly across the heads. Fail here with a
# clear message instead of with a shape error deep inside a forward pass.
assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
head_dim = embedding_dim // num_heads

num_layers = 6         # number of stacked decoder layers
dropout = 0.2          # dropout probability used by the attention projections and the feed-forward block

In [61]:
class Head(nn.Module):
    """A single scaled dot-product attention head (self- or cross-attention).

    Args:
        head_dim: Output width of the key, query and value projections.
        mask: When True, a lower-triangular (causal) mask is applied to the
            attention scores so a query position cannot attend to later
            key positions.
        cross_head: Accepted for backward compatibility; currently unused.
            Cross-attention is selected at call time by passing
            ``encoder_embeddings`` to ``forward``.
        embed_dim: Input width of the projections. Defaults to the
            notebook-level ``embedding_dim``.
        max_context: Side length of the causal mask buffer. Defaults to the
            notebook-level ``context_length``.
    """

    def __init__(self, head_dim, mask=False, cross_head = False, embed_dim=None, max_context=None):
        super().__init__()

        # Fall back to the notebook-level configuration only when no explicit
        # size is supplied, so existing call sites keep working unchanged.
        in_features = embedding_dim if embed_dim is None else embed_dim

        self.key = nn.Linear(in_features, head_dim, bias=False)
        self.query = nn.Linear(in_features, head_dim, bias=False)
        self.value = nn.Linear(in_features, head_dim, bias=False)
        self.mask = mask

        if self.mask:
            max_len = context_length if max_context is None else max_context
            # Registered as a buffer so it follows the module across devices
            # without being treated as a trainable parameter.
            self.register_buffer("tril", torch.tril(torch.ones(max_len, max_len)))

    def forward(self, embeddings, encoder_embeddings = None):
        """Attend from ``embeddings`` to itself, or to ``encoder_embeddings`` if given.

        Args:
            embeddings: Tensor whose last two dimensions are (query length,
                input width); queries are always computed from it.
            encoder_embeddings: Optional tensor that keys and values are
                computed from. When omitted, keys and values come from
                ``embeddings`` (self-attention).

        Returns:
            Tensor with the same leading dimensions and query length as
            ``embeddings`` and a last dimension of ``head_dim``.
        """
        # Keys and values share one source: the encoder output for
        # cross-attention, otherwise the input itself.
        source = embeddings if encoder_embeddings is None else encoder_embeddings

        query = self.query(embeddings)
        key = self.key(source)
        value = self.value(source)

        # Scale by 1/sqrt(d_k), where d_k is the per-head key width (not the
        # full embedding width), as in standard scaled dot-product attention.
        scale = key.shape[-1] ** -0.5
        wei = query @ key.transpose(-2, -1) * scale

        if self.mask:
            q_len, k_len = query.shape[-2], key.shape[-2]
            # Slice rows by query length and columns by key length so the mask
            # matches the score matrix even when the two lengths differ.
            wei = wei.masked_fill(self.tril[:q_len, :k_len] == 0, float("-inf"))

        wei = F.softmax(wei, dim=-1)

        return wei @ value

In [62]:
class MultiHeadAttention(nn.Module):
    """Masked multi-head self-attention followed by a linear projection and dropout.

    Args:
        num_heads: Number of attention heads to run side by side.
        head_dim: Output width of each individual head.
    """

    def __init__(self, num_heads, head_dim):
        super().__init__()

        # Every head is built with mask=True, i.e. with the causal mask enabled.
        self.heads = nn.ModuleList([Head(head_dim, mask=True) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_dim wide; size the
        # projection from that product rather than assuming it equals
        # embedding_dim, which only holds when the division is exact.
        self.proj = nn.Linear(num_heads * head_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embeddings):
        """Run every head on ``embeddings``, concatenate on the last dim, project, apply dropout."""
        per_head = [head(embeddings) for head in self.heads]
        output = torch.cat(per_head, dim=-1)
        output = self.dropout(self.proj(output))

        return output

In [63]:
class MultiCrossHeadAttention(nn.Module):
    """Unmasked multi-head attention that also receives encoder embeddings.

    Args:
        num_heads: Number of attention heads to run side by side.
        head_dim: Output width of each individual head.
    """

    def __init__(self, num_heads, head_dim):
        super().__init__()

        # Heads are built with the default mask=False, so no causal mask is applied.
        self.heads = nn.ModuleList([Head(head_dim) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_dim wide; size the
        # projection from that product rather than assuming it equals
        # embedding_dim, which only holds when the division is exact.
        self.proj = nn.Linear(num_heads * head_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embeddings, encoder_embeddings):
        """Call every head with ``(embeddings, encoder_embeddings)``, concatenate, project, apply dropout."""
        per_head = [head(embeddings, encoder_embeddings) for head in self.heads]
        output = torch.cat(per_head, dim=-1)
        output = self.dropout(self.proj(output))

        return output

In [64]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the embedding width, ReLU, project back, dropout.

    Args:
        embedding_dim: Width of the last dimension of both the input and the output.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        hidden_dim = 4 * embedding_dim
        stages = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, embeddings):
        """Apply the MLP to ``embeddings`` and return the result."""
        transformed = self.ffn(embeddings)
        return transformed

In [65]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sublayer's output is passed through its own LayerNorm and then added
    to the running residual.

    Args:
        head_dim: Output width of each attention head.
        embedding_dim: Width used for the feed-forward block and the LayerNorms.
    """

    def __init__(self, head_dim, embedding_dim):
        super().__init__()

        # num_heads is taken from the notebook-level configuration.
        self.self_mha = MultiHeadAttention(num_heads, head_dim)
        self.cross_mha = MultiCrossHeadAttention(num_heads, head_dim)
        self.ffwd = FeedForward(embedding_dim)
        self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(embedding_dim) for _ in range(3))

    def forward(self, embeddings, encoder_embeddings):
        """Apply the three residual sublayers in order and return the result."""
        # NOTE(review): the norm wraps the sublayer output only
        # (x + LN(f(x))), which is neither the usual pre-LN x + f(LN(x)) nor
        # post-LN LN(x + f(x)) -- confirm this placement is intentional.
        self_attended = self.ln1(self.self_mha(embeddings))
        hidden = embeddings + self_attended

        cross_attended = self.ln2(self.cross_mha(hidden, encoder_embeddings))
        hidden = hidden + cross_attended

        fed_forward = self.ln3(self.ffwd(hidden))
        return hidden + fed_forward

In [73]:
class Decoder(nn.Module):
    """Stack of ``DecoderLayer`` blocks followed by a final linear layer.

    Args:
        num_layers: Number of decoder layers to stack.
    """

    def __init__(self, num_layers):
        super().__init__()
        # nn.ModuleList (not a plain Python list) registers every layer as a
        # submodule. Without it the layers' parameters are missing from
        # parameters() and state_dict(), and .to(), .train() and .eval() never
        # reach them -- so they would not be optimised, saved, moved to a
        # device, or have their dropout switched off for evaluation.
        self.decoders = nn.ModuleList(
            [DecoderLayer(head_dim, embedding_dim) for _ in range(num_layers)]
        )
        self.final_linear = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, embeddings, encoder_embeddings):
        """Feed ``embeddings`` through every layer, then through the final linear layer.

        The same ``encoder_embeddings`` tensor is passed to every layer.
        """
        output = embeddings
        for decoder in self.decoders:
            output = decoder(output, encoder_embeddings)
        output = self.final_linear(output)

        return output

In [74]:
# Smoke test: push random tensors through a freshly initialised decoder stack.
input_shape = (batch_size, context_length, embedding_dim)

# x is passed as the decoder-side embeddings, y as the encoder embeddings.
x = torch.normal(mean=0.0, std=1.0, size=input_shape)
y = torch.normal(mean=0.0, std=5.0, size=input_shape)

decoder = Decoder(num_layers)
output = decoder(x, y)
output

tensor([[[ 1.3666e+00,  1.4253e+00,  2.7469e+00,  ..., -6.0862e-01,
           3.5372e-01,  2.6608e+00],
         [ 1.4086e+00,  3.7723e+00,  6.8644e-01,  ..., -2.8146e+00,
          -1.7313e+00,  1.5454e+00],
         [ 1.1342e+00,  4.5126e+00,  2.7331e+00,  ..., -3.9873e+00,
          -4.7525e-01,  2.9878e+00],
         ...,
         [-3.6420e-01,  1.9867e+00,  3.5650e+00,  ..., -2.8463e+00,
           3.5189e-01, -1.6098e-01],
         [-2.6925e+00,  4.9364e-01,  3.6550e+00,  ..., -1.6877e+00,
          -2.0918e+00,  1.0198e+00],
         [-1.7929e+00,  3.0967e+00,  3.2468e+00,  ..., -1.3510e+00,
           2.7414e-01,  2.5220e+00]],

        [[ 4.5378e+00,  2.5681e+00, -1.7457e+00,  ...,  2.6736e+00,
          -1.9984e+00,  1.3154e+00],
         [ 1.7903e+00,  2.3349e+00, -6.8890e-01,  ...,  2.0767e+00,
          -1.4876e+00, -4.5297e-01],
         [ 6.3503e-01,  2.8398e+00,  1.4644e-01,  ...,  1.3603e+00,
          -2.4791e+00,  1.0748e+00],
         ...,
         [ 2.8744e+00,  6