<a href="https://colab.research.google.com/github/jiyun1006/deeplearning-pytorch/blob/main/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [None]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [None]:
# Token id used to right-pad shorter sequences, and the size of the toy vocabulary.
pad_id, vocab_size = 0, 100

# Ten variable-length token-id sequences (lengths 1 to 20) forming one toy batch.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [None]:
def padding(data, pad=None):
  """Right-pad every sequence in `data` to the length of the longest one.

  Args:
    data: sequences of token ids (lists/tuples of ints) of varying length.
    pad: value appended to shorter sequences; defaults to the notebook-level `pad_id`.

  Returns:
    (padded, max_len): a NEW list of padded lists (the caller's `data` and its
    inner sequences are left unmodified) and the common length they were padded
    to. An empty `data` yields `([], 0)` instead of raising.
  """
  if pad is None:
    pad = pad_id  # notebook-level padding token id
  # `default=0` keeps an empty batch from raising ValueError inside max().
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum sequence length: {max_len}")

  # Build fresh lists instead of assigning into `data`, so the argument is not mutated.
  padded = [list(seq) + [pad] * (max_len - len(seq)) for seq in tqdm(data)]
  return padded, max_len

In [None]:
# Pad every sequence to the longest length; rebinds `data` to the padded result and
# records that common length in `max_len`.
data, max_len = padding(data)

100%|██████████| 10/10 [00:00<00:00, 38872.14it/s]

Maximum sequence length: 20





In [None]:
# Bare final expression: the notebook renders the padded batch (every row now has max_len entries).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [None]:
# Transformer hyperparameters: the model's hidden size and the number of attention heads.
d_model, num_heads = 512, 8

