# Libraries

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Transformer

Reference implementation: [Transformer tutorial (WikiDocs)](https://wikidocs.net/31379) <br>
Paper: [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/pdf/1706.03762)

## Implementation

### Positional Encoding

In [19]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017, section 3.5).

        PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))

    The (seq_len, d_model) table is computed once at construction and added to
    the input in ``forward``.
    """

    def __init__(self, seq_len: int, d_model: int):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        # Registered as a *non-persistent* buffer: it moves with the module on
        # .to(device) / .half(), but is not written to state_dict, so checkpoints
        # saved before this change still load with strict=True.
        self.register_buffer(
            'embedding',
            self.positional_encoding(self.seq_len, self.d_model),
            persistent=False,
        )

    def get_angles(self, seq_len, i, d_model):
        """Return ``pos / 10000^(2*(i//2)/d_model)``.

        Args:
            seq_len: position indices (despite the name), e.g. shape (seq_len, 1).
            i: dimension indices, e.g. shape (1, d_model).
            d_model: model dimension.
        """
        angles = 1 / (10000 ** (2 * (i // 2) / d_model))

        return seq_len * angles

    def positional_encoding(self, seq_len, d_model):
        """Build the (seq_len, d_model) table: sin on even dims, cos on odd dims."""
        angle_rads = self.get_angles(
            seq_len=torch.arange(seq_len, dtype=torch.float32).unsqueeze(1),
            i=torch.arange(d_model, dtype=torch.float32).unsqueeze(0),
            d_model=d_model,
        )

        sines = torch.sin(angle_rads[:, ::2])
        cosines = torch.cos(angle_rads[:, 1::2])

        embedding = torch.zeros(angle_rads.shape)
        embedding[:, ::2] = sines
        embedding[:, 1::2] = cosines

        return embedding

    def forward(self, x):
        """Return ``x + PE`` for the first ``x.size(1)`` positions.

        x: (batch, seq_len, d_model) with seq_len <= self.seq_len.
        """
        return x + self.embedding[:x.size(1)]

In [20]:
# Smoke test: adding the positional encoding must preserve the input shape.
x = torch.randn((32, 20, 512))
positional_encoding = PositionalEncoding(20, 512)
encoded = positional_encoding(x)
assert encoded.shape == x.shape
encoded

tensor([[[ 1.4819e+00,  2.9181e+00, -4.2600e-01,  ...,  5.4064e-01,
          -1.0419e+00, -1.3650e+00],
         [ 1.5785e+00,  8.8602e-01, -5.9404e-01,  ...,  1.7294e+00,
           1.7922e+00,  1.8669e+00],
         [ 1.1808e+00, -8.4693e-01,  1.1339e+00,  ...,  1.2836e+00,
          -1.3964e+00,  9.3487e-01],
         ...,
         [ 8.5370e-02,  1.0630e+00, -1.6198e-01,  ...,  2.3108e+00,
           1.0211e+00,  2.5693e+00],
         [ 6.9371e-01,  1.6353e+00,  4.7266e-01,  ...,  2.5780e+00,
           1.5265e+00, -3.4820e-01],
         [ 9.0398e-01,  2.8847e-01, -9.7043e-01,  ...,  1.3065e+00,
           2.4085e-01,  7.7492e-01]],

        [[-9.2326e-01,  1.4857e+00, -2.5263e+00,  ...,  8.4135e-01,
          -2.2634e-02,  1.6392e+00],
         [ 1.1847e+00, -4.9983e-01,  2.1163e+00,  ..., -6.9308e-01,
          -2.7996e-01,  8.5579e-01],
         [-3.5988e-01, -1.3606e-01,  8.6927e-01,  ...,  2.3313e-01,
           1.0369e+00, -7.9081e-02],
         ...,
         [-2.5520e+00, -4

### Attention

#### scaled_dot_product_attention

In [56]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor of shape (..., q_len, depth).
        key: tensor of shape (..., k_len, depth).
        value: tensor of shape (..., k_len, depth_v).
        mask: optional tensor broadcastable to (..., q_len, k_len). Positions
            where mask == 1 are excluded from attention; 0 keeps them.

    Returns:
        (output, attention_weights) with shapes (..., q_len, depth_v) and
        (..., q_len, k_len). Each row of attention_weights sums to 1.
    """
    matmul_qk = query @ key.transpose(-2, -1)  # (..., q_len, k_len)
    depth = query.shape[-1]

    logits = matmul_qk / np.sqrt(depth)

    if mask is not None:
        # Push masked logits to a huge negative value so softmax gives them ~0
        # weight. A finite value (not -inf) keeps fully-masked rows NaN-free.
        # Out-of-place so a mask that broadcasts to a larger shape still works.
        logits = logits + mask.to(logits.dtype) * -1e9

    # Normalise over the key axis (last dim) regardless of how many leading
    # batch/head dims the inputs have.
    attention_weights = F.softmax(logits, dim=-1)
    output = attention_weights @ value

    return output, attention_weights

In [57]:
# Self-attention smoke test: a single random tensor acts as query, key and value.
query = torch.randn(32, 20, 512)
scaled_dot_product_attention(query, key=query, value=query)

(tensor([[[-1.9979, -0.1066, -0.4298,  ...,  0.7281,  1.2474, -0.1675],
          [ 0.7960,  0.2741,  0.5972,  ...,  1.4414, -1.4739,  0.8064],
          [ 0.6012, -0.6820, -0.5358,  ..., -0.0535,  2.1505,  1.3132],
          ...,
          [-0.1742,  0.0524,  0.7115,  ..., -0.5074, -0.6272, -1.1105],
          [-0.1750, -0.8036, -0.9733,  ...,  0.0575,  1.4971, -1.5334],
          [ 1.3329,  1.6873,  1.3500,  ..., -0.2212,  1.1462, -0.9416]],
 
         [[ 0.8099,  1.6177, -0.3885,  ..., -1.9721, -1.6653, -0.9206],
          [ 0.9991, -2.8307, -1.2456,  ...,  1.4407,  1.0783,  1.5209],
          [ 2.3108, -1.5950,  0.8184,  ..., -0.4678, -0.5187, -0.7061],
          ...,
          [-1.1083,  0.8813,  2.1368,  ..., -0.8860,  2.0806,  0.4568],
          [-0.2779,  1.9584,  0.3380,  ..., -0.2399,  1.5800, -0.8274],
          [-1.1893, -1.6869,  1.2096,  ..., -0.3806,  2.1417,  0.8775]],
 
         [[ 0.8766,  0.7907, -0.9868,  ...,  0.3646, -1.8486, -0.0483],
          [-0.7748,  0.7050,

#### Multi-head Attention

In [124]:
class MultiheadAttention(nn.Module):
    """Multi-head attention (Vaswani et al., 2017, section 3.2.2).

    Projects query/key/value with separate linear layers, splits d_model into
    ``num_heads`` heads of size ``depth = d_model // num_heads``, applies
    ``scaled_dot_product_attention`` per head, concatenates the heads and
    applies a final linear projection.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        assert self.d_model % self.num_heads == 0, "d_model must be divisible by num_heads"

        self.depth = self.d_model // self.num_heads

        self.query_dense = nn.Linear(self.d_model, self.d_model)
        self.key_dense = nn.Linear(self.d_model, self.d_model)
        self.value_dense = nn.Linear(self.d_model, self.d_model)

        self.dense = nn.Linear(self.d_model, self.d_model)

    def _split_heads(self, x):
        """Reshape (batch, len, d_model) -> (batch, num_heads, len, depth)."""
        # Each tensor uses its *own* length, so key/value may differ in length
        # from query (encoder-decoder attention).
        batch_size, length = x.shape[:2]
        x = x.reshape(batch_size, length, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, inputs: dict):
        """Apply multi-head attention.

        Args:
            inputs: dict with keys 'query' (batch, q_len, d_model),
                'key' and 'value' (batch, k_len, d_model), and optional 'mask'
                broadcastable to (batch, num_heads, q_len, k_len).

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        query, key, value = inputs.get('query'), inputs.get('key'), inputs.get('value')
        mask = inputs.get('mask')
        batch_size, query_len = query.shape[:2]

        query = self._split_heads(self.query_dense(query))  # batch, num_heads, q_len, depth
        key = self._split_heads(self.key_dense(key))        # batch, num_heads, k_len, depth
        value = self._split_heads(self.value_dense(value))  # batch, num_heads, k_len, depth

        scaled_attention, _ = scaled_dot_product_attention(query, key, value, mask)  # batch, num_heads, q_len, depth
        scaled_attention = scaled_attention.permute(0, 2, 1, 3)                      # batch, q_len, num_heads, depth
        concat_attention = scaled_attention.reshape(batch_size, query_len, self.d_model)  # batch, q_len, d_model

        outputs = self.dense(concat_attention)  # batch, q_len, d_model

        return outputs

### Encoder

In [125]:
class EncoderLayer(nn.Module):
    """A single post-LayerNorm Transformer encoder block.

    Sub-layer 1: multi-head self-attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: feed-forward (Linear -> ReLU -> Linear) -> dropout ->
    residual add -> LayerNorm.
    """

    def __init__(self, d_model: int, d_ff: int, num_heads: int, dropout_ratio: float):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio

        # --- sub-layer 1: self-attention ---
        self.multi_head_attention = MultiheadAttention(d_model, num_heads)
        self.dropout1 = nn.Dropout(dropout_ratio)
        self.layer_norm1 = nn.LayerNorm(d_model)

        # --- sub-layer 2: position-wise feed-forward ---
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.dropout2 = nn.Dropout(dropout_ratio)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run both sub-layers; ``mask`` is passed to self-attention unchanged."""
        attended = self.multi_head_attention(
            {'query': x, 'key': x, 'value': x, 'mask': mask}
        )
        hidden = self.layer_norm1(self.dropout1(attended) + x)

        transformed = self.dropout2(self.ffn(hidden))
        return self.layer_norm2(hidden + transformed)

In [126]:
# Smoke test one encoder layer on random (batch=32, seq_len=20, d_model=512) input.
x = torch.randn((32, 20, 512))
encoder_layer = EncoderLayer(d_model=512, d_ff=2048, num_heads=8, dropout_ratio=0.1)
encoder_layer(x)

tensor([[[-1.9848,  0.6702,  1.3187,  ..., -1.8875,  0.7444, -0.6271],
         [-0.2133, -0.5521, -0.4357,  ...,  1.3083, -0.9965,  0.5198],
         [ 0.9016,  0.3068, -0.1425,  ..., -0.2294,  0.3499,  0.8375],
         ...,
         [ 1.6197,  0.8537,  0.2259,  ...,  0.1337, -0.7509,  0.5322],
         [ 1.5149,  1.3291,  1.7559,  ...,  1.4643, -1.3589,  2.1033],
         [-0.7850, -0.8179, -0.9058,  ...,  0.9911,  0.9380,  1.4185]],

        [[-1.8846, -0.8105,  1.5370,  ...,  0.6021, -0.8217, -2.6228],
         [-0.9016,  0.0991,  0.5283,  ..., -0.1668, -1.8993, -0.6681],
         [-1.6052, -0.0443,  0.9234,  ...,  1.5764, -1.0744,  1.1181],
         ...,
         [-0.5053,  0.0063, -2.5317,  ..., -0.6285, -2.0232, -1.3003],
         [-0.4995, -1.2293, -0.7102,  ...,  0.9446,  0.7619,  0.2118],
         [-1.8678,  0.9061, -0.1156,  ..., -0.2425, -1.2721, -0.6108]],

        [[-1.7102,  0.6463, -2.1659,  ...,  0.0577, -0.5304,  0.7939],
         [ 0.6090,  0.9775,  0.4625,  ..., -0

In [127]:
class Encoder(nn.Module):
    """Transformer encoder stack.

    Embeds token ids, scales them by sqrt(d_model), adds sinusoidal positional
    encodings, applies dropout, then runs ``num_layers`` EncoderLayer blocks.
    """

    def __init__(
        self,
        seq_len: int,
        vocab_size: int,
        num_layers: int,
        d_model: int,
        d_ff: int,
        num_heads: int,
        dropout_ratio: float,
        ):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.positional_encoding = PositionalEncoding(self.seq_len, self.d_model)
        # Dropout on (embedding + positional encoding), as in the paper (section 5.4).
        # Has no parameters, so state_dict is unchanged.
        self.dropout = nn.Dropout(self.dropout_ratio)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(self.d_model, self.d_ff, self.num_heads, self.dropout_ratio)
            for _ in range(self.num_layers)
        ])

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: integer token ids, (batch, seq_len) with seq_len <= self.seq_len.
            mask: optional padding mask passed unchanged to every EncoderLayer.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        x = self.embedding(x) * (self.d_model ** 0.5)
        # PositionalEncoding.forward already returns x + PE, so assign its result.
        # `x += self.positional_encoding(x)` would compute 2*x + PE.
        x = self.positional_encoding(x)
        x = self.dropout(x)

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, mask)

        return x

In [128]:
# Token ids in [0, 1000) for a batch of 32 sequences of length 20.
x = torch.randint(0, 1000, (32, 20))
encoder = Encoder(
    seq_len=20,
    vocab_size=1000,
    num_layers=6,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    dropout_ratio=0.1,
)
encoder(x)

tensor([[[ 1.8475,  0.0231, -0.0385,  ..., -0.3078, -0.0699, -0.7248],
         [ 1.9553, -0.5205, -0.0569,  ..., -0.3857, -0.1446, -1.7681],
         [ 0.5508,  0.8090, -0.1973,  ...,  0.0668, -0.9554, -1.3088],
         ...,
         [ 1.3308, -1.1035, -0.9705,  ...,  0.1107, -0.8299, -2.2690],
         [ 1.6041,  0.3398,  0.5741,  ...,  0.4815,  0.1334, -2.1437],
         [-0.7404,  0.8648,  0.5233,  ...,  0.1159, -0.4453, -1.6110]],

        [[ 0.3412, -1.1536,  0.4852,  ..., -0.0238, -1.6976, -1.2767],
         [ 1.4794, -0.4740,  1.0270,  ..., -1.1288, -1.7871,  0.1446],
         [ 0.1760,  0.5634,  1.4709,  ..., -0.7118, -1.2055,  0.0260],
         ...,
         [ 0.0421, -0.2462,  0.5400,  ..., -0.6138, -0.7388,  0.9914],
         [-0.6090,  0.6252, -0.3613,  ..., -0.8281, -1.5281,  0.1960],
         [-0.7196, -0.4042,  1.1014,  ..., -0.4013, -1.5679, -0.3405]],

        [[ 0.5150, -0.8768, -0.0494,  ...,  0.0285, -0.1379, -1.7948],
         [ 0.3135, -0.1646,  0.2634,  ..., -0

### Decoder

In [149]:
class DecoderLayer(nn.Module):
    """A single post-LayerNorm Transformer decoder block (Vaswani et al., 2017, section 3.1).

    Sub-layer 1: masked multi-head self-attention over the decoder input.
    Sub-layer 2: encoder-decoder attention (queries from the decoder, keys and
                 values from ``encoder_out``).
    Sub-layer 3: position-wise feed-forward network.
    Each sub-layer is followed by dropout, a residual add and LayerNorm.
    """

    def __init__(self, d_model, d_ff, num_heads, dropout_ratio):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio

        # Sub-layer 1: masked self-attention
        self.multi_head_attention1 = MultiheadAttention(self.d_model, self.num_heads)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.layer_norm1 = nn.LayerNorm(self.d_model)

        # Sub-layer 2: encoder-decoder attention
        self.multi_head_attention2 = MultiheadAttention(self.d_model, self.num_heads)
        self.dropout2 = nn.Dropout(self.dropout_ratio)
        self.layer_norm2 = nn.LayerNorm(self.d_model)

        # Sub-layer 3: feed-forward
        self.ffn = nn.Sequential(
            nn.Linear(self.d_model, self.d_ff),
            nn.ReLU(),
            nn.Linear(self.d_ff, self.d_model),
        )
        self.dropout3 = nn.Dropout(self.dropout_ratio)
        self.layer_norm3 = nn.LayerNorm(self.d_model)

    def forward(self, x, encoder_out, mask=None, padding_mask=None):
        """Apply the decoder block.

        Args:
            x: decoder hidden states, (batch, tgt_len, d_model).
            encoder_out: encoder output, (batch, src_len, d_model).
            mask: mask for self-attention (e.g. look-ahead + target padding);
                1 marks positions to hide.
            padding_mask: optional source padding mask for encoder-decoder
                attention; 1 marks encoder positions to hide. None = no masking.

        Returns:
            Tensor of shape (batch, tgt_len, d_model).
        """
        # 1) Masked self-attention over the decoder's own inputs.
        inputs = {'query': x, 'key': x, 'value': x, 'mask': mask}
        x_multi_head_output = self.multi_head_attention1(inputs)
        x_multi_head_output = self.dropout1(x_multi_head_output)
        x = self.layer_norm1(x + x_multi_head_output)

        # 2) Encoder-decoder attention: queries come from the decoder, keys and
        #    values from the encoder output. The look-ahead mask does not apply here.
        inputs = {'query': x, 'key': encoder_out, 'value': encoder_out, 'mask': padding_mask}
        x_multi_head_output = self.multi_head_attention2(inputs)
        x_multi_head_output = self.dropout2(x_multi_head_output)
        x = self.layer_norm2(x + x_multi_head_output)

        # 3) Position-wise feed-forward.
        ffn_output = self.ffn(x)
        ffn_output = self.dropout3(ffn_output)
        output = self.layer_norm3(x + ffn_output)

        return output

In [150]:
# Smoke test one decoder layer; the encoder layer from the earlier cell supplies encoder_out.
x = torch.randn((32, 20, 512))

decoder_layer = DecoderLayer(d_model=512, d_ff=2048, num_heads=8, dropout_ratio=0.1)
encoder_out = encoder_layer(x)
decoder_layer(x, encoder_out=encoder_out)

tensor([[[ 0.6483, -0.6845, -0.4586,  ...,  0.0303, -0.1651,  1.7737],
         [ 0.7005, -0.3682, -0.5978,  ..., -0.6435, -0.8065,  0.1445],
         [-0.2216,  0.1596,  2.0457,  ..., -0.1270,  1.4141,  0.0807],
         ...,
         [-0.2865,  1.1048, -0.3875,  ..., -0.5319, -0.7416,  0.5476],
         [ 1.0356,  0.6441,  0.9505,  ...,  0.9825,  0.9149, -0.3902],
         [-1.2154, -0.7707,  0.4412,  ...,  0.2220,  1.2852, -2.4506]],

        [[ 0.5901,  0.6001, -0.5165,  ...,  0.2963,  0.2034,  1.0321],
         [ 0.8078,  0.1995, -0.3824,  ...,  0.9018, -0.7067,  1.0001],
         [ 0.2083, -2.0880,  0.2439,  ...,  1.8097,  0.4219,  0.5539],
         ...,
         [-1.3043,  0.3543, -0.9915,  ...,  0.5974,  1.4838,  1.6512],
         [-0.5786, -0.1079, -1.2314,  ..., -0.3962, -0.2895,  0.5198],
         [ 1.3889, -0.6902, -0.6552,  ...,  0.6637, -0.3890,  0.2125]],

        [[-0.1468,  0.2401,  0.8810,  ...,  1.5274,  1.8057, -1.3392],
         [-0.7990,  1.9811,  0.5768,  ...,  0

In [151]:
class Decoder(nn.Module):
    """Transformer decoder stack.

    Embeds target token ids, scales them by sqrt(d_model), adds sinusoidal
    positional encodings, applies dropout, then runs ``num_layers``
    DecoderLayer blocks that also receive the encoder output.
    """

    def __init__(
        self,
        seq_len: int,
        vocab_size: int,
        num_layers: int,
        d_model: int,
        d_ff: int,
        num_heads: int,
        dropout_ratio: float,
        ):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.positional_encoding = PositionalEncoding(self.seq_len, self.d_model)
        # Dropout on (embedding + positional encoding), as in the paper (section 5.4).
        # Has no parameters, so state_dict is unchanged.
        self.dropout = nn.Dropout(self.dropout_ratio)
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(self.d_model, self.d_ff, self.num_heads, self.dropout_ratio)
            for _ in range(self.num_layers)
        ])

    def forward(self, x, encoder_out, mask=None, padding_mask=None):
        """Decode a batch of target token ids.

        Args:
            x: integer token ids, (batch, tgt_len) with tgt_len <= self.seq_len.
            encoder_out: encoder output passed to every DecoderLayer.
            mask: self-attention mask passed to every DecoderLayer.
            padding_mask: optional source padding mask. It is forwarded to each
                DecoderLayer as ``padding_mask=`` only when not None.

        Returns:
            Tensor of shape (batch, tgt_len, d_model).
        """
        x = self.embedding(x) * (self.d_model ** 0.5)
        # PositionalEncoding.forward already returns x + PE, so assign its result.
        # `x += self.positional_encoding(x)` would compute 2*x + PE.
        x = self.positional_encoding(x)
        x = self.dropout(x)

        layer_kwargs = {} if padding_mask is None else {'padding_mask': padding_mask}
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(x, encoder_out, mask, **layer_kwargs)

        return x


In [152]:
# Target token ids in [0, 1000); reuses `encoder_out` (32, 20, 512) from the DecoderLayer demo cell.
x = torch.randint(0, 1000, (32, 20))
decoder = Decoder(
    seq_len=20,
    vocab_size=1000,
    num_layers=6,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    dropout_ratio=0.1,
)
decoder(x, encoder_out)

tensor([[[ 9.7404e-01,  1.3745e-02,  1.3542e+00,  ..., -1.0685e+00,
          -4.3461e-01,  2.5127e-01],
         [ 1.1715e+00,  5.2758e-01,  4.6482e-01,  ...,  8.1149e-02,
           4.0944e-01, -7.3005e-01],
         [ 4.7255e-01,  3.1805e-01,  1.0587e+00,  ..., -5.2677e-01,
          -7.0303e-01, -6.6924e-01],
         ...,
         [-5.1178e-01,  4.9820e-01,  5.7020e-01,  ..., -1.8420e-01,
           2.9207e-01, -4.9416e-01],
         [ 1.2077e+00, -4.2450e-01,  7.2602e-01,  ..., -4.5297e-02,
           1.3392e-01, -4.9377e-01],
         [ 1.1497e+00,  4.7561e-01,  7.0672e-01,  ...,  4.8869e-02,
          -5.4957e-02, -1.4425e-01]],

        [[ 6.7264e-01,  8.7291e-01,  1.0770e+00,  ...,  1.9916e+00,
           1.5692e+00, -4.6445e-01],
         [ 7.0538e-01,  4.7634e-01,  9.8668e-01,  ...,  1.2911e+00,
           1.0658e+00, -7.0589e-01],
         [ 1.8358e+00,  8.0074e-01,  8.4752e-01,  ...,  1.0242e+00,
           1.4352e+00, -5.6544e-01],
         ...,
         [ 1.0219e+00,  8

### Transformer

In [153]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer followed by a linear projection to
    ``decoder_vocab_size`` logits.

    Token id 0 is treated as padding when building masks.
    """

    def __init__(
        self,
        seq_len: int,
        encoder_vocab_size: int,
        decoder_vocab_size: int,
        num_layers: int,
        d_model: int,
        d_ff: int,
        num_heads: int,
        dropout_ratio: float,
        ):
        super().__init__()
        self.seq_len = seq_len
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio

        self.encoder = Encoder(
            self.seq_len,
            self.encoder_vocab_size,
            self.num_layers,
            self.d_model,
            self.d_ff,
            self.num_heads,
            self.dropout_ratio,
        )
        self.decoder = Decoder(
            self.seq_len,
            self.decoder_vocab_size,
            self.num_layers,
            self.d_model,
            self.d_ff,
            self.num_heads,
            self.dropout_ratio,
        )

        self.linear = nn.Linear(self.d_model, self.decoder_vocab_size)

    def create_padding_mask(self, x):
        """Return 1 where ``x`` == 0 (padding), else 0.

        x: (batch, seq_len) token ids -> mask of shape (batch, 1, 1, seq_len),
        which broadcasts over heads and query positions.
        """
        padding_mask = torch.where(x == 0, 1, 0).unsqueeze(1).unsqueeze(2)

        return padding_mask

    def _create_look_ahead_mask(self, x):
        """Return a (seq_len, seq_len) mask with 1 strictly above the diagonal.

        The size comes from ``x.size(1)`` rather than ``self.seq_len``, so
        decoder inputs shorter than the configured maximum still work. The mask
        is built on ``x``'s device.
        """
        seq_len = x.size(1)
        look_ahead_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1)

        return look_ahead_mask

    def create_decoder_mask(self, x):
        """Combine the target padding mask and the look-ahead mask (element-wise max)."""
        padding_mask = self.create_padding_mask(x).squeeze(1)            # (batch, 1, 1, seq_len) -> (batch, 1, seq_len)
        look_ahead_mask = self._create_look_ahead_mask(x).unsqueeze(0)  # (seq_len, seq_len) -> (1, seq_len, seq_len)
        decoder_mask = torch.max(padding_mask, look_ahead_mask)         # (batch, seq_len, seq_len)
        decoder_mask = decoder_mask.unsqueeze(1)                        # (batch, 1, seq_len, seq_len); broadcasts over heads

        return decoder_mask

    def forward(self, encoder_input, decoder_input):
        """Return logits of shape (batch, tgt_len, decoder_vocab_size)."""
        encoder_padding_mask = self.create_padding_mask(encoder_input)
        decoder_mask = self.create_decoder_mask(decoder_input)

        encoder_output = self.encoder(encoder_input, encoder_padding_mask)
        # NOTE(review): encoder_padding_mask is not passed to the decoder here, so
        # any attention over encoder_output inside the decoder is not masked for
        # source padding.
        decoder_output = self.decoder(decoder_input, encoder_output, decoder_mask)
        output = self.linear(decoder_output)

        return output
        

In [154]:
transformer = Transformer(
    seq_len=20,
    encoder_vocab_size=1000,
    decoder_vocab_size=2000,
    num_layers=6,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    dropout_ratio=0.1,
)

In [155]:
# Random source ids in [0, 1000) and target ids in [0, 2000), batch=32, length=20.
x_encoder = torch.randint(0, 1000, (32, 20))
x_decoder = torch.randint(0, 2000, (32, 20))

transformer(encoder_input=x_encoder, decoder_input=x_decoder)

tensor([[[-0.0820, -1.1812, -0.1125,  ...,  1.4003, -0.2161,  0.1365],
         [-0.4783, -0.3005, -0.3276,  ...,  0.7966,  0.4404,  0.2628],
         [-0.5069, -0.5645,  0.0715,  ...,  0.9491,  0.0972, -0.2219],
         ...,
         [-0.4835, -0.2690, -0.1206,  ...,  0.5565,  0.0024,  0.0447],
         [-0.0030, -0.3034, -0.4317,  ...,  0.6190,  0.1846,  0.1748],
         [ 0.1121, -0.7533, -0.6052,  ...,  0.9467,  0.6045, -0.1068]],

        [[ 0.0652, -0.0562, -0.4453,  ...,  0.1493,  0.0606, -0.5253],
         [ 0.2705, -0.1263, -0.4444,  ..., -0.3763, -0.2798, -0.1256],
         [ 0.3461, -0.5293, -0.3556,  ..., -0.0823, -0.4584, -0.4155],
         ...,
         [ 0.4573, -0.4650, -0.6154,  ..., -0.0576,  0.1152, -0.4791],
         [ 0.6136, -0.4380, -0.5972,  ...,  0.0552, -0.2022, -0.6156],
         [ 0.4565, -0.0433, -0.8857,  ..., -0.2820, -0.3574, -0.6190]],

        [[ 0.0997, -0.1967,  0.2268,  ...,  0.1672, -0.6713,  0.4421],
         [-0.0685, -0.2879,  0.3876,  ...,  0