# Transformer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Input Layer

In [2]:
input_sentence = "I have a dream"
tokenize_input = ["<sos>","I", "have", "a", 'dream', "<eos>"]
input_idx = [1,5,34,7,45,2]
target_idx = [1,5,34,7,45]

input_tensor = torch.tensor(input_idx)
target_tensor = torch.tensor(target_idx)
input_tensor

tensor([ 1,  5, 34,  7, 45,  2])

In [3]:
class Config():
    max_position_dim = 20
    dim_token_emb = 10
    num_dict = 100
    num_head = 2
    hidden_size = 50
config = Config()

In [4]:
class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.token_embeddings = nn.Embedding(config.num_dict, config.dim_token_emb)
        self.position_embeddings = nn.Embedding(config.max_position_dim, config.dim_token_emb)

    def forward(self, input_ids):
        position_len = input_ids.size(0)
        position_ids = torch.arange(position_len, dtype=torch.long).unsqueeze(0)

        word_emb = self.token_embeddings(input_ids)
        pos_emb = self.position_embeddings(position_ids)

        embeddings = word_emb + pos_emb

        return embeddings

In [5]:
embedding_layer = Embedding(config)
emb = embedding_layer(input_tensor)
embedded_size=emb.size(2)
embedded_size

10

In [6]:
def scaled_dot_product_attention(query, key, value, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)

    return weights.bmm(value)

In [7]:
class Attention(nn.Module):
    def __init__(self, embedded_size, head_dim):
        super().__init__()

        self.q_linear = nn.Linear(embedded_size, head_dim)
        self.k_linear = nn.Linear(embedded_size, head_dim)
        self.v_linear = nn.Linear(embedded_size, head_dim)

    def forward(self, x, mask=None):
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)

        self_att_result = scaled_dot_product_attention(q,k,v,mask)

        return self_att_result

