<a href="https://colab.research.google.com/github/kyungminkim-dev/boostcamp-ai-tech/blob/main/7_multi_head_attention_ipynb%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [None]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [None]:
# Token id used to fill sequences up to a common length (see `padding`).
pad_id = 0
# Number of distinct token ids; used as the row count of the embedding table.
vocab_size = 100

# Toy batch: ten token-id sequences of differing lengths (1 to 20 tokens).
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [None]:
def padding(data):
  """Right-pad every sequence in ``data`` with ``pad_id`` to a common length.

  The list is modified in place: each sequence shorter than the longest one
  is replaced by a new, padded list. Sequences that already have the maximum
  length are left untouched.

  Args:
    data: list of token-id lists. May be empty.

  Returns:
    A ``(data, max_len)`` tuple holding the same list object that was passed
    in and the length of its longest sequence (0 when ``data`` is empty).
  """
  # `default=0` keeps an empty batch from raising ValueError inside max().
  max_len = max(map(len, data), default=0)
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      # Build a new list rather than extending `seq`, so the original
      # sequence object is never mutated.
      data[i] = seq + [pad_id] * (max_len - len(seq))

  return data, max_len

In [None]:
# Pad every sequence to the longest length and keep that length for later use.
data, max_len = padding(data)

100%|██████████| 10/10 [00:00<00:00, 39681.21it/s]

Maximum sequence length: 20





In [None]:
# Display the padded sequences (notebook cell output).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [None]:
d_model = 512  # hidden size of the model
num_heads = 8  # number of attention heads

In [None]:
# Lookup table with one d_model-dimensional vector per token id.
embedding = nn.Embedding(vocab_size, d_model)

# Shape legend -- B: batch size, L: maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [None]:
# Inspect the embedded batch and its (B, L, d_model) shape.
print(batch_emb)
print(batch_emb.shape)

