In [1]:
import torch
from torch import nn

from tokenizers import Tokenizer

import info
from data import CustomDataset, collate_fn
from layers import EmbeddingWithPosition, EncoderLayer, DecoderLayer
from sublayers import MultiHeadAttention, LayerLorm, PositionWiseFeedForward
from torch.utils.data import DataLoader

In [2]:
# Load the tokenizer serialized at tokenizer_path and read its vocabulary size.
tokenizer_path = "../data/ende_WMT14_Tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)
vocab_size = tokenizer.get_vocab_size()  # reused below as num_embeddings / vocab_size

In [3]:
# Build the dataset from a pair of text files (test_en.txt as source, test_de.txt as target).
# NOTE(review): the variable names say "train" but both paths point into ../data/test/ —
# confirm this is intentional (e.g. a small file used for quick experiments).
src_train_data_path = "../data/test/test_en.txt"
tgt_train_data_path = "../data/test/test_de.txt"
training_dataset = CustomDataset(tokenizer=tokenizer, src_path=src_train_data_path, tgt_path=tgt_train_data_path)

data tokenizing & loading: 2737it [00:00, 4671.54it/s]


In [4]:
# Batch the dataset in groups of 6, assembling each batch with the project's collate_fn.
# NOTE(review): shuffle=True and no RNG seed is set in the cells shown, so the batch
# drawn below (and every printed value derived from it) changes between runs.
train_dataloader = DataLoader(dataset=training_dataset, batch_size=6, shuffle=True, collate_fn=collate_fn)

In [5]:
# Pull a single batch; each batch unpacks into four items.
# Presumably src_len / tgt_len are per-sequence lengths — confirm against collate_fn.
src_data, tgt_data, src_len, tgt_len = next(iter(train_dataloader))

In [6]:
# Shape of the source batch (first axis is the DataLoader batch_size of 6).
src_data.shape

torch.Size([6, 58])

In [7]:
# Shape of the target batch.
tgt_data.shape

torch.Size([6, 58])

In [8]:
# Embedding stage: a single nn.Embedding table is created here and handed to
# EmbeddingWithPosition through its `shared_parameter` argument.
d_model = info.base_hyper_params['d_model']
shared_parameter = nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=d_model,
    padding_idx=info.PAD,
)
emb_layer = EmbeddingWithPosition(
    vocab_size=vocab_size,
    pos_max_len=info.max_len,
    embedding_dim=d_model,  # same value as info.base_hyper_params['d_model'], bound above
    drop_rate=info.base_hyper_params['dropout_rate'],
    shared_parameter=shared_parameter,
)

In [9]:
# Run the source token batch through the embedding layer.
emb = emb_layer(src_data)

In [10]:
# Multi-head attention sub-layer with 8 heads; d_k / d_v come from the base hyper-parameters.
# NOTE(review): is_masked=False presumably disables causal (look-ahead) masking — confirm in sublayers.
mha = MultiHeadAttention(head=8, d_model=d_model, d_k=info.base_hyper_params['d_k'], d_v=info.base_hyper_params['d_v'], is_masked=False)

In [11]:
def get_mask(x, pad_idx=0, device=None):
    """Build padding masks for a batch of token ids.

    Args:
        x: tensor of token ids with shape (N, L).
        pad_idx: token id that marks padding. Defaults to 0, the value this
            function has always used; pass ``info.PAD`` to follow the project
            constant.
        device: device on which the returned tensors are placed. Defaults to
            ``info.device``.

    Returns:
        A ``(mask, mask_output)`` tuple:

        * ``mask`` -- float tensor of shape (N, L); 1. for real tokens and 0.
          for padding.
        * ``mask_output`` -- bool tensor of shape (N, L, L);
          ``mask_output[n, i, j]`` is True when position ``i`` or position
          ``j`` of sequence ``n`` is padding.
    """
    if device is None:
        device = info.device

    # Boolean "is a real token" map, moved to the target device once.
    valid = (x != pad_idx).to(device)  # N, L
    mask = valid.to(torch.get_default_dtype())  # N, L -> padding = 0, others = 1

    # Outer AND of the valid map is True only where both positions are real
    # tokens; its negation matches the former `bmm(mask, mask^T) == 0` without
    # a float matrix product.
    mask_output = ~(valid.unsqueeze(2) & valid.unsqueeze(1))  # N, L, L
    return mask, mask_output

In [12]:
# pad_mask: (N, L) float, 1 for real tokens and 0 for padding.
# mask_info: (N, L, L) bool, True wherever either of the two positions is padding.
pad_mask, mask_info = get_mask(src_data)

In [13]:
# Self-attention: the same embeddings are passed as the first three arguments
# (presumably query / key / value — confirm in sublayers) together with the (N, L, L) mask.
attn = mha(emb, emb, emb, mask_info)