In [8]:
class Multi_Attention(nn.Module):
    def __init__(self, embedded_size, num_head):
        super().__init__()

        self.head_dim = int(embedded_size//num_head)
        self.heads = nn.ModuleList([Attention(embedded_size, self.head_dim) for _ in range(num_head)])

    def forward(self, x, mask = None):
        atts = torch.cat([h(x, mask) for h in self.heads], dim=-1)
        added_resout = atts + x

        normal = nn.BatchNorm1d(added_resout.size(1))
        add_norm_result = normal(added_resout)

        return add_norm_result

In [9]:
class TransformEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.multihead = Multi_Attention(config.dim_token_emb, config.num_head)
        self.lin1 = nn.Linear(config.dim_token_emb, config.hidden_size)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(config.hidden_size, config.dim_token_emb)

    def forward(self, x):
        add_norm_result = self.multihead(x)
        fnn_out = self.lin2(self.act(self.lin1(add_norm_result)))
        final_feature = fnn_out + add_norm_result
        return final_feature


In [10]:
encoder = TransformEncoder(config)
encoder_emb = encoder(emb)
encoder_emb.size()

torch.Size([1, 6, 10])

In [11]:
encoder_emb

tensor([[[-7.1782e-01,  5.6422e-01,  1.0185e+00,  2.3639e-01, -2.2755e+00,
          -9.6843e-01,  2.6404e-01, -8.5887e-01,  1.6938e+00,  2.2063e-01],
         [-1.5367e+00, -6.7713e-02,  9.5526e-01, -1.1703e+00,  4.4261e-01,
          -1.2672e+00,  1.3898e+00, -1.5787e-01,  1.0250e+00,  1.0717e-01],
         [-1.0632e+00,  8.0510e-01, -1.4137e+00,  1.9267e-01, -8.7029e-02,
           1.1389e+00,  3.4189e-01,  1.2672e+00,  6.8516e-01, -1.5228e+00],
         [-1.8357e+00, -9.3293e-01,  3.2706e-01,  6.7380e-01, -1.5521e+00,
          -5.0011e-01,  1.0309e+00,  2.0341e-01,  4.5301e-01,  1.3568e+00],
         [ 1.3247e+00,  6.2438e-01, -1.7641e+00, -7.6097e-01,  9.4227e-01,
          -7.9003e-01,  1.1809e+00,  9.4054e-01, -6.3874e-01, -2.1180e-01],
         [ 1.2781e+00, -9.5282e-02,  1.1536e+00,  6.4716e-02, -8.2029e-02,
          -2.4706e+00, -5.2237e-01,  2.3329e-01,  1.1082e-01,  2.0039e-03]]],
       grad_fn=<AddBackward0>)

### Do it your-self!

# Decoder

In [12]:
class Decoder_Attention(nn.Module):
    def __init__(self, embedded_size, head_dim):
        super().__init__()

        self.q_linear = nn.Linear(embedded_size, head_dim)
        self.k_linear = nn.Linear(embedded_size, head_dim)
        self.v_linear = nn.Linear(embedded_size, head_dim)

    def forward(self, encoder_emb, decoder_emb):
        q = self.q_linear(decoder_emb)
        k = self.k_linear(encoder_emb)
        v = self.v_linear(encoder_emb)

        self_att_result = scaled_dot_product_attention(q,k,v)

        return self_att_result

In [13]:
class Decoder_Multi_Attention(nn.Module):
    def __init__(self, embedded_size, num_head):
        super().__init__()

        self.head_dim = int(embedded_size//num_head)
        self.heads = nn.ModuleList([Decoder_Attention(embedded_size, self.head_dim) for _ in range(num_head)])

    def forward(self, encoder_emb, decoder_emb):
        atts = torch.cat([h(encoder_emb, decoder_emb) for h in self.heads], dim=-1)
        added_resout = atts + decoder_emb

        normal = nn.BatchNorm1d(added_resout.size(1))
        add_norm_result = normal(added_resout)

        return add_norm_result

In [14]:
class TransformDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mask_multihead = Multi_Attention(config.dim_token_emb, config.num_head)
        self.decoder_multihead = Decoder_Multi_Attention(config.dim_token_emb, config.num_head)

        self.lin1 = nn.Linear(config.dim_token_emb, config.hidden_size)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(config.hidden_size, config.dim_token_emb)


    def forward(self, encoder_emb, ouputs_emb, mask):

        mask_result = self.mask_multihead(ouputs_emb, mask)

        decoder_norm_result = self.decoder_multihead(encoder_emb, mask_result)

        fnn_out = self.lin2(self.act(self.lin1(decoder_norm_result)))
        final_feature = fnn_out + decoder_norm_result

        return final_feature

In [15]:
def make_trg_mask( trg):

    trg_pad_mask = (trg != 0).unsqueeze(0).unsqueeze(1)
    trg_len = trg.size(0)
    trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).int()
    trg_mask = trg_pad_mask & trg_sub_mask

    return trg_mask

In [16]:
embedding_layer = Embedding(config)
trg_emb = embedding_layer(target_tensor)
decoder = TransformDecoder(config)
decoder_output = decoder(encoder_emb, trg_emb, make_trg_mask(target_tensor))
decoder_output

tensor([[[-1.0912, -0.3218, -0.5308,  1.0944,  0.9801, -0.5948,  1.7363,
           0.3040,  0.0286, -2.4114],
         [-0.7211, -0.8557, -0.6374,  0.5509, -0.2350,  0.5361,  1.6870,
          -0.1681,  0.6900, -2.3676],
         [-1.9349,  0.5703,  2.1938,  0.2725, -0.2195, -0.0857,  0.5558,
           0.7392, -0.5505, -0.5934],
         [-1.2148, -0.7244,  2.0587,  1.0142, -0.4511, -0.2802,  0.8883,
          -0.4689, -0.3711, -0.5686],
         [-1.4425,  1.5077, -0.7445,  0.9362,  1.3313, -1.9959,  0.2558,
           1.6181, -0.6208, -0.1947]]], grad_fn=<AddBackward0>)

In [17]:
decoder_output.shape

torch.Size([1, 5, 10])

# Transformer

In [18]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embedding_layer = Embedding(config)
        self.encoder = TransformEncoder(config)
        self.decoder = TransformDecoder(config)
        self.lin = nn.Linear(config.dim_token_emb, config.dim_token_emb)


    def forward(self, source, target):
        src_emb = self.embedding_layer(source)
        trg_emb = self.embedding_layer(target)

        encoder_emb = self.encoder(emb)
        decoder_emb = self.decoder(encoder_emb, trg_emb, make_trg_mask(target))
        output =  torch.softmax(self.lin(decoder_emb),dim=-1)

        return output

    def make_trg_mask(self, trg):

        trg_pad_mask = (trg != 0).unsqueeze(0).unsqueeze(1)
        trg_len = trg.size(0)
        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).int()
        trg_mask = trg_pad_mask & trg_sub_mask

        return trg_mask


In [19]:
transformer = Transformer(config)

In [20]:
output = transformer(input_tensor, target_tensor)
print(output.shape)
output

torch.Size([1, 5, 10])


tensor([[[0.0474, 0.1468, 0.0896, 0.0633, 0.1248, 0.0795, 0.0196, 0.0529,
          0.1699, 0.2060],
         [0.0598, 0.0592, 0.0764, 0.0705, 0.0693, 0.0879, 0.3266, 0.1419,
          0.0411, 0.0673],
         [0.0972, 0.0495, 0.0499, 0.0472, 0.0786, 0.1216, 0.2386, 0.1413,
          0.0717, 0.1043],
         [0.0563, 0.0525, 0.1111, 0.0408, 0.1222, 0.0676, 0.3066, 0.1098,
          0.0489, 0.0840],
         [0.0630, 0.0997, 0.0903, 0.0733, 0.1625, 0.2256, 0.0476, 0.0661,
          0.0568, 0.1152]]], grad_fn=<SoftmaxBackward0>)