# Understanding Multi-head Attention

In [57]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
     
# Configuration for this walkthrough.
SEED = 0
# Seed torch's global RNG so the random input below and the randomly initialised
# nn.Linear weights in later cells are reproducible on Restart & Run All.
torch.manual_seed(SEED)

sequence_length = 10  # number of token positions in the input sequence
batch_size = 1
d_model = 512  # size of every token vector (input and output of attention)

# Random stand-in for a batch of embedded tokens: [batch_size, sequence_length, d_model]
x = torch.randn((batch_size, sequence_length, d_model))

> Project the input sequence to three times the model dimension, so each position carries its query, key, and value projections together

In [58]:
# One linear layer maps every d_model-wide token to 3 * d_model features,
# which hold the queries, keys and values for all heads at once.
qkv_layer = nn.Linear(in_features=d_model, out_features=3 * d_model)
qkv = qkv_layer(x)  # [batch_size, sequence_length, 3 * d_model]
qkv.shape

torch.Size([1, 10, 1536])

> Split into multiple heads

In [59]:
num_heads = 8
# Every head must get the same number of features; fail loudly instead of
# producing a reshape error (or a silently wrong split) further down.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
head_dim = d_model // num_heads  # features per head

# NOTE: this cell rebinds `qkv`; re-run the projection cell before re-running this one.
# [batch_size, sequence_length, 3*d_model] -> [batch_size, sequence_length, num_heads, 3*head_dim]
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
# Move the head axis in front of the sequence axis so each head is attended independently.
qkv = qkv.permute(0, 2, 1, 3)  # [batch_size, num_heads, sequence_length, 3*head_dim]
qkv.shape

torch.Size([1, 8, 10, 192])

> Extract the Q, K, and V tensors

In [60]:
# Cut the last axis (3 * head_dim) into three equal head_dim-wide slices.
q, k, v = torch.chunk(qkv, 3, dim=-1)  # each: [batch_size, num_heads, sequence_length, head_dim]
(q.shape, k.shape, v.shape)

(torch.Size([1, 8, 10, 64]),
 torch.Size([1, 8, 10, 64]),
 torch.Size([1, 8, 10, 64]))

> Apply scaled dot-product attention (reused from self_attention.ipynb)

In [61]:
import math

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Args:
        q, k, v: tensors whose last two axes are (sequence, feature); leading
            axes (e.g. batch, head) are treated as batch dimensions by matmul.
        mask: optional additive mask, added in place to the score tensor, so it
            must broadcast to the scores' shape (use -inf to block a position).

    Returns:
        (values, attention): the attention-weighted values and the softmax weights.
    """
    scale = math.sqrt(q.size(-1))
    logits = q @ k.transpose(-2, -1) / scale
    if mask is not None:
        logits += mask
    weights = F.softmax(logits, dim=-1)
    return weights @ v, weights

def create_mask(batch_size=1, num_heads=8, sequence_length=512):
    """Build an additive causal (look-ahead) mask for attention scores.

    Entries strictly above the main diagonal are -inf, so position i cannot
    attend to any later position j > i. Entries on and below the diagonal are 0.

    Returns:
        Tensor of shape [batch_size, num_heads, sequence_length, sequence_length].
    """
    mask = torch.full(
        [batch_size, num_heads, sequence_length, sequence_length], float('-inf')
    )
    # triu keeps the strictly-upper triangle (diagonal=1) and sets every other entry to 0.
    # The original version built this tensor but never returned it, so callers got None.
    return torch.triu(mask, diagonal=1)

In [62]:
# Encoder-style self-attention: no mask, so every position can attend to every position.
encoder_v, encoder_a = scaled_dot_product(q, k, v)

# Decoder-style self-attention: create_mask builds a -inf upper-triangular mask
# intended to stop each position from attending to later positions.
mask = create_mask(batch_size=batch_size, num_heads=num_heads, sequence_length=sequence_length)
decoder_v, decoder_a = scaled_dot_product(q, k, v, mask=mask)

(decoder_v.shape, decoder_a.shape)

(torch.Size([1, 8, 10, 64]), torch.Size([1, 8, 10, 10]))

> Concatenate the heads back into one `d_model`-wide vector per position

In [63]:
# decoder_v is [batch_size, num_heads, sequence_length, head_dim]. Swap the head and
# sequence axes first so each position's head outputs sit next to each other, then
# flatten them. Reshaping directly would interleave values from different positions
# and heads.
# NOTE: rebinds decoder_v; re-run the attention cell before re-running this one.
decoder_v = decoder_v.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
decoder_v.shape

torch.Size([1, 10, 512])

> Output projection: a linear layer that lets the heads' outputs interact with each other

In [64]:
# Output projection: a d_model -> d_model linear layer over the concatenated heads,
# so information from different heads can be combined.
linear_layer = nn.Linear(in_features=d_model, out_features=d_model)
decoder_out = linear_layer(decoder_v)  # [batch_size, sequence_length, d_model]
decoder_out.shape

torch.Size([1, 10, 512])

# Class Implementation

In [65]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from a single fused QKV projection.

    Args:
        d_model: size of each input/output token vector.
        num_heads: number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Without this check the reshape in forward() would fail later with a less clear error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Learnable layer that produces Q, K and V for all heads in one matmul.
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        # Output projection that mixes the concatenated heads.
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape [batch_size, sequence_length, d_model].
            mask: optional additive mask forwarded unchanged to ``scaled_dot_product``.

        Returns:
            Tensor of shape [batch_size, sequence_length, d_model].
        """
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv_layer(x)  # [batch, seq, 3*d_model]
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [batch, heads, seq, 3*head_dim]
        q, k, v = qkv.chunk(3, dim=-1)  # each [batch, heads, seq, head_dim]
        values, _attention = scaled_dot_product(q, k, v, mask)
        # Put the sequence axis back in front of the head axis before flattening.
        # Reshaping [batch, heads, seq, head_dim] directly would mix positions and heads.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.d_model)
        return self.linear_layer(values)