In [1]:
import torch
from torch import nn

from tokenizers import Tokenizer

import info
from data import CustomDataset, collate_fn
from layers import EmbeddingWithPosition, EncoderLayer, DecoderLayer
from sublayers import MultiHeadAttention, LayerLorm, PositionWiseFeedForward
from torch.utils.data import DataLoader

In [2]:
# Load the tokenizer from its JSON file and record its vocabulary size;
# vocab_size later sizes the embedding table.
tokenizer_path = "../data/ende_WMT14_Tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()

In [3]:
# Build the paired source/target dataset from the test_en.txt / test_de.txt files.
# NOTE(review): the variables are named *train* but both paths point at ../data/test/ --
# presumably a small split used for quick experiments; confirm this is intentional.
src_train_data_path = "../data/test/test_en.txt"
tgt_train_data_path = "../data/test/test_de.txt"
training_dataset = CustomDataset(tokenizer=tokenizer, src_path=src_train_data_path, tgt_path=tgt_train_data_path)

data tokenizing & loading: 2737it [00:00, 4636.18it/s]


In [4]:
# Shuffled mini-batches of 6 pairs; collate_fn assembles each batch.
# A dedicated, seeded generator drives the shuffle so that the batch drawn in the next
# cell (and therefore the shapes shown below) is the same after every Restart & Run All.
# It only pins the sampling order -- model initialisation and dropout remain unseeded.
train_dataloader = DataLoader(
    dataset=training_dataset,
    batch_size=6,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(0),
)

In [5]:
# Pull a single batch to experiment with; each batch unpacks into four items
# (source ids, target ids and, judging by the names, their lengths -- confirm against collate_fn).
src_data, tgt_data, src_len, tgt_len = next(iter(train_dataloader))

In [6]:
src_data.shape

torch.Size([6, 42])

In [7]:
tgt_data.shape

torch.Size([6, 42])

In [8]:
# Model width, read once and reused so the embedding table and the embedding layer
# cannot drift apart.
d_model = info.base_hyper_params['d_model']
# Single embedding table handed to the embedding layer as `shared_parameter`;
# padding_idx keeps the PAD row out of gradient updates.
shared_parameter = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=info.PAD)
emb_layer = EmbeddingWithPosition(
    vocab_size=vocab_size,
    pos_max_len=info.max_len,
    embedding_dim=d_model,
    drop_rate=info.base_hyper_params['dropout_rate'],
    shared_parameter=shared_parameter,
)

In [9]:
# Embed the source token ids.
# NOTE(review): presumably returns (N, L, d_model) -- confirm against EmbeddingWithPosition.
emb = emb_layer(src_data)

In [10]:
# Multi-head attention with is_masked=False.
# NOTE(review): head=8 is hard-coded while d_model / d_k / d_v come from
# info.base_hyper_params -- confirm the values agree.
mha = MultiHeadAttention(head=8, d_model=d_model, d_k=info.base_hyper_params['d_k'], d_v=info.base_hyper_params['d_v'], is_masked=False)

In [11]:
def get_mask(x, pad_idx=0):
    """Build padding masks for a batch of token ids.

    Args:
        x: integer tensor of token ids with shape (N, L).
        pad_idx: id that marks padding. Defaults to 0, the value this helper has
            always assumed; pass the real pad id if it differs.

    Returns:
        A tuple ``(mask, mask_output)``, both placed on ``info.device``:
        mask: float tensor (N, L) -- 1. for real tokens, 0. for padding.
        mask_output: bool tensor (N, L, L) -- True at (i, j) when position i or
            position j is padding, False only where both are real tokens.
    """
    keep = (x != pad_idx).to(info.device)  # N, L -> True for real tokens
    mask = keep.to(torch.get_default_dtype())  # same float dtype torch.zeros would give
    # Pairwise AND of "is real token" replaces the float outer product; negating it
    # marks every pair that touches a padded position.
    mask_output = ~(keep.unsqueeze(2) & keep.unsqueeze(1))  # N, L, L
    return mask, mask_output

In [12]:
# Padding masks for the source batch: pad_mask is (N, L), mask_info is (N, L, L).
pad_mask, mask_info = get_mask(src_data)

In [13]:
# Self-attention: the embedded source is passed as all three inputs
# (presumably query, key, value -- confirm against MultiHeadAttention) together with the padding mask.
attn = mha(emb, emb, emb, mask_info)

In [14]:
# Residual connection: add the attention output back onto its input.
residual_connection = attn+emb

In [15]:
residual_connection[0]

