In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the tensors shown below
sequence_len = 4  # tokens per sequence
batch_size = 1
input_dim = 512  # feature size of each input token
d_model = 512  # model width, later split across attention heads
x = torch.randn((batch_size, sequence_len, input_dim))  # (batch_size, sequence_len, input_dim)

In [3]:
x.shape  # same torch.Size as x.size(): (batch_size, sequence_len, input_dim)

torch.Size([1, 4, 512])

In [4]:
# One fused projection yields queries, keys and values together (3 * d_model outputs).
qkv_layer = nn.Linear(in_features=input_dim, out_features=3 * d_model)

In [5]:
qkv = qkv_layer(x)  # fused projection: (batch_size, sequence_len, 3 * d_model)

In [6]:
qkv.size()  # (batch_size, sequence_len, 3 * d_model)

torch.Size([1, 4, 1536])

In [7]:
# Optional: plot the distribution of the qkv projection over [-3, 3] (needs matplotlib).
# Left commented out so the notebook runs without a plotting dependency.
# import matplotlib.pyplot as plt
# y_val = torch.histc(qkv, bins=200, min=-3, max=3)
# x_val = np.arange(-1, 1, 0.01)*3
# plt.bar(x_val, y_val, align='center', color=['forestgreen'])
# plt.title('qkv distribution')

In [8]:
num_heads = 8
# Every head gets an equal slice of d_model, so the split must be exact.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
head_dim = d_model // num_heads
# Split the last axis per head; each head's chunk holds its [q | k | v] (3 * head_dim).
# NOTE: this rebinds `qkv`, so re-run from the `qkv = qkv_layer(x)` cell rather than
# re-running this cell on its own (especially after the permute below).
qkv = qkv.reshape(batch_size, sequence_len, num_heads, 3 * head_dim)

In [9]:
qkv.size()  # (batch_size, sequence_len, num_heads, 3 * head_dim)

torch.Size([1, 4, 8, 192])

In [10]:
# Swap the sequence and head axes: (batch, heads, seq, 3 * head_dim).
qkv = qkv.transpose(1, 2)
qkv.size()

torch.Size([1, 8, 4, 192])

In [11]:
# Each head's last axis is laid out as [q | k | v]; cut it into three equal parts.
q, k, v = torch.chunk(qkv, 3, dim=-1)
tuple(t.shape for t in (q, k, v))

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

# Self-Attention for Multiple Heads

In [14]:
d_k = q.shape[-1]  # per-head key dimension used for scaling
# Logits Q K^T / sqrt(d_k); only the last two axes of k are swapped so batch/head axes line up.
scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.size()

torch.Size([1, 8, 4, 4])

In [17]:
# `Tensor.T` reverses *every* axis and is deprecated by PyTorch for tensors that are not 2-D.
# An explicit permute shows the same reversed shape, which is why attention uses
# k.transpose(-2, -1) (swap only the last two axes) rather than k.T.
k.permute(*range(k.dim() - 1, -1, -1)).shape

torch.Size([64, 4, 8, 1])

In [18]:
# Small 2-D example: transposing dims (0, 1) and (1, 0) is the same operation.
y = torch.randn(2, 3)
y.transpose(0, 1)

tensor([[-1.4633,  0.5615],
        [ 0.4364, -1.1420],
        [ 0.6907,  0.0925]])

In [19]:
y.transpose(1, 0)  # argument order of the two dims does not matter

tensor([[-1.4633,  0.5615],
        [ 0.4364, -1.1420],
        [ 0.6907,  0.0925]])

In [20]:
# Additive causal mask: 0 where attention is allowed, -inf strictly above the diagonal.
mask = torch.full(scaled.size(), float('-inf')).triu(diagonal=1)
mask[0, 1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [21]:
(scaled + mask)[0, 0]  # batch 0, head 0: logits with future positions set to -inf

tensor([[ 0.5556,    -inf,    -inf,    -inf],
        [ 0.0318, -0.0891,    -inf,    -inf],
        [-0.1515, -0.0797, -0.1964,    -inf],
        [ 0.4185, -0.4006,  0.2747, -0.1385]], grad_fn=<SelectBackward0>)

In [23]:
# Add the causal mask out-of-place: PyTorch's autograd notes discourage in-place ops on
# tensors that are part of the graph (`scaled` carries a grad_fn from the cell above).
scaled = scaled + mask

In [24]:
# Check one softmax weight by hand using the live logits, not numbers copied from an earlier
# printout (those change on every re-run). Row 1 of batch 0 / head 0 has two unmasked
# positions, so its first weight is exp(a) / (exp(a) + exp(b)).
logit_0, logit_1 = scaled[0, 0, 1, :2].tolist()
np.exp(logit_0) / (np.exp(logit_0) + np.exp(logit_1))

0.5301882376437154

In [25]:
attention = scaled.softmax(dim=-1)  # normalise each query row; -inf logits become weight 0

In [26]:
attention.size()  # (batch, heads, seq, seq)

torch.Size([1, 8, 4, 4])

In [27]:
attention[0, 0]  # batch 0, head 0: lower-triangular weights, each row sums to 1

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5302, 0.4698, 0.0000, 0.0000],
        [0.3300, 0.3545, 0.3155, 0.0000],
        [0.3472, 0.1531, 0.3007, 0.1989]], grad_fn=<SelectBackward0>)

In [28]:
# Weighted sum of value vectors per head: (batch, heads, seq, seq) @ (batch, heads, seq, head_dim).
values = attention @ v
values.size()

torch.Size([1, 8, 4, 64])

# Function for the Above Process

