In [13]:
import torch
import torch.nn as nn
import math

# Parameters
batch_size = 2
sequence_length = 10
d_model = 32  # Dimensionality of the model
num_heads = 4  # Number of attention heads

sample = "abcdefghijklmnopqrstuvwxyz"
sample = [ord(c) - 97 for c in sample]
vocab_size = 27 + 1 + 1 # 1 for padding, 1 for start token

# Create one training example Seq x B 
batches = [sample[i:i+sequence_length-3] for i in range(2)]
source = [[27, 27, 27] + list(b) for b in batches]
src = torch.tensor(source)
print(src.shape, src)


torch.Size([2, 10]) tensor([[27, 27, 27,  0,  1,  2,  3,  4,  5,  6],
        [27, 27, 27,  1,  2,  3,  4,  5,  6,  7]])


In [14]:
# Pass through embedding layer for Seq x B x Dim
layer = nn.Embedding(vocab_size, d_model)
src = layer(src)
print(src.shape, src)

torch.Size([2, 10, 32]) tensor([[[ 7.9016e-01,  8.2979e-01, -1.5933e+00, -3.9233e-02,  9.4113e-01,
           1.6581e+00,  1.2207e+00, -2.6705e-02,  1.4590e-02, -6.0749e-01,
           8.9882e-01,  7.1349e-02,  4.7115e-01, -1.0111e+00, -8.0462e-02,
           5.5488e-01, -9.7055e-01,  7.9285e-01,  4.1483e-01,  7.3875e-01,
           1.0313e+00,  1.1615e+00, -4.7090e-01,  8.3787e-02,  6.6528e-01,
           1.0267e+00, -1.2128e+00, -6.2543e-01,  1.4762e-01,  6.0096e-02,
           9.6613e-01, -1.1305e+00],
         [ 7.9016e-01,  8.2979e-01, -1.5933e+00, -3.9233e-02,  9.4113e-01,
           1.6581e+00,  1.2207e+00, -2.6705e-02,  1.4590e-02, -6.0749e-01,
           8.9882e-01,  7.1349e-02,  4.7115e-01, -1.0111e+00, -8.0462e-02,
           5.5488e-01, -9.7055e-01,  7.9285e-01,  4.1483e-01,  7.3875e-01,
           1.0313e+00,  1.1615e+00, -4.7090e-01,  8.3787e-02,  6.6528e-01,
           1.0267e+00, -1.2128e+00, -6.2543e-01,  1.4762e-01,  6.0096e-02,
           9.6613e-01, -1.1305e+00],
  

In [15]:
# Create the tgt batch for teacher forcing
batches = [sample[i+7:i+7+sequence_length-4] for i in range(2)]
target = [[27, 27, 27, 28] + list(b) for b in batches]
print(target)
tgt = torch.tensor(target)
layer2 = nn.Embedding(vocab_size, d_model)
tgt = layer2(tgt)
print(tgt.shape, tgt)

[[27, 27, 27, 28, 7, 8, 9, 10, 11, 12], [27, 27, 27, 28, 8, 9, 10, 11, 12, 13]]
torch.Size([2, 10, 32]) tensor([[[-1.4124e+00,  2.1237e+00,  1.8613e+00, -2.4381e-02, -4.8382e-01,
          -6.5562e-01, -2.7951e-01,  2.7260e-01,  7.5491e-01,  1.9111e+00,
          -2.8320e-01,  5.2296e-01, -3.7284e-02, -1.3830e+00, -4.6526e-01,
          -6.2308e-02, -4.5181e-01, -5.9649e-02, -7.5916e-01,  1.0060e+00,
           6.5362e-02,  3.0836e-01, -4.3143e-02,  1.0548e+00,  2.4573e-01,
          -3.9501e-01, -7.6813e-01, -1.1998e-01, -1.7393e-01,  8.9106e-01,
           1.6803e+00,  4.8900e-01],
         [-1.4124e+00,  2.1237e+00,  1.8613e+00, -2.4381e-02, -4.8382e-01,
          -6.5562e-01, -2.7951e-01,  2.7260e-01,  7.5491e-01,  1.9111e+00,
          -2.8320e-01,  5.2296e-01, -3.7284e-02, -1.3830e+00, -4.6526e-01,
          -6.2308e-02, -4.5181e-01, -5.9649e-02, -7.5916e-01,  1.0060e+00,
           6.5362e-02,  3.0836e-01, -4.3143e-02,  1.0548e+00,  2.4573e-01,
          -3.9501e-01, -7.6813e-01

In [16]:
# Add positional encoding to both src and tgt
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

position = PositionalEncoding(d_model)
src_transposed = src.transpose(0, 1)
tgt_transposed = src.transpose(0, 1)
src_embedded = position(src_transposed)
tgt_embedded = position(tgt_transposed)
print(src_embedded.shape, tgt_embedded.shape)
print(src_embedded, tgt_embedded)

torch.Size([10, 2, 32]) torch.Size([10, 2, 32])
tensor([[[ 8.7795e-01,  2.0331e+00, -1.7703e+00,  1.0675e+00,  1.0457e+00,
           2.9534e+00,  1.3564e+00,  1.0814e+00,  1.6211e-02,  0.0000e+00,
           9.9869e-01,  1.1904e+00,  5.2351e-01, -1.2330e-02, -8.9403e-02,
           1.7276e+00, -1.0784e+00,  1.9921e+00,  4.6092e-01,  1.9319e+00,
           1.1459e+00,  2.4017e+00, -5.2322e-01,  1.2042e+00,  7.3920e-01,
           2.2519e+00, -0.0000e+00,  4.1618e-01,  0.0000e+00,  1.1779e+00,
           0.0000e+00, -1.4498e-01],
         [ 8.7795e-01,  2.0331e+00, -1.7703e+00,  1.0675e+00,  1.0457e+00,
           0.0000e+00,  1.3564e+00,  1.0814e+00,  1.6211e-02,  4.3612e-01,
           9.9869e-01,  1.1904e+00,  5.2351e-01, -1.2330e-02, -8.9403e-02,
           1.7276e+00, -1.0784e+00,  1.9921e+00,  4.6092e-01,  1.9319e+00,
           0.0000e+00,  2.4017e+00, -5.2322e-01,  1.2042e+00,  7.3920e-01,
           0.0000e+00, -1.3475e+00,  4.1618e-01,  1.6402e-01,  1.1779e+00,
           1.07

In [17]:
# Based on prior knowledge of padding tokens, create src_key_padding_mask and tgt_key_padding_mask

# Src mask: ignore the first three tokens always 
# Tgt mask: ignore the first three tokens always since sequence length for src = sequence length for tgt
padding = [[True for _ in range(3)] + [False for _ in range(7)] for _ in range(batch_size)]
src_key_padding_mask = torch.tensor(padding)
tgt_key_padding_mask = torch.tensor(padding)
print(src_key_padding_mask.shape, src_key_padding_mask)


torch.Size([2, 10]) tensor([[ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]])


In [18]:
# Make tgt_mask, the triangular mask for self-attention
base = torch.tensor([[False for _ in range(sequence_length)] for _ in range(sequence_length)])
tgt_mask = base.masked_fill(torch.triu(torch.ones(sequence_length, sequence_length)) == 0, True)
print(tgt_mask.shape, tgt_mask)

torch.Size([10, 10]) tensor([[False, False, False, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]])


In [19]:
# Make the model
model = torch.nn.Transformer(d_model=d_model, dim_feedforward=128)
output = model(src_embedded, tgt_embedded, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, 
               memory_key_padding_mask=src_key_padding_mask)
print(output.shape, output)

torch.Size([10, 2, 32]) tensor([[[-2.2855e+00,  1.1811e+00, -1.1609e+00, -2.8465e-01,  4.7382e-01,
           1.7363e+00,  4.0376e-01, -9.0100e-01, -6.4272e-01,  7.3389e-01,
          -1.4083e+00,  1.0995e+00, -4.8675e-01,  2.6777e-01,  1.4801e-01,
           1.6436e+00,  1.5918e+00, -2.2470e-01, -1.2635e+00, -3.3192e-01,
           4.5677e-01, -3.0907e-01, -1.8443e+00, -4.8271e-01, -1.6415e-01,
           5.7380e-01,  6.3218e-01, -1.2305e+00,  1.0536e+00,  5.1774e-01,
          -2.7404e-01,  7.8098e-01],
         [-2.3898e+00,  7.9866e-01, -2.4634e-01,  7.0526e-02,  9.1736e-01,
           1.9906e+00, -5.3515e-01, -6.9564e-01, -5.2566e-01, -5.8530e-01,
          -1.3632e+00,  1.4280e+00, -7.1912e-01,  1.6548e-01,  9.6604e-01,
           1.3112e+00,  2.7866e-01, -3.6365e-01, -1.1626e+00,  5.8547e-01,
          -3.6694e-02, -2.1272e-01, -2.1815e+00,  8.8755e-02,  3.7245e-01,
           7.8782e-01, -1.0995e+00, -5.0677e-01,  1.3882e+00,  7.0478e-01,
          -1.7823e-01,  9.4795e-01]],



In [20]:
from torch.nn import functional as F

# Make sense of the model's outputs
# output_transpose = output.transpose(0, 1)
final = nn.Linear(d_model, vocab_size)
logits = final(output)
logits_transpose = logits.transpose(0, 1)
probs = F.softmax(logits_transpose[:, -1, :], dim=-1)
outputs = torch.multinomial(probs, num_samples=1)
print(outputs.shape, outputs)
chars = [[chr(c+97) for c in tensor] for tensor in outputs]
print(chars)

torch.Size([2, 1]) tensor([[ 1],
        [13]])
[['b'], ['n']]


In [21]:
# Create the targets tensor containing the true/expected outputs
bases = [sample[i+7:i+7+sequence_length-3] for i in range(2)]
expected = torch.tensor([[27, 27, 27] + list(b) for b in bases])
print(expected.shape, expected)

torch.Size([2, 10]) tensor([[27, 27, 27,  7,  8,  9, 10, 11, 12, 13],
        [27, 27, 27,  8,  9, 10, 11, 12, 13, 14]])


In [22]:
# Transform arguments for cross-entropy loss
T, B, C = logits.shape
logits_transformed = logits.reshape(T * B, C)
expected_transpose = torch.transpose(expected, 0, 1)
expected_transformed = expected_transpose.reshape(T * B)

print(logits_transformed.shape, expected_transformed.shape)
print(logits_transformed, expected_transformed)

torch.Size([20, 29]) torch.Size([20])
tensor([[-6.2342e-01,  2.7832e-01,  1.8441e-02,  1.0144e+00, -5.1528e-01,
          4.3129e-02, -1.6538e-01, -1.1743e-01, -1.7563e+00,  2.9658e-01,
          1.0451e+00, -4.6845e-02,  7.0998e-01,  5.6971e-01,  6.4561e-02,
         -5.6046e-01, -4.5784e-01,  7.1247e-01, -3.0143e-01,  2.1461e-01,
          4.0552e-02, -1.7129e-01, -6.1912e-01, -5.1148e-03, -6.5415e-01,
          6.5041e-01, -1.5502e-01, -5.9484e-01, -7.2411e-01],
        [-6.2766e-01,  5.8591e-01, -4.7423e-01,  5.5343e-01, -1.2249e-01,
         -3.1339e-01, -2.0723e-01, -1.7426e-01, -1.6982e+00, -2.5320e-01,
          5.9999e-01, -4.4895e-02,  1.4758e-01,  1.1065e+00, -2.0715e-01,
          1.0247e-01, -7.2721e-01,  5.5958e-01,  6.1067e-02,  1.8970e-01,
          2.9303e-01, -5.3665e-01, -3.4513e-01, -5.8633e-01, -1.0748e+00,
          4.9794e-01,  1.2520e-01, -4.8674e-01, -5.2026e-01],
        [-2.5942e-01,  2.1169e-01, -3.7228e-01,  6.6269e-01, -4.7197e-01,
          4.2000e-01, -5

In [23]:
# Calculate loss while ignoring padding inputs
from torch.nn import functional as F

weights = torch.tensor([1.0 for _ in range(26)] + [0.0 for _ in range(vocab_size - 26)]) # ignore tokens that aren't from our original vocabulary
print(weights)
loss = F.cross_entropy(logits_transformed, expected_transformed, weight=weights)
print(loss)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.])
tensor(3.5062, grad_fn=<NllLossBackward0>)


In [24]:
# One backprop step
optimizer = torch.optim.AdamW(model.parameters(), lr=3)
loss.backward()

# Show that model output has lower loss value when same inputs used
output_redo = model(src_embedded, tgt_embedded, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, 
               memory_key_padding_mask=src_key_padding_mask)
logits_redo = final(output)
T, B, C = logits_redo.shape
logits_redo_transformed = logits_redo.reshape(T * B, C)

weights = torch.tensor([1.0 for _ in range(26)] + [0.0 for _ in range(vocab_size - 26)]) 
loss = F.cross_entropy(logits_transformed, expected_transformed, weight=weights)
print(loss)

tensor(3.5062, grad_fn=<NllLossBackward0>)
