<a href="https://colab.research.google.com/github/gagyeomkim/Deep-Learning-Paper-Review-and-Practice/blob/main/code_practice/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Attention Is All You Need

ref: https://cpm0722.github.io/pytorch-implementation/transformer

- 해당 코드는 **Transformer**의 Encoder-Decoder 구조를 이해하기 위한 실습입니다.
    - Dropout 기법을 사용하는 부분은 포함되어 있지 않습니다.

### **block**

#### encoder_block.py

In [48]:
import copy

import torch.nn
import torch.nn as nn

# from models.layer.residual_connection_layer import ResidualConnectionLayer

class EncoderBlock(nn.Module):
    """One encoder layer: multi-head self-attention followed by a position-wise FFN.

    Each of the two sub-layers is wrapped in a ``ResidualConnectionLayer`` that
    owns its own deep copy of ``norm``.
    """

    def __init__(self, self_attention, position_ff, norm):
        """
        Args:
            self_attention: Multi-Head Attention layer, called with keyword
                arguments ``query``, ``key``, ``value`` and ``mask``.
            position_ff: Position-wise Feed-Forward layer.
            norm: normalization module; each residual connection receives an
                independent deep copy of it.
        """
        super(EncoderBlock, self).__init__()
        self.self_attention = self_attention
        self.position_ff = position_ff
        # nn.ModuleList (not a plain list) so that the residual layers and their
        # norm parameters are registered as submodules: they appear in
        # parameters()/state_dict(), get trained, and follow .to(device).
        self.residuals = nn.ModuleList(
            [ResidualConnectionLayer(copy.deepcopy(norm)) for _ in range(2)]
        )

    def forward(self, src, src_mask):
        """
        Args:
            src: embedded encoder input sentence.
            src_mask: pad mask passed through to the self-attention layer.

        Returns:
            Output of the block, same shape as ``src``.
        """
        out = src
        # ResidualConnectionLayer calls sub_layer with a single argument, so the
        # attention call (which also needs the mask) is wrapped in a lambda.
        out = self.residuals[0](
            out,
            lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask),
        )
        out = self.residuals[1](out, self.position_ff)
        return out

- `lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask)`은 아래와 동일한 코드이다.
```python
def temp(out):
    return self.self_attention(query=out, key=out, value=out, mask=src_mask)
```

#### decoder_block.py

In [49]:
import copy
import torch.nn as nn

