# Highly referenced from : https://github.com/hyunwoongko/transformer

<img src="images/transformer.png" width="300" height="150">

##### 1. Transformer = Encoder + Decoder

##### 2. Encoder = positional encoding + <span style='color:blue'> multi-head attention </span> + Layer-norm + positionwise feedforward
##### 3. Decoder = positional encoding + <span style='color:green'> masked </span> multi-head attention + Layer-norm +  <span style='color:blue'> multi-head attention </span> + positionwise feedforward


##### 4. Multi-head attention = layer of multiple (head > 1)  <span style='color:blue'> cross-attention </span> operations
##### 5. <span style='color:green'> Masked </span> multi-head attention = layer of multiple (head > 1) <span style='color:green'> self-attention </span> operations.
<- 미래의 정보는 사용하면 안 되므로 현재와 과거 word를 제외한 나머지 정보는 masking>

# ScaleDotProductAttention: single-head attention
<img src="images/scaled_dot_product_attention.png" width="200" height="100">
<img src="images/attention(eq).png" width="400" height="200">

In [9]:
import math

import torch
import torch.nn as nn
class ScaleDotProductAttention(nn.Module):
    """
    Compute scaled dot-product attention (a single attention head).

    Query : the positions an output is being computed for
    Key   : the positions every query is compared against
    Value : the content that is averaged with the attention weights

    The result is ``softmax(Q @ K^T / sqrt(d_k)) @ V``. Batched matrix
    multiplication only needs the last two dimensions to line up, so any
    number of leading (batch / head) dimensions is accepted.
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        # normalise over the key axis so every query's weights sum to 1
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """
        :param q: [..., length_q, d_tensor]
        :param k: [..., length_k, d_tensor]
        :param v: [..., length_k, d_value]
        :param mask: optional tensor broadcastable to [..., length_q, length_k];
                     positions where ``mask == 0`` are excluded from attention.
                     A causal (no look-ahead) mask for self-attention is e.g.
                     ``torch.tril(torch.ones(length_q, length_q))``.
        :return: (weighted values [..., length_q, d_value],
                  attention weights [..., length_q, length_k])
        """
        d_tensor = k.size(-1)

        # 1. similarity of every query with every key, scaled by sqrt(d_k)
        #    so the softmax does not saturate for large head dimensions
        score = (q @ k.transpose(-2, -1)) / math.sqrt(d_tensor)

        # 2. apply masking (opt): push masked scores to the most negative
        #    finite value so that they receive ~0 probability after softmax.
        #    A finite value (instead of -inf) keeps fully masked rows free
        #    of NaN and is representable in every floating dtype.
        if mask is not None:
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        # 3. softmax turns the scores into weights in the [0, 1] range
        score = self.softmax(score)

        # 4. weighted sum of the values
        v = score @ v

        return v, score

# multihead_attention
<img src="images/multihead_attention.png" width="200" height="100">
<img src="images/multihead_att(eq).png" width="400" height="200">

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: project q/k/v, run ``n_head`` scaled dot-product
    attentions in parallel on slices of the model dimension, then merge the
    heads and apply a final linear projection.
    """

    def __init__(self, d_model, n_head):
        """
        :param d_model: dimension of model (must be divisible by n_head)
        :param n_head: number of attention heads
        """
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            # each head works on an equal d_model // n_head slice
            raise ValueError("d_model must be divisible by n_head")
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """
        :param q: [batch_size, length_q, d_model]
        :param k: [batch_size, length_k, d_model]
        :param v: [batch_size, length_k, d_model]
        :param mask: optional mask forwarded unchanged to the attention module
        :return: [batch_size, length_q, d_model]
        """
        # 1. dot product with weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. split tensor by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. do scale dot product to compute similarity
        #    (the attention weights are computed but not returned)
        out, attention = self.attention(q, k, v, mask=mask)

        # 4. concat and pass to linear layer
        out = self.concat(out)
        out = self.w_concat(out)

        return out

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()

        d_tensor = d_model // self.n_head
        # carve d_model into (head, d_tensor), then move the head axis in
        # front of length so each head attends over the whole sequence
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)

        return tensor

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()

        d_model = head * d_tensor
        # transpose yields a non-contiguous tensor; view needs contiguous memory
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor

