In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, reduce

Transformer 의 구조는 다음과 같다.

1. Encoder + Decoder
2. Encoder = Enc_sublayer + Enc_sublayer + Enc_sublayer + Enc_sublayer + Enc_sublayer + Enc_sublayer
3. Enc_sublayer = residual connection [Multi-head attention + Feed Forward]

In [2]:
class Transformer(nn.Module):
    """Top-level model that chains an encoder and a decoder.

    Args:
        encoder: callable invoked as ``encoder(input_sentence, mask)``.
        decoder: callable invoked as ``decoder(some_sentence, context)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_sentence, some_sentence, mask):
        """Encode ``input_sentence`` and decode ``some_sentence`` against it.

        The mask is handed to the encoder only; the decoder receives the
        encoder output as its second argument.
        """
        encoded = self.encoder(input_sentence, mask)
        return self.decoder(some_sentence, encoded)

In [3]:
import copy

class Encoder(nn.Module):
    """Stack of ``n_layer`` independent copies of ``encoder_layer``.

    Args:
        encoder_layer: prototype layer; it is deep-copied ``n_layer`` times so
            that the copies do not share weights.
        n_layer: number of copies to stack.
    """

    def __init__(self, encoder_layer, n_layer):
        super(Encoder, self).__init__()
        # nn.ModuleList (not a plain list) so that the copied layers are
        # registered as sub-modules: their parameters then show up in
        # .parameters() / .state_dict() and follow .to(device).
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(n_layer)]
        )

    def forward(self, x, mask):
        """Pass ``x`` through every layer in order, giving each the same ``mask``."""
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out

In [20]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward
    network, each wrapped in its own residual connection + normalization.

    Args:
        multi_head_attention_layer: module called with keyword arguments
            ``query``, ``key``, ``value`` and ``mask``.
        position_wise_feed_forward_layer: module called with a single tensor.
        norm_layer: prototype normalization layer; deep-copied once per
            residual wrapper so the two wrappers do not share parameters.
    """

    def __init__(self, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention_layer = multi_head_attention_layer
        self.position_wise_feed_forward_layer = position_wise_feed_forward_layer
        # nn.ModuleList so the norm layers inside the wrappers are registered
        # as sub-modules (a plain list would hide their parameters).
        self.residual_connection_layers = nn.ModuleList(
            [ResidualConnectionLayer(copy.deepcopy(norm_layer)) for _ in range(2)]
        )

    def forward(self, x, mask):
        """Apply self-attention, then the feed-forward network, to ``x``."""
        out = self.residual_connection_layers[0](
            x,
            lambda t: self.multi_head_attention_layer(query=t, key=t, value=t, mask=mask),
        )
        # Feed the attention output (not the raw input) into the second
        # sub-layer; otherwise the attention result would be discarded.
        out = self.residual_connection_layers[1](
            out,
            self.position_wise_feed_forward_layer,
        )
        return out

### self-attention
self-Attention 에서의 계산하는 대상은 Query 가 주어졌을 때, 다른 token 에 대한 관계다. 먼저 Query, Key, Value 3가지 input 을 받게 된다.

1. Query: 현재 시점의 token을 의미
2. Key: attention을 구하고자 하는 대상 token을 의미
3. Value: attention을 구하고자 하는 대상 token을 의미 (Key와 동일한 token)

Query, Key, Value 는 input으로 들어오는 token embedding vector를 fully connected layer에 넣어 세 vector를 만들어낸다. ([n, d_embed] -> [n, dk])

수식은 이하와 같이 계산된다. 

$$
\text { Query's Attention }(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
$$

Q[1, dk] x K.T[dk, n] = Attention Score [1, n]
즉, Attention score 는 Query 와 각각의 Key 사이의 유사도를 나타낸다. (내적 = 유사도)

이렇게 구한 attention score 를 value 에 곱해주면 attention 결과를 계산할 수 있다. 

AS[1, n] x V[n, dk] = Query's attention [1, dk]

1개의 token 을 n개로 확장하려면 1 을 n 으로 바꿔주기만 하면 된다.

실제 논문에서는 multi-head attention 모델이 사용되는데, self-attention 을 병렬적으로 여러개 수행하는 것이다. 이를 통해 덜 중요한 attention 까지 포함할 수 있는 attention 을 얻을 수 있도록 돕는다.

h = 8 이라면, query, key, value's shape: (n_batch, seq_len, d_k * h) 가 된다. 이제 d_k * h 를 d_model 이라고 부른다. 

In [17]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: total projected width (``h * d_k``); stored for reference.
        h: number of attention heads.
        qkv_fc_layer: prototype projection (documented as d_embed -> d_model);
            deep-copied three times so query, key and value get independent
            weights.
        fc_layer: output projection (documented as d_model -> d_embed).
    """

    def __init__(self, d_model, h, qkv_fc_layer, fc_layer):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h
        self.query_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.key_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.value_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.fc_layer = fc_layer

    def calculate_attention(self, query, key, value, mask):
        """Return ``softmax(Q K^T / sqrt(d_k)) V``.

        Works on any number of leading batch dimensions; the last two
        dimensions of each input are (seq_len, d_k). Positions where
        ``mask == 0`` are excluded from the softmax. ``mask`` must be
        broadcastable to the score tensor, or ``None``.
        """
        d_k = key.size(-1)
        # d_k is a Python int, so scale with ** 0.5 (torch.sqrt needs a tensor).
        attention_score = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            # Out-of-place fill keeps the score tensor intact for autograd.
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        attention_prob = F.softmax(attention_score, dim=-1)
        return torch.matmul(attention_prob, value)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, merge heads and project back.

        Args:
            query, key, value: tensors of shape (n_batch, seq_len, d_embed).
            mask: optional tensor of shape (n_batch, seq_len, seq_len); a
                head dimension is inserted so it broadcasts over all heads.

        Returns:
            Tensor of shape (n_batch, seq_len, <fc_layer output width>).
        """

        def transform(x, fc_layer):
            # (n_batch, seq_len, d_embed) -> (n_batch, h, seq_len, d_k)
            out = fc_layer(x)  # (n_batch, seq_len, h * d_k)
            out = out.reshape(out.size(0), out.size(1), self.h, -1)
            return out.transpose(1, 2)

        query = transform(query, self.query_fc_layer)
        key = transform(key, self.key_fc_layer)
        value = transform(value, self.value_fc_layer)

        if mask is not None:
            mask = mask.unsqueeze(1)  # (n_batch, 1, seq_len, seq_len)

        out = self.calculate_attention(query, key, value, mask)  # (n_batch, h, seq_len, d_k)
        # Merge the heads back: (n_batch, h, seq_len, d_k) -> (n_batch, seq_len, h * d_k)
        out = out.transpose(1, 2)
        out = out.reshape(out.size(0), out.size(1), -1)
        return self.fc_layer(out)

### Position-wise Feed Forward Layer

Attention 의 결과를 Add & Norm 에 넣어 처리한 뒤 Feed Forward layer 에 넣어준다. 이때 사용하는 layer 가 바로 position-wise feed forward layer 다.

$$
\operatorname{FFN}(x)=\max \left(0, x W_{1}+b_{1}\right) W_{2}+b_{2}
$$

In [18]:
class PositionWiseFeedForwardLayer(nn.Module):
    """Two-layer feed-forward network: ``second(dropout(relu(first(x))))``.

    Args:
        first_fc_layer: first projection module, applied to the input.
        second_fc_layer: second projection module, applied after ReLU and
            dropout.
        dropout_rate: drop probability used between the two projections.
    """

    def __init__(self, first_fc_layer, second_fc_layer, dropout_rate=0.1):
        # Module.__init__ must run before sub-modules can be assigned.
        super(PositionWiseFeedForwardLayer, self).__init__()
        self.first_fc_layer = first_fc_layer
        self.second_fc_layer = second_fc_layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` and return the result."""
        out = self.first_fc_layer(x)
        out = F.relu(out)
        out = self.dropout(out)
        return self.second_fc_layer(out)

### residual connection + Layer Normalization

instead of using $y = f(x)$, choose $y = f(x) + x$

In [19]:
class ResidualConnectionLayer(nn.Module):
    """Wrap a sub-layer as ``norm_layer(sub_layer(x) + x)``.

    Args:
        norm_layer: module applied to the sum of the sub-layer output and
            its input.
    """

    def __init__(self, norm_layer):
        super().__init__()
        self.norm_layer = norm_layer

    def forward(self, x, sub_layer):
        """Run ``sub_layer`` on ``x``, add the skip connection, then normalize.

        ``sub_layer`` is any callable taking ``x`` and returning a tensor
        that can be added to ``x``.
        """
        summed = sub_layer(x) + x
        return self.norm_layer(summed)

## Decoder

먼저 teacher forcing 방식으로 학습을 시키기 때문에 masking 을 위한 함수가 필요하다.

In [None]:
def subsequent_mask(size):
    attention_shape = (1, size, size)
    mask = torch.triu(torch.ones(attention_shape), k = 1)