# from models.layer.residual_connection_layer import ResidualConnectionLayer

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, then a position-wise FFN.

    Each of the three sub-layers is wrapped in a ``ResidualConnectionLayer``
    with its own deep copy of ``norm``.
    """

    def __init__(self, self_attention, cross_attention, position_ff, norm):
        """
        Args:
            self_attention: Multi-Head Attention layer over the target sequence.
            cross_attention: Multi-Head Attention layer whose key/value come
                from the encoder output.
            position_ff: Position-wise Feed-Forward layer.
            norm: normalization module; each residual connection receives an
                independent deep copy of it.
        """
        super(DecoderBlock, self).__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.position_ff = position_ff
        # nn.ModuleList so the residual norms are registered submodules
        # (trained, saved in state_dict, moved by .to(device)).
        self.residuals = nn.ModuleList(
            [ResidualConnectionLayer(copy.deepcopy(norm)) for _ in range(3)]
        )

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        """
        Args:
            tgt: embedded decoder input sentence.
            encoder_out: final encoder output (the context), used as key/value
                of the cross-attention.
            tgt_mask: subsequent + pad mask for the self-attention.
            src_tgt_mask: pad mask between target queries and source keys.

        Returns:
            Output of the block, same shape as ``tgt``.
        """
        out = tgt
        out = self.residuals[0](
            out,
            lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask),
        )
        out = self.residuals[1](
            out,
            lambda out: self.cross_attention(
                query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask
            ),
        )
        out = self.residuals[2](out, self.position_ff)
        return out

### **embedding**

#### transformer_embedding.py

- `nn.Sequential()`: input으로 준 module에 대해 순차적으로 forward() method를 호출해주는 역할

In [50]:
import torch.nn as nn

class TransformerEmbedding(nn.Module):
    """Token embedding followed by positional encoding, chained with nn.Sequential."""

    def __init__(self, token_embed, pos_embed):
        super().__init__()
        # nn.Sequential runs token_embed first, then feeds its output to pos_embed.
        self.embedding = nn.Sequential(token_embed, pos_embed)

    def forward(self, x):
        """
        Args:
            x: un-embedded sentence of token ids, shape (n_batch, seq_len).

        Returns:
            Result of ``pos_embed(token_embed(x))``.
        """
        return self.embedding(x)

#### token_embedding.py

In [51]:
import math
import torch.nn as nn

class TokenEmbedding(nn.Module):
    """Lookup-table token embedding scaled by sqrt(d_embed)."""

    def __init__(self, d_embed, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.d_embed = d_embed

    def forward(self, x):
        # Scale the embeddings up by sqrt(d_embed) so the token information is
        # not drowned out by the positional encoding added afterwards.
        scale = math.sqrt(self.d_embed)
        return self.embedding(x) * scale

#### positional_encoding.py

In [52]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal positional encoding to an embedded sequence.

    encoding[pos, 2i]   = sin(pos / 10000^(2i / d_embed))
    encoding[pos, 2i+1] = cos(pos / 10000^(2i / d_embed))
    """

    def __init__(self, d_embed, max_len=256, device=torch.device("cpu")):
        """
        Args:
            d_embed: embedding dimension (may be odd).
            max_len: number of positions precomputed in the table.
            device: device the table is created on.
        """
        super(PositionalEncoding, self).__init__()
        # position: (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term[i] = 1 / 10000^(2i / d_embed), one entry per even column 2i.
        div_term = torch.exp(
            torch.arange(0, d_embed, 2, dtype=torch.float) * -(math.log(10000.0) / d_embed)
        )
        # encoding: (max_len, d_embed)
        encoding = torch.zeros(max_len, d_embed)
        # Even columns use sin.
        encoding[:, 0::2] = torch.sin(position * div_term)
        # Odd columns use cos. For odd d_embed there is one fewer odd column
        # than even column, so div_term is trimmed to fit.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_embed // 2])
        # encoding: (1, max_len, d_embed). Registered as a non-persistent
        # buffer: it gets no gradient, follows module.to(device)/deepcopy, and
        # stays out of state_dict (like the former plain attribute).
        self.register_buffer("encoding", encoding.unsqueeze(0).to(device), persistent=False)

    def forward(self, x):
        """
        Args:
            x: embedded sequence of shape (n_batch, seq_len, d_embed), with
               seq_len <= max_len.

        Returns:
            ``x`` plus the encodings of positions 0..seq_len-1 (same shape).
        """
        _, seq_len, _ = x.size()
        # Take only the first seq_len positions; broadcasts over the batch.
        pos_embed = self.encoding[:, :seq_len, :]
        out = x + pos_embed
        return out

- `torch.arange(0, d_embed, 2)`: $2i$
- `-(math.log(10000.0) / d_embed)`: $- \frac{\ln 10000}{d_{embed}}$
    - `math.log()`의 base의 default는 $e$
- `torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed)`: $2i * \left(- \frac{\ln 10000}{d_{embed}}\right)$
- `torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed))`: $exp\left(2i * \left(- \frac{\ln 10000}{d_{embed}}\right)\right)$
- 식 전개시 아래와 같이 변형된다

$$exp(-\frac{2i}{d_{embed}} * \ln 10000)$$

$$exp(\ln 10000^{-\frac{2i}{d_{embed}}})$$

$$10000^{-\frac{2i}{d_{embed}}}$$

$$\frac{1}{10000^{\frac{2i}{d_{embed}}}}$$

### **model**

#### encoder.py

In [53]:
import copy
import torch.nn as nn

