In [48]:
import math
import numpy as np

import torch
import torch.nn as nn
import torch.functional as F

## Self Attention

In [11]:
# Initialize dimensions (L = length of sequence | d_k, d_q and d_v are the dimension of the query, key and value vector respectively)
L, d_q, d_k, d_v = 4, 8, 8, 8

q = np.random.rand(L, d_q)
k = np.random.rand(L, d_k)
v = np.random.rand(L, d_v)

In [12]:
# K.Q^T
np.matmul(q, k.T)

array([[1.97034639, 2.09695661, 2.28891774, 2.13474498],
       [1.68359999, 1.99113588, 1.76012616, 2.03409332],
       [1.35021503, 1.24722467, 0.91013416, 1.0667486 ],
       [1.76674198, 2.18895891, 2.24594791, 2.26042321]])

In [21]:
# Why do we need the sqrt(d_k) --> To stabilize the variance of the dot-product of query and key values
q.var(), k.var(), np.matmul(q, k.T).var(), (np.matmul(q, k.T)/np.sqrt(d_k)).var()

# Scaled dot-product
scaled_dot_product = np.matmul(q, k.T)/np.sqrt(d_k)
scaled_dot_product

array([[0.69662265, 0.74138612, 0.80925463, 0.75474632],
       [0.59524249, 0.70397284, 0.62229857, 0.71916059],
       [0.4773731 , 0.44096051, 0.32178102, 0.37715259],
       [0.62463762, 0.77391384, 0.7940625 , 0.79918029]])

In [24]:
# Masking - To ensure we do not look ahead for decoding
mask = np.tril(np.ones((L, L)))
mask[mask == 0] = -np.infty
mask[mask == 1] = 0

scaled_dot_product = scaled_dot_product + mask
scaled_dot_product

array([[0.69662265,       -inf,       -inf,       -inf],
       [0.59524249, 0.70397284,       -inf,       -inf],
       [0.4773731 , 0.44096051, 0.32178102,       -inf],
       [0.62463762, 0.77391384, 0.7940625 , 0.79918029]])

In [35]:
# Attention (After Softmax)
def softmax(x):
    return np.exp(scaled_dot_product)/np.sum(np.exp(scaled_dot_product), axis=-1)

attention = softmax(scaled_dot_product)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.90358946, 0.52715584, 0.        , 0.        ],
       [0.80312128, 0.40524178, 0.30349735, 0.        ],
       [0.93054482, 0.5653456 , 0.48670314, 0.26248191]])

In [37]:
attention_scores = np.matmul(attention, v)
attention_scores

array([[0.97760902, 0.422824  , 0.23773325, 0.25334508, 0.01667512,
        0.29083252, 0.38963413, 0.58275223],
       [1.33535946, 0.69665477, 0.22186223, 0.55445162, 0.26947833,
        0.42079728, 0.46213142, 0.59249975],
       [1.41975113, 0.66866223, 0.49369465, 0.75235602, 0.36916941,
        0.43552976, 0.61813096, 0.79486238],
       [1.90960847, 1.07926621, 0.90911764, 1.1603667 , 0.59701321,
        0.63747457, 0.91029532, 1.13494983]])

In [44]:
# Encapsulating the scaled_dot_product attention
def scaled_dot_product_attention(q, k, v, mask = None):
    d_k = q.shape[-1]
    scaled_dot_product = np.matmul(q, k.T)/np.sqrt(d_k)
    if mask is not None:
        scaled_dot_product = scaled_dot_product + mask
    attention = softmax(scaled_dot_product)
    attention_scores = np.matmul(attention, v)
    return attention_scores, attention

In [45]:
values, attention_matrix = scaled_dot_product_attention(q, k, v, mask)
attention_matrix

array([[1.        , 0.        , 0.        , 0.        ],
       [0.90358946, 0.52715584, 0.        , 0.        ],
       [0.80312128, 0.40524178, 0.30349735, 0.        ],
       [0.93054482, 0.5653456 , 0.48670314, 0.26248191]])

## Multi Head Attention

In [47]:
sequence_length = 4
batch_size = 1
input_dim = 512
d_model = 512
x = torch.randn((batch_size, sequence_length, input_dim))
x.size()

torch.Size([1, 4, 512])