In [14]:
# Residual connection: add the attention output back onto the embeddings it was computed from.
residual_connection = attn+emb

In [15]:
# Inspect the first sequence of the batch.
residual_connection[0]

tensor([[ -8.1076,   8.7598,  36.4360,  ...,  -5.2853,  -8.7302,  36.0597],
        [-13.8747,   5.2363,  11.0384,  ...,  18.6591, -14.5472,  49.6891],
        [ 14.3876,  -3.1392, -10.8425,  ...,  12.7769, -41.5699, -22.4045],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       grad_fn=<SelectBackward0>)

In [16]:
# NOTE(review): `LayerLorm` looks like a typo for "LayerNorm", but it is the name
# imported from sublayers, so any rename has to start there.
ln_layer = LayerLorm(d_model)

In [17]:
# Apply the layer to the residual output; the result is the input of the feed-forward block below.
ff_input = ln_layer(residual_connection)

In [18]:

class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Padded positions are zeroed after each linear projection, so the bias
    terms never show up at padded positions of the result.

    Args:
        d_model: size of the last dimension of the input and of the output.
        d_ff: width of the hidden (inner) projection.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Creation order (inner first, then outer) is kept so that seeded runs
        # draw the same initial weights.
        self.inner_layer = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.outer_layer = nn.Linear(d_ff, d_model)

    def forward(self, x, mask_info):
        """Run the block on ``x`` and zero every padded position.

        **INPUT SHAPE**
        x -> N, L, d_model
        mask_info -> N, L (entries equal to 0 mark padding)

        **OUTPUT SHAPE**
        N, L, d_model
        """
        # True at padded positions; the trailing axis broadcasts over features.
        is_pad = mask_info.eq(0).unsqueeze(-1)

        hidden = self.inner_layer(x).masked_fill(is_pad, 0)
        hidden = self.relu(hidden)

        projected = self.outer_layer(hidden)  # N, L, d_model
        return projected.masked_fill(is_pad, 0)

    def initialization(self):
        """Re-initialise both linear weights with Xavier-uniform; biases are left as they are."""
        for layer in (self.inner_layer, self.outer_layer):
            nn.init.xavier_uniform_(layer.weight)

In [19]:
# Instantiate the feed-forward block defined in this notebook; that definition
# shadows the PositionWiseFeedForward imported from sublayers in the first cell.
PWFFN_layer = PositionWiseFeedForward(d_model=d_model, d_ff=info.base_hyper_params['d_ff'])

In [20]:
# pad_mask is the (N, L) mask with 1 for real tokens; forward() zeroes the positions where it is 0.
ff_output = PWFFN_layer(ff_input, pad_mask)

In [28]:
# Number of real (non-pad) tokens in the first sequence: pad_mask holds 1 for
# tokens and 0 for padding, so the row sum is the sequence length.
pad_mask[0].sum()

tensor(42.)

In [29]:
# Show only the non-pad rows of the first sequence. The cut-off is derived from
# pad_mask instead of a hard-coded 42: the DataLoader shuffles without a fixed
# seed, so the length of the first sequence differs from run to run.
# Like the literal slice it replaces, this assumes padding sits at the end of the row.
ff_output[0][:int(pad_mask[0].sum())]

tensor([[-0.0397, -0.1271, -0.0488,  ...,  0.0168,  0.0755, -0.1220],
        [ 0.0492, -0.1153,  0.1711,  ..., -0.2641,  0.1518,  0.1332],
        [-0.2595, -0.0150,  0.1485,  ...,  0.0032,  0.2474,  0.1573],
        ...,
        [ 0.2332,  0.0788, -0.0571,  ...,  0.0133, -0.1447,  0.1605],
        [ 0.0944, -0.1248,  0.0224,  ..., -0.1638, -0.3185, -0.0559],
        [ 0.2455,  0.3135,  0.0048,  ...,  0.1833, -0.0470,  0.1489]],
       grad_fn=<SliceBackward0>)

In [23]:
# Zero the padded positions of the feed-forward output and show the first sequence.
# Padding is read from pad_mask (1 = token, 0 = pad) rather than from
# `ff_input == 0`: an exact float comparison on activations would also blank any
# real activation that happens to be 0, and it only works if the preceding layer
# leaves padded rows at exactly 0.
# With PositionWiseFeedForward as defined in this notebook, forward() already
# applies this mask, so the cell doubles as a consistency check.
ff_output.masked_fill((pad_mask == 0).unsqueeze(-1), 0)[0]

tensor([[-0.0397, -0.1271, -0.0488,  ...,  0.0168,  0.0755, -0.1220],
        [ 0.0492, -0.1153,  0.1711,  ..., -0.2641,  0.1518,  0.1332],
        [-0.2595, -0.0150,  0.1485,  ...,  0.0032,  0.2474,  0.1573],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward0>)