In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# --- Model shape -----------------------------------------------------------
embedding_dim = 72                     # input width of every key/query/value projection in Head
num_heads = 6
head_dim = embedding_dim // num_heads  # per-head width; floor division, exact for 72 / 6
num_layers = 6                         # not referenced by the attention classes defined below
context_length = 36                    # side length of the causal mask buffer built by Head

# --- Regularisation --------------------------------------------------------
dropout = 0.2                          # probability handed to nn.Dropout in MultiHeadAttention

# --- Batching --------------------------------------------------------------
batch_size = 12                        # presumably sequences per batch; not referenced by the classes below

In [3]:
class Head(nn.Module):
    """One head of scaled dot-product self-attention, optionally causal.

    Args:
        head_dim: Output width of the key, query and value projections.
        mask: When True, position ``t`` may only attend to positions ``<= t``
            (a lower-triangular mask is applied before the softmax).
        in_dim: Width of the incoming embeddings. Defaults to the
            notebook-level ``embedding_dim``.
        max_context: Longest sequence the causal mask supports. Only used when
            ``mask`` is True; defaults to the notebook-level ``context_length``.
    """

    def __init__(self, head_dim, mask=False, in_dim=None, max_context=None):
        super().__init__()

        # Fall back to the notebook hyperparameters unless explicitly overridden.
        if in_dim is None:
            in_dim = embedding_dim

        self.key = nn.Linear(in_dim, head_dim, bias=False)
        self.query = nn.Linear(in_dim, head_dim, bias=False)
        self.value = nn.Linear(in_dim, head_dim, bias=False)
        self.mask = mask
        if self.mask:
            if max_context is None:
                max_context = context_length
            # register_buffer (not a plain attribute) so the mask follows the
            # module across .to(device) calls without becoming a parameter.
            self.register_buffer('tril', torch.tril(torch.ones(max_context, max_context)))

    def forward(self, embeddings):
        """Attend over the sequence dimension of ``embeddings``.

        Args:
            embeddings: Tensor of shape (B, T, in_dim). When ``mask`` is True,
                T must not exceed the size of the causal mask buffer.

        Returns:
            Tensor of shape (B, T, head_dim).
        """
        T = embeddings.shape[-2]

        key = self.key(embeddings)      # (B, T, head_dim)
        query = self.query(embeddings)  # (B, T, head_dim)
        value = self.value(embeddings)  # (B, T, head_dim)

        # Compute the attention scores, scaled by 1/sqrt(head_dim): the width
        # of the keys/queries, not the width of the input embeddings.
        scale = key.shape[-1] ** -0.5
        wei = query @ key.transpose(-2, -1) * scale  # (B, T, T)
        if self.mask:
            # -inf scores become exactly zero weight after the softmax.
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B, T, T)

        wei = F.softmax(wei, dim=-1)

        output = wei @ value  # (B, T, T) @ (B, T, head_dim) -> (B, T, head_dim)

        return output

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and project the result.

    Args:
        num_heads: Number of ``Head`` modules to create.
        head_dim: Output width of each head.
        mask: Forwarded to every ``Head``; True enables causal masking.
    """

    def __init__(self, num_heads, head_dim, mask):
        super().__init__()

        self.heads = nn.ModuleList([Head(head_dim, mask) for _ in range(num_heads)])
        # The concatenated heads are num_heads * head_dim wide. That equals
        # embedding_dim only when head_dim * num_heads == embedding_dim, so size
        # the projection from the actual concatenated width.
        self.proj = nn.Linear(num_heads * head_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embeddings):
        """Apply every head to ``embeddings`` and mix their outputs.

        Args:
            embeddings: Tensor whose last dimension matches the input width
                expected by ``Head``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``embedding_dim``, after the output projection and dropout.
        """
        # Concatenate along the feature axis: each head contributes head_dim features.
        output = torch.cat([head(embeddings) for head in self.heads], dim=-1)
        output = self.dropout(self.proj(output))

        return output