In [3]:
import torch
import torch.nn as nn

<h3>Implementing Multi-head attention</h3>
This is achieved by creating multiple instances of the self-attention mechanism, each with its own weights, and combining their outputs.<br>
To implement this, we create a `MultiHeadAttentionWrapper` class that stacks multiple instances of the causal self-attention module (`CasualAttention`, defined in the next cell) and concatenates their outputs.

In [4]:
class CasualAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    The class name keeps its original spelling so existing references keep
    working; the mechanism it implements is *causal* attention: each token
    may only attend to itself and to earlier tokens.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum sequence length; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Creation order (query, key, value) is preserved so that seeded
        # weight initialisation produces the same parameters.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal flag the "future" positions to hide.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return context vectors of shape (batch, num_tokens, d_out).

        ``x`` must be 3-D, (batch, num_tokens, d_in); ``num_tokens`` is
        expected not to exceed ``context_length``, otherwise the sliced mask
        is too small to cover the score matrix.
        """
        _, seq_len, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Swap only the token/feature axes of k so the batch axis is untouched.
        scores = q @ k.transpose(1, 2)
        # Crop the precomputed mask to the actual sequence length.
        hidden = self.mask[:seq_len, :seq_len].bool()
        scores.masked_fill_(hidden, -torch.inf)

        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return weights @ v


In [5]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CasualAttention`` heads.

    Every head has its own query/key/value weights. The heads run on the same
    input, and their outputs are concatenated along the last (feature)
    dimension, giving ``num_heads * d_out`` features per token.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CasualAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        # Run every head on the same input, then join their feature columns.
        per_head_outputs = [attention(x) for attention in self.heads]
        return torch.cat(per_head_outputs, dim=-1)
    

In [8]:
# Toy 3-d embedding vectors, one row per token of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x1: "Your"
    [0.55, 0.87, 0.66],  # x2: "journey"
    [0.57, 0.85, 0.64],  # x3: "starts"
    [0.22, 0.58, 0.33],  # x4: "with"
    [0.77, 0.25, 0.10],  # x5: "one"
    [0.05, 0.80, 0.55],  # x6: "step"
])
inputs.shape

torch.Size([6, 3])

In [9]:
# Stack two copies of `inputs` along a new leading axis: (2, 6, 3) = (batch, tokens, d_in).
batch = torch.stack([inputs] * 2, dim=0)
print(batch.shape)
print(batch)

torch.Size([2, 6, 3])
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [10]:
# Two causal heads, each projecting 3 -> 2 features; the concatenated output has 4 features per token.
torch.manual_seed(123)  # fixes the random initialisation of every head
d_in = 3
d_out = 2
context_length = batch.size(1)  # number of tokens per sequence
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, dropout=0.0, num_heads=2
)
context_vectors = mha(batch)
print(context_vectors)
print(f"Context vectors.shape: {context_vectors.shape} ")
print(f"Context vectors.shape: {context_vectors.shape} ")

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
Context vectors.shape: torch.Size([2, 6, 4]) 


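Implementing multi-head attention with weight splits: instead of running one separate attention module per head, a single set of query/key/value projections is computed once and then split into heads. This is a more efficient way to implement MHA.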

In [17]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one fused projection per Q/K/V.

    The class does not create one attention module per head. Instead, the
    queries, keys and values are each produced by a single ``d_in -> d_out``
    linear layer. The result is then reshaped into ``num_heads`` heads of
    width ``head_dim = d_out // num_heads``. The head outputs are concatenated
    back to ``d_out`` features and mixed by a final linear layer
    (``out_proj``).

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: Maximum sequence length; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of each head, so that num_heads * head_dim == d_out.
        self.head_dim = d_out // num_heads

        # Trainable projections. They are created in this order (query, key,
        # value, out_proj) so that seeded initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(self.d_out, self.d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the future positions to hide.
        self.register_buffer(
            'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``; ``num_tokens`` is
                expected not to exceed ``context_length`` (the mask is only
                that large).

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        b, num_token, _ = x.shape

        # Project once for all heads: each is (b, num_tokens, d_out).
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads. This is only a reshape of the
        # projected activations; the weights are not changed:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_token, self.num_heads, self.head_dim)
        values = values.view(b, num_token, self.num_heads, self.head_dim)
        queries = queries.view(b, num_token, self.num_heads, self.head_dim)

        # Group by head instead of by token, so each head becomes an
        # independent batch of matrices:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        # Scaled dot-product attention scores, computed per head:
        # (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Crop the causal mask to the actual sequence length and convert it to
        # boolean. It broadcasts over the batch and head axes.
        mask_bool = self.mask[:num_token, :num_token].bool()
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # After the split, keys.shape[-1] is head_dim, so this scales by sqrt(head_dim).
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vector = (attn_weights @ values).transpose(1, 2)

        # Merge the heads back, using d_out == num_heads * head_dim. The
        # contiguous() call is needed because view() cannot flatten the
        # transposed (non-contiguous) axes.
        context_vector = context_vector.contiguous().view(b, num_token, self.d_out)
        # Final linear projection that mixes information across heads.
        context_vector = self.out_proj(context_vector)

        return context_vector


In [18]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

batch_size, context_length, d_in = batch.shape
d_out = 6
# Reseed so this cell's random initialisation (and output) is reproducible on its own
# instead of depending on RNG state left over from earlier cells.
torch.manual_seed(123)
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vector = mha(batch)
print(context_vector)
print(f"Context vector.shape: {context_vector.shape}")

torch.Size([2, 6, 3])
tensor([[[-0.1084,  0.3656, -0.0570,  0.0369,  0.0875,  0.0014],
         [-0.0203,  0.3179, -0.0227, -0.0276,  0.1205,  0.0277],
         [ 0.0096,  0.3001, -0.0120, -0.0471,  0.1314,  0.0367],
         [ 0.0237,  0.2843, -0.0300, -0.0220,  0.1289,  0.0451],
         [ 0.0008,  0.2862, -0.0476,  0.0349,  0.1411,  0.0135],
         [ 0.0205,  0.2776, -0.0468,  0.0170,  0.1342,  0.0338]],

        [[-0.1084,  0.3656, -0.0570,  0.0369,  0.0875,  0.0014],
         [-0.0203,  0.3179, -0.0227, -0.0276,  0.1205,  0.0277],
         [ 0.0096,  0.3001, -0.0120, -0.0471,  0.1314,  0.0367],
         [ 0.0237,  0.2843, -0.0300, -0.0220,  0.1289,  0.0451],
         [ 0.0008,  0.2862, -0.0476,  0.0349,  0.1411,  0.0135],
         [ 0.0205,  0.2776, -0.0468,  0.0170,  0.1342,  0.0338]]],
       grad_fn=<ViewBackward0>)
Context vector.shape: torch.Size([2, 6, 6])
