# Multi Head attention

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
sequence_length = 4 # My name is Ankit
batch_size =1 # number of sentences in batch
input_dim = 512 # vector dimension of each input
d_model = 512 # vector dimension of each output        
x = torch.rand((batch_size, sequence_length, input_dim)) # this will be the input , we are skipping positional encoding for now.

In [5]:
x.size() # batch_size X seq_len X input_dim

torch.Size([1, 4, 512])

In [7]:
qkv_layer = nn.Linear(input_dim, 3*d_model) # this is done to concatinate all three q, k, v vectors together which we will split up later as per multi heads

In [9]:
qkv_layer # takes 512 input features and return 1536 output values. 
#Linear creates a linear layer where input tensor of 512 is multiplied with some weight matrix and bias is added to transform the result into 1536 o/p.

Linear(in_features=512, out_features=1536, bias=True)

In [10]:
qkv = qkv_layer(x)

In [11]:
qkv.shape # here we can see that each word vector is 3*512 -> 1536 in size

torch.Size([1, 4, 1536])

In [12]:
num_heads= 8 # now we have 8 attention head
head_dim= d_model//num_heads # each head will have dim of 512 / 8 => 64
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3*head_dim)

In [13]:
qkv.shape # (1 sentence in a batch, each seq of lenght 4, 8 attention heads, q+k+v (64+64+64) is 192)

torch.Size([1, 4, 8, 192])

In [14]:
# let's change sequencing in qkv just to make things easy in future and perform parallel operations of last 2 dimensions
qkv=qkv.permute(0, 2, 1, 3) # [batch_size, num_heads, sequence_length, 3*head_dim]
qkv.shape

torch.Size([1, 8, 4, 192])

In [15]:
# now we will divide the chuck of each head into q, k, v
q, k, v=qkv.chunk(3, dim=-1) # divide last dimension by 3 i.e 192/3 = 64
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

# Self Attention for multiple heads

**For single head**

self attention = softmax( ((Q.K^T)/d_k^0.5) + M)

new V = self attention . V

In [16]:
import math
d_k = q.size()[-1] # get size of one of these vectors
scaled = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(d_k) # (4*64) X (64*4) = (4*4)
scaled.shape

torch.Size([1, 8, 4, 4])

In [17]:
k.shape

torch.Size([1, 8, 4, 64])

In [18]:
k.transpose(-1,-2) == k.transpose(-2,-1)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         ...,

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          ...,
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, 

In [19]:
k.shape

torch.Size([1, 8, 4, 64])

In [20]:
k.transpose(-1,-2).shape

torch.Size([1, 8, 64, 4])

In [21]:
mask = torch.full(scaled.size(), float('-inf'))
mask = torch.triu(mask, diagonal=1)
mask[0][1] # mask for input to single head

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [23]:
(scaled+mask)[0][0] # This is the tensor for one head that is 4 X 4 matrix

tensor([[-0.0866,    -inf,    -inf,    -inf],
        [-0.0618,  0.0062,    -inf,    -inf],
        [ 0.0147,  0.0777,  0.1410,    -inf],
        [-0.0816,  0.0330,  0.0682,  0.0899]], grad_fn=<SelectBackward0>)

In [24]:
scaled+=mask

In [25]:
np.exp(0.0216) / (np.exp(0.0216) + np.exp(0.1516)) # for attention[0][0][1][0]

0.4675456936126812

In [26]:
attention = F.softmax(scaled, dim = -1)

In [27]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4830, 0.5170, 0.0000, 0.0000],
        [0.3125, 0.3329, 0.3546, 0.0000],
        [0.2237, 0.2509, 0.2599, 0.2656]], grad_fn=<SelectBackward0>)

In [28]:
values = torch.matmul(attention, v) # this value vector will be much more context aware then the input vectors.
values.shape

torch.Size([1, 8, 4, 64])

## Function

In [29]:
import math

def scaled_dot_product(q, k, v, mask = None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(d_k)
    if mask is not None:
        scaled+= mask
        attention = F.softmax(scaled, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

In [31]:
values, attention = scaled_dot_product(q, k, v, mask=mask)


In [32]:
attention.shape

torch.Size([1, 8, 4, 4])

In [33]:
attention[0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4830, 0.5170, 0.0000, 0.0000],
        [0.3125, 0.3329, 0.3546, 0.0000],
        [0.2237, 0.2509, 0.2599, 0.2656]], grad_fn=<SelectBackward0>)

In [34]:
values.size()

torch.Size([1, 8, 4, 64])

In [35]:
# now it's time to combine back all of the heads together
values = values.reshape(batch_size, sequence_length, num_heads*head_dim)
values.size()

torch.Size([1, 4, 512])

In [36]:
linear_layer = nn.Linear(d_model,d_model) # to make heads also communicate with each other about the info. they have learned we are going to pass it through the linear layer
out=linear_layer(values)
out.shape

torch.Size([1, 4, 512])

In [37]:
out

tensor([[[-0.2161,  0.3665,  0.3851,  ...,  0.0876,  0.0324,  0.1388],
         [ 0.0430, -0.0295, -0.1194,  ..., -0.2154,  0.0133, -0.0976],
         [-0.0010,  0.2153, -0.0640,  ...,  0.3443,  0.0404,  0.4272],
         [ 0.1973,  0.1521,  0.0879,  ..., -0.0050, -0.2509, -0.1399]]],
       grad_fn=<ViewBackward0>)

# Class

In [38]:

import torch
import torch.nn as nn
import math

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None): # forward pass
        batch_size, sequence_length, input_dim = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out

In [39]:
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadAttention(input_dim, d_model, num_heads)
out = model.forward(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
