In [1]:
from src.layers import (
    MultiheadAttention, PointwiseFeedForward, 
    EncoderLayer, DecoderLayer, Encoder, Decoder,
    Transformer, MaskedCrossEntropyLoss
)
from src.utils import create_masks, create_padding_mask, create_look_ahead_mask
import torch

In [2]:
# Seed the global torch RNG once, before the first cell that draws random numbers
# (presumably the layers' weight init, plus the torch.randn / torch.rand calls in
# later cells), so Restart & Run All reproduces the same tensors and losses.
torch.manual_seed(0)

# Multi-head attention over 512-dim inputs with 2 heads (the attention weights
# below have a head axis of size 2). Arguments presumably (d_model, num_heads).
# NOTE(review): the bare "256" printed under this cell (= 512 / 2), and the "64" /
# "16" lines under later cells, appear to come from a leftover debug print inside
# src.layers, not from this notebook; consider removing it there.
m = MultiheadAttention(512, 2)

256


In [3]:
# Random activations, presumably laid out as (batch, seq_len, d_model) = (5, 16, 512).
x = torch.randn((5, 16, 512))

In [4]:
# Self-attention: the same tensor x is passed in all three argument slots
# (presumably query, key, value). Returns the attended output `o` and the
# attention weights `a`; their shapes are displayed in the next cell.
o, a = m(x, x, x)

In [5]:
# `o` keeps x's shape; `a` presumably is (batch, heads, query_len, key_len).
(o.shape, a.shape)

(torch.Size([5, 16, 512]), torch.Size([5, 2, 16, 16]))

In [6]:
# Position-wise feed-forward block; arguments presumably (d_model=512, dff=64),
# i.e. a small hidden width for this demo — TODO confirm against src.layers.
f = PointwiseFeedForward(512, 64)

In [7]:
# The feed-forward block maps o back to the same (5, 16, 512) shape.
f(o).shape

torch.Size([5, 16, 512])

In [8]:
# One encoder layer applied to random activations. Positional arguments are
# presumably (d_model, num_heads, dff), mirroring the keywords used for Encoder below.
# NOTE: this rebinds `x` (previously the (5, 16, 512) tensor); the next cell uses this one.
sample_encoder_layer = EncoderLayer(512, 8, 2048)
x = torch.randn(64, 43, 512)
sample_encoder_layer_output = sample_encoder_layer(x)
sample_encoder_layer_output.size()  # (batch_size, input_seq_len, d_model)

64


torch.Size([64, 43, 512])

In [9]:
# One decoder layer: `x` serves as the decoder-side input and the encoder layer's
# output as the memory it attends to. Of the three returned values only the first
# (the layer output) is kept here.
sample_decoder_layer = DecoderLayer(512, 8, 2048)
(sample_decoder_layer_output, _, _) = sample_decoder_layer(
    x, sample_encoder_layer_output
)
sample_decoder_layer_output.size()  # (batch_size, target_seq_len, d_model)

64
64


torch.Size([64, 43, 512])

In [10]:
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8,
                         dff=2048, input_vocab_size=8500,
                         maximum_position_encoding=10000)

# Feed random token ids in [0, input_vocab_size). The previous
# torch.rand(64, 62).long() truncated every value in [0, 1) to 0, so the encoder
# only ever saw token id 0. torch.randint already returns int64 (long).
sample_encoder_output = sample_encoder(torch.randint(0, 8500, (64, 62)), mask=None)

assert sample_encoder_output.shape == (64, 62, 512)
sample_encoder_output.shape  # (batch_size, input_seq_len, d_model)

64
64
torch.Size([64, 62, 512])


In [11]:
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8,
                         dff=2048, target_vocab_size=8000,
                         maximum_position_encoding=5000)

# Feed random target token ids in [0, target_vocab_size). The previous
# torch.rand(64, 26).long() truncated every value in [0, 1) to 0, so the decoder
# only ever saw token id 0.
output, attn = sample_decoder(torch.randint(0, 8000, (64, 26)),
                              enc_output=sample_encoder_output,
                              look_ahead_mask=None,
                              padding_mask=None)

assert output.shape == (64, 26, 512)
# 'decoder_layer2_block2' has query length 26 and key length 62, i.e. the decoder
# attending over the encoder output.
output.shape, attn['decoder_layer2_block2'].shape

64
64
64
64


(torch.Size([64, 26, 512]), torch.Size([64, 8, 26, 62]))

In [12]:
sample_transformer = Transformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048,
    input_vocab_size=8500, target_vocab_size=8000,
    pe_input=10000, pe_target=6000)

