In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

PyTorch implementation following the Harvard NLP Transformer tutorial ("The Annotated Transformer").  
https://colab.research.google.com/drive/1-B8obSiAgAcq-VEJBVVzKIijySDRwm3k?usp=sharing#scrollTo=f15nl75sXT1M

In [2]:
class Embeddings(nn.Module):
    """Token-id embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        d_model: size of each embedding vector.
        vocab: number of distinct token ids in the lookup table.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.d_model = d_model
        # Lookup table mapping each token id to a d_model-sized vector.
        self.lut = nn.Embedding(vocab, d_model)

    def forward(self, x):
        """Return the embeddings of the ids in ``x`` scaled by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        embedded = self.lut(x)
        return embedded * scale

![Figure 2](../images/Multi-Head_Attention.png)

In [3]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V and return it with the weights.

    Expected shapes (as used by MultiHeadAttention):
        query: (batch_size, num_heads, query_len, d_model / num_heads)
        key:   (batch_size, num_heads, key_len,   d_model / num_heads)
        value: (batch_size, num_heads, value_len, d_model / num_heads)
        mask:  (batch_size, 1, 1, key_len); positions equal to 0 are masked out.

    Returns:
        (output, attention_weights) where output has shape
        (batch_size, num_heads, query_len, d_model / num_heads) and
        attention_weights has shape (batch_size, num_heads, query_len, key_len).
    """
    # Raw attention scores: Q times K transposed -> (batch, heads, q_len, k_len).
    scores = query @ key.transpose(-2, -1)
    print("matmul_qk.shape : ", scores.shape)

    # Scale by the square root of the per-head key dimension d_k.
    d_k = key.size(-1)
    scale = torch.sqrt(torch.tensor(d_k, dtype=torch.float32, device=query.device))
    logits = scores / scale
    print("logits.shape : ", logits.shape)

    if mask is not None:
        print("logits.shape : ", logits.shape)
        # A very large negative logit makes softmax assign ~0 weight to masked keys.
        logits = logits.masked_fill(mask.eq(0), -1e9)

    # Softmax runs over the last axis, i.e. across key positions.
    attention_weights = torch.softmax(logits, dim=-1)

    # Weighted sum of the values for every query position.
    output = attention_weights @ value
    return output, attention_weights

Linear의 shape이 $d_{model} \times d_{model}$인 이유는?  
multi-head 연산을 병렬로 수행하기 위해 $d_{model} \times depth$ 크기의 head별 가중치 행렬을 head 개수(여기서는 8개)만큼 이어 붙여 놓은 것이기 때문이다.
``` python
self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
```

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    The query/key/value projections are each a single d_model x d_model linear
    layer; their output is reshaped into ``num_heads`` slices of size
    ``depth = d_model // num_heads`` so all heads are computed in one batched call.

    Args:
        d_model: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        # Input projections for query, key and value.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated.
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) to (batch_size, num_heads, seq_len, depth)."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)  # (batch_size, seq_len, num_heads, depth)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key:   (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_v, d_model)
            mask: optional tensor passed through to ``scaled_dot_product_attention``;
                may be omitted (``None``) to attend to every key position.

        Returns:
            (output, attention_weight): output is (batch_size, seq_len_q, d_model);
            attention_weight is whatever ``scaled_dot_product_attention`` returns
            as its second value.
        """
        batch_size = query.size(0)

        # 1. Linear projections
        q = self.wq(query)  # (batch_size, seq_len_q, d_model)
        k = self.wk(key)    # (batch_size, seq_len_k, d_model)
        v = self.wv(value)  # (batch_size, seq_len_v, d_model)
        print("1th q.shape : ", q.shape)

        # 2. Split into heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        print("2th q.shape : ", q.shape)
        # mask defaults to None, so only report its shape when one was supplied;
        # accessing mask.shape unconditionally raised AttributeError for mask=None.
        if mask is not None:
            print("maks shape : ", mask.shape)

        # 3. Scaled dot-product attention
        scaled_attention, attention_weight = scaled_dot_product_attention(q, k, v, mask)  # (batch_size, num_heads, seq_len_q, depth)

        # 4. Concat heads (contiguous() is required before view() after a permute)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = scaled_attention.view(batch_size, -1, self.d_model)  # (batch_size, seq_len_q, d_model)

        # 5. Final linear layer
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weight


마스킹은 'Key'값을 기준으로 설정됨, Query랑은 상관 없음.

In [5]:
def create_padding_mask(x,pad=0):
    """Build a float mask that is 1.0 for real tokens and 0.0 for padding.

    Args:
        x: token ids of shape (batch_size, seq_len).
        pad: id that marks a padding position (default 0).

    Returns:
        Float tensor of shape (batch_size, 1, 1, seq_len).
    """
    is_token = x.ne(pad).float()
    # Insert two singleton axes after the batch axis.
    return is_token[:, None, None]

In [6]:
# Number of attention heads.
h = 8
# Model (embedding) dimension; MultiHeadAttention asserts it is divisible by the head count.
d_model = 512

In [7]:
# Attention layer under test: d_model=512 split over h=8 heads gives a per-head depth of 64.
attn = MultiHeadAttention(d_model, h)

In [8]:
# Embedding table for an 11-token vocabulary (ids 0-10), each mapped to a d_model-sized vector.
embedding = Embeddings(d_model,11)

In [9]:
# One random sequence of 10 token ids drawn from [0, 11); id 0 counts as padding
# under create_padding_mask's default pad value.
# NOTE(review): no random seed is set, so the ids (and all outputs below) differ on every run.
x = torch.from_numpy(np.random.randint(0, 11, size=(1, 10)))
#x = embedding(x)

In [10]:
# Embed the token ids: (1, 10) -> (1, 10, d_model).
xx = embedding(x)
print("xx input shape : ", xx.shape)
# The padding mask is built from the raw token ids, not from the embeddings.
pad_mask = create_padding_mask(x)
print("pad_mask shape : ", pad_mask.shape)
# Self-attention: the same embedded sequence serves as query, key and value.
y, weight = attn(xx,xx,xx,pad_mask)

xx input shape :  torch.Size([1, 10, 512])
pad_mask shape :  torch.Size([1, 1, 1, 10])
1th q.shape :  torch.Size([1, 10, 512])
2th q.shape :  torch.Size([1, 8, 10, 64])
maks shape :  torch.Size([1, 1, 1, 10])
matmul_qk.shape :  torch.Size([1, 8, 10, 10])
logits.shape :  torch.Size([1, 8, 10, 10])
logits.shape :  torch.Size([1, 8, 10, 10])


In [11]:
# Inspect the mask: 1. marks a real token, 0. marks a padding position (token id 0).
pad_mask

tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]]]])

In [12]:
# Attention weights for batch item 0, head 0: rows index query positions, columns index key positions.
weight[0,0]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         5.0001e-01, 4.9999e-01, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         5.0000e-01, 5.0000e-01, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000