In [29]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Computes softmax(q @ k^T / sqrt(d_k) + mask) @ v over the last two axes; any leading
    (batch, head, ...) axes are carried through.

    Args:
        q, k, v: tensors whose last axis is the feature dimension; q and k must share it.
        mask: optional additive mask broadcastable against the (..., len_q, len_k) scores,
            0 where attention is allowed and -inf where it is blocked.

    Returns:
        (values, attention): the attended values and the attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place: does not mutate a tensor in the autograd graph, and the mask may
        # broadcast to a larger shape than the raw scores (in-place `+=` cannot).
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [33]:
values, attention = scaled_dot_product(q=q, k=k, v=v, mask=mask)  # same result via the helper

In [34]:
attention.size()  # (batch, heads, seq, seq)

torch.Size([1, 8, 4, 4])

In [35]:
attention[0, 0]  # should match the step-by-step result above

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5302, 0.4698, 0.0000, 0.0000],
        [0.3300, 0.3545, 0.3155, 0.0000],
        [0.3472, 0.1531, 0.3007, 0.1989]], grad_fn=<SelectBackward0>)

In [36]:
values.size()  # (batch, heads, seq, head_dim)

torch.Size([1, 8, 4, 64])

In [37]:
# values is (batch, heads, seq, head_dim). Move the sequence axis back in front of the heads
# before flattening, so each output row concatenates every head's vector for the SAME token.
# Reshaping directly would splice different tokens (and heads) together into one row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_len, num_heads * head_dim)
values.shape

torch.Size([1, 4, 512])

In [38]:
# Output projection applied to the concatenated heads (d_model -> d_model).
linear_layer = nn.Linear(in_features=d_model, out_features=d_model)

In [39]:
out = linear_layer(values)  # mixes the concatenated head outputs; shape stays (batch, seq, d_model)

In [40]:
out.size()  # (batch_size, sequence_len, d_model)

torch.Size([1, 4, 512])

In [41]:
out  # display the projected output (torch abbreviates the middle columns)

tensor([[[ 0.1038,  0.2173, -0.1450,  ...,  0.1270,  0.0147, -0.4040],
         [-0.0090, -0.2104, -0.0502,  ..., -0.0574, -0.1809,  0.2877],
         [-0.0950,  0.1971,  0.1374,  ..., -0.0719,  0.2286, -0.1615],
         [-0.2862,  0.0359,  0.1088,  ...,  0.3490, -0.1342, -0.2290]]],
       grad_fn=<ViewBackward0>)

# Final Implementation as a `MultiheadAttention` Class

In [42]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by scaled_dot_product; keeps this cell self-contained


def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Computes softmax(q @ k^T / sqrt(d_k) + mask) @ v over the last two axes; leading
    (batch, head, ...) axes are carried through unchanged.

    Args:
        q, k, v: tensors whose last axis is the feature dimension; q and k share it.
        mask: optional additive mask broadcastable against the (..., len_q, len_k)
            scores, 0 where attention is allowed and -inf where it is blocked.

    Returns:
        (values, attention): the attended values and the attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place so no tensor in the autograd graph is mutated and the mask may
        # broadcast to a larger shape than the raw scores.
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


class MultiheadAttention(nn.Module):
    """Multi-head self-attention over a (batch, seq, input_dim) input.

    One linear layer projects each token to 3 * d_model features. These are split into
    num_heads heads of head_dim = d_model // num_heads features, with each head's slice
    laid out as [q | k | v]. Head outputs are concatenated per token and passed through
    a final d_model -> d_model linear layer.

    Args:
        input_dim: feature size of each input token.
        d_model: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        verbose: if True (default, matching the original behaviour), print intermediate
            tensor sizes on every forward pass.

    Raises:
        ValueError: if d_model is not divisible by num_heads.
    """

    def __init__(self, input_dim, d_model, num_heads, verbose=True):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.verbose = verbose
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def _log(self, message):
        # Debug trace of tensor sizes, gated so the module can run quietly.
        if self.verbose:
            print(message)

    def forward(self, x, mask=None):
        """Apply self-attention to x, unpacked as (batch, seq, input_dim).

        Args:
            x: 3-D input tensor.
            mask: optional additive mask broadcastable to (batch, heads, seq, seq).

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        batch_size, sequence_len, _ = x.size()  # unpacking also enforces a 3-D input
        self._log(f"x.size():{x.size()}")
        qkv = self.qkv_layer(x)
        self._log(f"qkv.size():{qkv.size()}")

        qkv = qkv.reshape(batch_size, sequence_len, self.num_heads, 3 * self.head_dim)
        self._log(f"qkv.size():{qkv.size()}")
        # (batch, heads, seq, 3 * head_dim): heads become a batch-like axis for attention.
        qkv = qkv.permute(0, 2, 1, 3)
        self._log(f"qkv.size():{qkv.size()}")

        q, k, v = qkv.chunk(3, dim=-1)
        self._log(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")

        values, attention = scaled_dot_product(q, k, v, mask)
        self._log(f"values.size(): {values.size()}, attention.size(): {attention.size()}")

        # Undo the head/sequence swap before flattening so each row concatenates every
        # head's output for the same token; reshaping (batch, heads, seq, head_dim)
        # directly would splice different tokens into one row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_len, self.num_heads * self.head_dim
        )
        self._log(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        self._log(f"out.size():{out.size()}")

        return out
        
    

In [43]:
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_len = 5
torch.manual_seed(0)  # reproducible input and layer initialisation
x = torch.randn((batch_size, sequence_len, input_dim))
model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself, not .forward(), so nn.Module hooks run.
out = model(x)

x.size():torch.Size([30, 5, 1024])
qkv.size():torch.Size([30, 5, 1536])
qkv.size():torch.Size([30, 5, 8, 192])
qkv.size():torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size(): torch.Size([30, 8, 5, 5])
values.size(): torch.Size([30, 5, 512])
out.size():torch.Size([30, 5, 512])
