# Multi-head Attention

1. Multi-head attention 및 self-attention 구현
2. 각 과정에서 일어나는 연산과 input/output 형태 이해

## 라이브러리

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

## 데이터 전처리

In [2]:
pad_id = 0
vocab_size = 100

data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [3]:
def padding(data):
    max_len = len(max(data, key=len))
    print(f"Maximum sequence length : {max_len}")
    
    for i, seq in enumerate(tqdm(data)):
        if len(seq) < max_len:
            data[i] = seq + [pad_id] * (max_len - len(seq))
            
    return data, max_len

In [4]:
data, max_len = padding(data)

100%|██████████| 10/10 [00:00<00:00, 137068.76it/s]

Maximum sequence length : 20





In [5]:
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

## Hyperparameter 세팅 및 embedding

In [6]:
d_model = 512 # model 의 hidden size
num_heads = 8 # head 의 개수

In [7]:
embedding = nn.Embedding(vocab_size, d_model)

# B : batch size, L : maximum sequnce length
batch = torch.LongTensor(data) # (B, L)
batch_emb = embedding(batch) # (B, L, d_model)

In [8]:
print(batch)
print(batch_emb)

tensor([[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75,  0,  0,
          0,  0],
        [60, 96, 51, 32, 90,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [35, 45, 48, 65, 91, 99, 92, 10,  3, 21, 54,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [75, 51,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [66, 88, 98, 47,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [21, 39, 10, 64, 21,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [98,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0],
        [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66,  3,
         43, 10],
        [70, 64, 98, 25, 99, 53,  4, 13, 69, 62, 66, 76, 15, 75, 45, 34,  0,  0,
          0,  0],
        [20, 64, 81, 35, 76, 85,  1, 62,  8, 45, 99, 77, 19, 43,  0,  0,  0,  0,
          0,  0]])
tensor([[

## Linear transformation & 여러 head 로 나누기

Multi-head attention 내에서 쓰이는 linear transformation matrix 들을 정의합니다.

In [9]:
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

In [10]:
w_0 = nn.Linear(d_model, d_model)

In [13]:
batch_emb.shape

torch.Size([10, 20, 512])

In [11]:
q = w_q(batch_emb) # (B, L, d_model)
k = w_k(batch_emb) # (B, L, d_model)
v = w_v(batch_emb) # (B, L, d_model)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


Q, K, V 를 `num_head` 개의 차원 분할된 여러 vector 로 만듭니다.

In [14]:
batch_size = q.shape[0]
d_k = d_model // num_heads

q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [15]:
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


## Scaled dot-product self-attention 구현

각 head에서 실행되는 self-attention 과정

In [16]:
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# (B, num_heads, L, L)
attn_dists = F.softmax(attn_scores, dim=-1) # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[0.0699, 0.0504, 0.0650,  ..., 0.0240, 0.0240, 0.0240],
          [0.0390, 0.0360, 0.0306,  ..., 0.0468, 0.0468, 0.0468],
          [0.0395, 0.0448, 0.0584,  ..., 0.0456, 0.0456, 0.0456],
          ...,
          [0.0339, 0.0286, 0.0389,  ..., 0.0845, 0.0845, 0.0845],
          [0.0339, 0.0286, 0.0389,  ..., 0.0845, 0.0845, 0.0845],
          [0.0339, 0.0286, 0.0389,  ..., 0.0845, 0.0845, 0.0845]],

         [[0.1091, 0.0719, 0.0330,  ..., 0.0496, 0.0496, 0.0496],
          [0.0264, 0.0488, 0.0625,  ..., 0.0522, 0.0522, 0.0522],
          [0.0280, 0.0555, 0.0380,  ..., 0.0912, 0.0912, 0.0912],
          ...,
          [0.0367, 0.0368, 0.0342,  ..., 0.0537, 0.0537, 0.0537],
          [0.0367, 0.0368, 0.0342,  ..., 0.0537, 0.0537, 0.0537],
          [0.0367, 0.0368, 0.0342,  ..., 0.0537, 0.0537, 0.0537]],

         [[0.0822, 0.0392, 0.0309,  ..., 0.0501, 0.0501, 0.0501],
          [0.0572, 0.0349, 0.0709,  ..., 0.0453, 0.0453, 0.0453],
          [0.0317, 0.0567, 0.0387,  ..., 0

In [17]:
attn_values = torch.matmul(attn_dists, v) # (B, num_heads, L, d_k)
print(attn_values.shape)

torch.Size([10, 8, 20, 64])


## 각 head 의 결과물 병합

각 head 의 결과물을 concat 하고 동일 차원으로 linear transformation 합니다.

In [18]:
attn_values = attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [19]:
outputs = w_0(attn_values)

print(outputs)
print(outputs.shape)

tensor([[[-0.0524, -0.0433,  0.0039,  ...,  0.0866, -0.0977,  0.0717],
         [-0.0616, -0.1470,  0.1049,  ...,  0.0226, -0.1224,  0.0990],
         [ 0.0364, -0.0454,  0.0514,  ...,  0.0789, -0.0852,  0.1245],
         ...,
         [ 0.0358,  0.0075,  0.0507,  ...,  0.0781, -0.1399,  0.1180],
         [ 0.0358,  0.0075,  0.0507,  ...,  0.0781, -0.1399,  0.1180],
         [ 0.0358,  0.0075,  0.0507,  ...,  0.0781, -0.1399,  0.1180]],

        [[ 0.2211,  0.0014, -0.0457,  ...,  0.0663, -0.2574,  0.1505],
         [ 0.1832, -0.0408, -0.0717,  ...,  0.0201, -0.2497,  0.1716],
         [ 0.1919, -0.0275, -0.0324,  ...,  0.0249, -0.2271,  0.1705],
         ...,
         [ 0.2243, -0.0071, -0.0034,  ...,  0.0448, -0.2512,  0.1563],
         [ 0.2243, -0.0071, -0.0034,  ...,  0.0448, -0.2512,  0.1563],
         [ 0.2243, -0.0071, -0.0034,  ...,  0.0448, -0.2512,  0.1563]],

        [[ 0.1840,  0.0386,  0.0989,  ..., -0.0412, -0.1912,  0.1261],
         [ 0.1636,  0.0452,  0.0928,  ..., -0

## 전체코드

하나의 Multi-head attention 모듈 구현

In [20]:
class MultiheadAttention(nn.Module):
    def __init__(self):
        super(MultiheadAttention, self).__init__()

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Linear transformation for concatenated outputs
        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        batch_size = q.shape[0]

        q = self.w_q(q)  # (B, L, d_model)
        k = self.w_k(k)  # (B, L, d_model)
        v = self.w_v(v)  # (B, L, d_model)

        q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
        k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
        v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

        q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
        k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
        v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

        attn_values = self.self_attention(q, k, v)  # (B, num_heads, L, d_k)
        attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)

        return self.w_0(attn_values)

    def self_attention(self, q, k, v):
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
        attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

        attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

        return attn_values

In [21]:
multihead_attn = MultiheadAttention()
outputs = multihead_attn(batch_emb, batch_emb, batch_emb)
# (B, L, d_model)

In [22]:
print(outputs)
print(outputs.shape)

tensor([[[-0.0924, -0.1833, -0.0233,  ..., -0.0078, -0.2165,  0.1387],
         [-0.0651, -0.1497, -0.0011,  ..., -0.0165, -0.2354,  0.1337],
         [-0.0770, -0.1244, -0.0269,  ...,  0.0083, -0.2136,  0.1760],
         ...,
         [-0.0547, -0.1334, -0.0058,  ...,  0.0088, -0.2223,  0.1259],
         [-0.0547, -0.1334, -0.0058,  ...,  0.0088, -0.2223,  0.1259],
         [-0.0547, -0.1334, -0.0058,  ...,  0.0088, -0.2223,  0.1259]],

        [[ 0.0800, -0.2903, -0.0331,  ...,  0.0686,  0.0250,  0.3676],
         [ 0.1047, -0.2659,  0.0068,  ...,  0.0529,  0.0105,  0.3308],
         [ 0.1233, -0.2852,  0.0123,  ...,  0.0626,  0.0367,  0.3798],
         ...,
         [ 0.1296, -0.2597,  0.0376,  ...,  0.0455,  0.0120,  0.3828],
         [ 0.1296, -0.2597,  0.0376,  ...,  0.0455,  0.0120,  0.3828],
         [ 0.1296, -0.2597,  0.0376,  ...,  0.0455,  0.0120,  0.3828]],

        [[ 0.0602, -0.2527,  0.0954,  ...,  0.0286, -0.1102,  0.2559],
         [ 0.1172, -0.2607,  0.0771,  ...,  0