In [None]:
# Lookup table mapping each token id in [0, vocab_size) to a d_model-dimensional vector.
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
# `torch.tensor(..., dtype=torch.long)` replaces the legacy typed constructor
# `torch.LongTensor(...)`, which PyTorch no longer recommends.
batch = torch.tensor(data, dtype=torch.long)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [None]:
# Inspect the embedded batch, then its (B, L, d_model) shape, one per line.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[ 2.4761e+00, -7.2860e-01,  6.0909e-01,  ...,  3.9985e-01,
          -7.9964e-02,  6.7592e-01],
         [-8.6604e-01, -4.7887e-01,  3.3405e-01,  ..., -8.2642e-01,
          -8.2004e-01,  1.4619e+00],
         [-2.4253e-03,  6.0598e-01,  1.0527e+00,  ..., -1.4679e+00,
           6.8125e-01, -3.7941e-01],
         ...,
         [ 1.4205e+00,  2.5069e+00, -1.3622e+00,  ...,  1.2034e+00,
          -5.0359e-01, -9.8363e-01],
         [ 1.4205e+00,  2.5069e+00, -1.3622e+00,  ...,  1.2034e+00,
          -5.0359e-01, -9.8363e-01],
         [ 1.4205e+00,  2.5069e+00, -1.3622e+00,  ...,  1.2034e+00,
          -5.0359e-01, -9.8363e-01]],

        [[-6.2930e-01,  1.4235e+00,  1.4935e+00,  ..., -6.0022e-01,
           5.0821e-01,  6.9357e-01],
         [ 6.1116e-01,  1.2564e+00, -6.4565e-01,  ...,  1.5727e+00,
          -1.0037e+00,  2.4589e+00],
         [-8.1986e-01, -1.8498e+00, -1.3489e+00,  ..., -1.4555e+00,
          -6.7207e-02,  1.0520e-01],
         ...,
         [ 1.4205e+00,  2

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [None]:
# Learnable d_model -> d_model projections for queries, keys and values
# (constructed in q, k, v order, so parameter initialisation order is unchanged).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [None]:
# Output projection that combines the concatenated heads once multi-head attention is done.
w_0 = nn.Linear(in_features=d_model, out_features=d_model)

In [None]:
# Project the same embeddings into query / key / value spaces; each stays (B, L, d_model).
q, k, v = (proj(batch_emb) for proj in (w_q, w_k, w_v))

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 각각 `num_heads`개의 head로 분할하여, 크기가 `d_k = d_model // num_heads`인 여러 vector로 만듭니다.

In [None]:
batch_size = q.shape[0]
d_k = d_model // num_heads  # feature size handled by each head

# Split the last axis into num_heads chunks of d_k: (B, L, d_model) -> (B, L, num_heads, d_k)
q, k, v = (t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v))

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [None]:
# Move the head axis ahead of the sequence axis so every head owns its own L x d_k matrix:
# (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$.

In [None]:
# Scaled dot-product score between every query and key position: (B, num_heads, L, L)
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
# Normalise over the key axis so each query's attention weights sum to 1.
attn_dists = attn_scores.softmax(dim=-1)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[0.0517, 0.0424, 0.0419,  ..., 0.0877, 0.0877, 0.0877],
          [0.0289, 0.0506, 0.0353,  ..., 0.0510, 0.0510, 0.0510],
          [0.0750, 0.0927, 0.0622,  ..., 0.0343, 0.0343, 0.0343],
          ...,
          [0.0643, 0.0409, 0.0585,  ..., 0.0502, 0.0502, 0.0502],
          [0.0643, 0.0409, 0.0585,  ..., 0.0502, 0.0502, 0.0502],
          [0.0643, 0.0409, 0.0585,  ..., 0.0502, 0.0502, 0.0502]],

         [[0.0556, 0.0473, 0.0396,  ..., 0.0671, 0.0671, 0.0671],
          [0.0564, 0.0414, 0.0341,  ..., 0.0437, 0.0437, 0.0437],
          [0.0512, 0.0475, 0.0431,  ..., 0.0638, 0.0638, 0.0638],
          ...,
          [0.0568, 0.0541, 0.0358,  ..., 0.0442, 0.0442, 0.0442],
          [0.0568, 0.0541, 0.0358,  ..., 0.0442, 0.0442, 0.0442],
          [0.0568, 0.0541, 0.0358,  ..., 0.0442, 0.0442, 0.0442]],

         [[0.0437, 0.0581, 0.0405,  ..., 0.0456, 0.0456, 0.0456],
          [0.0508, 0.0783, 0.0448,  ..., 0.0461, 0.0461, 0.0461],
          [0.0363, 0.0614, 0.0407,  ..., 0

Attention 분포(`attn_dists`)를 가중치로 사용해 value vector들의 가중합을 구하는 부분입니다.


In [None]:
# Weighted sum of the value vectors using the attention weights: (B, num_heads, L, d_k)
attn_values = attn_dists @ v

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원으로 linear transformation합니다.

In [None]:
# Undo the head split: (B, num_heads, L, d_k) -> (B, L, num_heads, d_k) -> (B, L, d_model).
# `view` needs a contiguous tensor, and the transposed one is not, hence `.contiguous()`.
attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [None]:
# Final linear projection over the concatenated heads; shape stays (B, L, d_model).
outputs = w_0(attn_values)

print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.1122,  0.0884,  0.0375,  ...,  0.1817,  0.0401,  0.0583],
         [ 0.0847,  0.0543,  0.1429,  ...,  0.1294,  0.0264,  0.0867],
         [ 0.0859,  0.0524,  0.1587,  ...,  0.1333,  0.0111,  0.0725],
         ...,
         [ 0.0919,  0.1168,  0.1214,  ...,  0.1824,  0.0553, -0.0015],
         [ 0.0919,  0.1168,  0.1214,  ...,  0.1824,  0.0553, -0.0015],
         [ 0.0919,  0.1168,  0.1214,  ...,  0.1824,  0.0553, -0.0015]],

        [[ 0.4535,  0.3756, -0.3434,  ...,  0.3122,  0.3434, -0.0211],
         [ 0.4135,  0.4138, -0.3151,  ...,  0.3449,  0.2774, -0.0910],
         [ 0.4243,  0.4217, -0.3060,  ...,  0.3528,  0.3420, -0.0577],
         ...,
         [ 0.3870,  0.4032, -0.2868,  ...,  0.3416,  0.3164, -0.0681],
         [ 0.3870,  0.4032, -0.2868,  ...,  0.3416,  0.3164, -0.0681],
         [ 0.3870,  0.4032, -0.2868,  ...,  0.3416,  0.3164, -0.0681]],

        [[ 0.2257,  0.3273, -0.1103,  ...,  0.2016,  0.2097, -0.1306],
         [ 0.2509,  0.2792, -0.1773,  ...,  0

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [None]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects q/k/v with learnable linear maps, splits them into `num_heads` heads
  of size `d_k = d_model // num_heads`, runs scaled dot-product attention per
  head, concatenates the heads and applies a final output projection.

  Args:
    d_model: hidden size of inputs and outputs. Defaults to the notebook-level `d_model`.
    num_heads: number of heads; must divide `d_model`. Defaults to the
      notebook-level `num_heads`.

  Raises:
    ValueError: if `d_model` is not divisible by `num_heads` (the head split
      via `view(..., -1, ...)` could otherwise silently mix tokens).
  """

  def __init__(self, d_model=None, num_heads=None):
    super().__init__()

    # Fall back to the module-level hyperparameters so `MultiheadAttention()` keeps
    # working; the parameters shadow those globals, hence the explicit globals() lookup.
    module_vars = globals()
    self.d_model = module_vars["d_model"] if d_model is None else d_model
    self.num_heads = module_vars["num_heads"] if num_heads is None else num_heads
    if self.d_model % self.num_heads != 0:
      raise ValueError(
          f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})")
    # Per-head size kept on the instance instead of reading a global set in another cell.
    self.d_k = self.d_model // self.num_heads

    # Q, K, V learnable matrices (same creation order as before, so initialisation is unchanged)
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: query inputs, (B, L_q, d_model).
      k: key inputs, (B, L_k, d_model).
      v: value inputs, (B, L_k, d_model).
      mask: optional tensor broadcastable to (B, L_q, L_k), e.g. a padding mask of
        shape (B, 1, L_k). Nonzero/True marks key positions that may be attended
        to; the others receive zero attention weight. `None` (default) attends to
        every position, exactly as before.

    Returns:
      (B, L_q, d_model) tensor.
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q))  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k))  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v))  # (B, num_heads, L_k, d_k)

    if mask is not None:
      mask = mask.unsqueeze(1)  # add a head axis so one mask broadcasts over all heads

    attn_values = self.self_attention(q, k, v, mask)  # (B, num_heads, L_q, d_k)
    # (B, num_heads, L_q, d_k) -> (B, L_q, num_heads, d_k) -> (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def _split_heads(self, x):
    """Reshape a projected (B, L, d_model) tensor into (B, num_heads, L, d_k)."""
    return x.view(x.shape[0], -1, self.num_heads, self.d_k).transpose(1, 2)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention computed independently for every head.

    Args:
      q: (B, num_heads, L_q, d_k); k, v: (B, num_heads, L_k, d_k).
      mask: optional tensor broadcastable to (B, num_heads, L_q, L_k); zero/False
        entries are excluded from the softmax.

    Returns:
      (B, num_heads, L_q, d_k) tensor of attention-weighted values.
    """
    d_k = q.size(-1)  # derived from the input rather than a notebook global
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L_q, L_k)
    if mask is not None:
      # The dtype's most negative finite value (rather than -inf) still gives masked keys
      # exactly zero weight, but a fully-masked row becomes uniform instead of NaN.
      attn_scores = attn_scores.masked_fill(~mask.bool(), torch.finfo(attn_scores.dtype).min)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    return torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

In [None]:
# Self-attention: the same embeddings act as the query, key and value inputs.
multihead_attn = MultiheadAttention()
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb)  # (B, L, d_model)

