Reference:
https://atcold.github.io/pytorch-Deep-Learning/

In [1]:
import torch 
from torch import nn
import torch.nn.functional as f
import numpy as np 

## Multi-head attention

<img src='./self_attn_full.png' width=350 height=350>

In [7]:
class MultiHeadAttention(nn.Module):
    ''' Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by their own bias-free linear
    layer (W_q, W_k, W_v), split into `num_heads` heads of size
    `d_k = d_model // num_heads`, attended independently per head, merged back
    together and passed through a final bias-free linear layer (W_h).

    Args:
        d_model: dimension of the input embedding; must be divisible by num_heads.
        num_heads: number of attention heads.
    '''

    def __init__(self, d_model, num_heads):
        super().__init__()

        self.d_model = d_model      # dimension of the input embedding
        self.num_heads = num_heads  # number of heads in multi-head attention
        # Query, key and value inputs all share the embedding dimension
        self.d_xq = self.d_xk = self.d_xv = self.d_model

        # Make sure the input embedding dimension is divisible by num_heads
        assert self.d_model % self.num_heads == 0, \
            "d_model must be divisible by num_heads"

        # Dimension of each head, e.g. d_model=512, num_heads=8 -> d_k=64
        self.d_k = self.d_model // self.num_heads

        # Initial linear layers: Q = W_q x X_q, K = W_k x X_k, V = W_v x X_v
        self.W_q = nn.Linear(self.d_xq, self.d_model, bias=False)
        self.W_k = nn.Linear(self.d_xk, self.d_model, bias=False)
        self.W_v = nn.Linear(self.d_xv, self.d_model, bias=False)

        # Final linear layer applied to the re-grouped heads
        self.W_h = nn.Linear(self.d_model, self.d_model, bias=False)

    def scaled_dot_product(self, Q, K, V):
        ''' Scaled dot product to calculate attention.

        Args:
            Q: queries, (batch_size, num_heads, q_length, dim_per_head)
            K: keys,    (batch_size, num_heads, k_length, dim_per_head)
            V: values,  (batch_size, num_heads, k_length, dim_per_head)

        Returns:
            H: attention-weighted values, (batch_size, num_heads, q_length, dim_per_head)
            A: attention weights,         (batch_size, num_heads, q_length, k_length)

        Note: the scale factor is sqrt(self.d_k), taken from the module
        configuration and not from the last dimension of Q. The tensor sizes
        are printed as a tracing aid.
        '''

        print("Size of Q:", Q.size())
        print("Size of K:", K.size())
        print("Size of V:", V.size())
        # scaling - divide by square root of d_k
        Q = Q / np.sqrt(self.d_k)  # (batch_size, num_heads, q_length, dim_per_head)
        # dot-product of every query with every key
        scores = torch.matmul(Q, K.transpose(2, 3))  # (batch_size, num_heads, q_length, k_length)

        # Softmax over the key axis so each query's weights sum to 1
        A = torch.softmax(scores, dim=-1)  # (batch_size, num_heads, q_length, k_length)
        print("Size of A:", A.size())

        # Get the weighted average of values, V
        H = torch.matmul(A, V)  # (batch_size, num_heads, q_length, dim_per_head)
        print("Size of H:", H.size())

        return H, A

    def split_heads(self, x):
        ''' Split the embeddings into multiple heads.

        (batch_size, seq_length, d_model) -> (batch_size, num_heads, seq_length, dim_per_head)
        '''

        batch_size = x.size(0)
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def group_heads(self, x):
        ''' Combine the heads again.

        (batch_size, num_heads, seq_length, dim_per_head) -> (batch_size, seq_length, d_model)
        '''

        batch_size = x.size(0)
        # transpose() returns a non-contiguous view, and view() cannot merge the
        # head and per-head dimensions of such a tensor when seq_length > 1 and
        # num_heads > 1, so materialise a contiguous copy first.
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

    def forward(self, X_q, X_k, X_v):  # for self-attention X_q = X_k = X_v = X
        ''' Apply multi-head attention.

        Args:
            X_q: query input, (batch_size, q_length, d_model)
            X_k: key input,   (batch_size, k_length, d_model)
            X_v: value input, (batch_size, k_length, d_model)

        Returns:
            H: attended output,   (batch_size, q_length, d_model)
            A: attention weights, (batch_size, num_heads, q_length, k_length)

        Raises:
            ValueError: if X_q is not a 3-D tensor.
        '''

        if X_q.dim() != 3:
            raise ValueError(
                "X_q must have shape (batch_size, seq_length, d_model), got %s"
                % (tuple(X_q.size()),)
            )

        # Step 1: Linear layer and split heads -- each input gets its own projection
        Q = self.split_heads(self.W_q(X_q))  # (batch_size, num_heads, q_length, dim_per_head)
        K = self.split_heads(self.W_k(X_k))  # (batch_size, num_heads, k_length, dim_per_head)
        V = self.split_heads(self.W_v(X_v))  # (batch_size, num_heads, k_length, dim_per_head)

        # Step 2: Scaled dot product
        # H_cat: (batch_size, num_heads, q_length, dim_per_head)
        # A:     (batch_size, num_heads, q_length, k_length)
        H_cat, A = self.scaled_dot_product(Q, K, V)

        # Step 3: Combine the heads and apply the final linear layer
        H_cat = self.group_heads(H_cat)  # (batch_size, q_length, d_model)
        H = self.W_h(H_cat)  # (batch_size, q_length, d_model)

        return H, A

