In [2]:
import torch
import torch.nn as nn

In [3]:
# One 3-dimensional embedding per token of "Your journey starts with one step";
# the single sequence is then repeated along a new leading axis to form a batch of two.
input_embeddings = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x_0 <- Your
        [0.55, 0.87, 0.66],  # x_1 <- journey
        [0.57, 0.85, 0.64],  # x_2 <- starts
        [0.22, 0.58, 0.33],  # x_3 <- with
        [0.77, 0.25, 0.10],  # x_4 <- one
        [0.05, 0.80, 0.55],  # x_5 <- step
    ]
).repeat(2, 1, 1)
input_embeddings.shape

torch.Size([2, 6, 3])

## simple, inefficient approach => stack multiple causal self-attention layers

In [4]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each token may only attend to itself and to earlier tokens in the same
    sequence; attention to later positions is masked out before the softmax.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the query/key/value projections and of each output vector.
        context_length: Longest sequence the pre-built causal mask supports.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute causal context vectors.

        Args:
            x: Tensor of shape (..., num_tokens, d_in). Any number of leading
                batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape (..., num_tokens, d_out).

        Raises:
            RuntimeError: If num_tokens exceeds the context_length the mask
                was built for.
        """
        *_batch_dims, num_tokens, _d_in = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this guard a (1, 1) mask slice would silently broadcast
            # over a longer sequence and attention would no longer be causal.
            raise RuntimeError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Negative indices keep the token/feature axes correct for any number
        # of leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        # Masked entries are -inf, so they receive exactly zero weight.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values

In [5]:
class SimpleMultiHeadAttention(nn.Module):
    """Multi-head attention built from independent CausalAttention modules.

    Every head receives the same input and produces d_out features; the head
    outputs are joined along the last dimension, so the result has
    d_out * num_heads features per token.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float, num_heads: int):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(CausalAttention(d_in, d_out, context_length, dropout))

    def forward(self, x):
        """Run every head on x and concatenate the results feature-wise."""
        per_head_outputs = [attention(x) for attention in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [6]:
torch.manual_seed(123)
# Token count and embedding width are read straight from the (batch, tokens, features) tensor.
context_length, d_in = input_embeddings.shape[1:]
d_out = 2
mh_attn = SimpleMultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mh_attn(input_embeddings)
# The last dimension is d_out * num_heads because the head outputs are concatenated.
display(context_vecs.shape, context_vecs)

torch.Size([2, 6, 4])

tensor([[[-0.0960,  0.7940, -0.2296,  0.3355],
         [ 0.0285,  0.9387, -0.3357,  0.3490],
         [ 0.0657,  0.9850, -0.3718,  0.3576],
         [ 0.1062,  0.9604, -0.3583,  0.3177],
         [ 0.0659,  0.9308, -0.3349,  0.3549],
         [ 0.1188,  0.9375, -0.3443,  0.3120]],

        [[-0.0960,  0.7940, -0.2296,  0.3355],
         [ 0.0285,  0.9387, -0.3357,  0.3490],
         [ 0.0657,  0.9850, -0.3718,  0.3576],
         [ 0.1062,  0.9604, -0.3583,  0.3177],
         [ 0.0659,  0.9308, -0.3349,  0.3549],
         [ 0.1188,  0.9375, -0.3443,  0.3120]]], grad_fn=<CatBackward0>)

## efficient approach => multi-head causal self-attention as one class

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention computed with shared projection matrices.

    The query/key/value projections each produce d_out features, which are
    split into num_heads chunks of head_dim features. Attention is computed for
    all heads in one batched matrix multiplication, the head outputs are merged
    back into d_out features, and a final linear layer mixes them.

    Args:
        d_in: Size of each input token embedding.
        d_out: Total output size across all heads; must be divisible by num_heads.
        context_length: Longest sequence the pre-built causal mask supports.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float, num_heads: int):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute causal multi-head context vectors.

        Args:
            x: Tensor of shape (..., num_tokens, d_in). Any number of leading
                batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape (..., num_tokens, d_out).

        Raises:
            RuntimeError: If num_tokens exceeds the context_length the mask
                was built for.
        """
        *batch_dims, num_tokens, _d_in = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this guard a (1, 1) mask slice would silently broadcast
            # over a longer sequence and attention would no longer be causal.
            raise RuntimeError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        # Linear projections for all heads at once: (..., num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the feature axis into heads and move the head axis in front of
        # the token axis: (..., num_heads, num_tokens, head_dim)
        head_shape = (*batch_dims, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(*head_shape).transpose(-3, -2)
        queries = queries.view(*head_shape).transpose(-3, -2)
        values = values.view(*head_shape).transpose(-3, -2)

        # Per-head attention scores: (..., num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # Masked entries are -inf, so they receive exactly zero weight.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Bring the token axis back in front of the heads, then merge the heads.
        # The transpose leaves the tensor non-contiguous, hence contiguous()
        # before view().
        context_vec = (attn_weights @ values).transpose(-3, -2)
        context_vec = context_vec.contiguous().view(*batch_dims, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [9]:
# All dimensions of the demo are read from the (batch, tokens, features) input tensor.
batch_size, context_length, d_in = input_embeddings.shape
d_out = 2
# Seed right before construction so the layer initialisation is reproducible.
torch.manual_seed(123)
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(input_embeddings)
display(context_vecs.shape, context_vecs)

torch.Size([2, 6, 2])

tensor([[[ 0.7732, -0.2205],
         [ 0.7706, -0.1791],
         [ 0.7684, -0.1686],
         [ 0.7485, -0.1963],
         [ 0.7558, -0.1972],
         [ 0.7427, -0.2082]],

        [[ 0.7732, -0.2205],
         [ 0.7706, -0.1791],
         [ 0.7684, -0.1686],
         [ 0.7485, -0.1963],
         [ 0.7558, -0.1972],
         [ 0.7427, -0.2082]]], grad_fn=<ViewBackward0>)