tensor([[-43.6471,   6.5885,  14.8586,  ...,  43.5617, -28.3112,  -6.9557],
        [-14.7832, -37.1398,  13.9888,  ...,   8.8846,  -3.9687,   0.1097],
        [ 12.4530,  11.3051,  -3.7756,  ...,   1.7943,  -4.7080,  23.5679],
        ...,
        [  9.7276, -20.4963, -42.4083,  ..., -11.8166, -19.7441, -47.0385],
        [  3.2305,  23.4806,  14.3141,  ...,  10.5617,  -1.6362, -25.9419],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       grad_fn=<SelectBackward0>)

In [16]:
# Normalisation layer built for width d_model.
# NOTE(review): "LayerLorm" is the name imported from sublayers; it looks like a typo
# for LayerNorm -- confirm before renaming anything.
ln_layer = LayerLorm(d_model)

In [17]:
# Normalise the residual sum; the result is the input to the feed-forward sublayer.
ff_input = ln_layer(residual_connection)

In [26]:

# NOTE: from this cell on, this definition replaces the `PositionWiseFeedForward`
# name imported from `sublayers` in the first cell.
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward block: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    When a padding mask is supplied, padded positions are zeroed after each linear
    layer, so the biases never leak non-zero values into padding.

    Args:
        d_model: width of the input and output features.
        d_ff: width of the hidden (inner) layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.inner_layer = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.outer_layer = nn.Linear(d_ff, d_model)

    def forward(self, x, mask_info=None):
        """Apply the feed-forward block.

        **INPUT SHAPE**
        x -> N, L, d_model
        mask_info -> N, L (0 marks padding, anything else a real token), or None

        Returns a tensor of shape N, L, d_model. With a mask, every padded position
        is all zeros; with ``mask_info=None`` no masking is applied.
        """
        if mask_info is None:
            return self.outer_layer(self.relu(self.inner_layer(x)))

        # Built once and reused: True where the position is padding, broadcast over features.
        pad_positions = (mask_info == 0).unsqueeze(-1)  # N, L, 1
        # Masking before the ReLU keeps padded rows at exactly zero going into the outer layer.
        inner_output = self.relu(self.inner_layer(x).masked_fill(pad_positions, 0))
        # The outer layer's bias makes padded rows non-zero again, so mask a second time.
        outer_output = self.outer_layer(inner_output).masked_fill(pad_positions, 0)  # N, L, d_model
        return outer_output

    def initialization(self):
        """Re-initialise both weight matrices with Xavier-uniform values.

        Not called by ``__init__``; invoke it explicitly after construction if this
        initialisation is wanted. Biases are left untouched.
        """
        nn.init.xavier_uniform_(self.inner_layer.weight)
        nn.init.xavier_uniform_(self.outer_layer.weight)
In [27]:
# Instantiate the feed-forward block with the model width and the inner width
# taken from the shared hyper-parameters.
PWFFN_layer = PositionWiseFeedForward(d_model=d_model, d_ff=info.base_hyper_params['d_ff'])

In [28]:
# Run the feed-forward block; pad_mask (N, L) zeroes the padded positions of the output.
ff_output = PWFFN_layer(ff_input, pad_mask)

In [33]:
# Number of non-padding tokens in the second sequence of the batch.
# Tensor.sum() reduces in one call instead of iterating the tensor element by element
# with the Python builtin.
pad_mask[1].sum()

tensor(21.)

In [38]:
ff_output

tensor([[[ 1.2204e-01, -5.8269e-03, -1.1710e-01,  ..., -2.0872e-01,
           5.3894e-02,  4.2081e-01],
         [-2.6326e-01, -1.4075e-01,  9.9380e-02,  ...,  1.0107e-01,
          -2.3084e-02,  3.9233e-01],
         [-9.9831e-03,  5.6520e-01, -3.6336e-01,  ..., -1.9493e-01,
          -2.3575e-01,  3.2691e-01],
         ...,
         [-3.0033e-01,  1.1433e-01, -2.2971e-01,  ..., -4.5300e-01,
           2.2035e-01,  2.7764e-01],
         [ 2.3948e-03,  2.9507e-01, -2.6200e-01,  ..., -1.3607e-01,
           1.2082e-01,  2.4062e-01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-3.8747e-03, -1.2472e-01, -1.1271e-01,  ..., -1.4071e-02,
           2.5923e-01,  2.5664e-01],
         [-4.4890e-01, -2.5148e-01, -1.5256e-02,  ...,  8.1923e-02,
          -7.6544e-02, -4.1019e-03],
         [-2.7214e-01, -3.9339e-01, -4.2240e-01,  ..., -1.6787e-01,
           2.4117e-01, -1.5117e-01],
         ...,
         [ 0.0000e+00,  0

In [33]:
# Alternative check: zero every element of ff_output where ff_input is exactly 0,
# then show the first sequence of the batch.
# NOTE(review): `ff_input == 0` is an element-wise test, so a non-padding element that
# happens to be exactly 0 would be zeroed too; pad_mask is the explicit per-position signal.
ff_output.masked_fill((ff_input == 0), 0)[0]

tensor([[-0.0378,  0.2440, -0.1261,  ..., -0.0192, -0.2499, -0.0563],
        [ 0.0783, -0.0004, -0.1130,  ..., -0.0748, -0.2160,  0.0278],
        [-0.1758, -0.1294,  0.1887,  ..., -0.2898, -0.0820, -0.1759],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward0>)