# Multi-head Attention

The idea is to split the given query, key, and value along the feature dimension into $H$ heads, run attention on each head separately, and concatenate the per-head results.

<img src="images/scaled_mhead_attentions.png" alt="Figure 2. From ‘Attention Is All You Need’ by Vaswani et al." style="width:80%;"/>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q: Query tensor whose last two dimensions are (J, d_k).
        k: Key tensor whose last two dimensions are (I, d_k).
        v: Value tensor whose last two dimensions are (I, d_v).
        mask: Optional tensor broadcastable against the (..., J, I) scores.
            Positions where it equals 0 are excluded from attention. A mask
            with fewer dimensions than the scores gets a singleton axis
            inserted at position 1 (the head axis of B x H x J x I scores).
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention-weighted combination of ``v``; its last two
        dimensions are (J, d_v).
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.shape[-1])
    if mask is not None:
        # Only add the head axis when the mask lacks it. Unsqueezing a mask
        # that already has the rank of ``scores`` (e.g. un-split 3-D inputs)
        # would broadcast the result to a spurious extra leading dimension.
        if mask.dim() < scores.dim():
            mask = mask.unsqueeze(1)
        # A very negative score turns into a ~0 weight after the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)

In [3]:
# Toy sizes: q is B x J x K, k and v are B x I x K, split into H heads.
B, I, J, K, H = 3, 9, 10, 18, 3

# width of each of the H feature segments of q, k, v
d_k = K // H
q = torch.rand((B, J, K))
k = torch.rand((B, I, K))
v = torch.rand((B, I, K))

# run attention head by head on the matching feature slices
cs = []
for h in range(H):
    seg = slice(h * d_k, (h + 1) * d_k)
    qh, kh, vh = q[..., seg], k[..., seg], v[..., seg]
    cs.append(scaled_dot_product_attention(qh, kh, vh))
# stitch the per-head contexts back together along the feature axis
c1 = torch.cat(cs, dim=-1).contiguous()

In [4]:
# Batched equivalent: give the heads their own axis so that a single
# attention call covers all of them.
# B x J/I x K -> B x J/I x H x d_k -> B x H x J/I x d_k
qh, kh, vh = (t.view(B, -1, H, d_k).transpose(1, 2) for t in (q, k, v))
ctx = scaled_dot_product_attention(qh, kh, vh)
# B x H x J x d_k -> B x J x H x d_k -> B x J x K
c2 = ctx.transpose(1, 2).contiguous().view(B, -1, K)

In [5]:
# the per-head loop and the batched version must agree element-wise
(c1 == c2).all()

tensor(True)

In [6]:
def multi_head_attention(q, k, v, H):
    """Run scaled dot-product attention over ``H`` heads in one batched call.

    Args:
        q: Query tensor of shape B x J x K.
        k: Key tensor of shape B x I x K.
        v: Value tensor of shape B x I x K.
        H: Number of heads; must be positive and divide K evenly.

    Returns:
        Tensor of shape B x J x K holding the concatenated per-head contexts.

    Raises:
        ValueError: If ``H`` is not positive or does not divide K.
    """
    B, _, K = q.shape
    # Without this check ``K // H`` truncates silently and the views below
    # either fail with an opaque size error or regroup features incorrectly.
    if H <= 0 or K % H != 0:
        raise ValueError(
            f"feature size K={K} must be divisible by a positive number of heads, got H={H}"
        )
    d_k = K // H

    # B x J/I x K -> B x J/I x H x d_k -> B x H x J/I x d_k
    qh = q.view(B, -1, H, d_k).transpose(1, 2)
    kh = k.view(B, -1, H, d_k).transpose(1, 2)
    vh = v.view(B, -1, H, d_k).transpose(1, 2)

    # B x H x J x d_k -> B x J x H x d_k -> B x J x K
    context = scaled_dot_product_attention(qh, kh, vh)
    return context.transpose(1, 2).contiguous().view(B, -1, K)

In [7]:
# the function must reproduce the per-head loop result element-wise
(c1 == multi_head_attention(q, k, v, H)).all()

tensor(True)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned input and output projections.

    Queries, keys and values each pass through their own ``K -> K`` linear
    layer, are split into ``H`` heads of width ``K // H``, attended with
    ``scaled_dot_product_attention``, re-joined, and projected once more.

    Args:
        H: Number of attention heads; must be positive and divide ``K``.
        K: Size of the last dimension of the inputs and of the output.
        dropout: Probability passed to the ``nn.Dropout`` that is applied
            to the attention weights.

    Raises:
        ValueError: If ``H`` is not positive or does not divide ``K``.
    """

    def __init__(self, H, K, dropout = 0.1):
        super().__init__()

        # Fail early: with a truncated K // H the head split in forward()
        # either errors with an opaque size message or regroups features.
        if H <= 0 or K % H != 0:
            raise ValueError(
                f"model size K={K} must be divisible by a positive number of heads, got H={H}"
            )

        self.K = K
        self.d_k = K // H
        self.H = H

        # Creation order (q, v, k) is kept as-is so that parameter
        # initialisation under a fixed seed does not change.
        self.q_linear = nn.Linear(K, K)
        self.v_linear = nn.Linear(K, K)
        self.k_linear = nn.Linear(K, K)

        self.dropout = nn.Dropout(dropout)

        self.out = nn.Linear(K, K)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` and return the projected result.

        Args:
            q: Query tensor whose first dimension is the batch size B and
                whose last dimension is K.
            k: Key tensor with the same batch size and last dimension K.
            v: Value tensor with the same batch size and last dimension K.
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape B x J x K, where J is the query length.
        """
        B = q.size(0)

        # project, then split into heads:
        # B x J/I x K -> B x J/I x H x d_k -> B x H x J/I x d_k
        qh = self.q_linear(q).view(B, -1, self.H, self.d_k).transpose(1, 2)
        kh = self.k_linear(k).view(B, -1, self.H, self.d_k).transpose(1, 2)
        vh = self.v_linear(v).view(B, -1, self.H, self.d_k).transpose(1, 2)

        # per-head attention output, B x H x J x d_k
        context = scaled_dot_product_attention(
            qh, kh, vh, mask, self.dropout)

        # concatenate heads and put through final linear layer
        concat = context.transpose(1, 2).contiguous().view(B, -1, self.K)
        return self.out(concat)