# Random token ids drawn from each side's vocabulary. The previous
# torch.rand(...).long() truncated every value in [0, 1) to 0, so both
# sequences consisted solely of token id 0.
temp_input = torch.randint(0, 8500, (64, 62))   # ids < input_vocab_size
temp_target = torch.randint(0, 8000, (64, 26))  # ids < target_vocab_size

fn_out, _ = sample_transformer(temp_input, temp_target,
                               enc_padding_mask=None,
                               look_ahead_mask=None,
                               dec_padding_mask=None)

assert fn_out.shape == (64, 26, 8000)
fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)

64
64
64
64
64
64


torch.Size([64, 26, 8000])

In [13]:
# Hyper-parameters for the small transformer trained/evaluated below.
num_layers, num_heads = 4, 8
d_model, dff = 128, 512  # dff: presumably the feed-forward hidden width

# 10000 regular tokens plus 2 extra ids (presumably special tokens — TODO confirm).
input_vocab_size = target_vocab_size = 10000 + 2
dropout_rate = 0.1

In [15]:
# Build the model used on the toy batch below.
# NOTE(review): the 5th positional argument is a bare None whose parameter name is
# not visible here (the earlier sample_transformer call omits it entirely), and the
# two vocab sizes are then bound positionally after it. Confirm against
# Transformer's signature and pass these by keyword so the binding is explicit.
# NOTE(review): pe_input / pe_target reuse the vocabulary sizes as the positional-
# encoding lengths; any value >= the longest sequence works, TODO confirm intent.
transformer = Transformer(num_layers, d_model, num_heads, dff, None,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

16
16
16
16
16
16
16
16
16
16
16
16


In [16]:
# Loss for the toy batch below; judging by its name it presumably skips padding
# (id 0) positions — TODO confirm in src.layers.
loss_function = MaskedCrossEntropyLoss()

In [17]:
# Toy batch: 2 source sequences of length 3 and 2 decoder-input sequences of length 4.
# Id 0 acts as padding (the masks below flag exactly the positions that hold 0),
# so draw real tokens from [1, vocab_size): the previous (vocab * rand).long()
# could occasionally emit 0 and silently add an unplanned padding position.
inp = torch.randint(1, input_vocab_size, (2, 3))
# Target-side ids come from the target vocabulary (previously input_vocab_size was used).
tar_inp = torch.randint(1, target_vocab_size, (2, 4))
# Pad one position per source row; inp is int64, so assign the integer 0 rather than 0.
inp[0, 1] = 0
inp[1, 2] = 0
inp

tensor([[7626,    0, 5443],
        [3561, 4571,    0]])

In [18]:
# Build the encoder padding mask, the combined (look-ahead + padding) decoder mask,
# and the decoder's padding mask over the encoder output from the toy batch.
enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

In [19]:
# Shapes of the three masks.
# NOTE(review): a look-ahead mask for a length-4 target would normally be (4, 4)
# (broadcast over the batch). combined_mask comes out (2, 4), and its printed rows
# match the first two rows of such an upper triangle. Check create_masks in src.utils.
tuple(mask.size() for mask in (enc_padding_mask, combined_mask, dec_padding_mask))

(torch.Size([2, 3]), torch.Size([2, 4]), torch.Size([2, 3]))

In [20]:
# Show each mask on its own line; a 1 marks a masked position (compare with the
# zeros placed in `inp`).
print(enc_padding_mask, combined_mask, dec_padding_mask, sep="\n")

tensor([[0., 1., 0.],
        [0., 0., 1.]])
tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.]])
tensor([[0., 1., 0.],
        [0., 0., 1.]])


In [21]:
# Forward pass on the toy batch followed by the masked loss.
# NOTE(review): the masks are passed positionally. The earlier sample_transformer
# call names them enc_padding_mask / look_ahead_mask / dec_padding_mask; confirm the
# positional order matches the forward signature (or pass them by keyword).
# NOTE(review): the loss scores predictions against tar_inp, the same sequence fed
# to the decoder. A training step would usually compare against the target shifted
# by one position; fine for a smoke test, TODO confirm intent.
predictions, _ = transformer(inp, tar_inp, 
                             enc_padding_mask, 
                             combined_mask, 
                             dec_padding_mask)
loss = loss_function(predictions, tar_inp)

In [22]:
# Scalar loss for the untrained model. For reference, the cross-entropy of a uniform
# guess over 10002 classes is ln(10002) ≈ 9.21, so a value near that is expected here.
loss

tensor(9.6417, grad_fn=<DivBackward0>)

In [23]:
# Count of non-padding (non-zero) positions in tar_inp, as a float tensor.
tar_inp.ne(0.).float().sum()

tensor(8.)