<a href="https://colab.research.google.com/github/dongwon0002/Paper/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Steps performed by ``forward``:
      1. Project the value/key/query inputs with bias-free
         ``nn.Linear(embed_size, embed_size)`` layers.
      2. Reshape each projection to ``(N, length, heads, head_dim)``.
      3. Compute query-key scores per head with ``einsum``.
      4. Fill masked positions with a large negative number.
      5. Scale by ``sqrt(head_dim)`` and softmax over the key axis (dim 3).
      6. Weight the values by the attention matrix and merge the heads.
      7. Apply the output projection ``fc_out``.

    Args:
        embed_size: Size of the last dimension of every input and of the output.
        heads: Number of attention heads; must divide ``embed_size`` exactly.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Proceed only when embed_size splits evenly across the heads.
        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        # Projections are applied before the per-head split, so each maps
        # embed_size -> embed_size (head_dim = embed_size // heads per head).
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` positions over ``keys``/``values``.

        Args:
            values: Tensor reshaped here as ``(N, value_len, embed_size)``.
            keys: Tensor reshaped here as ``(N, key_len, embed_size)``.
            query: Tensor reshaped here as ``(N, query_len, embed_size)``.
            mask: ``None`` or a tensor broadcastable against the score tensor
                ``(N, heads, query_len, key_len)``; entries equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 1-2. Project, then split the embedding axis into (heads, head_dim).
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        # The PROJECTED queries must be reshaped here; reshaping the raw
        # ``query`` would bypass ``self.queries`` and leave it untrained.
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # 3. Scores: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # 4. Masked positions get a large negative score -> ~0 after softmax.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e9"))

        # 5. Normalise over the key axis so each query's weights sum to 1.
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # 6. Weighted sum of values, then concatenate the heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # 7. Output projection.
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer_out + residual))``.
    The feed-forward network widens the features to
    ``forward_expansion * embed_size`` with a ReLU and projects them back to
    ``embed_size``.

    Args:
        embed_size: Feature size of the inputs and the output.
        heads: Number of attention heads passed to ``SelfAttention``.
        dropout: Dropout probability applied after each layer norm.
        forward_expansion: Width multiplier of the hidden feed-forward layer.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and feed-forward, each with a residual connection.

        The attention residual is taken from ``query``; the feed-forward
        residual is taken from the normalised attention output.
        """
        attended = self.attention(value, key, query, mask)
        hidden = self.norm1(attended + query)
        hidden = self.dropout(hidden)

        projected = self.feed_forward(hidden)
        result = self.norm2(projected + hidden)
        return self.dropout(result)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    A ``(max_length, embed_size)`` table is filled once at construction time:
    even columns hold ``sin(pos / 10000 ** (2i / embed_size))`` and odd columns
    hold the matching ``cos``. ``forward`` slices the table to the input's
    sequence length and returns it, to be added to the token embeddings by the
    caller.

    The table is stored as a non-persistent buffer so that it follows the
    module through ``.to(...)`` / ``.cuda()`` without being added to the
    ``state_dict``.

    Args:
        embed_size: Number of encoding columns (odd sizes are supported).
        max_length: Number of positions precomputed in the table.
        device: Device on which the table is initially created.
    """

    def __init__(self, embed_size, max_length, device):
        super(PositionalEncoding, self).__init__()
        self.device = device

        pos = torch.arange(0, max_length, device=device).float().unsqueeze(1)
        _2i = torch.arange(0, embed_size, step=2, device=device).float()
        # (max_length, ceil(embed_size / 2)) angle for every (position, pair).
        angles = pos / (10000 ** (_2i / embed_size))

        encoding = torch.zeros(max_length, embed_size, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # columns, so the cosine block is truncated to fit.
        encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Freshly created tensors do not require grad; registering as a buffer
        # keeps the table on the same device as the rest of the model.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encoding for every position of ``x``.

        Args:
            x: 2-D tensor ``(batch_size, seq_length)``; only its size is used.

        Returns:
            Tensor ``(batch_size, seq_length, embed_size)`` that is an expanded
            view of the stored table.

        Raises:
            ValueError: If ``seq_length`` exceeds the precomputed table length.
        """
        batch_size, seq_length = x.size()
        max_length = self.encoding.size(0)
        if seq_length > max_length:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_length {max_length}"
            )
        return self.encoding[:seq_length, :].unsqueeze(0).expand(batch_size, seq_length, -1)


class Encoder(nn.Module):
    """Transformer encoder: embeddings followed by stacked TransformerBlocks.

    Token ids are looked up in an embedding table, summed with the positional
    encoding, passed through dropout, and then fed through ``num_layers``
    ``TransformerBlock`` layers. Each layer receives the running activation as
    value, key and query.

    Args:
        src_vacab_size: Number of rows in the source embedding table. The
            spelling is kept as-is so existing keyword callers keep working.
        embed_size: Embedding / model feature size.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Number of attention heads per block.
        device: Device handed to ``PositionalEncoding``.
        forward_expansion: Feed-forward width multiplier per block.
        dropout: Dropout probability.
        max_length: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(self,
                 src_vacab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        # Use the constructor argument; reading a differently spelled name
        # here would silently depend on a module-level global (or raise
        # NameError when none exists).
        self.word_embedding = nn.Embedding(src_vacab_size, embed_size)
        self.position_embedding = PositionalEncoding(embed_size, max_length, device)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout=dropout,
                                 forward_expansion=forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: 2-D tensor of token ids, unpacked here as ``(N, seq_length)``.
            mask: Attention mask forwarded unchanged to every layer.

        Returns:
            The output of the last layer, ``(N, seq_length, embed_size)``.
        """
        # The embedding output already has shape (N, seq_length, embed_size),
        # so it can be summed with the positional encoding directly.
        out = self.dropout(self.word_embedding(x) + self.position_embedding(x))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out


class DecoderBlock(nn.Module):
    """One decoder layer: target self-attention, then a TransformerBlock.

    The target sequence first attends to itself through ``SelfAttention``
    using the target mask. The normalised result becomes the query of a
    ``TransformerBlock`` whose values and keys are supplied by the caller
    (the encoder output in ``Decoder``).

    Args:
        embed_size: Feature size of all inputs and the output.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        dropout: Dropout probability.
        device: Accepted for signature compatibility; not used in this class.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Return the block output for target activations ``x``."""
        # Self-attention over the target, with a residual connection to x.
        self_attended = self.attention(x, x, x, trg_mask)
        normed = self.norm(self_attended + x)
        query = self.dropout(normed)

        # Attention over the supplied value/key pair plus feed-forward.
        return self.transformer_block(value, key, query, src_mask)


class Decoder(nn.Module):
    """Transformer decoder: embeddings, stacked DecoderBlocks, vocab projection.

    Mirrors the encoder: target token ids are embedded, summed with the
    positional encoding and passed through dropout, then through
    ``num_layers`` ``DecoderBlock`` layers that also receive the encoder
    output. A final linear layer maps features to ``trg_vocab_size`` scores.

    Args:
        trg_vocab_size: Target embedding rows and size of the output scores.
        embed_size: Embedding / model feature size.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        heads: Number of attention heads per block.
        forward_expansion: Feed-forward width multiplier per block.
        dropout: Dropout probability.
        device: Device handed to ``PositionalEncoding`` and ``DecoderBlock``.
        max_length: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(
                self,
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length
              ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = PositionalEncoding(embed_size, max_length, device)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            )
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target ids ``x`` (a 2-D tensor) against ``enc_out``.

        Returns:
            Scores whose last dimension has size ``trg_vocab_size``.
        """
        token_features = self.word_embedding(x)
        positions = self.position_embedding(x)
        hidden = self.dropout(token_features + positions)

        # Every block uses the encoder output as both value and key.
        for block in self.layers:
            hidden = block(hidden, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(hidden)

class Transformer(nn.Module):
    """Encoder-decoder Transformer that wires all sub-modules together.

    Builds an ``Encoder`` and a ``Decoder`` from the given hyper-parameters,
    creates the source padding mask and the target causal mask, runs the
    encoder on the source, and feeds the encoder output plus both masks to
    the decoder.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        src_pad_idx: Source token id treated as padding by ``make_src_mask``.
        trg_pad_idx: Stored on the instance; the target mask built here is
            causal only and does not consult it.
        embed_size: Model feature size.
        num_layers: Number of encoder layers and of decoder layers.
        forward_expansion: Feed-forward width multiplier.
        heads: Number of attention heads.
        dropout: Dropout probability.
        device: Device passed to the encoder and decoder constructors.
        max_length: Maximum sequence length for the positional encodings.
    """

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=512,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cpu",
                 max_length=100
                ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
              src_vocab_size,
              embed_size,
              num_layers,
              heads,
              device,
              forward_expansion,
              dropout,
              max_length)
        self.decoder = Decoder(
              trg_vocab_size,
              embed_size,
              num_layers,
              heads,
              forward_expansion,
              dropout,
              device,
              max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean ``(N, 1, 1, src_len)`` mask, False at padding.

        The mask is derived from ``src`` and therefore lives on the same
        device as the input, independent of the constructor-time ``device``.
        """
        # Two unsqueezes add the head and query axes for broadcasting.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a lower-triangular ``(N, 1, trg_len, trg_len)`` causal mask.

        Entry ``[.., q, k]`` is 1 when ``k <= q`` and 0 otherwise, so each
        target position can only attend to itself and earlier positions.
        """
        N, trg_len = trg.shape
        # Build the mask directly on the input's device so it always matches
        # the attention scores, even after the model was moved with ``.to``.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Run the full model.

        Args:
            src: 2-D tensor of source token ids ``(N, src_len)``.
            trg: 2-D tensor of target token ids ``(N, trg_len)``.

        Returns:
            The decoder output for ``trg``.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

In [28]:
if __name__ == "__main__":
    # Smoke test: build a tiny model and push one toy batch through it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Vocabulary sizes and padding ids for the toy data.
    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10

    # Two source sentences (9 tokens) and two target sentences (8 tokens).
    x = torch.tensor(
        [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]],
        device=device,
    )
    trg = torch.tensor(
        [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]],
        device=device,
    )

    model = Transformer(
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=8,
        device=device,
    )
    model = model.to(device)

    # The decoder input drops the last target token.
    out = model(x, trg[:, :-1])
    print(out.shape)

cuda
torch.Size([2, 7, 10])
