# Lecture 17.2: Multi-Head-Attention Mechanism with Weight Splits

In [613]:
import torch
from torch import nn

### input data / sequence

In [614]:
torch.manual_seed(123)

# Three tokens (rows), each represented by a 6-dimensional embedding (columns).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],
    ]
)

### variable definition

In [615]:
# Duplicate the single sequence to form a batch of two identical sequences:
# shape (batch=2, tokens, embedding_dim).
batch = torch.stack([inputs] * 2)
d_in = batch.shape[-1]           # embedding size of each input token
d_out = 6                        # total output size across all heads
context_length = batch.shape[1]  # number of tokens per sequence
dropout = 0.0
num_heads = 2
head_dim = d_out // num_heads    # features per head after splitting d_out

print(batch.shape)
print(f"Input Dimensions: {d_in}\nOutput Dimensions: {d_out}\nContext Length: {context_length}")
print(f"Dropout: {dropout}\nNumber of Heads: {num_heads}\nHead Dimension: {head_dim}")

torch.Size([2, 3, 6])
Input Dimensions: 6
Output Dimensions: 6
Context Length: 3
Dropout: 0.0
Number of Heads: 2
Head Dimension: 3


# Multi-Head-Attention Class

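### Largely similar to the causal attention class, the Multi-Head-Attention class extends it: the query, key and value projections are split into several heads that attend in parallel, and the concatenated head outputs pass through an output projection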

In [616]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention implemented with weight splits.

    One Linear layer per projection (query / key / value) maps ``d_in`` input
    features to ``d_out`` features. Those ``d_out`` features are then split into
    ``num_heads`` heads of ``head_dim = d_out // num_heads`` features each.
    Every head attends causally (a token cannot attend to later tokens). The
    heads' outputs are concatenated back to ``d_out`` features and passed
    through an output projection.

    Args:
        d_in: embedding size of the input tokens (last dimension of ``x``).
        d_out: total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: maximum number of tokens covered by the causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # The output projection consumes the concatenated heads, which carry
        # d_out features (not d_in), so it must map d_out -> d_out.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the future positions each
        # query position must not attend to.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (batch, num_tokens, d_in) with
                num_tokens <= context_length.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        # Projections: (b, num_tokens, d_out)
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Split d_out into heads: (b, num_tokens, num_heads, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Put heads in front of tokens: (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Slice the mask to the actual sequence length so shorter inputs work.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        # Scale by sqrt(head_dim); the last axis of keys is the per-head size.
        attn_scores_masked_scaled = attn_scores / torch.sqrt(torch.tensor(keys.shape[-1]))
        attn_weights = torch.softmax(attn_scores_masked_scaled, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Concatenate the heads back into d_out features. The transpose left the
        # tensor non-contiguous, so .contiguous() is required before .view().
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

### creating an instance of the Multi-Head-Attention Class

In [617]:
multi_head_attention = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads)
# Invoke the module itself rather than .forward(): nn.Module.__call__ runs any
# registered forward hooks around forward(), which a direct .forward() skips.
context_matrix = multi_head_attention(batch)
print(context_matrix)

tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)


# Context Vector Step by Step

### initializing input and variables

In [618]:
torch.manual_seed(50)
# Same three 6-dimensional token embeddings as at the top, kept under their own
# name so this step-by-step walkthrough is self-contained.
inputs_test = torch.tensor([[0.43, 0.15, 0.89, 0.55, 0.87, 0.66],
                            [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],
                            [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]])

batch_size1 = 1
# Build the batch from this section's own input (two identical sequences),
# not from `inputs` defined in an earlier section.
batch_2 = torch.stack((inputs_test, inputs_test), dim=0)
batch_size2 = batch_2.shape[0]
context_length = inputs_test.shape[0]
input_dimension = inputs_test.shape[-1]
output_dimension = input_dimension
number_heads = 2
# The output dimension is split evenly across heads further below.
assert output_dimension % number_heads == 0, "output_dimension must be divisible by number_heads"
head_dimension = output_dimension // number_heads
dropout = 0.0

print(f"Context Length: {context_length}\nInput Dimension: {input_dimension}\nOutput Dimension: {output_dimension}")
print(f"Number of Attention Heads: {number_heads}\nHead Dimension: {head_dimension}\n")
print(f"\nInput Embeddings:\n{inputs_test}\n{inputs_test.shape}")

Context Length: 3
Input Dimension: 6
Output Dimension: 6
Number of Attention Heads: 2
Head Dimension: 3


Input Embeddings:
tensor([[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]])
torch.Size([3, 6])


### initializing trainable weight matrices

In [619]:
torch.manual_seed(50)
# Random stand-ins for the projection weights, each shaped
# (input_dimension, output_dimension) so that embeddings @ W yields
# output_dimension features per token. They are plain tensors (no
# requires_grad); this walkthrough only shows the forward computation.
W_query = torch.rand(input_dimension, output_dimension)
W_key = torch.rand(input_dimension, output_dimension)
W_value = torch.rand(input_dimension, output_dimension)

print(f"Trainable Weight Matrix Query:\n{W_query}\n{W_query.shape}\n")
print(f"Trainable Weight Matrix Key:\n{W_key}\n{W_key.shape}\n")
print(f"Trainable Weight Matrix Value:\n{W_value}\n{W_value.shape}")

Trainable Weight Matrix Query:
tensor([[0.6180, 0.0687, 0.3893, 0.0404, 0.4013, 0.1442],
        [0.4605, 0.4877, 0.5927, 0.9634, 0.1230, 0.4048],
        [0.4985, 0.9987, 0.6049, 0.5229, 0.6974, 0.2505],
        [0.3624, 0.4621, 0.7145, 0.5058, 0.0518, 0.2492],
        [0.2395, 0.4233, 0.0022, 0.6848, 0.7497, 0.2489],
        [0.3490, 0.1953, 0.2792, 0.2526, 0.3792, 0.7686]])
torch.Size([6, 6])

Trainable Weight Matrix Key:
tensor([[0.6907, 0.7526, 0.1184, 0.8699, 0.8391, 0.0532],
        [0.1705, 0.2914, 0.5682, 0.0380, 0.2715, 0.6987],
        [0.0944, 0.6719, 0.3782, 0.8738, 0.7267, 0.9486],
        [0.4388, 0.4589, 0.0359, 0.3525, 0.0389, 0.3880],
        [0.1319, 0.5018, 0.8077, 0.9322, 0.5680, 0.5034],
        [0.2074, 0.4779, 0.4203, 0.4014, 0.0436, 0.6490]])
torch.Size([6, 6])

Trainable Weight Matrix Value:
tensor([[0.9007, 0.8274, 0.3785, 0.7518, 0.2839, 0.7220],
        [0.8421, 0.2474, 0.6834, 0.0551, 0.7824, 0.7901],
        [0.9596, 0.6065, 0.5926, 0.7300, 0.9768, 0.9680

### computing queries, keys and values matrices by multiplying inputs with trainable weight matrices

### 2 X 3 X 6 ---> 2 - batch size (number of sequences), 3 - context length (number of tokens), 6 - token embedding dimension

In [620]:
torch.manual_seed(50)
# Project every token of every sequence: (2, 3, 6) @ (6, 6) -> (2, 3, 6).
queries, keys, values = (batch_2 @ W for W in (W_query, W_key, W_value))

print(f"Queries Matrix:\n{queries}\n{queries.shape}\n")
print(f"Keys Matrix:\n{keys}\n{keys.shape}\n")
print(f"Values Matrix:\n{values}\n{values.shape}")

Queries Matrix:
tensor([[[1.4165, 1.7428, 1.3738, 1.6680, 1.7427, 1.2065],
         [1.3965, 1.5045, 1.3634, 1.7684, 1.3510, 1.0393],
         [1.0425, 0.7438, 0.6994, 1.0363, 1.2204, 0.8715]],

        [[1.4165, 1.7428, 1.3738, 1.6680, 1.7427, 1.2065],
         [1.3965, 1.5045, 1.3634, 1.7684, 1.3510, 1.0393],
         [1.0425, 0.7438, 0.6994, 1.0363, 1.2204, 0.8715]]])
torch.Size([2, 3, 6])

Keys Matrix:
tensor([[[0.8996, 1.9697, 1.4726, 2.4273, 1.5926, 2.0517],
         [0.8405, 1.6564, 1.4076, 1.8381, 1.5266, 1.8229],
         [0.8254, 1.4067, 1.1502, 1.7509, 1.2670, 1.0896]],

        [[0.8996, 1.9697, 1.4726, 2.4273, 1.5926, 2.0517],
         [0.8405, 1.6564, 1.4076, 1.8381, 1.5266, 1.8229],
         [0.8254, 1.4067, 1.1502, 1.7509, 1.2670, 1.0896]]])
torch.Size([2, 3, 6])

Values Matrix:
tensor([[[2.5330, 1.7337, 1.9566, 2.3132, 2.3171, 1.7606],
         [2.5011, 1.4389, 1.7511, 1.7171, 2.0760, 1.9639],
         [1.9532, 1.0375, 1.1871, 1.6099, 1.1633, 1.1343]],

        [[2.533

### reshaping the queries, keys and values matrices by splitting output dimension in number of heads and head dimension

### 2 X 3 X 2 X 3 ---> 2 main blocks, 3 sub blocks per main block, 2 rows per sub block, 3 columns per row
#### 2 - batch size (main blocks), 3 - tokens per sequence (sub blocks), 2 - number of attention heads (rows per sub block), 3 - head dimension (columns)

In [621]:
torch.manual_seed(50)
# Split the last axis (output_dimension) into (number_heads, head_dimension):
# (batch, tokens, output_dimension) -> (batch, tokens, number_heads, head_dimension).
queries_rs = queries.view(batch_size2, context_length, number_heads, head_dimension)
keys_rs = keys.view(batch_size2, context_length, number_heads, head_dimension)
values_rs = values.view(batch_size2, context_length, number_heads, head_dimension)

print(f"Queries Matrix reshaped:\n{queries_rs}\n{queries_rs.shape}\n")
print(f"Keys Matrix reshaped:\n{keys_rs}\n{keys_rs.shape}\n")
print(f"Values Matrix reshaped:\n{values_rs}\n{values_rs.shape}")

Queries Matrix reshaped:
tensor([[[[1.4165, 1.7428, 1.3738],
          [1.6680, 1.7427, 1.2065]],

         [[1.3965, 1.5045, 1.3634],
          [1.7684, 1.3510, 1.0393]],

         [[1.0425, 0.7438, 0.6994],
          [1.0363, 1.2204, 0.8715]]],


        [[[1.4165, 1.7428, 1.3738],
          [1.6680, 1.7427, 1.2065]],

         [[1.3965, 1.5045, 1.3634],
          [1.7684, 1.3510, 1.0393]],

         [[1.0425, 0.7438, 0.6994],
          [1.0363, 1.2204, 0.8715]]]])
torch.Size([2, 3, 2, 3])

Keys Matrix reshaped:
tensor([[[[0.8996, 1.9697, 1.4726],
          [2.4273, 1.5926, 2.0517]],

         [[0.8405, 1.6564, 1.4076],
          [1.8381, 1.5266, 1.8229]],

         [[0.8254, 1.4067, 1.1502],
          [1.7509, 1.2670, 1.0896]]],


        [[[0.8996, 1.9697, 1.4726],
          [2.4273, 1.5926, 2.0517]],

         [[0.8405, 1.6564, 1.4076],
          [1.8381, 1.5266, 1.8229]],

         [[0.8254, 1.4067, 1.1502],
          [1.7509, 1.2670, 1.0896]]]])
torch.Size([2, 3, 2, 3])

Values 

### transposing the two inner dimensions of the queries, keys and values matrices: the token dimension and the attention-head dimension swap places

### 2 X 2 X 3 X 3 ---> 2 main blocks (batch size), 2 attention heads per block (sub blocks), 3 tokens per sequence (rows in each sub block), 3 head dimensions (columns)


In [622]:
torch.manual_seed(50)
# Swap the token and head axes so each head's tokens are grouped together:
# (batch, tokens, heads, head_dim) -> (batch, heads, tokens, head_dim).
queries_t = queries_rs.permute(0, 2, 1, 3)
keys_t = keys_rs.permute(0, 2, 1, 3)
values_t = values_rs.permute(0, 2, 1, 3)

print(f"Queries transposed:\n{queries_t}\n{queries_t.shape}\n")
print(f"Keys transposed:\n{keys_t}\n{keys_t.shape}\n")
print(f"Values transposed:\n{values_t}\n{values_t.shape}")

Queries transposed:
tensor([[[[1.4165, 1.7428, 1.3738],
          [1.3965, 1.5045, 1.3634],
          [1.0425, 0.7438, 0.6994]],

         [[1.6680, 1.7427, 1.2065],
          [1.7684, 1.3510, 1.0393],
          [1.0363, 1.2204, 0.8715]]],


        [[[1.4165, 1.7428, 1.3738],
          [1.3965, 1.5045, 1.3634],
          [1.0425, 0.7438, 0.6994]],

         [[1.6680, 1.7427, 1.2065],
          [1.7684, 1.3510, 1.0393],
          [1.0363, 1.2204, 0.8715]]]])
torch.Size([2, 2, 3, 3])

Keys transposed:
tensor([[[[0.8996, 1.9697, 1.4726],
          [0.8405, 1.6564, 1.4076],
          [0.8254, 1.4067, 1.1502]],

         [[2.4273, 1.5926, 2.0517],
          [1.8381, 1.5266, 1.8229],
          [1.7509, 1.2670, 1.0896]]],


        [[[0.8996, 1.9697, 1.4726],
          [0.8405, 1.6564, 1.4076],
          [0.8254, 1.4067, 1.1502]],

         [[2.4273, 1.5926, 2.0517],
          [1.8381, 1.5266, 1.8229],
          [1.7509, 1.2670, 1.0896]]]])
torch.Size([2, 2, 3, 3])

Values transposed:
tensor

### before the Attention Scores Matrix is computed, the last 2 dimensions of the reshaped and transposed Keys matrix have to be swapped

### within each sub block (one attention head) of both batch entries, rows become columns and columns become rows

In [623]:
torch.manual_seed(50)
# Turn each head's (tokens, head_dim) key block into (head_dim, tokens) so that
# queries @ keys^T yields one (tokens x tokens) score matrix per head.
keys_t_t = keys_t.transpose(-2, -1)
attn_scores = torch.matmul(queries_t, keys_t_t)

print(f"Keys Matrix with its last two dimensions transposed:\n{keys_t_t}\n")
print(f"Attention Scores Matrix:\n{attn_scores}\n{attn_scores.shape}")

Keys Matrix with its last two dimensions transposed:
tensor([[[[0.8996, 0.8405, 0.8254],
          [1.9697, 1.6564, 1.4067],
          [1.4726, 1.4076, 1.1502]],

         [[2.4273, 1.8381, 1.7509],
          [1.5926, 1.5266, 1.2670],
          [2.0517, 1.8229, 1.0896]]],


        [[[0.8996, 0.8405, 0.8254],
          [1.9697, 1.6564, 1.4067],
          [1.4726, 1.4076, 1.1502]],

         [[2.4273, 1.8381, 1.7509],
          [1.5926, 1.5266, 1.2670],
          [2.0517, 1.8229, 1.0896]]]])

Attention Scores Matrix:
tensor([[[[6.7302, 6.0112, 5.2011],
          [6.2273, 5.5849, 4.8373],
          [3.4329, 3.0928, 2.7114]],

         [[9.2994, 7.9254, 6.4430],
          [8.5764, 7.2074, 5.9404],
          [6.2472, 5.3565, 4.3103]]],


        [[[6.7302, 6.0112, 5.2011],
          [6.2273, 5.5849, 4.8373],
          [3.4329, 3.0928, 2.7114]],

         [[9.2994, 7.9254, 6.4430],
          [8.5764, 7.2074, 5.9404],
          [6.2472, 5.3565, 4.3103]]]])
torch.Size([2, 2, 3, 3])


### the mask is being initialized

In [624]:
# Ones strictly above the diagonal mark the future positions each token must not attend to.
mask = torch.ones(context_length, context_length).triu(diagonal=1)

print(f"Mask to be set before application:\n{mask}")

Mask to be set before application:
tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])


### the mask is being applied on the attention scores matrix to mask out future tokens

In [625]:
# Future positions get -inf so that softmax later assigns them zero weight;
# the (tokens, tokens) mask broadcasts over the batch and head axes.
attn_scores_masked = attn_scores.masked_fill(mask.to(torch.bool), -torch.inf)

print(f"Mask applied on Attention Scores:\n{attn_scores_masked}")

Mask applied on Attention Scores:
tensor([[[[6.7302,   -inf,   -inf],
          [6.2273, 5.5849,   -inf],
          [3.4329, 3.0928, 2.7114]],

         [[9.2994,   -inf,   -inf],
          [8.5764, 7.2074,   -inf],
          [6.2472, 5.3565, 4.3103]]],


        [[[6.7302,   -inf,   -inf],
          [6.2273, 5.5849,   -inf],
          [3.4329, 3.0928, 2.7114]],

         [[9.2994,   -inf,   -inf],
          [8.5764, 7.2074,   -inf],
          [6.2472, 5.3565, 4.3103]]]])


### the masked attention scores are scaled by the square root of the last keys dimension, i.e. the head dimension (after splitting and transposing the keys matrix)

In [626]:
# keys_t's last axis is the per-head size (head_dimension), not the full output dimension.
scale = torch.sqrt(torch.tensor(keys_t.shape[-1]))
attn_scores_masked_scaled = attn_scores_masked / scale

print(f"Original keys Matrix shape: {keys.shape}\n")
print(f"Keys Matrix after splitting last dimension: {keys_rs.shape}\n")
print(f"Keys Matrix after transposition of inner dimensions: {keys_t.shape}\n")
print(f"Masked Attention Scores scaled by the square root of the keys dimension:\n{attn_scores_masked_scaled}")

Original keys Matrix shape: torch.Size([2, 3, 6])

Keys Matrix after splitting last dimension: torch.Size([2, 3, 2, 3])

Keys Matrix after transposition of inner dimensions: torch.Size([2, 2, 3, 3])

Masked Attention Scores scaled by the square root of the keys dimension:
tensor([[[[3.8857,   -inf,   -inf],
          [3.5953, 3.2244,   -inf],
          [1.9820, 1.7856, 1.5654]],

         [[5.3690,   -inf,   -inf],
          [4.9516, 4.1612,   -inf],
          [3.6068, 3.0926, 2.4886]]],


        [[[3.8857,   -inf,   -inf],
          [3.5953, 3.2244,   -inf],
          [1.9820, 1.7856, 1.5654]],

         [[5.3690,   -inf,   -inf],
          [4.9516, 4.1612,   -inf],
          [3.6068, 3.0926, 2.4886]]]])


### attention weights are being calculated by applying the softmax function on the masked and scaled attention scores 

In [627]:
# Normalize each query row over the key positions (last axis); -inf entries become 0.
attn_weights_sm = attn_scores_masked_scaled.softmax(dim=-1)

print(f"Attention Weights:\n{attn_weights_sm}")

Attention Weights:
tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5917, 0.4083, 0.0000],
          [0.4031, 0.3312, 0.2657]],

         [[1.0000, 0.0000, 0.0000],
          [0.6879, 0.3121, 0.0000],
          [0.5195, 0.3107, 0.1698]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5917, 0.4083, 0.0000],
          [0.4031, 0.3312, 0.2657]],

         [[1.0000, 0.0000, 0.0000],
          [0.6879, 0.3121, 0.0000],
          [0.5195, 0.3107, 0.1698]]]])


### dropout is applied (with a dropout rate of 0.0 the attention weights stay unchanged)

In [628]:
torch.manual_seed(50)
# Keep the layer under its own name. Rebinding `dropout` (the float rate) to the
# module would break re-running this cell, because nn.Dropout would then receive
# a Module instead of a probability.
# A freshly built nn.Dropout is in training mode; with dropout == 0.0 it is an
# identity, so the weights come out unchanged.
dropout_layer = nn.Dropout(dropout)
attn_weights_do = dropout_layer(attn_weights_sm)

print(f"Attention Weights:\n{attn_weights_do}\n{attn_weights_do.shape}")

Attention Weights:
tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5917, 0.4083, 0.0000],
          [0.4031, 0.3312, 0.2657]],

         [[1.0000, 0.0000, 0.0000],
          [0.6879, 0.3121, 0.0000],
          [0.5195, 0.3107, 0.1698]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.5917, 0.4083, 0.0000],
          [0.4031, 0.3312, 0.2657]],

         [[1.0000, 0.0000, 0.0000],
          [0.6879, 0.3121, 0.0000],
          [0.5195, 0.3107, 0.1698]]]])
torch.Size([2, 2, 3, 3])


### the context vector is being computed 

In [629]:
# Per head: (tokens x tokens) weights @ (tokens x head_dim) values, giving
# (batch, heads, tokens, head_dim); then swap heads and tokens back, giving
# (batch, tokens, heads, head_dim).
context_vector = torch.matmul(attn_weights_do, values_t).transpose(1, 2)

print(f"Attention Weights Shape: {attn_weights_do.shape}\nValues transposed Shape: {values_t.shape}")
print(f"Context Vector:\n{context_vector}\n{context_vector.shape}")

Attention Weights Shape: torch.Size([2, 2, 3, 3])
Values transposed Shape: torch.Size([2, 2, 3, 3])
Context Vector:
tensor([[[[2.5330, 1.7337, 1.9566],
          [2.3132, 2.3171, 1.7606]],

         [[2.5200, 1.6133, 1.8727],
          [2.1271, 2.2419, 1.8240]],

         [[2.3684, 1.4510, 1.6841],
          [2.0086, 2.0463, 1.7174]]],


        [[[2.5330, 1.7337, 1.9566],
          [2.3132, 2.3171, 1.7606]],

         [[2.5200, 1.6133, 1.8727],
          [2.1271, 2.2419, 1.8240]],

         [[2.3684, 1.4510, 1.6841],
          [2.0086, 2.0463, 1.7174]]]])
torch.Size([2, 3, 2, 3])


In [630]:
# Merge (number_heads, head_dimension) back into output_dimension, i.e. concatenate
# the heads per token. reshape() copies the non-contiguous transposed tensor when
# a plain view is impossible.
context_vector = context_vector.reshape(batch_size2, context_length, output_dimension)

print(f"Final Context Vector:\n{context_vector}\n{context_vector.shape}")

Final Context Vector:
tensor([[[2.5330, 1.7337, 1.9566, 2.3132, 2.3171, 1.7606],
         [2.5200, 1.6133, 1.8727, 2.1271, 2.2419, 1.8240],
         [2.3684, 1.4510, 1.6841, 2.0086, 2.0463, 1.7174]],

        [[2.5330, 1.7337, 1.9566, 2.3132, 2.3171, 1.7606],
         [2.5200, 1.6133, 1.8727, 2.1271, 2.2419, 1.8240],
         [2.3684, 1.4510, 1.6841, 2.0086, 2.0463, 1.7174]]])
torch.Size([2, 3, 6])