class Encoder(nn.Module):
    """Stack of ``n_layer`` deep-copied encoder blocks followed by a final norm."""

    def __init__(self, encoder_block, n_layer, norm):
        """
        Initialize the Encoder.

        Args:
            encoder_block: one Encoder Block instance; it is deep-copied
                ``n_layer`` times so every layer has its own weights.
            n_layer: number of Encoder Blocks.
            norm: normalization module applied to the output of the last block.
        """
        super(Encoder, self).__init__()
        # Store n_layer (as Decoder does) and build the stack from the argument;
        # self.n_layer was previously read before ever being assigned.
        self.n_layer = n_layer
        self.layers = nn.ModuleList([copy.deepcopy(encoder_block) for _ in range(n_layer)])
        self.norm = norm

    def forward(self, src, src_mask):
        """
        Run the Encoder Blocks in order, feeding each block's output into the next.

        Args:
            src: embedded encoder input sentence.
            src_mask: pad mask, passed unchanged to every block.

        Returns:
            The normalized output of the last block, i.e. the context.
        """
        out = src
        for layer in self.layers:
            out = layer(out, src_mask)
        out = self.norm(out)
        return out

- Encoder block마다 Layer Normalization을 적용해주었는데, Encoder가 끝난 뒤에도 한 번 더 적용해준 이유  
: 모델의 안정성과 성능을 위한 의도적인 설계로 추가해주었다고 함 (Decoder도 마지막에 동일하게 적용)

#### decoder.py

In [54]:
import copy
import torch.nn as nn

class Decoder(nn.Module):
    """Stack of ``n_layer`` deep-copied decoder blocks followed by a final norm."""

    def __init__(self, decoder_block, n_layer, norm):
        """
        Args:
            decoder_block: one Decoder Block instance, deep-copied ``n_layer`` times.
            n_layer: number of Decoder Blocks.
            norm: normalization module applied after the last block.
        """
        super().__init__()
        self.n_layer = n_layer
        # Submodules kept in a Python container must be wrapped in nn.ModuleList
        # so PyTorch registers them (parameters, state_dict, .to(device)).
        self.layers = nn.ModuleList(copy.deepcopy(decoder_block) for _ in range(self.n_layer))
        self.norm = norm

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        """
        Args:
            tgt: embedded decoder input sentence.
            encoder_out: final Encoder output, i.e. the context.
            tgt_mask: subsequent + pad mask of the target sentence, used by the
                self-attention layers.
            src_tgt_mask: pad mask between queries coming from the decoder and
                keys/values coming from the Encoder.

        Returns:
            Normalized output of the last Decoder Block.
        """
        hidden = tgt
        for block in self.layers:
            hidden = block(hidden, encoder_out, tgt_mask, src_tgt_mask)
        return self.norm(hidden)


#### transformer.py

