In [None]:
# Kernel setup: reload edited, already-imported modules automatically before each cell runs.
# NOTE: re-running this cell makes %load_ext print "already loaded" (see the captured output).
%load_ext autoreload
%autoreload 2
# No cell in this section plots anything; presumably kept for figures added later.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import torch
import torch.nn as nn


# Reproducibility: seed torch's global RNG so parameter initialisation and the
# torch.rand inputs in later cells are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Batch size
B = 10
# Source (encoder input) sequence length
SRC = 40
# Model / embedding dimension (d_model)
D = 512
# Target (decoder input) sequence length; TGT is kept as a readable alias of S
S = TGT = 32

# Number of attention heads; D must be divisible by H (512 / 16 = 32 dims per head)
H = 16
# Number of layers used for every encoder and decoder stack in this notebook
L = 12

# Pass L for BOTH stacks: nn.Transformer defaults num_decoder_layers to 6, which
# would silently disagree with the L-layer standalone decoder built further down.
transformer_model = nn.Transformer(
    d_model=D,
    nhead=H,
    num_encoder_layers=L,
    num_decoder_layers=L,
    batch_first=True,
)

In [None]:
# Full encoder-decoder forward pass on random batch-first inputs.
src = torch.rand((B, SRC, D))
tgt = torch.rand((B, S, D))
# Causal (look-ahead) mask for decoder self-attention: -inf strictly above the
# diagonal, 0 elsewhere, so target position i only attends to positions <= i.
tgt_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)
out = transformer_model(src, tgt, tgt_mask=tgt_mask)

In [None]:
assert out.shape == torch.Size((B, S, D))

In [None]:
# how to create up triangle matrix
# Upper-triangular causal mask: -inf strictly above the main diagonal, 0 on and below it.
torch.full((S, S), float('-inf')).triu(diagonal=1)

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:
# Encoder

# Encoder stack: L copies of one encoder-layer template, applied to a random source batch.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=D,
    nhead=H,
    batch_first=True,
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=L)
src = torch.rand((B, SRC, D))
out = transformer_encoder(src)

In [None]:
out.size()

torch.Size([10, 40, 512])

In [None]:
assert tuple(out.shape) == (B, SRC, D)

In [None]:
# Decoder

# Decoder stack: L copies of one decoder layer. Each layer self-attends over `tgt`
# and attends over `memory` (here a random stand-in for an encoder's output).
decoder_layer = nn.TransformerDecoderLayer(d_model=D, nhead=H, batch_first=True)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=L)

memory = torch.rand(B, SRC, D)
tgt = torch.rand((B, S, D))
# Causal mask so decoder self-attention cannot look at future target positions.
tgt_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)
out = transformer_decoder(tgt, memory, tgt_mask=tgt_mask)

In [None]:
out.size()

torch.Size([10, 32, 512])

In [None]:
assert tuple(out.shape) == (B, S, D)

In [None]:
# Encoder Layer

# A single encoder layer applied once to a random source batch.
encoder_layer = nn.TransformerEncoderLayer(batch_first=True, nhead=H, d_model=D)
src = torch.rand((B, SRC, D))
out = encoder_layer(src)

In [None]:
assert tuple(out.shape) == (B, SRC, D)

In [None]:
# Decoder Layer

# A single decoder layer: self-attention over `tgt`, attention over `memory`
# (a random stand-in for encoder output), then feed-forward.
decoder_layer = nn.TransformerDecoderLayer(d_model=D, nhead=H, batch_first=True)
memory = torch.rand(B, SRC, D)
tgt = torch.rand((B, S, D))
# Causal mask so target position i only attends to target positions <= i.
tgt_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)
out = decoder_layer(tgt, memory, tgt_mask=tgt_mask)

In [None]:
assert tuple(out.shape) == (B, S, D)

In [None]:
# Multhead attention

# Multi-head attention with query, key and value all of shape (B, S, D).
# With default arguments the returned weights are one (S, S) map per batch item
# (see the (B, S, S) shape displayed below).
multihead_attn = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)

query, key, value = (torch.rand((B, S, D)) for _ in range(3))

attn_output, attn_output_weights = multihead_attn(query=query, key=key, value=value)

In [None]:
attn_output.size()

torch.Size([10, 32, 512])

In [None]:
assert attn_output.size() == query.size()

In [None]:
attn_output_weights.size()

torch.Size([10, 32, 32])

In [None]:
assert tuple(attn_output_weights.shape) == (B, S, S)