## Multi Head Attention

In [12]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Toy input: one batch holding a single 4-token sequence of 512-d embeddings.
batch_size, sequence_length = 1, 4    # 4 words
input_dim = d_model = 512
x = torch.randn(batch_size, sequence_length, input_dim)

In [4]:
# Display the random input tensor created above.
x

tensor([[[-0.2694, -1.2359, -0.9782,  ..., -0.8081, -0.3821, -0.1259],
         [-0.5397, -1.9118, -1.3352,  ..., -0.3518,  0.4289,  1.0851],
         [ 1.7145,  0.7449,  0.7258,  ...,  1.4510, -1.5135,  1.3949],
         [-1.0144,  0.3269, -0.1946,  ...,  0.6608, -0.5699, -0.6814]]])

In [5]:
# One fused projection produces queries, keys and values in a single matmul.
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [6]:
# Project every token to a 3 * d_model vector (q, k and v still fused).
qkv = qkv_layer(x)

In [7]:
# Display the fused q/k/v projection.
qkv

tensor([[[ 0.2081, -0.2032,  0.7234,  ..., -0.2698, -0.6272,  1.5717],
         [-0.0928, -0.2201,  0.7269,  ..., -0.4078, -0.4672,  0.2020],
         [-0.0757,  0.0749, -0.1767,  ..., -0.1290,  0.5579, -0.3411],
         [ 0.3501, -0.2855, -0.2798,  ..., -0.4156,  0.3686, -0.0827]]],
       grad_fn=<ViewBackward0>)

In [8]:
# Carve the fused projection into heads: each head owns a 3 * head_dim slice
# of the last dimension.
num_heads = 8
head_dim = d_model // num_heads
qkv = torch.reshape(qkv, (batch_size, sequence_length, num_heads, 3 * head_dim))

In [9]:
# Swap the sequence and head axes -> [batch_size, num_heads, sequence_length, 3*head_dim]
qkv = qkv.transpose(1, 2)
qkv.shape

torch.Size([1, 8, 4, 192])

In [10]:
# Cut the last dimension into three equal pieces: queries, keys and values.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

## Self Attention for multiple heads

For a single head:
$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [13]:
# Raw attention scores: q @ k^T over the last two axes, scaled by sqrt(d_k).
d_k = q.shape[-1]
scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.shape

torch.Size([1, 8, 4, 4])

In [19]:
# Reversing *all* dimensions (what the deprecated `k.T` does on tensors with
# more than 2 dims) is not a per-head transpose, which is why the score
# computation uses `k.transpose(-2, -1)` instead. `permute` spells the same
# reversal explicitly and avoids the deprecation warning.
k.permute(*reversed(range(k.ndim))).shape

  k.T.shape


torch.Size([64, 4, 8, 1])

In [21]:
# Small 2-D example: swapping dims 0 and 1 turns the (2, 3) tensor into (3, 2).
y = torch.randn(2, 3)
y.transpose(0, 1)

tensor([[ 0.2370, -0.6806],
        [-1.1236,  0.9982],
        [-0.0747,  0.1154]])

In [22]:
# The order of the two dims passed to transpose does not matter.
y.transpose(1, 0)

tensor([[ 0.2370, -0.6806],
        [-1.1236,  0.9982],
        [-0.0747,  0.1154]])

In [29]:
# Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
mask = torch.full(scaled.size(), float('-inf')).triu(diagonal=1)
mask[0, 1] # mask for input to a single head

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [30]:
# Preview of the masked scores (display only; nothing is assigned here).
scaled + mask