In [55]:
"""
transformer.py
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer(nn.Module):
    """
    Complete encoder-decoder Transformer: embeddings, encoder, decoder and
    output generator, plus the mask helpers used by ``forward``.
    """
    def __init__(self, src_embed, tgt_embed, encoder, decoder, generator):
        """
        Initialize the Transformer.

        Args:
            src_embed: embedding module applied to source token ids.
            tgt_embed: embedding module applied to target token ids.
            encoder: encoder instance, called as ``encoder(x, src_mask)``.
            decoder: decoder instance, called as
                ``decoder(x, encoder_out, tgt_mask, src_tgt_mask)``.
            generator: FC layer mapping the decoder output to vocabulary logits.
        """
        super(Transformer, self).__init__()
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def encode(self, src, src_mask):
        """
        Encoder half of the model.

        Args:
            src: source sentence of token ids, (n_batch, src_seq_len).
            src_mask: pad mask for the source.

        Returns:
            The encoder output (context).
        """
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        """
        Decoder half of the model.

        Args:
            tgt: target sentence of token ids, (n_batch, tgt_seq_len).
            encoder_out: context produced by ``encode``.
            tgt_mask: subsequent + pad mask for the decoder input.
            src_tgt_mask: pad mask between target queries and source keys.

        Returns:
            The decoder output.
        """
        # Must call the decoder submodule; calling self.decode here would recurse forever.
        return self.decoder(self.tgt_embed(tgt), encoder_out, tgt_mask, src_tgt_mask)

    def forward(self, src, tgt):
        """
        Encode ``src`` into a context, then decode ``tgt`` against that context.

        Args:
            src: encoder input sentence, (n_batch, src_seq_len).
            tgt: decoder input sentence, (n_batch, tgt_seq_len).

        Returns:
            A tuple ``(log_probs, decoder_out)``. ``log_probs`` is the
            log-softmax of the generator output over the last (vocabulary)
            dimension. ``decoder_out`` is the raw decoder output.
        """
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        src_tgt_mask = self.make_src_tgt_mask(src, tgt)
        encoder_out = self.encode(src, src_mask)
        # decoder_out: (n_batch, tgt_seq_len, d_embed)
        decoder_out = self.decode(tgt, encoder_out, tgt_mask, src_tgt_mask)
        # out: (n_batch, tgt_seq_len, vocab_size)
        out = self.generator(decoder_out)
        out = F.log_softmax(out, dim=-1)  # normalize over the vocabulary dimension
        return out, decoder_out

    def make_src_mask(self, src):
        """Pad mask of the source against itself: (n_batch, 1, src_len, src_len)."""
        return self.make_pad_mask(src, src)

    @staticmethod
    def make_subsequent_mask(query, key):
        """
        Subsequent (look-ahead) mask used for teacher forcing.

        Args:
            query: (n_batch, query_seq_len)
            key: (n_batch, key_seq_len)

        Returns:
            Bool tensor (query_seq_len, key_seq_len). Entry (i, j) is True iff
            j <= i (lower triangle INCLUDING the diagonal).
        """
        query_seq_len, key_seq_len = query.size(1), key.size(1)
        ones = torch.ones(query_seq_len, key_seq_len, device=query.device)
        return torch.tril(ones).bool()

    def make_tgt_mask(self, tgt):
        """
        Decoder self-attention mask: subsequent mask AND pad mask.

        Returns:
            Bool tensor (n_batch, 1, tgt_len, tgt_len); the 2-D subsequent mask
            broadcasts over the batch dimension.
        """
        pad_mask = self.make_pad_mask(tgt, tgt)
        seq_mask = self.make_subsequent_mask(tgt, tgt)
        return pad_mask & seq_mask

    def make_src_tgt_mask(self, src, tgt):
        """
        Pad mask between decoder queries (from ``tgt``) and encoder keys/values
        (from ``src``): (n_batch, 1, tgt_len, src_len).
        """
        return self.make_pad_mask(tgt, src)

    def make_pad_mask(self, query, key, pad_idx=1):
        """
        Build a pad mask.

        Args:
            query: (n_batch, query_seq_len) token ids.
            key: (n_batch, key_seq_len) token ids.
            pad_idx: id of the <PAD> token (default 1).

        Returns:
            Bool tensor (n_batch, 1, query_seq_len, key_seq_len). Entry (q, k)
            is True only when neither query token q nor key token k is padding.
        """
        query_seq_len, key_seq_len = query.size(1), key.size(1)

        # Padding flags along the key_seq_len axis.
        key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2)  # (n_batch, 1, 1, key_seq_len)
        key_mask = key_mask.repeat(1, 1, query_seq_len, 1)  # (n_batch, 1, query_seq_len, key_seq_len)

        # Padding flags along the query_seq_len axis (unsqueeze(3) puts them on dim 2).
        query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3)  # (n_batch, 1, query_seq_len, 1)
        query_mask = query_mask.repeat(1, 1, 1, key_seq_len)  # (n_batch, 1, query_seq_len, key_seq_len)

        # Attention is only meaningful when both the query and key are real tokens.
        return key_mask & query_mask



In [56]:
import numpy as np

# Result of np.tril when query_seq_len and key_seq_len are both 10:
# ones on and below the main diagonal, zeros strictly above it (k=0 keeps the diagonal).
np.tril(np.ones((10, 10), dtype=np.uint8))

- `.ne`: not equal. 같지 않으면 True. 즉 pad_idx가 아닌 곳들은 `True`로, pad_idx인 곳들은 `False`로

### **layer**

Self-Attention Code in Pytorch

In [57]:
def calculate_attention(query, key, value, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: (n_batch, seq_len, d_k)
        mask: pad mask broadcastable to (n_batch, q_len, k_len), or None.
            Positions where mask == 0 receive (effectively) zero attention.

    Returns:
        (n_batch, q_len, d_k) attention output.
    """
    d_k = key.shape[-1]
    # attention_score: (n_batch, q_len, k_len) = Q x K^T / sqrt(d_k)
    attention_score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply the mask only when one is given (the check was previously inverted).
    if mask is not None:
        attention_score = attention_score.masked_fill(mask == 0, -1e9)
    attention_prob = F.softmax(attention_score, dim=-1)  # (n_batch, q_len, k_len)
    return torch.matmul(attention_prob, value)  # (n_batch, q_len, d_k)

