In [47]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# generate_square_subsequent_mask

In [48]:
# Upper-triangular part (diagonal included) of a 3x3 matrix of ones.
torch.ones(3, 3).triu()

tensor([[1., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])

In [49]:
# Same upper-triangular pattern, but as a boolean tensor.
torch.ones(3, 3).eq(1).triu()

tensor([[ True,  True,  True],
        [False,  True,  True],
        [False, False,  True]])

In [50]:
# Transposing the upper-triangular boolean matrix yields a lower-triangular one.
torch.ones(3, 3).eq(1).triu().t()

tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])

In [51]:
# Boolean lower-triangular mask: row i is True for columns 0..i.
mask = torch.ones(3, 3).eq(1).triu().t()

In [52]:
# Display the boolean lower-triangular mask assigned in the previous cell.
mask

tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])

In [53]:
# Cast the boolean mask to float32 (True -> 1.0, False -> 0.0).
mask.to(torch.float32)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [54]:
# Put -inf wherever the boolean mask is False; True positions keep the value 1.0.
mask.float().masked_fill(~mask, float('-inf'))

tensor([[1., -inf, -inf],
        [1., 1., -inf],
        [1., 1., 1.]])

In [55]:
# Final mask values: 0.0 where the boolean mask is True, -inf where it is False.
mask.float().masked_fill(~mask, float('-inf')).masked_fill(mask, float(0))

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])

# create_mask

In [56]:
import pickle as pkl

In [57]:

dataset_dir_path = '../../data/processed/tokenized_data/'

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point these paths at files produced by a trusted pipeline.
train_pickle_path = dataset_dir_path + 'train_data.pkl'
valid_pickle_path = dataset_dir_path + 'valid_data.pkl'

# Load the training split first, then the validation split.
with open(train_pickle_path, 'rb') as train_file:
    tokenized_train_data = pkl.load(train_file)

with open(valid_pickle_path, 'rb') as valid_file:
    tokenized_valid_data = pkl.load(valid_file)

In [58]:
vocab_dir_path = '../../data/processed/vocab/'

# The four vocabulary tables are read in this fixed order and unpacked below.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load vocabulary files that come from a trusted source.
vocab_files = (
    'token2idx_de.pkl',
    'token2idx_en.pkl',
    'idx2token_de.pkl',
    'idx2token_en.pkl',
)

vocab_tables = []
for vocab_file in vocab_files:
    with open(vocab_dir_path + vocab_file, 'rb') as f:
        vocab_tables.append(pkl.load(f))

token2idx_de, token2idx_en, idx2token_de, idx2token_en = vocab_tables

In [59]:
# Number of sentence pairs per mini-batch.
batch_size = 128

# Ids of the special tokens.
# NOTE(review): the pad id is read from the German vocabulary only, yet it is
# also used to pad English targets — presumably '<pad>' has the same id in
# both vocabularies; confirm.
PAD_INDEX = token2idx_de['<pad>']
START_INDEX, END_INDEX = (token2idx_en[special] for special in ('<start>', '<end>'))

In [60]:
def generate_batch(data_batch):
    """Collate (source, target) pairs into two padded batch tensors.

    Args:
        data_batch: iterable of ``(src, tgt)`` pairs, each element a tensor
            accepted by ``pad_sequence``.

    Returns:
        Tuple ``(batch_src, batch_tgt)``. Each is the ``pad_sequence`` result
        for its side, padded with the module-level ``PAD_INDEX``. Because
        ``batch_first`` is left at its default, the sequence dimension
        comes first.
    """
    # Materialise the pairs once; the 2-tuple unpacking rejects malformed items.
    pairs = [(source_ids, target_ids) for source_ids, target_ids in data_batch]

    # Pad the source side first, then the target side.
    padded_src, padded_tgt = (
        pad_sequence([pair[side] for pair in pairs], padding_value=PAD_INDEX)
        for side in (0, 1)
    )
    return padded_src, padded_tgt

In [61]:
# Both loaders share the same batching options; batches are collated (padded)
# by generate_batch and drawn in shuffled order.
loader_options = dict(batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
train_iter = DataLoader(tokenized_train_data, **loader_options)
valid_iter = DataLoader(tokenized_valid_data, **loader_options)

In [62]:
# Pull one mini-batch from the training loader. The loader shuffles, so the
# exact batch (and its padded lengths) differs between runs.
first_batch = next(iter(train_iter))
src, tgt = first_batch
src.size(), tgt.size()

(torch.Size([26, 128]), torch.Size([29, 128]))

In [112]:
# Take column 3 of the batch, i.e. one sentence along the minibatch dimension.
sample_src = src.select(1, 3)
print(src.size())
sample_src

torch.Size([26, 128])


tensor([  2,  60, 120, 234,  42,  34, 244,   3,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1])

In [111]:
# Count the padding tokens in each sentence of the batch. Sentences are the
# columns of `src`, so the comparison is summed over dim 0 (the token axis).
# The pad id comes from PAD_INDEX instead of a hard-coded 1 — the decoded
# sample in this notebook shows id 1 as '<pad>', so the result is presumably
# unchanged; confirm if the vocabulary changes.
pad_counts = (src == PAD_INDEX).sum(dim=0)

# Keep `counts` as a plain list of Python ints, one entry per sentence.
counts = pad_counts.tolist()

print(max(counts))

22


In [84]:
# Decode the sample source sentence: look up each id in the German
# index-to-token table and print the tokens, each followed by one space.
sample_tokens = [idx2token_de[token_id] for token_id in sample_src.tolist()]
print(''.join(f'{token} ' for token in sample_tokens), end='')

<start> Zwei Männer bauen etwas zusammen <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 

In [90]:
# padding_mask_src
# Dim 0 of `src` is the padded sentence length shared by the whole batch.
seq_len_src = src.size(0)
print(f'src shape (number of tokens x minibatch size) : {src.shape}')
print(f'uniformed token size : {seq_len_src}')

src shape (number of tokens x minibatch size) : torch.Size([26, 128])
uniformed token size : 26


In [95]:
# Square all-False boolean mask with one row and one column per source position.
mask_src = torch.zeros(seq_len_src, seq_len_src).bool()
print(f'mask shape : {mask_src.shape}')
print(mask_src)

mask shape : torch.Size([26, 26])
tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, Fal

In [98]:
# True marks padding positions. Transposed so that dim 0 indexes sentences:
# the result has shape (minibatch size, number of tokens).
src.eq(PAD_INDEX).t()

tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]])

In [97]:
# Display the source batch tensor (number of tokens x minibatch size).
src

tensor([[  2,   2,   2,  ...,   2,   2,   2],
        [ 21,  21, 126,  ...,   5,  60,   5],
        [ 31,  47,  69,  ...,  12,  27,  12],
        ...,
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1]])