tensor([[[[ 0.0146,    -inf,    -inf,    -inf],
          [ 0.1821,  0.4276,    -inf,    -inf],
          [-0.2670, -0.0280, -0.3163,    -inf],
          [ 0.1123, -0.3236,  0.1044,  0.0513]],

         [[ 0.5942,    -inf,    -inf,    -inf],
          [-0.0722, -0.2064,    -inf,    -inf],
          [-0.7415, -0.0156, -0.2027,    -inf],
          [ 0.1921, -0.1916,  0.0409, -0.0272]],

         [[-0.4398,    -inf,    -inf,    -inf],
          [-0.3835, -0.1010,    -inf,    -inf],
          [ 0.1020, -0.0861, -0.0317,    -inf],
          [-0.4299, -0.4382, -0.4538,  0.1704]],

         [[-0.2862,    -inf,    -inf,    -inf],
          [-0.0650, -0.4676,    -inf,    -inf],
          [-0.2815, -0.2485,  0.1706,    -inf],
          [ 0.0828, -0.0689, -0.1944, -0.4020]],

         [[-0.3037,    -inf,    -inf,    -inf],
          [-0.1706, -0.1514,    -inf,    -inf],
          [-0.1198, -0.2822, -0.1319,    -inf],
          [ 0.0145,  0.1137,  0.2772,  0.0864]],

         [[-0.5260,    -inf,  

In [31]:
# Apply the causal mask before the softmax: the preceding cell only displayed
# `scaled + mask` without storing it, so `scaled` alone is still unmasked.
attention = F.softmax(scaled + mask, dim=-1)

In [32]:
# Weighted sum of the value vectors, one result per head and query position.
values = attention @ v
values.shape

torch.Size([1, 8, 4, 64])

## Encapsulate everything in a function

In [33]:
import math

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention over the last two dimensions.

    Args:
        q: Query tensor of shape ``[..., seq_q, d_k]``.
        k: Key tensor of shape ``[..., seq_k, d_k]``.
        v: Value tensor of shape ``[..., seq_k, d_v]``.
        mask: Optional additive mask that broadcasts against the
            ``[..., seq_q, seq_k]`` score tensor. It is added to the scores
            before the softmax, so ``-inf`` entries receive zero weight.

    Returns:
        Tuple ``(values, attention)``: the attention-weighted sum of ``v`` and
        the softmax-normalised attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: `scaled += mask` raises whenever broadcasting with
        # the mask would enlarge the result (e.g. a mask with an extra leading
        # dimension), because an in-place op cannot resize its target.
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [34]:
# Run masked attention for all heads at once using the helper defined above.
values, attention = scaled_dot_product(q, k, v, mask=mask)

In [35]:
# Bring the head axis next to head_dim ([batch, seq, heads, head_dim]) before
# flattening. Reshaping [batch, heads, seq, head_dim] directly would pack
# vectors belonging to different tokens into the same output row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()

torch.Size([1, 4, 512])

In [36]:
# Output projection applied to the concatenated heads.
linear_layer = nn.Linear(in_features=d_model, out_features=d_model)

In [37]:
# Apply the output projection to the concatenated head outputs.
out = linear_layer(values)

## Putting it all together

In [38]:
import torch
import torch.nn as nn
import math

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention over the last two dimensions.

    Args:
        q: Query tensor of shape ``[..., seq_q, d_k]``.
        k: Key tensor of shape ``[..., seq_k, d_k]``.
        v: Value tensor of shape ``[..., seq_k, d_v]``.
        mask: Optional additive mask that broadcasts against the
            ``[..., seq_q, seq_k]`` score tensor. It is added to the scores
            before the softmax, so ``-inf`` entries receive zero weight.

    Returns:
        Tuple ``(values, attention)``: the attention-weighted sum of ``v`` and
        the softmax-normalised attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: `scaled += mask` raises whenever broadcasting with
        # the mask would enlarge the result (e.g. a mask with an extra leading
        # dimension), because an in-place op cannot resize its target.
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head self-attention built on a single fused q/k/v projection.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        d_model: Total width of the attention output (all heads concatenated).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # on the first forward pass.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One projection yields q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        # Output projection applied after the heads are concatenated.
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, sequence_length, input_dim]``.
            mask: Optional mask forwarded unchanged to ``scaled_dot_product``.

        Returns:
            Tensor of shape ``[batch_size, sequence_length, d_model]``.

        The ``print`` calls trace the tensor shape after each step.
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Give each head its own 3 * head_dim slice of the fused projection.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # -> [batch_size, num_heads, sequence_length, 3 * head_dim]
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # Move the head axis back next to head_dim before flattening, so each
        # output row concatenates the heads of ONE token. Reshaping the
        # [batch, heads, seq, head_dim] tensor directly would mix different
        # tokens into the same row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out

In [39]:
# Demo: 30 sequences of 5 tokens with 1024-d inputs, attended with 8 heads
# into a 512-d output.
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn( (batch_size, sequence_length, input_dim) )

model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than `.forward()` so that nn.Module.__call__
# (and any registered hooks) runs.
out = model(x)

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
