## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [7]:
# Token id reserved for padding; the real tokens below only use ids 1..99.
pad_id = 0
# Number of distinct token ids (the embedding table size used later).
vocab_size = 100

# Ten toy sequences of token ids with different lengths (1 to 20 tokens).
# They are written as tuples and converted to lists, because the padding step
# extends each sequence by Python list concatenation.
data = [list(tokens) for tokens in (
  (62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75),
  (60, 96, 51, 32, 90),
  (35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54),
  (75, 51),
  (66, 88, 98, 47),
  (21, 39, 10, 64, 21),
  (98,),
  (77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10),
  (70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34),
  (20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43),
)]

In [3]:
def padding(data, pad_value=None):
  """Right-pad every sequence to the length of the longest one.

  Prints the maximum length and iterates through ``tqdm`` so a progress bar
  is shown, exactly like the original notebook cell.

  Args:
    data: list of token-id sequences (lists or tuples) of varying length.
    pad_value: token id appended as padding. ``None`` (default) uses the
      notebook-level ``pad_id``, looked up at call time.

  Returns:
    ``(padded, max_len)``: ``padded`` is a NEW list in which every sequence
    is a list of exactly ``max_len`` tokens; ``max_len`` is the length of the
    longest input sequence (0 when ``data`` is empty). The input ``data`` is
    left unmodified.
  """
  if pad_value is None:
    pad_value = pad_id

  # Length of the longest sequence; default=0 keeps an empty batch from
  # raising ValueError inside max().
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum sequence length: {max_len}")

  # Build fresh lists instead of overwriting data[i] so the caller's input is
  # not mutated as a side effect.
  padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in tqdm(data)]

  return padded, max_len

In [11]:
data, max_len = padding(data)  # pad every sequence to the longest length (20 for this data)

100%|██████████| 10/10 [00:00<00:00, 10240.00it/s]

Maximum sequence length: 20





In [12]:
data  # show the padded batch: shorter sequences now end with pad_id (0) tokens

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [13]:
d_model = 512  # hidden size of the model (embedding / projection width)
num_heads = 8  # number of attention heads

# Every head gets d_model // num_heads features, so the split must be exact.
# Otherwise d_k truncates and the later .view(batch, -1, num_heads, d_k) head
# split either raises or silently produces a wrong sequence dimension.
if d_model % num_heads != 0:
  raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

In [14]:
# Embedding table: one learnable d_model-dimensional vector per token id.
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# Notation: B = batch size, L = padded (maximum) sequence length.
batch = torch.tensor(data, dtype=torch.long)  # (B, L) token ids
batch_emb = embedding(batch)  # (B, L, d_model): one vector per token

