<a href="https://colab.research.google.com/github/chen-star/llm_model_trainings/blob/main/3_3_transformer_impl_multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1> ⭐ Multi-head Attention ⭐ </h1>

# ✈ Imports

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import override

# 🔢 Hyperparameters

In [2]:
# Fix the RNG so weight initialisation and the random test batch are reproducible
# under Restart & Run All.
seed = 42
torch.manual_seed(seed)

batch_size = 3

embedding_dimension = 100
num_heads = 4 # embedding_dimension must be divisible by num_heads

context_window_size = 8

# Enforce the divisibility requirement here so a bad edit fails immediately,
# not deep inside the attention module's reshape.
assert num_heads > 0 and embedding_dimension % num_heads == 0, "embedding_dimension must be divisible by num_heads"

# 👓 Multi-head Attention

Review the single-head attention defined in [3_2_transformer_impl_transformer_block.ipynb](https://github.com/chen-star/llm_model_trainings/blob/main/3_2_transformer_impl_transformer_block.ipynb)

In [3]:
class AttentionHead(nn.Module):
  def __init__(self, embedding_dimension):
    super().__init__()

    # define W_Q, W_K, W_V
    self.q_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.k_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.v_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)

    # define W0
    self.w0_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)


  @override
  def forward(self, X):
    # Q = XW_Q
    # K = XW_K
    # V = XW_V
    Q = self.q_layer(X)
    K = self.k_layer(X)
    V = self.v_layer(X)

    attention_score = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    return self.w0_layer(attention_score)

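Define multi-head attention: project the input to Q, K and V, split each along the embedding axis into `num_heads` heads of size `embedding_dimension // num_heads`, run causal scaled dot-product attention in every head in parallel, then concatenate the heads and apply the output projection $W_0$.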

In [4]:
class MultiHeadAttention(nn.Module):
  def __init__(self, embedding_dimension, num_heads):
    super().__init__()

    # define W_Q, W_K, W_V
    self.q_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.k_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.v_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)

    # define W0
    self.w0_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)

    # ***** New in multi-head *****
    self.num_heads = num_heads
    self.head_dimension = embedding_dimension // num_heads
    # *****************************


  @override
  def forward(self, X, print_dimension_info: bool=False):
    batch_size, context_window_size, embedding_dimension = X.shape
    if print_dimension_info: print(f"X.shape: {X.shape}")

    # Q = XW_Q
    # K = XW_K
    # V = XW_V
    Q = self.q_layer(X)
    K = self.k_layer(X)
    V = self.v_layer(X)
    if print_dimension_info: print(f"Before split: Q.shape: {Q.shape}, K.shape: {K.shape}, V.shape: {V.shape}")

    # ***** Split Q,K,V *****
    Q = Q.view(batch_size, context_window_size, self.num_heads, self.head_dimension)
    K = K.view(batch_size, context_window_size, self.num_heads, self.head_dimension)
    V = V.view(batch_size, context_window_size, self.num_heads, self.head_dimension)
    if print_dimension_info: print(f"After split: Q.shape: {Q.shape}, K.shape: {K.shape}, V.shape: {V.shape}")

    # For attention score calculation, pytorch expects the shape to be
    # [batch_size, num_heads, context_window_size, head_dimension]
    Q = Q.transpose(1,2)
    K = K.transpose(1,2)
    V = V.transpose(1,2)
    if print_dimension_info: print(f"After transpose: Q.shape: {Q.shape}, K.shape: {K.shape}, V.shape: {V.shape}")
    # *****************************

    attention_score = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    if print_dimension_info: print(f"attention_score.shape: {attention_score.shape}")

    # Transpose back
    attention_score = attention_score.transpose(1,2)
    if print_dimension_info: print(f"Transposed attention_score.shape: {attention_score.shape}")

    # ***** Merge heads *****
    attention_score = attention_score.reshape(batch_size, context_window_size, embedding_dimension)
    if print_dimension_info: print(f"After merge heads: attention_score.shape: {attention_score.shape}")
    # *****************************

    return self.w0_layer(attention_score)

In [5]:
# Build the module from the hyperparameters cell and display its layer summary.
multi_head_attention = MultiHeadAttention(embedding_dimension=embedding_dimension, num_heads=num_heads)
multi_head_attention

MultiHeadAttention(
  (q_layer): Linear(in_features=100, out_features=100, bias=False)
  (k_layer): Linear(in_features=100, out_features=100, bias=False)
  (v_layer): Linear(in_features=100, out_features=100, bias=False)
  (w0_layer): Linear(in_features=100, out_features=100, bias=False)
)

In [6]:
# Pass one random batch through the module and print the shape after every stage.
print(f"batch_size: {batch_size}")
print(f"context_window_size: {context_window_size}")
print(f"embedding_dimension: {embedding_dimension}")
print(f"num_heads: {num_heads}")
# Report the head size the module actually uses instead of recomputing it here.
print(f"head_dimension: {multi_head_attention.head_dimension}\n")

random_X = torch.randn(batch_size, context_window_size, embedding_dimension)
output = multi_head_attention(random_X, print_dimension_info=True)

# Self-attention must preserve the input shape (batch, context, embedding).
assert output.shape == random_X.shape, f"unexpected output shape {tuple(output.shape)}"

batch_size: 3
context_window_size: 8
embedding_dimension: 100
num_heads: 4
head_dimension: 25

X.shape: torch.Size([3, 8, 100])
Before split: Q.shape: torch.Size([3, 8, 100]), K.shape: torch.Size([3, 8, 100]), V.shape: torch.Size([3, 8, 100])
After split: Q.shape: torch.Size([3, 8, 4, 25]), K.shape: torch.Size([3, 8, 4, 25]), V.shape: torch.Size([3, 8, 4, 25])
After transpose: Q.shape: torch.Size([3, 4, 8, 25]), K.shape: torch.Size([3, 4, 8, 25]), V.shape: torch.Size([3, 4, 8, 25])
attention_score.shape: torch.Size([3, 4, 8, 25])
Transposed attention_score.shape: torch.Size([3, 8, 4, 25])
After merge heads: attention_score.shape: torch.Size([3, 8, 100])