#### multi_head_attention_layer.py

In [58]:
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V projections."""

    def __init__(self, d_model, h, qkv_fc, out_fc):
        """
        Args:
            d_model: d_k * h, total size of the projected Q/K/V.
            h: number of attention heads (h=8 in the paper).
            qkv_fc: FC layer mapping d_embed -> d_model. It is deep-copied three
                times so Q, K and V each get their own weights.
            out_fc: FC layer applied after attention, mapping d_model back to
                the embedding size.

        Raises:
            ValueError: if d_model is not divisible by h.
        """
        super(MultiHeadAttentionLayer, self).__init__()
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")
        self.d_model = d_model
        self.h = h
        self.q_fc = copy.deepcopy(qkv_fc)  # (d_embed, d_model)
        self.k_fc = copy.deepcopy(qkv_fc)  # (d_embed, d_model)
        self.v_fc = copy.deepcopy(qkv_fc)  # (d_embed, d_model)
        self.out_fc = out_fc  # (d_model, d_embed)

    @staticmethod
    def calculate_attention(query, key, value, mask):
        """
        Scaled dot-product attention applied to every head independently.

        Args:
            query: (n_batch, h, q_len, d_k)
            key, value: (n_batch, h, k_len, d_k)
            mask: broadcastable to (n_batch, h, q_len, k_len), e.g.
                (n_batch, 1, q_len, k_len), or None. Positions where
                mask == 0 receive (effectively) zero attention.

        Returns:
            (n_batch, h, q_len, d_k)
        """
        d_k = key.shape[-1]
        # attention_score: (n_batch, h, q_len, k_len)
        attention_score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        attention_prob = F.softmax(attention_score, dim=-1)
        return torch.matmul(attention_prob, value)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (n_batch, q_len, d_embed)
            key, value: (n_batch, k_len, d_embed)
            mask: see ``calculate_attention``; None disables masking.

        Returns:
            out_fc applied to the merged heads: (n_batch, q_len, out_fc output size).
        """
        n_batch = query.size(0)
        d_k = self.d_model // self.h

        def transform(x, fc):
            # (n_batch, seq_len, d_embed) -> (n_batch, seq_len, d_model)
            out = fc(x)
            # -> (n_batch, seq_len, h, d_k)
            out = out.view(n_batch, -1, self.h, d_k)
            # -> (n_batch, h, seq_len, d_k)
            return out.transpose(1, 2)

        query = transform(query, self.q_fc)
        key = transform(key, self.k_fc)
        value = transform(value, self.v_fc)

        # out: (n_batch, h, q_len, d_k)
        out = self.calculate_attention(query, key, value, mask)
        # Merge heads: (n_batch, q_len, h, d_k) -> (n_batch, q_len, d_model)
        out = out.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.out_fc(out)

#### position_wise_feed_forward_layer.py

In [59]:
import torch.nn as nn

class PositionWiseFeedForwardLayer(nn.Module):
    """FFN(x) = fc2(ReLU(fc1(x))), applied identically at every position."""

    def __init__(self, fc1, fc2):
        """
        Args:
            fc1: FC layer mapping d_embed -> d_ff.
            fc2: FC layer mapping d_ff -> d_embed.
        """
        super().__init__()
        self.fc1 = fc1
        self.relu = nn.ReLU()
        self.fc2 = fc2

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

#### residual_connection_layer.py

In [60]:
import torch.nn as nn