tensor([[[-0.3130, -0.6466, -0.6364,  ...,  0.4092,  0.7185,  0.8237],
         [ 1.6666,  0.6779, -0.8941,  ..., -0.2906,  1.4590, -0.5031],
         [-0.7811,  0.8180,  0.6448,  ..., -0.0807,  1.2186,  0.2906],
         ...,
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754],
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754],
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754]],

        [[ 0.1669, -0.8360,  1.5726,  ..., -1.4015,  2.0910, -0.9003],
         [-1.3691,  0.6413,  0.2234,  ...,  0.0307, -0.5075, -1.0385],
         [ 1.2657,  0.3973, -0.4212,  ...,  0.9322,  1.0555, -0.2589],
         ...,
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754],
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754],
         [-1.2335, -1.0142, -0.8085,  ..., -0.9706, -1.2609,  1.3754]],

        [[-0.5233,  1.3509,  0.1022,  ..., -0.0256,  1.5855,  0.4957],
         [ 0.7227, -0.0131, -1.5200,  ...,  0

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [None]:
# Learnable query / key / value projections, created in that order.
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [None]:
# Output projection applied to the concatenated head outputs.
w_0 = nn.Linear(d_model, d_model)

In [None]:
# Self-attention: query, key and value are all projections of the same batch.
q, k, v = (proj(batch_emb) for proj in (w_q, w_k, w_v))  # each (B, L, d_model)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 각각 `num_heads`개의 head로 분할하여 `d_k` 차원의 vector들로 만듭니다.

In [None]:
batch_size = q.size(0)
d_k = d_model // num_heads  # per-head feature size

# Split the last dimension into (num_heads, d_k); -1 lets view() infer L.
q, k, v = (
  t.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
  for t in (q, k, v)
)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [None]:
# Move the head axis ahead of the sequence axis so each head is attended
# over independently: (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [None]:
# Scaled dot product of every query with every key, per head.
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
# Normalise over the key axis so each query's weights sum to 1.
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

print(attn_dists)
print(attn_dists.shape)

tensor([[[[0.0730, 0.0249, 0.0589,  ..., 0.0396, 0.0396, 0.0396],
          [0.0851, 0.0448, 0.0288,  ..., 0.0371, 0.0371, 0.0371],
          [0.0614, 0.0490, 0.0504,  ..., 0.0252, 0.0252, 0.0252],
          ...,
          [0.0506, 0.0413, 0.0359,  ..., 0.0433, 0.0433, 0.0433],
          [0.0506, 0.0413, 0.0359,  ..., 0.0433, 0.0433, 0.0433],
          [0.0506, 0.0413, 0.0359,  ..., 0.0433, 0.0433, 0.0433]],

         [[0.0404, 0.0511, 0.0407,  ..., 0.0616, 0.0616, 0.0616],
          [0.0303, 0.0663, 0.0474,  ..., 0.0398, 0.0398, 0.0398],
          [0.0704, 0.0413, 0.0385,  ..., 0.0379, 0.0379, 0.0379],
          ...,
          [0.0553, 0.0407, 0.0458,  ..., 0.0453, 0.0453, 0.0453],
          [0.0553, 0.0407, 0.0458,  ..., 0.0453, 0.0453, 0.0453],
          [0.0553, 0.0407, 0.0458,  ..., 0.0453, 0.0453, 0.0453]],

         [[0.0469, 0.0534, 0.0682,  ..., 0.0433, 0.0433, 0.0433],
          [0.0702, 0.0597, 0.0324,  ..., 0.0644, 0.0644, 0.0644],
          [0.0436, 0.0659, 0.0606,  ..., 0

In [None]:
# Attention-weighted sum of the value vectors.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원으로 linear transformation합니다.

In [None]:
# Concatenate the heads back into a single d_model-sized feature axis.
attn_values = (
  attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
  .contiguous()  # view() needs contiguous memory after the transpose
  .view(batch_size, -1, d_model)  # (B, L, d_model)
)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [None]:
# Final linear transformation over the concatenated heads.
outputs = w_0(attn_values)  # (B, L, d_model)

print(outputs)
print(outputs.shape)

tensor([[[-3.2953e-01, -6.5045e-03,  5.5121e-02,  ...,  6.3708e-03,
           1.0475e-02,  8.1765e-02],
         [ 9.4638e-02, -8.5143e-03, -5.8232e-02,  ..., -1.6051e-01,
          -1.5074e-01,  7.4297e-02],
         [ 8.4103e-02, -4.2499e-02, -1.4555e-01,  ...,  9.6339e-03,
           8.3171e-04,  6.2981e-02],
         ...,
         [-1.3716e-01,  6.6446e-02,  1.8032e-03,  ...,  1.9378e-01,
          -1.7945e-01, -9.1599e-03],
         [-1.9815e-01, -1.0303e-01, -2.4722e-02,  ...,  2.4668e-01,
           1.2503e-01, -8.2504e-02],
         [ 1.3707e-01, -8.2942e-02,  8.0011e-02,  ...,  1.5869e-01,
          -1.2215e-02,  8.4874e-02]],

        [[-5.8641e-01,  9.1846e-02,  1.4971e-01,  ..., -6.4201e-02,
           1.8333e-01, -9.2120e-02],
         [ 2.7623e-01, -2.9758e-02, -1.5169e-01,  ..., -3.2475e-01,
          -2.6500e-01,  2.1315e-01],
         [ 1.8831e-01,  4.3489e-03, -2.8373e-01,  ...,  3.2250e-01,
          -1.1052e-02,  1.2384e-01],
         ...,
         [ 1.1189e-01,  3

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [None]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects queries, keys and values with learned linear maps, splits each
  projection into ``num_heads`` heads of size ``d_k``, runs scaled dot-product
  attention in every head, concatenates the heads and applies a final linear
  projection.

  No mask is applied, so every key position (padded ones included) takes part
  in the softmax.

  Args:
    model_dim: hidden size of the inputs and outputs. When omitted, the
      module-level ``d_model`` is used.
    head_count: number of attention heads. When omitted, the module-level
      ``num_heads`` is used.

  Raises:
    ValueError: if the hidden size is not divisible by the number of heads.
  """

  def __init__(self, model_dim=None, head_count=None):
    super().__init__()

    # Snapshot the hyperparameters at construction time so that later changes
    # to the module-level values cannot desynchronise the layer sizes from
    # the reshapes done in forward().
    self.d_model = d_model if model_dim is None else model_dim
    self.num_heads = num_heads if head_count is None else head_count
    if self.d_model % self.num_heads != 0:
      raise ValueError(
        f"d_model ({self.d_model}) must be divisible by "
        f"num_heads ({self.num_heads})"
      )
    self.d_k = self.d_model // self.num_heads

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into per-head layout (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v):
    """Run multi-head attention.

    Args:
      q: query tensor of shape (B, L_q, d_model).
      k: key tensor of shape (B, L_k, d_model).
      v: value tensor of shape (B, L_k, d_model).

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v)  # (B, num_heads, L_q, d_k)
    # The transpose makes the tensor non-contiguous, so contiguous() is
    # required before view() can merge the head axis back into d_model.
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # (B, L_q, d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v):
    """Scaled dot-product attention on per-head tensors.

    Args:
      q: (B, num_heads, L_q, d_k) queries.
      k: (B, num_heads, L_k, d_k) keys.
      v: (B, num_heads, L_k, d_k) values.

    Returns:
      (B, num_heads, L_q, d_k) attention-weighted values.
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L_q, L_k)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

    return attn_values

In [None]:
multihead_attn = MultiheadAttention()

# Self-attention: the same embedded batch serves as query, key and value.
outputs = multihead_attn(batch_emb, batch_emb, batch_emb)  # (B, L, d_model)

In [None]:
# Inspect the module output and its shape.
print(outputs)
print(outputs.shape)

tensor([[[ 0.0529, -0.0699,  0.0349,  ...,  0.0748, -0.0798,  0.0258],
         [ 0.1503, -0.0778, -0.0345,  ...,  0.0374, -0.0475,  0.0037],
         [ 0.0719, -0.0942,  0.0203,  ...,  0.0674, -0.0556, -0.0073],
         ...,
         [ 0.1416, -0.0343, -0.0094,  ...,  0.0688, -0.0589,  0.0386],
         [ 0.1416, -0.0343, -0.0094,  ...,  0.0688, -0.0589,  0.0386],
         [ 0.1416, -0.0343, -0.0094,  ...,  0.0688, -0.0589,  0.0386]],

        [[ 0.1498, -0.0180,  0.1246,  ...,  0.2544,  0.2209,  0.0353],
         [ 0.1812,  0.0320,  0.1634,  ...,  0.2869,  0.1846,  0.0479],
         [ 0.1669,  0.0575,  0.1120,  ...,  0.2746,  0.2328,  0.0891],
         ...,
         [ 0.1625,  0.0239,  0.1223,  ...,  0.3176,  0.1976,  0.0513],
         [ 0.1625,  0.0239,  0.1223,  ...,  0.3176,  0.1976,  0.0513],
         [ 0.1625,  0.0239,  0.1223,  ...,  0.3176,  0.1976,  0.0513]],

        [[ 0.1589,  0.0334,  0.2046,  ...,  0.2300,  0.1656, -0.0525],
         [ 0.1099,  0.0454,  0.0962,  ...,  0