In [None]:
# Show the module's output tensor, then its shape, one per line.
print(outputs, outputs.shape, sep="\n")

tensor([[[ 2.8643e-03,  1.1555e-01,  1.2496e-01,  ..., -3.7488e-02,
          -3.0124e-02,  3.1039e-02],
         [-1.4078e-02,  1.0972e-01,  1.0326e-01,  ..., -6.7813e-02,
          -5.5103e-02,  6.8791e-02],
         [-2.8697e-02,  6.5568e-02,  1.0969e-01,  ..., -8.0120e-02,
          -1.6853e-02,  1.4531e-02],
         ...,
         [ 1.7992e-02,  1.0360e-01,  6.1830e-02,  ..., -1.1121e-01,
          -4.9953e-02,  8.3009e-03],
         [ 1.7992e-02,  1.0360e-01,  6.1830e-02,  ..., -1.1121e-01,
          -4.9953e-02,  8.3010e-03],
         [ 1.7992e-02,  1.0360e-01,  6.1830e-02,  ..., -1.1121e-01,
          -4.9953e-02,  8.3010e-03]],

        [[-1.9409e-01, -1.6303e-01,  8.1865e-02,  ..., -3.5009e-01,
          -2.2865e-01, -5.5588e-03],
         [-1.8015e-01, -1.0694e-01,  1.1501e-01,  ..., -3.4100e-01,
          -3.0557e-01,  2.6842e-02],
         [-1.8776e-01, -1.3219e-01,  1.0427e-01,  ..., -3.7695e-01,
          -2.3110e-01,  6.5802e-02],
         ...,
         [-1.7091e-01, -1