# Multi-Head Attention From Scratch

In [30]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Why multihead?
In simple terms, think of each attention head as wearing its own pair of 'glasses' for looking at the data. Each pair lets a head focus on specific features or relationships in the data. With multiple heads (multiple pairs of glasses), the model can view the data from several perspectives at the same time. This multiplicity of perspectives yields a more comprehensive representation of the data, which is crucial for tasks that require a nuanced understanding of complex relationships, such as language understanding and translation.

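Using multiple attention heads in Transformer models enhances the network's ability to capture different kinds of relationships in the data. Each head works in a smaller subspace of size `d_model / num_heads`, so running several heads costs roughly the same as running one full-width head.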

In [11]:
# Toy sizes for the walkthrough: one batch of 4 tokens, each a 512-dim vector
batch_size = 1
sequence_length = 4
input_dim = 512
# Output size of the attention layer; kept equal to input_dim so layers can be stacked
d_model = input_dim

# Random stand-in for a batch of token embeddings: [batch_size, sequence_length, input_dim]
x = torch.randn(batch_size, sequence_length, input_dim)
x.size()

torch.Size([1, 4, 512])

In [12]:
# Rather than three separate query/key/value layers, use one fused linear layer
# that produces all three at once (3 * d_model output features)
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [13]:
# Step 1: project the input once to get the concatenated q/k/v features
# -> [batch_size, sequence_length, 3 * d_model]
qkv = qkv_layer(x)
qkv.size()

torch.Size([1, 4, 1536])

In [17]:
# Split the model dimension evenly across the heads: each head attends within
# its own head_dim-sized subspace (d_model = num_heads * head_dim)
num_heads = 8
# Floor division would silently truncate here, and the reshape below would then
# fail with a confusing size error, so check divisibility up front.
if d_model % num_heads != 0:
    raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
head_dim = d_model // num_heads
head_dim

64

In [18]:
# The heads come from a reshape alone: regroup the 3 * d_model features into
# num_heads groups of 3 * head_dim
# -> [batch_size, sequence_length, num_heads, 3 * head_dim]
qkv = qkv.reshape((batch_size, sequence_length, num_heads, 3 * head_dim))
qkv.size()

torch.Size([1, 4, 8, 192])

In [20]:
# The reshape only regroups elements; it neither adds nor drops any.
# Check this against the tensor itself rather than hard-coded numbers.
qkv.numel() == batch_size * sequence_length * 3 * d_model

True

In [22]:
# Swap the sequence and head axes so each head acts like an extra batch dimension:
#   [batch_size, sequence_length, num_heads, 3*head_dim]
#   -> [batch_size, num_heads, sequence_length, 3*head_dim]
qkv = qkv.transpose(1, 2)
qkv.size()

torch.Size([1, 8, 4, 192])

In [24]:
# Cut the last axis into three equal parts: query, key and value,
# each [batch_size, num_heads, sequence_length, head_dim]
q, k, v = torch.chunk(qkv, 3, dim=-1)
tuple(t.shape for t in (q, k, v))

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [32]:
# Same scaled dot-product as single-head attention. Transposing only the last
# two axes of k lets matmul batch over [batch_size, num_heads].
scaled = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1])

In [33]:
# One attention-score matrix per head, so the last two dims are
# [sequence_length, sequence_length] (4 x 4 here)
scaled.size()

torch.Size([1, 8, 4, 4])

In [37]:
# Causal (look-ahead) mask, as in single-head attention. It is only needed for the
# decoder's self-attention; the encoder may attend to the whole sequence.
# Start from a tensor of -inf with the same size as the score tensor.
mask = torch.full(scaled.size(), float("-inf"))
# triu(diagonal=1) keeps -inf strictly above the main diagonal (future positions)
# and zeros the diagonal and everything below it (current and past positions).
mask = torch.triu(mask, diagonal=1)
mask[0][1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [38]:
# Apply the mask by (in-place) addition: blocked positions become -inf,
# allowed positions are unchanged because they add 0
scaled.add_(mask)
scaled[0, 1]

tensor([[ 0.2898,    -inf,    -inf,    -inf],
        [-0.3619,  0.4547,    -inf,    -inf],
        [-0.1549,  0.0713, -0.1431,    -inf],
        [ 0.0760, -0.1361,  0.3425,  0.2237]], grad_fn=<SelectBackward0>)

In [39]:
# Softmax over the last (key) axis turns each row of scores into weights that
# sum to 1; the -inf entries become exactly 0
attention = scaled.softmax(dim=-1)

In [41]:
# Finally, weight the value vectors by the attention weights
# -> [batch_size, num_heads, sequence_length, head_dim]
value = attention @ v
value.size()

torch.Size([1, 8, 4, 64])

In [101]:
# Put it all together: a reusable attention function plus an nn.Module that uses it
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q, k, v: tensors whose last two dims are (seq_len, head_dim); any leading
            dims (e.g. batch, heads) are treated as batch dimensions by matmul.
        mask: optional additive mask broadcastable to the score tensor
            (..., seq_len, seq_len); -inf blocks a position, 0 leaves it unchanged.

    Returns:
        (attention, values): the softmax weights (..., seq_len, seq_len) and the
        attention-weighted values (..., seq_len, head_dim).
    """
    # Divide by sqrt(head_dim) so score magnitudes don't grow with the head size.
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    # Compare against None: truth-testing a multi-element tensor raises an error.
    if mask is not None:
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return attention, values


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused Q/K/V projection.

    Args:
        input_dim: size of the last dimension of the input ``x``.
        embed_dim: model dimension; split evenly across ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``n_heads`` < 1 or ``embed_dim`` is not divisible by it.
    """

    def __init__(self, input_dim, embed_dim, n_heads=8):
        # nn.Module.__init__ must run before any submodule is assigned; the
        # one-argument super(MultiHeadAttention) form never called it.
        super().__init__()
        if n_heads < 1 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.input_dim, self.embed_dim, self.n_heads = input_dim, embed_dim, n_heads
        self.head_dim = embed_dim // n_heads

        # One layer produces q, k and v for every head at once.
        self.qkv = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention.

        Args:
            x: input of shape [batch_size, seq_len, input_dim].
            mask: optional additive mask broadcastable to
                [batch_size, n_heads, seq_len, seq_len] (-inf = blocked).

        Returns:
            (attention, output): attention weights of shape
            [batch_size, n_heads, seq_len, seq_len] and the projected output of
            shape [batch_size, seq_len, embed_dim].
        """
        batch_size, seq_len, _ = x.size()

        # [batch_size, seq_len, 3 * embed_dim]
        qkv = self.qkv(x)

        # [batch_size, seq_len, n_heads, 3 * head_dim]: each head gets its own q|k|v slice
        qkv = qkv.reshape(batch_size, seq_len, self.n_heads, 3 * self.head_dim)

        # [batch_size, n_heads, seq_len, 3 * head_dim]
        qkv = qkv.permute(0, 2, 1, 3)

        # each [batch_size, n_heads, seq_len, head_dim]
        q, k, v = qkv.chunk(3, dim=-1)

        attention, values = scaled_dot_product(q, k, v, mask)

        # Merge the heads back: [batch_size, seq_len, n_heads * head_dim],
        # and n_heads * head_dim == embed_dim.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, seq_len, self.embed_dim)

        # Final linear projection mixes information across heads.
        output = self.o_proj(values)
        return attention, output

In [102]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused Q/K/V projection (compact form).

    Args:
        input_dim: size of the last dimension of the input ``x``.
        embed_dim: model dimension, split evenly across ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``n_heads`` < 1 or ``embed_dim`` is not divisible by it.
    """

    def __init__(self, input_dim, embed_dim, n_heads=8):
        # nn.Module.__init__ must run before submodules are assigned; the
        # one-argument super(MultiHeadAttention) form never called it.
        super().__init__()
        if n_heads < 1 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.input_dim, self.embed_dim, self.n_heads = input_dim, embed_dim, n_heads
        self.head_dim = embed_dim // n_heads
        self.qkv = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape [batch_size, seq_len, input_dim].

        ``mask`` (optional) is forwarded unchanged to ``scaled_dot_product``.
        Returns (attention, output) with shapes
        [batch_size, n_heads, seq_len, seq_len] and [batch_size, seq_len, embed_dim].
        """
        batch_size, seq_len, _ = x.size()
        # [batch_size, seq_len, n_heads, 3 * head_dim] -> [batch_size, n_heads, seq_len, 3 * head_dim]
        qkv = self.qkv(x).reshape(batch_size, seq_len, self.n_heads, 3 * self.head_dim)
        q, k, v = qkv.permute(0, 2, 1, 3).chunk(3, dim=-1)
        attention, values = scaled_dot_product(q, k, v, mask)
        # Merge heads back into [batch_size, seq_len, embed_dim] before the output projection.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim)
        output = self.o_proj(values)
        return attention, output