In [15]:
# Expected shape (10, 20, 512): 10 sequences, each padded to 20 tokens,
# every token represented by a 512-dimensional vector.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[ 0.1662,  0.3838,  2.4579,  ...,  0.2768,  0.3098, -0.7172],
         [ 1.4068, -2.4069, -0.2465,  ...,  0.8217, -1.0459,  0.7447],
         [-0.0828,  0.4877, -0.2899,  ..., -1.2002,  0.3610, -0.5172],
         ...,
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074],
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074],
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074]],

        [[-1.5674, -1.6302, -1.1374,  ...,  0.5802, -0.4963, -1.2123],
         [-1.4222,  1.3504, -1.8712,  ...,  0.9889, -0.2262,  0.4924],
         [ 0.0355, -0.9395,  1.0601,  ..., -0.6091,  1.2678,  1.9001],
         ...,
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074],
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074],
         [-0.6203,  0.5518,  1.4681,  ..., -0.6171,  0.1487,  2.0074]],

        [[-0.3575,  0.3955,  0.8756,  ...,  1.3762,  0.9230, -1.8228],
         [ 0.4443, -0.9241, -0.1414,  ...,  1

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [22]:
# Learnable projections producing queries, keys and values (created in the
# order w_q, w_k, w_v).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [23]:
w_0 = nn.Linear(d_model, d_model) # output projection applied after the head outputs are concatenated

In [24]:
# Self-attention: the same embeddings (batch_emb, shape (10, 20, 512)) are
# projected three ways. Each result is (B, L, d_model).
q, k, v = (proj(batch_emb) for proj in (w_q, w_k, w_v))

print(*(t.shape for t in (q, k, v)), sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 각각 `num_heads`개의 head로 차원 분할된 여러 vector로 만듭니다.

In [25]:
# Split the d_model features into num_heads chunks of d_k features each.
batch_size = q.shape[0]
d_k = d_model // num_heads  # 64 with d_model=512 and num_heads=8

# (B, L, d_model) -> (B, L, num_heads, d_k); -1 lets view() infer L.
q, k, v = (t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v))

print(*(t.shape for t in (q, k, v)), sep="\n")

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [26]:
# Put the head axis before the sequence axis so each head holds its own
# (L, d_k) matrix: (B, L, num_heads, d_k) -> (B, num_heads, L, d_k).
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

print(*(t.shape for t in (q, k, v)), sep="\n")

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [29]:
# For each of the 10 sequences and each of the 8 heads:
# (20 x 64) @ (64 x 20) -> (20 x 20) scores, scaled by sqrt(d_k).
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
# Normalise over the key axis so every query's weights sum to 1.
# NOTE: no padding mask is applied, so pad positions also receive weight.
attn_dists = attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[0.0489, 0.0607, 0.0438,  ..., 0.0750, 0.0750, 0.0750],
          [0.0513, 0.0421, 0.0390,  ..., 0.0581, 0.0581, 0.0581],
          [0.0638, 0.0425, 0.0495,  ..., 0.0432, 0.0432, 0.0432],
          ...,
          [0.0502, 0.0530, 0.0599,  ..., 0.0377, 0.0377, 0.0377],
          [0.0502, 0.0530, 0.0599,  ..., 0.0377, 0.0377, 0.0377],
          [0.0502, 0.0530, 0.0599,  ..., 0.0377, 0.0377, 0.0377]],

         [[0.0497, 0.0355, 0.0502,  ..., 0.0544, 0.0544, 0.0544],
          [0.0725, 0.0616, 0.0459,  ..., 0.0627, 0.0627, 0.0627],
          [0.0421, 0.0309, 0.0679,  ..., 0.0679, 0.0679, 0.0679],
          ...,
          [0.0479, 0.0736, 0.0361,  ..., 0.0455, 0.0455, 0.0455],
          [0.0479, 0.0736, 0.0361,  ..., 0.0455, 0.0455, 0.0455],
          [0.0479, 0.0736, 0.0361,  ..., 0.0455, 0.0455, 0.0455]],

         [[0.0701, 0.0398, 0.0367,  ..., 0.0546, 0.0546, 0.0546],
          [0.0520, 0.0305, 0.0380,  ..., 0.0591, 0.0591, 0.0591],
          [0.0575, 0.0413, 0.0418,  ..., 0

In [30]:
# Weighted sum of the value vectors using the attention weights.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원으로 linear transformation합니다.

In [31]:
# Yes, this is the concatenation step: move L back in front of the head axis,
# then merge (num_heads, d_k) into a single d_model axis.
# .contiguous() is needed because view() generally cannot merge the axes of
# the non-contiguous tensor returned by transpose().
attn_values = (
  attn_values
  .transpose(1, 2)                 # (B, L, num_heads, d_k)
  .contiguous()
  .view(batch_size, -1, d_model)   # (B, L, d_model)
)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [46]:
# Final linear projection of the concatenated head outputs back to d_model.
outputs = w_0(attn_values)  # (B, L, d_model)

print(outputs.shape, outputs, sep="\n")

torch.Size([10, 20, 512])
tensor([[[-0.0258,  0.1026, -0.2086,  ..., -0.1510,  0.2324, -0.0641],
         [-0.0064,  0.1259, -0.2044,  ..., -0.1109,  0.2817, -0.0371],
         [-0.0137,  0.1545, -0.1765,  ..., -0.0479,  0.2172, -0.0493],
         ...,
         [-0.0681,  0.0860, -0.2196,  ..., -0.0833,  0.2378, -0.0051],
         [-0.0681,  0.0860, -0.2196,  ..., -0.0833,  0.2378, -0.0051],
         [-0.0681,  0.0860, -0.2196,  ..., -0.0833,  0.2378, -0.0051]],

        [[-0.2378,  0.0199, -0.1467,  ..., -0.4187,  0.3819, -0.3111],
         [-0.2250,  0.0102, -0.1895,  ..., -0.4359,  0.4152, -0.3523],
         [-0.2186,  0.0431, -0.1793,  ..., -0.4798,  0.4759, -0.3452],
         ...,
         [-0.1684,  0.0335, -0.1808,  ..., -0.4452,  0.4642, -0.3485],
         [-0.1684,  0.0335, -0.1808,  ..., -0.4452,  0.4642, -0.3485],
         [-0.1684,  0.0335, -0.1808,  ..., -0.4452,  0.4642, -0.3485]],

        [[-0.0591, -0.0127, -0.2687,  ..., -0.2445,  0.2560, -0.2716],
         [-0.0912, 

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [47]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects query/key/value with learned linear layers, splits them into
  ``num_heads`` heads of ``d_k = d_model // num_heads`` features, runs scaled
  dot-product attention in every head, concatenates the heads and applies an
  output projection.

  Args:
    d_model: feature size of inputs and outputs. ``None`` (default) uses the
      notebook-level ``d_model``, so ``MultiheadAttention()`` works as before.
    num_heads: number of heads. ``None`` (default) uses the notebook-level
      ``num_heads``.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model=None, num_heads=None):
    super().__init__()

    # The parameters shadow the module-level names, so fall back via globals().
    if d_model is None:
      d_model = globals()["d_model"]
    if num_heads is None:
      num_heads = globals()["num_heads"]
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    # Q, K, V learnable matrices (same creation order as before, so the same
    # random initialisation for a given seed).
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: query input, (B, L_q, d_model).
      k: key input, (B, L_k, d_model).
      v: value input, (B, L_k, d_model).
      mask: optional tensor broadcastable to (B, num_heads, L_q, L_k).
        Positions where it is False/0 are excluded from attention; e.g. a
        key-padding mask of shape (B, 1, 1, L_k). ``None`` (default) attends
        to every position, as before.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask)  # (B, num_heads, L_q, d_k)
    # Concatenate heads: (B, num_heads, L_q, d_k) -> (B, L_q, num_heads, d_k) -> (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention computed independently in every head.

    Args:
      q: (B, num_heads, L_q, d_k).
      k, v: (B, num_heads, L_k, d_k).
      mask: optional, broadcastable to (B, num_heads, L_q, L_k); False/0
        entries are masked out.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    d_k = q.size(-1)
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L_q, L_k)
    if mask is not None:
      # Use the most negative finite value instead of -inf so a fully masked
      # row becomes a uniform distribution rather than NaN.
      attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

    return attn_values

In [48]:
multihead_attn = MultiheadAttention()

# Self-attention: the same embeddings serve as query, key and value.
# No padding mask is passed, so pad tokens are attended to as well.
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb)  # (B, L, d_model)

In [50]:
# Shape first, then the tensor itself: expected (10, 20, 512).
print(outputs.shape, outputs, sep="\n")

torch.Size([10, 20, 512])
tensor([[[-0.0049,  0.1423,  0.0913,  ...,  0.0972, -0.0080,  0.1448],
         [ 0.0450,  0.0733,  0.0623,  ...,  0.0858, -0.0140,  0.1039],
         [-0.0068,  0.0650,  0.0804,  ...,  0.0119, -0.0195,  0.0922],
         ...,
         [-0.0017,  0.0737,  0.1249,  ...,  0.0576, -0.0383,  0.1278],
         [-0.0017,  0.0737,  0.1249,  ...,  0.0576, -0.0383,  0.1278],
         [-0.0017,  0.0737,  0.1249,  ...,  0.0576, -0.0383,  0.1278]],

        [[-0.1339,  0.2719,  0.2828,  ..., -0.2396, -0.1392,  0.3415],
         [-0.1135,  0.3218,  0.2687,  ..., -0.2579, -0.0988,  0.3660],
         [-0.0815,  0.3035,  0.2919,  ..., -0.2346, -0.0671,  0.3198],
         ...,
         [-0.1189,  0.2846,  0.2677,  ..., -0.2408, -0.1038,  0.3249],
         [-0.1189,  0.2846,  0.2677,  ..., -0.2408, -0.1038,  0.3249],
         [-0.1189,  0.2846,  0.2677,  ..., -0.2408, -0.1038,  0.3249]],

        [[-0.0894,  0.1179,  0.1524,  ..., -0.1467, -0.0475,  0.2280],
         [-0.0531, 