# Lecture 17.1: Multi-Head Attention Mechanism

### Extending the causal attention mechanism to multiple heads

In [36]:
import torch
from torch import nn

# Toy input sequence: six tokens, one row per token, each a 3-dimensional embedding.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

### Defining hyperparameters and a batch of inputs

In [37]:
# Hyperparameters shared by the cells below.
d_in = inputs.shape[1]  # embedding size of each input token (number of columns of `inputs`)
d_out = 2  # output size produced by each attention head

# Stack three copies of the same sequence to simulate a batch of 3 sequences.
batch = torch.stack((inputs, inputs, inputs), dim=0)

# Tokens per sequence; also sizes the causal mask each head registers.
# Defined once, from `batch`, instead of being assigned twice.
context_length = batch.shape[1]

dropout = 0.0  # dropout rate applied to the attention weights (disabled here)
num_heads = 3  # number of parallel causal-attention heads in the wrapper
print(f"Context Length: {context_length}\nInput Dimension: {d_in}\nOutput Dimension: {d_out}")

Context Length: 6
Input Dimension: 3
Output Dimension: 2


# Redefining the Causal Attention Class

In [None]:
class CausalAttentionV1(nn.Module):
    """Single-head causal (masked) self-attention.

    Each token may attend only to itself and to earlier tokens. Scores for
    later positions are set to -inf before the softmax, so their attention
    weights are exactly zero.

    Args:
        d_in: size of each input token embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum number of tokens supported; sizes the mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        # Registered as a buffer so it moves with the module and lives in the
        # state_dict, without being a trainable parameter.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)`` or, unbatched,
               ``(num_tokens, d_in)``. ``num_tokens`` must not exceed the
               ``context_length`` the module was built with.

        Returns:
            Context vectors with the same leading shape as ``x`` and a last
            dimension of ``d_out``.
        """
        num_tokens = x.shape[-2]
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Transposing the last two dims (not dims 1 and 2) works with or
        # without a leading batch dimension.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Standard scaled dot-product attention: divide by sqrt(d_k).
        attn_scores = attn_scores / keys.shape[-1] ** 0.5
        # Use only the top-left slice of the mask for sequences shorter than
        # context_length. Non-in-place masked_fill avoids mutating/aliasing scores.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values

### Creating an instance of the causal attention class

In [39]:
torch.manual_seed(123)  # fixes the nn.Linear weight initialisation so the output is reproducible

# Use the `dropout` setting from the configuration cell rather than a literal.
causal_attention = CausalAttentionV1(d_in, d_out, context_length, dropout)
# Call the module itself rather than .forward() so nn.Module hooks run.
print(causal_attention(batch))

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


# Multi-Head Attention Wrapper Class

In [40]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent causal-attention heads.

    ``num_heads`` separate ``CausalAttentionV1`` modules are created. Each is
    run on the same input, and their context vectors are concatenated along
    the last dimension, so the output feature size is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Heads are constructed one after another, in index order.
        head_modules = []
        for _head_index in range(num_heads):
            head_modules.append(CausalAttentionV1(d_in, d_out, context_length, dropout, qkv_bias))
        # ModuleList registers every head as a submodule of this wrapper.
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Run every head on ``x`` and concatenate their outputs on the last dim."""
        per_head_outputs = [attention_head(x) for attention_head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

### Creating an instance of the multi-head attention wrapper

In [41]:
torch.manual_seed(123)  # same seed as before, so the first head reproduces the single-head output
multi_head_attention = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout, num_heads)
# Call the module itself rather than .forward() so nn.Module hooks run.
context_matrix = multi_head_attention(batch)
print(f"{context_matrix}\n{context_matrix.shape}")

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]],

        [[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]],

        [[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0

## 1st dimension --> batch size (number of sequences in the batch) = 3
## 2nd dimension --> context length = number of tokens in each sequence = 6
## 3rd dimension --> output dimension × number of heads = 2 × 3 = 6