In [16]:
# Importing necessary libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### Multi-Head Attention

In [17]:
sequence_length = 4 # Eg: "My name is Nav"
batch_size = 1
input_dim = 512 # Embedding dimension
d_model = 512
x = torch.randn( (batch_size, sequence_length, input_dim) )

In [18]:
x.size()

torch.Size([1, 4, 512])

In [19]:
qkv_layer = nn.Linear(input_dim, 3 * d_model) # For Q, K, V

In [20]:
qkv = qkv_layer(x) # passing I/P to the layer to generate Q, K, v vector

In [21]:
qkv.size()

torch.Size([1, 4, 1536])

In [22]:
num_heads = 8
head_dim = d_model // num_heads
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3*head_dim) # splitting into 8 heads, with each having (qkv) -> 192 dim (where q - 64, k - 64, v - 64)

In [23]:
qkv.size()

torch.Size([1, 4, 8, 192])

In [24]:
# Interchanging sequence_length and num_heads for ease of calculation
qkv = qkv.permute(0, 2, 1, 3) # (batch_size, num_heads, sequence_length, 3*head_dim)
qkv.size()

torch.Size([1, 8, 4, 192])

In [25]:
q, k, v = torch.chunk(qkv, 3, dim=-1) # splitting last dim into 3 parts
q.size(), k.size(), v.size()

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

### Self Attention for multi-heads

self attention = softmax(Q.K^T/sqrt(d_k) + M)

In [26]:
d_k = q.size()[-1]
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.size()

torch.Size([1, 8, 4, 4])

In [27]:
# masking
mask = torch.full(scaled.size(), float('-inf'))
mask = torch.triu(mask, diagonal=1) # upper triangular matrix
mask[0][0]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [28]:
(scaled + mask)[0][0]

tensor([[ 0.2947,    -inf,    -inf,    -inf],
        [ 0.4688,  0.0434,    -inf,    -inf],
        [ 0.1908, -0.1342,  0.0264,    -inf],
        [ 0.1232,  0.4566, -0.5165, -0.1695]], grad_fn=<SelectBackward0>)

In [29]:
scaled+= mask
scaled.size()

torch.Size([1, 8, 4, 4])

In [30]:
attention = F.softmax(scaled, dim=-1)

In [31]:
attention[0][0]
attention.size()

torch.Size([1, 8, 4, 4])

In [32]:
values = torch.matmul(attention, v) # new value vectors are more context aware than v
values.size()

torch.Size([1, 8, 4, 64])

In [36]:
# function to return values and attention
def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

### Positional encoding

In [37]:
max_sequence_length = 10
d_model = 6

In [38]:
even_i = torch.arange(0, d_model, 2).float()  # even indices
even_i

tensor([0., 2., 4.])

In [39]:
even_denominator = torch.pow(10000, even_i/d_model)
even_denominator

tensor([  1.0000,  21.5443, 464.1590])

In [40]:
odd_i = torch.arange(1, d_model, 2).float()  # odd indices
odd_i

tensor([1., 3., 5.])

In [42]:
odd_denominator = torch.pow(10000, (odd_i-1)/d_model)
odd_denominator

tensor([  1.0000,  21.5443, 464.1590])

In [43]:
# even and odd denominators are the same
denominator = even_denominator

In [49]:
position = torch.arange(max_sequence_length, dtype=torch.float).reshape(max_sequence_length, 1)
position.size()

torch.Size([10, 1])

In [50]:
denominator.size()

torch.Size([3])

In [51]:
even_PE = torch.sin(position / denominator)
odd_PE = torch.cos(position / denominator)

In [56]:
even_PE

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.8415,  0.0464,  0.0022],
        [ 0.9093,  0.0927,  0.0043],
        [ 0.1411,  0.1388,  0.0065],
        [-0.7568,  0.1846,  0.0086],
        [-0.9589,  0.2300,  0.0108],
        [-0.2794,  0.2749,  0.0129],
        [ 0.6570,  0.3192,  0.0151],
        [ 0.9894,  0.3629,  0.0172],
        [ 0.4121,  0.4057,  0.0194]])

In [53]:
odd_PE

tensor([[ 1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9989,  1.0000],
        [-0.4161,  0.9957,  1.0000],
        [-0.9900,  0.9903,  1.0000],
        [-0.6536,  0.9828,  1.0000],
        [ 0.2837,  0.9732,  0.9999],
        [ 0.9602,  0.9615,  0.9999],
        [ 0.7539,  0.9477,  0.9999],
        [-0.1455,  0.9318,  0.9999],
        [-0.9111,  0.9140,  0.9998]])

### Interleaving

In [57]:
stacked  = torch.stack( [even_PE, odd_PE], dim=2)
stacked.shape

torch.Size([10, 3, 2])

In [None]:
PE = stacked.reshape(max_sequence_length, d_model)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])

In [None]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()  # even indices
        denominator = torch.pow(10000, even_i/self.d_model)
        position = torch.arange(self.max_sequence_length).float().reshape(self.max_sequence_length, 1)
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack( [even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE
    
pos_enc = PositionalEncoding(d_model=6, max_sequence_length=10)
pos_enc.forward()


tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])