In [1]:
import torch
import torch.nn as nn
import math

# Parameters
batch_size = 2
sequence_length = 10
d_model = 32  # Dimensionality of the model
num_heads = 4  # Number of attention heads

sample = "abcdefghijklmnopqrstuvwxyz"
sample = [ord(c) - 97 for c in sample]
vocab_size = 27 + 1 + 1 # 1 for padding, 1 for start token

# Create one training example Seq x B 
batches = [sample[i:i+sequence_length-3] for i in range(2)]
source = [[27, 27, 27] + list(b) for b in batches]
src = torch.tensor(source)
print(src.shape, src)


torch.Size([2, 10]) tensor([[27, 27, 27,  0,  1,  2,  3,  4,  5,  6],
        [27, 27, 27,  1,  2,  3,  4,  5,  6,  7]])


In [2]:
# Pass through embedding layer for Seq x B x Dim
layer = nn.Embedding(vocab_size, d_model)
src = layer(src)
print(src.shape, src)

torch.Size([2, 10, 32]) tensor([[[-1.7005e+00,  1.0782e+00, -6.7454e-01,  9.1053e-02,  1.5202e-01,
          -2.3548e-01, -4.0013e-01, -2.2318e-01, -6.6264e-01, -3.3346e-01,
          -2.4648e+00,  3.6981e-01,  6.8890e-06,  3.9696e-02,  6.4057e-01,
          -9.2828e-01, -6.6453e-01, -1.9363e+00, -1.7545e+00,  4.1251e-01,
           1.3720e+00, -2.7778e+00, -1.0247e+00,  5.3663e-01,  1.8722e-01,
          -2.0146e-01,  1.1913e+00, -2.9929e-01,  4.3982e-01,  1.2609e+00,
           5.5834e-01, -1.6000e+00],
         [-1.7005e+00,  1.0782e+00, -6.7454e-01,  9.1053e-02,  1.5202e-01,
          -2.3548e-01, -4.0013e-01, -2.2318e-01, -6.6264e-01, -3.3346e-01,
          -2.4648e+00,  3.6981e-01,  6.8890e-06,  3.9696e-02,  6.4057e-01,
          -9.2828e-01, -6.6453e-01, -1.9363e+00, -1.7545e+00,  4.1251e-01,
           1.3720e+00, -2.7778e+00, -1.0247e+00,  5.3663e-01,  1.8722e-01,
          -2.0146e-01,  1.1913e+00, -2.9929e-01,  4.3982e-01,  1.2609e+00,
           5.5834e-01, -1.6000e+00],
  

In [3]:
# Create the tgt batch for teacher forcing
batches = [sample[i+7:i+7+sequence_length-4] for i in range(2)]
target = [[27, 27, 27, 28] + list(b) for b in batches]
print(target)
tgt = torch.tensor(target)
layer2 = nn.Embedding(vocab_size, d_model)
tgt = layer2(tgt)
print(tgt.shape, tgt)

[[27, 27, 27, 28, 7, 8, 9, 10, 11, 12], [27, 27, 27, 28, 8, 9, 10, 11, 12, 13]]
torch.Size([2, 10, 32]) tensor([[[ 1.4406, -0.4199, -0.6510,  1.8684, -0.2709,  1.1193, -1.7483,
           1.2059,  0.1277,  0.1760, -0.1456,  0.5224, -0.6909, -0.9805,
           1.2229,  0.5151,  2.0802, -0.8342, -0.8798,  0.4796, -0.7941,
           0.8804,  0.4520,  0.8692,  1.4337, -0.2454, -0.2260, -0.0085,
           2.2802, -0.1143, -0.4042,  0.9016],
         [ 1.4406, -0.4199, -0.6510,  1.8684, -0.2709,  1.1193, -1.7483,
           1.2059,  0.1277,  0.1760, -0.1456,  0.5224, -0.6909, -0.9805,
           1.2229,  0.5151,  2.0802, -0.8342, -0.8798,  0.4796, -0.7941,
           0.8804,  0.4520,  0.8692,  1.4337, -0.2454, -0.2260, -0.0085,
           2.2802, -0.1143, -0.4042,  0.9016],
         [ 1.4406, -0.4199, -0.6510,  1.8684, -0.2709,  1.1193, -1.7483,
           1.2059,  0.1277,  0.1760, -0.1456,  0.5224, -0.6909, -0.9805,
           1.2229,  0.5151,  2.0802, -0.8342, -0.8798,  0.4796, -0.7941,

In [4]:
# Add positional encoding to both src and tgt
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

position = PositionalEncoding(d_model)
src_transposed = src.transpose(0, 1)
tgt_transposed = src.transpose(0, 1)
src_embedded = position(src_transposed)
tgt_embedded = position(tgt_transposed)
print(src_embedded.shape, tgt_embedded.shape)
print(src_embedded, tgt_embedded)

torch.Size([10, 2, 32]) torch.Size([10, 2, 32])
tensor([[[-1.8895e+00,  2.3091e+00, -7.4949e-01,  1.2123e+00,  1.6891e-01,
           8.4947e-01, -4.4459e-01,  0.0000e+00, -7.3626e-01,  7.4060e-01,
          -0.0000e+00,  1.5220e+00,  7.6544e-06,  1.1552e+00,  7.1174e-01,
           0.0000e+00, -7.3837e-01, -1.0403e+00, -1.9495e+00,  1.5695e+00,
           1.5245e+00, -1.9753e+00, -1.1385e+00,  1.7074e+00,  2.0802e-01,
           8.8726e-01,  1.3237e+00,  7.7856e-01,  4.8869e-01,  0.0000e+00,
           6.2038e-01, -6.6670e-01],
         [-1.8895e+00,  2.3091e+00, -7.4949e-01,  1.2123e+00,  1.6891e-01,
           8.4947e-01, -4.4459e-01,  8.6313e-01, -7.3626e-01,  7.4060e-01,
          -2.7387e+00,  1.5220e+00,  7.6544e-06,  1.1552e+00,  7.1174e-01,
           7.9686e-02, -7.3837e-01, -1.0403e+00, -1.9495e+00,  1.5695e+00,
           1.5245e+00, -1.9753e+00, -1.1385e+00,  1.7074e+00,  2.0802e-01,
           8.8726e-01,  1.3237e+00,  7.7856e-01,  4.8869e-01,  0.0000e+00,
           0.00

In [5]:
# Based on prior knowledge of padding tokens, create src_key_padding_mask and tgt_key_padding_mask

# Src mask: ignore the first three tokens always 
# Tgt mask: ignore the first three tokens always since sequence length for src = sequence length for tgt
padding = [[True for _ in range(3)] + [False for _ in range(7)] for _ in range(batch_size)]
src_key_padding_mask = torch.tensor(padding)
tgt_key_padding_mask = torch.tensor(padding)
print(src_key_padding_mask.shape, src_key_padding_mask)


torch.Size([2, 10]) tensor([[ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]])


In [6]:
# Make tgt_mask, the triangular mask for self-attention
base = torch.tensor([[False for _ in range(sequence_length)] for _ in range(sequence_length)])
tgt_mask = base.masked_fill(torch.triu(torch.ones(sequence_length, sequence_length)) == 0, True)
print(tgt_mask.shape, tgt_mask)

torch.Size([10, 10]) tensor([[False, False, False, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]])


In [7]:
# Make the model
model = torch.nn.Transformer(d_model=d_model, dim_feedforward=128)
output = model(src_embedded, tgt_embedded, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, 
               memory_key_padding_mask=src_key_padding_mask)
print(output.shape, output)

torch.Size([10, 2, 32]) tensor([[[ 1.5070e+00,  1.6663e+00, -1.0674e+00,  2.3794e-01, -4.3317e-01,
          -3.9458e-01, -5.4527e-01,  4.8898e-01,  2.5693e+00,  1.9761e+00,
          -5.6008e-02, -4.9662e-01, -3.3387e-01,  1.7740e-01,  6.7733e-02,
           6.4111e-02,  1.0068e+00, -1.9058e-01, -8.0400e-01, -9.5027e-01,
           1.1406e+00, -6.8390e-01, -2.3240e+00, -7.3452e-01, -1.1001e-01,
          -2.3021e-01,  3.5871e-01,  5.8283e-01,  7.1940e-02, -1.5147e+00,
          -8.9745e-01, -1.4918e-01],
         [ 1.7583e-01,  9.3573e-01, -1.3586e-01,  1.1002e+00,  1.8602e-01,
           2.6841e-01, -9.9677e-01,  1.3502e+00,  1.8772e+00,  1.0097e+00,
           5.1272e-01, -4.4230e-01, -9.9913e-01,  8.9823e-01,  8.4190e-01,
           5.1016e-01,  5.4254e-01,  4.2347e-01, -1.1696e+00, -1.4586e-01,
           1.5949e+00, -1.5116e+00, -1.5494e+00, -1.0213e+00, -1.4228e+00,
          -1.2020e+00, -1.3941e+00,  7.7041e-01,  1.1005e+00, -1.1492e+00,
          -6.0621e-01, -3.5190e-01]],



In [8]:
from torch.nn import functional as F

# Make sense of the model's outputs
# output_transpose = output.transpose(0, 1)
final = nn.Linear(d_model, vocab_size)
logits = final(output)
logits_transpose = logits.transpose(0, 1)
probs = F.softmax(logits_transpose[:, -1, :], dim=-1)
outputs = torch.multinomial(probs, num_samples=1)
print(outputs.shape, outputs)
chars = [[chr(c+97) for c in tensor] for tensor in outputs]
print(chars)

torch.Size([2, 1]) tensor([[11],
        [10]])
[['l'], ['k']]


In [9]:
# Create the targets tensor containing the true/expected outputs
bases = [sample[i+7:i+7+sequence_length-3] for i in range(2)]
expected = torch.tensor([[27, 27, 27] + list(b) for b in bases])
print(expected.shape, expected)

torch.Size([2, 10]) tensor([[27, 27, 27,  7,  8,  9, 10, 11, 12, 13],
        [27, 27, 27,  8,  9, 10, 11, 12, 13, 14]])


In [10]:
# Transform arguments for cross-entropy loss
T, B, C = logits.shape
logits_transformed = logits.reshape(T * B, C)
expected_transpose = torch.transpose(expected, 0, 1)
expected_transformed = expected_transpose.reshape(T * B)

print(logits_transformed.shape, expected_transformed.shape)
print(logits_transformed, expected_transformed)

torch.Size([20, 29]) torch.Size([20])
tensor([[ 1.2654,  0.0104,  0.6204,  1.0383, -0.0044, -0.3028, -0.2002, -0.1106,
          0.9930,  0.1641,  0.4270,  0.7774, -0.3238,  0.5365,  0.4778,  0.1651,
         -0.2622,  0.2606,  0.0803, -0.0943,  1.3807, -0.4365, -0.7924,  0.4097,
          0.0247, -0.0355, -0.4363,  0.8820,  0.2662],
        [ 0.7499, -0.1279,  0.2429,  0.5525, -0.5683, -0.2144,  0.0221,  0.4339,
          1.0958, -0.3788,  0.2300,  0.9413, -0.8960,  0.3285,  0.3964,  0.2731,
         -0.7853, -0.0951, -0.5549, -0.2452,  1.1259, -0.4897, -1.2872,  0.1996,
          0.5018, -0.3252,  0.5100,  1.1737,  0.4933],
        [ 0.8359, -0.0453,  0.3825,  1.3930, -0.0809, -0.5833, -0.0573,  0.7016,
          0.9716,  0.1022,  0.2368,  1.1571, -0.2503,  0.3388,  0.7116,  0.2387,
         -0.2298,  0.0577,  0.0814,  0.1958,  1.0669, -0.1804, -1.2100,  0.3075,
          0.3758,  0.5099, -0.1012,  0.7580,  0.2403],
        [ 1.3639,  0.2359,  0.5032,  0.2075, -0.5543,  0.4779, -0.34

In [11]:
# Calculate loss while ignoring padding inputs
from torch.nn import functional as F

weights = torch.tensor([1.0 for _ in range(26)] + [0.0 for _ in range(vocab_size - 26)]) # ignore tokens that aren't from our original vocabulary
print(weights)
loss = F.cross_entropy(logits_transformed, expected_transformed, weight=weights)
print(loss)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.])
tensor(3.3467, grad_fn=<NllLossBackward0>)


In [12]:
# One backprop step
optimizer = torch.optim.AdamW(model.parameters(), lr=3)
loss.backward()

# Show that model output has lower loss value when same inputs used
output_redo = model(src_embedded, tgt_embedded, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, 
               memory_key_padding_mask=src_key_padding_mask)
logits_redo = final(output)
T, B, C = logits_redo.shape
logits_redo_transformed = logits_redo.reshape(T * B, C)

weights = torch.tensor([1.0 for _ in range(26)] + [0.0 for _ in range(vocab_size - 26)]) 
loss = F.cross_entropy(logits_transformed, expected_transformed, weight=weights)
print(loss)

tensor(3.3467, grad_fn=<NllLossBackward0>)