In [73]:
# Scaled Dot Product in Multi Head Scenario
qkv_layer = nn.Linear(input_dim, 3 * d_model)

qkv = qkv_layer(x)
print(qkv.size())

num_heads = 8
head_dim = d_model//8
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
print(qkv.size())

# Permute to get tensor (batch size, num_heads, sequence_length, d_model)
qkv = qkv.permute(0, 2, 1, 3)
print(qkv.size())

q, k, v = qkv.chunk(3, dim=-1)
print(q.size(), k.size(), v.size())

torch.Size([1, 4, 1536])
torch.Size([1, 4, 8, 192])
torch.Size([1, 8, 4, 192])
torch.Size([1, 8, 4, 64]) torch.Size([1, 8, 4, 64]) torch.Size([1, 8, 4, 64])


In [102]:
def torch_scaled_dot_product_attention(q, k, v, mask = None):
    d_k = q.size()[-1]
    scaled_dot_product = torch.matmul(q, torch.transpose(k, -2, -1))/math.sqrt(d_k)
    if mask is not None:
        scaled_dot_product += mask
    attention = torch.softmax(scaled_dot_product, -1)
    attention_scores = torch.matmul(attention, v)
    return attention_scores, attention

# For encoders
att_score, att_matrix = torch_scaled_dot_product_attention(q, k, v)
print(att_score.size())
print(att_matrix.size())

#  For decoders
mask = torch.full((batch_size, num_heads, sequence_length, sequence_length), float("-inf"))
mask = torch.triu(mask, diagonal=1)

att_score, att_matrix = torch_scaled_dot_product_attention(q, k, v, mask=mask)
print(att_score.size())
print(att_matrix.size())

torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 4])
torch.Size([1, 8, 4, 64])
torch.Size([1, 8, 4, 4])


In [134]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model//num_heads
        self.qkv_layer = nn.Linear(self.input_dim, 3 * d_model)
        self.feed_forward = nn.Linear(d_model, d_model)

    @staticmethod
    def scaled_dot_product_attention(q, k, v, mask = None):
        d_k = q.size()[-1]
        scaled_dot_product = torch.matmul(q, torch.transpose(k, -2, -1))/math.sqrt(d_k)
        if mask is not None:
            scaled_dot_product += mask
        attention = torch.softmax(scaled_dot_product, -1)
        attention_scores = torch.matmul(attention, v)
        return attention_scores, attention

    def forward(self, x, mask=None):
        # Input tensor (batch_size, sequence_length, input_dim)
        batch_size, sequence_length = x.size()[0], x.size()[1]
        qkv_tensor = self.qkv_layer(x)
        print(qkv_tensor.size())

        # Reshape to accomodate heads
        qkv_tensor = qkv_tensor.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(qkv_tensor.size())

        # Permute to get tensor (batch size, num_heads, sequence_length, d_model)
        qkv_tensor = qkv_tensor.permute(0, 2, 1, 3)
        print(qkv_tensor.size())

        # Extract individual q, k and v
        q, k, v = qkv_tensor.chunk(3, dim=-1)
        print(q.size(), k.size(), v.size())

        # Compute attention
        attention_score, self_attention_matrix = MultiHeadAttention.scaled_dot_product_attention(q, k, v)
        print(attention_score.size())

        # Gather the heads
        attention_score = attention_score.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        print(attention_score.size())

        # Layer adjustnemt
        encoded_input = self.feed_forward(attention_score)

        return encoded_input
    
# Test it out
input_dim = 1024
sequence_length = 5
batch_size = 30

input_sequence = torch.randn((batch_size, sequence_length, input_dim))
print(input_sequence.size())

attention_block = MultiHeadAttention(
    input_dim=input_dim,
    d_model=512, 
    num_heads=8
    )
value = attention_block.forward(input_sequence)


torch.Size([30, 5, 1024])
torch.Size([30, 5, 1536])
torch.Size([30, 5, 8, 192])
torch.Size([30, 8, 5, 192])
torch.Size([30, 8, 5, 64]) torch.Size([30, 8, 5, 64]) torch.Size([30, 8, 5, 64])
torch.Size([30, 8, 5, 64])
torch.Size([30, 5, 512])


In [133]:
value.size()

torch.Size([30, 5, 512])

## Positional Encoding