# positionwise feedforward
<img src="images/positionwise_feedforward(eq).png" width="300" height="150">

In [None]:
class PositionwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied to the last dimension of its
    input: Linear(d_model -> hidden) -> ReLU -> Dropout -> Linear(hidden -> d_model).
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        """
        :param d_model: feature size of the input x1 and the output x3
        :param hidden: feature size of the intermediate activation x2
        :param drop_prob: dropout probability applied to x2
        """
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x1):
        """
        :param x1: [..., d_model]
        :return: [..., d_model]
        """
        x2 = self.linear1(x1)  # expand to the hidden size
        x2 = self.relu(x2)
        x2 = self.dropout(x2)
        x3 = self.linear2(x2)  # project back to d_model
        return x3

# positional encoding
<img src="images/position_encoding(eq).png" width="300" height="150">

In [None]:
class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.

    Even feature columns hold sin(pos / 10000^(2i / d_model)) and odd
    columns hold the matching cosine, for every position below max_len.
    """
    def __init__(self, d_model, max_len, device=None):
        """
        :param d_model: dimension of model
        :param max_len: max sequence length
        :param device: optional device on which to build the table
                       (None = PyTorch's default device)
        """
        super(PositionalEncoding, self).__init__()

        # same size with input matrix (for adding with input matrix);
        # a freshly created tensor does not require gradients
        encoding = torch.zeros(max_len, d_model, device=device)

        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)  # 1D => 2D unsqueeze to represent word's position

        _2i = torch.arange(0, d_model, step=2, device=device).float()  # step=2 for cos/sin

        # compute positional encoding to consider positional information of words
        angle = pos / (10000 ** (_2i / d_model))
        encoding[:, 0::2] = torch.sin(angle)
        # for an odd d_model there is one cosine column fewer than sine columns
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # a non-persistent buffer follows .to(device) / .cuda() with the
        # module but adds no entry to the state_dict
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """
        :param x: tensor whose dimension 1 is the sequence length,
                  e.g. token ids of shape [batch_size, seq_len]
        :return: [seq_len, d_model] (at most max_len rows)
        """
        seq_len = x.size(1)

        return self.encoding[:seq_len, :]

# Encoder
##### 2. Encoder = positional encoding + <span style='color:blue'> multi-head attention </span> + Layer-norm + positionwise feedforward
<img src="images/encoder.png" width="300" height="150">

In [None]:
class EncoderLayer(nn.Module):
    """
    One encoder block: self-attention followed by a position-wise
    feed-forward network. Each sub-layer is wrapped as
    ``norm(residual + dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        """
        :param d_model: dimension of model
        :param ffn_hidden: hidden size of the feed-forward network
        :param n_head: number of attention heads
        :param drop_prob: dropout probability used after each sub-layer
        """
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, src_mask):
        """
        :param x: encoder input
        :param src_mask: mask for multi-head attention
        :return: tensor produced by the second add-and-norm step
        """
        # 1. compute self attention (query, key and value are all x)
        _x = x # a temporal variable which will be used for residual connection
        x = self.attention(q=x, k=x, v=x, mask=src_mask)

        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        # 3. positionwise feed forward network
        _x = x # a temporal variable which will be used for residual connection
        x = self.ffn(x)

        # 4. add and norm
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x

In [None]:
class Encoder(nn.Module):
    """
    Transformer encoder: an embedding module followed by ``n_layers``
    EncoderLayer blocks applied one after another.
    """

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        """
        :param enc_voc_size: vocabulary size handed to the embedding
        :param max_len: max sequence length handed to the embedding
        :param d_model: dimension of model
        :param ffn_hidden: hidden size of each layer's feed-forward network
        :param n_head: number of attention heads per layer
        :param n_layers: number of stacked EncoderLayer blocks
        :param drop_prob: dropout probability
        :param device: device handed to the embedding
        """
        super(Encoder, self).__init__()

        # embedding is created first, then the layer stack
        self.emb = TransformerEmbedding(
            d_model=d_model,
            max_len=max_len,
            vocab_size=enc_voc_size,
            drop_prob=drop_prob,
            device=device,
        )

        stack = []
        for _ in range(n_layers):
            stack.append(
                EncoderLayer(
                    d_model=d_model,
                    ffn_hidden=ffn_hidden,
                    n_head=n_head,
                    drop_prob=drop_prob,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, x, src_mask):
        """
        :param x: encoder input passed to the embedding
        :param src_mask: mask given to every encoder layer
        :return: output of the last encoder layer
        """
        hidden = self.emb(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

# Decoder
##### 3. Decoder = positional encoding + <span style='color:green'> masked </span> multi-head attention + Layer-norm +  <span style='color:blue'> multi-head attention </span> + positionwise feedforward
<img src="images/decoder.png" width="300" height="150">

In [None]:
class DecoderLayer(nn.Module):
    """
    One decoder block: masked self-attention, optional encoder-decoder
    attention, and a position-wise feed-forward network. Each sub-layer is
    wrapped as ``norm(residual + dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        """
        :param d_model: dimension of model
        :param ffn_hidden: hidden size of the feed-forward network
        :param n_head: number of attention heads
        :param drop_prob: dropout probability used after each sub-layer
        """
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask):
        """
        :param dec: decoder input
        :param enc: encoder output; when None the encoder-decoder attention
                    step is skipped
        :param trg_mask: mask for masked multi-head attention (hides future positions)
        :param src_mask: mask for the encoder-decoder multi-head attention
        :return: tensor produced by the final add-and-norm step
        """
        # 1. compute self attention (query, key and value are all dec)
        _x = dec
        x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)

        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            # 3. compute encoder - decoder attention:
            #    queries come from the decoder, keys/values from the encoder
            _x = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)

            # 4. add and norm
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        # 5. positionwise feed forward network
        _x = x
        x = self.ffn(x)

        # 6. add and norm
        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x

In [None]:
class Decoder(nn.Module):
    """
    Transformer decoder: an embedding module, ``n_layers`` DecoderLayer
    blocks applied in sequence, and a final linear layer that maps
    d_model features to dec_voc_size outputs.
    """

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        """
        :param dec_voc_size: vocabulary size (embedding input and output width)
        :param max_len: max sequence length handed to the embedding
        :param d_model: dimension of model
        :param ffn_hidden: hidden size of each layer's feed-forward network
        :param n_head: number of attention heads per layer
        :param n_layers: number of stacked DecoderLayer blocks
        :param drop_prob: dropout probability
        :param device: device handed to the embedding
        """
        super(Decoder, self).__init__()

        # creation order: embedding, layer stack, output projection
        self.emb = TransformerEmbedding(
            d_model=d_model,
            drop_prob=drop_prob,
            max_len=max_len,
            vocab_size=dec_voc_size,
            device=device,
        )

        stack = []
        for _ in range(n_layers):
            stack.append(
                DecoderLayer(
                    d_model=d_model,
                    ffn_hidden=ffn_hidden,
                    n_head=n_head,
                    drop_prob=drop_prob,
                )
            )
        self.layers = nn.ModuleList(stack)

        self.linear = nn.Linear(d_model, dec_voc_size)

    def forward(self, trg, src, trg_mask, src_mask):
        """
        :param trg: decoder input passed to the embedding
        :param src: passed to every decoder layer as its ``enc`` argument
        :param trg_mask: mask for the decoder self-attention
        :param src_mask: mask for the encoder-decoder attention
        :return: output of the final linear layer
        """
        hidden = self.emb(trg)

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, src, trg_mask, src_mask)

        # pass to LM head
        return self.linear(hidden)