class ResidualConnectionLayer(nn.Module):
    """Post-norm residual wrapper: returns norm(sub_layer(x) + x)."""

    def __init__(self, norm):
        super().__init__()
        self.norm = norm

    def forward(self, x, sub_layer):
        """
        Args:
            x: input tensor (in the blocks above, the embedded sequence).
            sub_layer: single-argument callable (an encoder/decoder sub-layer
                or a lambda wrapping one).

        Returns:
            ``norm(sub_layer(x) + x)``.
        """
        residual_sum = sub_layer(x) + x
        return self.norm(residual_sum)

### **build model**

In [61]:
import torch
import torch.nn as nn

# from models.model.transformer import Transformer
# from models.model.encoder import Encoder
# from models.model.decoder import Decoder
# from models.block.encoder_block import EncoderBlock
# from models.block.decoder_block import DecoderBlock
# from models.layer.multi_head_attention_layer import MultiHeadAttentionLayer
# from models.layer.position_wise_feed_forward_layer import PositionWiseFeedForwardLayer
# from models.embedding.transformer_embedding import TransformerEmbedding
# from models.embedding.token_embedding import TokenEmbedding
# from models.embedding.positional_encoding import PositionalEncoding

def build_model(src_vocab_size,
                tgt_vocab_size,
                device=torch.device("cpu"),
                max_len=256,
                d_embed=512,
                n_layer=6,
                d_model=512,
                h=8,
                d_ff=2048,
                dr_rate=0.1,
                norm_eps=1e-5):
    """
    Assemble the encoder-decoder Transformer from the components in this notebook.

    Args:
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size (also the generator's output size).
        device: device the positional-encoding tables and the model are placed on.
        max_len: number of positions precomputed by the positional encoding.
        d_embed: embedding size; every sub-layer maps back to this size.
        n_layer: number of encoder blocks and of decoder blocks.
        d_model: total attention projection size (d_k * h), split into h heads.
        h: number of attention heads.
        d_ff: hidden size of the position-wise feed-forward layer.
        dr_rate: accepted for API compatibility but unused, because this
            implementation contains no dropout.
        norm_eps: epsilon of every LayerNorm.

    Returns:
        The assembled ``Transformer`` on ``device``, with ``model.device`` set.
    """
    from copy import deepcopy

    src_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=src_vocab_size)
    tgt_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=tgt_vocab_size)

    pos_embed = PositionalEncoding(d_embed=d_embed, max_len=max_len, device=device)

    src_embed = TransformerEmbedding(token_embed=src_token_embed,
                                     pos_embed=deepcopy(pos_embed))
    tgt_embed = TransformerEmbedding(token_embed=tgt_token_embed,
                                     pos_embed=deepcopy(pos_embed))

    attention = MultiHeadAttentionLayer(d_model=d_model,
                                        h=h,
                                        qkv_fc=nn.Linear(d_embed, d_model),
                                        out_fc=nn.Linear(d_model, d_embed))

    position_ff = PositionWiseFeedForwardLayer(fc1=nn.Linear(d_embed, d_ff),
                                               fc2=nn.Linear(d_ff, d_embed))

    # Layer Normalization: applied after every sub-layer of the encoder and
    # decoder blocks, and once more at the end of each stack.
    norm = nn.LayerNorm(d_embed, eps=norm_eps)

    encoder_block = EncoderBlock(self_attention=deepcopy(attention),
                                 position_ff=deepcopy(position_ff),
                                 norm=deepcopy(norm))

    decoder_block = DecoderBlock(self_attention=deepcopy(attention),
                                 cross_attention=deepcopy(attention),
                                 position_ff=deepcopy(position_ff),
                                 norm=deepcopy(norm))

    encoder = Encoder(encoder_block=encoder_block, n_layer=n_layer, norm=deepcopy(norm))
    decoder = Decoder(decoder_block=decoder_block, n_layer=n_layer, norm=deepcopy(norm))

    # The decoder output has d_embed features: out_fc and fc2 both project back
    # to d_embed. The generator must therefore take d_embed inputs, not d_model
    # inputs (the two differ whenever d_embed != d_model).
    generator = nn.Linear(d_embed, tgt_vocab_size)

    model = Transformer(src_embed=src_embed,
                        tgt_embed=tgt_embed,
                        encoder=encoder,
                        decoder=decoder,
                        generator=generator).to(device)

    model.device = device

    return model