In [1]:
import torch
import torch.nn as nn

In [3]:
# Toy embeddings for the six-token sentence "Your journey starts with one step",
# one 3-dimensional vector per token.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your      (x^1)
        [0.55, 0.87, 0.66],  # journey   (x^2)
        [0.57, 0.85, 0.64],  # starts    (x^3)
        [0.22, 0.58, 0.33],  # with      (x^4)
        [0.77, 0.25, 0.10],  # one       (x^5)
        [0.05, 0.80, 0.55],  # step      (x^6)
    ]
)

# Two identical copies stacked along a new leading batch axis -> shape (2, 6, 3).
batch = torch.stack([inputs, inputs])

In [41]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``.
    Each projection is split into ``num_heads`` heads of width
    ``head_dim = d_out // num_heads``. Every head runs scaled dot-product
    attention under a causal mask. The heads are then concatenated back to
    width ``d_out`` and passed through ``out_proj``.

    Args:
        d_in: feature size of each input token.
        d_out: total output width across all heads; must be divisible by ``num_heads``.
        context_length: side length of the precomputed causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # Every head must receive an equal slice of the d_out projection.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Layers are created in the order queries, keys, values, output projection.
        # Under a fixed torch seed this order determines which initial weights each gets.
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs

        self.dropout = nn.Dropout(dropout)

        # Entries strictly above the diagonal are 1.0 and mark "future" positions.
        # It is kept as a float buffer named "mask"; forward() converts it to bool.
        future = torch.ones(context_length, context_length)
        self.register_buffer("mask", torch.triu(future, diagonal=1))

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, d_out) into (batch, num_heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal multi-head attention to ``x``.

        Args:
            x: tensor of shape (batch, num_tokens, d_in). The mask slicing below
                assumes num_tokens <= context_length.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.Wq(x), batch_size, seq_len)
        k = self._split_heads(self.Wk(x), batch_size, seq_len)
        v = self._split_heads(self.Wv(x), batch_size, seq_len)

        # One (seq, seq) score matrix per batch element and head.
        scores = q @ k.transpose(2, 3)

        # Crop the precomputed mask to the current length, then hide future tokens.
        future = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(future, -torch.inf)

        scale = k.shape[-1] ** 0.5  # sqrt(head_dim)
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim) -> (batch, seq, d_out)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_out)
        return self.out_proj(merged)

In [42]:
# Derive the toy configuration from the batch: 2 sequences x 6 tokens x 3 features.
batch_size, context_length, d_in = batch.shape
d_out = 2        # total output width across all heads
num_heads = 2    # gives head_dim = d_out // num_heads = 1
dropout = 0.0    # disables dropout so the forward pass is deterministic

In [43]:
# Seed before construction so the Linear layers' initial weights are reproducible.
torch.manual_seed(666)

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=num_heads,
)
context_vecs = mha(batch)

In [15]:
# Full tensor, one blank line, then its shape.
print(context_vecs)
print()
print("MHA-Context Vectors Shape =>", context_vecs.shape)

tensor([[[-0.0264,  0.4706],
         [-0.0254,  0.4416],
         [-0.0257,  0.4319],
         [-0.0198,  0.4514],
         [-0.0245,  0.4371],
         [-0.0170,  0.4541]],

        [[-0.0264,  0.4706],
         [-0.0254,  0.4416],
         [-0.0257,  0.4319],
         [-0.0198,  0.4514],
         [-0.0245,  0.4371],
         [-0.0170,  0.4541]]], grad_fn=<ViewBackward0>)

MHA-Context Vectors Shape => torch.Size([2, 6, 2])


In [16]:
# Exercise 3.3: initialize a GPT-2-sized attention module.
# 768-dim embeddings split over 12 heads gives head_dim = 64; the causal mask covers 1,024 tokens.
context_length_gpt2 = 1024
d_in_gpt2 = 768
d_out_gpt2 = 768
num_heads_gpt2 = 12
dropout_gpt2 = 0.1

mha_gpt2 = MultiHeadAttention(
    d_in=d_in_gpt2,
    d_out=d_out_gpt2,
    context_length=context_length_gpt2,
    dropout=dropout_gpt2,
    num_heads=num_heads_gpt2,
)

In [27]:
# The module repr lists every submodule with its in/out features and bias setting.
header = "GPT-2 Sized Attention Module\n\n"
print(header, mha_gpt2)

GPT-2 Sized Attention Module

 MultiHeadAttention(
  (Wq): Linear(in_features=768, out_features=768, bias=False)
  (Wk): Linear(in_features=768, out_features=768, bias=False)
  (Wv): Linear(in_features=768, out_features=768, bias=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [28]:
def count_parameters(model):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are included.

    Returns:
        The total as a comma-grouped string, e.g. ``"2,360,064"``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return f"{total:,}"

In [31]:
# 4 Linear layers => Wq, Wk, Wv, out_proj
#
# nn.Linear stores its weight as (out_features, in_features), so
# Wq, Wk & Wv each have a weight of shape (d_out, d_in) and
# out_proj has a weight of shape (d_out, d_out).
#
# Only out_proj has a bias vector (qkv_bias defaults to False); its size is d_out = 768.
#
# 3 * (768 * 768) + (768 * 768) = 2,359,296; plus the 768 bias terms = 2,360,064

# Check the hand derivation above against the actual module before reporting it.
expected_params_gpt2 = 3 * d_in_gpt2 * d_out_gpt2 + d_out_gpt2 * d_out_gpt2 + d_out_gpt2
assert count_parameters(mha_gpt2) == f"{expected_params_gpt2:,}", "parameter count does not match the hand derivation"

print("GPT2 Parameter Count =>", count_parameters(mha_gpt2))

GPT2 Parameter Count => 2,360,064