In [8]:
# Build the multi-head attention module shared by the cells that call `mha`:
# 512-dimensional embeddings split across 8 heads.
mha = MultiHeadAttention(num_heads=8, d_model=512)

In [9]:
def attention(Q, K, V):
    ''' Run mha.scaled_dot_product on (Q, K, V) and print the results.

    Prints the attention weights first, then the output, each with
    singleton dimensions squeezed away. Returns nothing.
    '''

    output, weights = mha.scaled_dot_product(Q, K, V)
    labelled = (
        ('Attention weights are:', weights),
        ('Output is:', output),
    )
    for label, tensor in labelled:
        print(label, tensor.squeeze())

In [10]:
# Keys: four 3-dimensional vectors; the third and fourth rows are identical.
# Two leading singleton axes are added, giving shape (1, 1, 4, 3).
test_K = torch.tensor(
    [
        [10, 0, 0],
        [0, 10, 0],
        [0, 0, 10],
        [0, 0, 10],
    ],
    dtype=torch.float32,
).unsqueeze(0).unsqueeze(0)

# Values: one row per key, same layout and shape as test_K.
test_V = torch.tensor(
    [
        [1, 0, 0],
        [10, 0, 0],
        [100, 5, 0],
        [1000, 6, 0],
    ],
    dtype=torch.float32,
).unsqueeze(0).unsqueeze(0)

In [11]:
# Query pointing along the second axis, i.e. aligned with the second key only.
test_Q = torch.tensor([[0, 10, 0]], dtype=torch.float32)[None, None]
attention(test_Q, test_K, test_V)

Size of Q: torch.Size([1, 1, 1, 3])
Size of K: torch.Size([1, 1, 4, 3])
Size of V: torch.Size([1, 1, 4, 3])
Size of A: torch.Size([1, 1, 1, 4])
Size of H: torch.Size([1, 1, 1, 3])
Attention weights are: tensor([3.7266e-06, 9.9999e-01, 3.7266e-06, 3.7266e-06])
Output is: tensor([1.0004e+01, 4.0993e-05, 0.0000e+00])


#### The query focuses on the second key and returns the second value.

In [13]:
# Query pointing along the third axis, matching the third and fourth keys equally.
test_Q = torch.tensor([[0, 0, 10]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
attention(test_Q, test_K, test_V)

Size of Q: torch.Size([1, 1, 1, 3])
Size of K: torch.Size([1, 1, 4, 3])
Size of V: torch.Size([1, 1, 4, 3])
Size of A: torch.Size([1, 1, 1, 4])
Size of H: torch.Size([1, 1, 1, 3])
Attention weights are: tensor([1.8633e-06, 1.8633e-06, 5.0000e-01, 5.0000e-01])
Output is: tensor([549.9979,   5.5000,   0.0000])


#### If we give a query that matches two keys exactly, it focuses on the two keys equally and returns the average of the two values for those two keys. 