In [None]:
import torch as tc
from tqdm import tqdm
import torch.nn as nn
from time import perf_counter
from torchsummary import summary

# Horizontal rule (64 dashes) for console output:
bar = 64 * '-'

# Fix the RNG seed so that runs are reproducible:
seed = 42
tc.manual_seed(seed)

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU:
if tc.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"using {device}")

using cuda


# <u>Using a Transformer with PyTorch</u>

With their capacity for parallelization and the ability to capture long-term dependencies in data, Transformers have immense potential in various fields, especially NLP tasks like translation, summarization, and sentiment analysis. In this notebook, we show how PyTorch's Transformer module `torch.nn.Transformer` can be used.

![](fig/transformer.png)
![](https://drive.google.com/uc?export=view&id=1N_tpFwQQ2FNk4yD829q3BLOfehJtQoWo)

 ### <u>Positional Encoding Layer</u>

 Since the transformer model lacks inherent knowledge of the order of tokens (due to its self-attention mechanism), the following class helps the model to consider the position of tokens in the sequence. Sinusoidal functions are used in order to allow the model to more easily learn to attend to relative positions, as they produce a unique and smooth encoding for each position in the sequence:

\begin{equation}
\text{PE}(\text{pos}, j) =
\begin{cases}
\sin\left(f(\text{pos}, 2i) \right), & j = 2i \\
\cos\left(f(\text{pos}, 2i) \right), & j = 2i+1
\end{cases},
\end{equation}
where:
\begin{equation}
  f(\text{pos}, k) = \text{pos} \cdot \exp \left(- \frac{4 \cdot \ln(10) \cdot k}{d_\text{model}} \right),
\end{equation}
and $i = 0, 1, \dots, d_\text{model}/2-1$, $\ \text{pos} = 0, 1, \dots, \text{max\_seq\_len}-1$. That is, each even column $2i$ holds a sine and the following odd column $2i+1$ holds the cosine of the same angle.

Afterwards, the positional encoding (PE) matrix is **added** (not appended) to the input (embeddings) matrix $X$, i.e., $\text{out}_{\text{pos}, i} = x_{\text{pos}, i} + \text{pe}_{\text{pos}, i}$. This provides a sense of sequence because it introduces a unique signal at each position in the sequence. This signal is designed in such a way that the model can theoretically determine the position of each token or the distance between different tokens based on their positional encodings alone. Concatenation instead would (1) increase the dimensionality and (2) might require the model to learn to separate positional information from semantic information explicitly.


In [None]:
# 4*ln(10) = ln(10^4): the log of the base used for the sinusoid frequencies.
loge4 = 4*tc.log(tc.tensor(10))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    A table ``pe`` of shape ``[1, max_seq_len, d_model]`` is precomputed once:
    even columns ``2i`` hold ``sin(pos * exp(-4*ln(10) * 2i / d_model))`` and
    odd columns ``2i+1`` hold the cosine of the same angle. The table is stored
    as a buffer, so it follows the module across devices but is not trainable.

    Args:
        d_model: Size of the embedding dimension (even or odd).
        max_seq_len: Number of positions to precompute; inputs passed to
            ``forward`` are expected to have at most this many positions.
    """
    def __init__(self, d_model, max_seq_len):
        super().__init__()
        # position = [0, 1, 2, ..., max_seq_len - 1] as a column vector:
        position = tc.arange(0, max_seq_len, dtype=tc.float).unsqueeze(1)
        # One frequency per even index i = 0, 2, 4, ... (< d_model):
        # exp(-4*ln(10)*i/d_model)
        div_term = tc.exp(tc.arange(0, d_model, 2).float() * -(loge4 / d_model))
        # angles[pos, k] = pos * div_term[k], shape [max_seq_len, ceil(d_model/2)]:
        angles = position * div_term
        # The positional encoding matrix is:
        pe = tc.zeros(max_seq_len, d_model)
        pe[:, 0::2] = tc.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so the surplus frequency is dropped (no-op for even d_model):
        pe[:, 1::2] = tc.cos(angles[:, :d_model // 2])
        # Register as a buffer: not a model parameter, hence not trainable and
        # no gradients. unsqueeze(0) adds a leading dimension of size 1 so the
        # table broadcasts over the batch dimension:
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model`` (e.g. ``[batch, seq_len, d_model]``).
        """
        return x + self.pe[:, :x.size(1)]

### <u>Embedding Layer</u>
Embeddings convert tokens (which could be words, characters, or subwords) into dense vectors of fixed size. In the case of the Transformer model, `nn.Embedding` layers are used to map each token in the source (`src_voc_size`) and target (`tgt_voc_size`) vocabularies to a high-dimensional space of size `d_model`. This means each integer token ID in the sequence is replaced by a vector of `d_model` dimensions. Consequently, the shape of the output from the embedding layers for both `src` and `tgt` becomes `[batch_size, seq_len, d_model]`.

An embedding layer contains a matrix of trainable parameters. For a vocabulary size of `src_vocab_size` and an embedding dimension of `d_model`, the size of this matrix is `[src_vocab_size, d_model]`. Each row of the matrix corresponds to the dense vector representation of a token in the vocabulary. The function for this transformation is essentially a lookup operation. Given a token ID, the embedding layer returns the corresponding row from its parameter matrix; it is like the key-value concept in a hashtable. Mathematically, this is equivalent to multiplying a one-hot encoding of the token ID with the embedding matrix.

### <u> Masking </u>

As we explained in our tutorial document, masking in Transformers is crucial for controlling the flow of information, especially in the decoder to prevent future tokens from influencing the prediction of the current token, maintaining the autoregressive property. Let us see an example:

Suppose we have a target sequence for a language translation task: ["Hello", "world", "!"]. During training, the transformer model tries to predict the next word based on the previous words. The input to the decoder at each step would ideally be:

Step 1: Input: ["\<bos\>"] Target: ["Hello"]
Step 2: Input: ["\<bos\>", "Hello"] Target: ["world"]
Step 3: Input: ["\<bos\>", "Hello", "world"] Target: ["!"]

To ensure the model only uses past information (and Beginning of Sequence \<bos\> token initially) to predict the next word, we use a target mask (`tgt_mask`) that looks like this for a sequence of length $3$:

\begin{bmatrix}
  0 & -\infty & -\infty \\
  0 & 0 & -\infty \\
  0 & 0 & 0
\end{bmatrix}
This triangular matrix is **added** to the attention scores (before the softmax), so that masked entries receive zero attention weight. This ensures that for each position in the sequence, the model can only attend to previous positions and itself, not future positions:
+ 0 means the model can attend to that position, i.e., information is visible/unmasked.
+ $-\infty$ means the model cannot attend to that position, i.e., information is invisible/masked.
In code, that could be done as: <br>
`tgt_mask = tc.triu(tc.ones((3, 3)) * float('-inf'), diagonal=1)` <br>
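A numerically safer option is to use a very large negative number, e.g., $-10^{9}$ or $-10^{10}$, instead of $-\infty$. Another option is to use a boolean mask where `True` corresponds to the unmasked and `False` to the masked values (note that the convention depends on the API: for boolean masks passed to PyTorch's `nn.Transformer`, `True` marks positions that are *not* allowed to attend). We will see that in the next notebook.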
In our code below, `tgt_mask` is created using `generate_square_subsequent_mask(tgt.size(1))` because it is already part of PyTorch's `nn.Transformer` module; we do not need to manually code it. This function generates the triangular mask matrix as described, tailored to the length of the target sequence. The mask is then applied in the transformer during the forward pass to prevent the decoder from "peeking" at future tokens.

In [None]:
class TransformerModel(nn.Module):
    """Sequence-to-sequence model built around PyTorch's ``nn.Transformer``.

    Token IDs are embedded, positionally encoded and passed through dropout
    before entering the encoder/decoder stack; a final linear layer projects
    the decoder output onto the target vocabulary.

    Args:
        src_voc_size: Number of rows of the source embedding table.
        tgt_voc_size: Number of rows of the target embedding table and size
            of the output projection.
        d_model: Embedding / model dimension; must be divisible by num_heads.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        d_ff: Hidden size of the position-wise feed-forward networks.
        max_seq_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability, used both inside ``nn.Transformer`` and
            on the encoded embeddings.
    """
    def __init__(self, src_voc_size, tgt_voc_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, d_ff, max_seq_len, dropout):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Using PyTorch's Transformer (batch_first => tensors are [batch, seq, d_model]):
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True
        )
        # Embedding layers for encoder and decoder:
        self.encoder_embedding = nn.Embedding(num_embeddings=src_voc_size, embedding_dim=d_model)
        self.decoder_embedding = nn.Embedding(num_embeddings=tgt_voc_size, embedding_dim=d_model)
        # Same positional encoding for encoder & decoder (no trainable parameters):
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        # Final linear layer (d_model -> target vocabulary):
        self.W_O = nn.Linear(d_model, tgt_voc_size)
        # Same dropout probability everywhere:
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt):
        """Compute output scores over the target vocabulary.

        Args:
            src: Source token IDs, ``[batch_size, src_seq_len]``.
            tgt: Target (decoder input) token IDs, ``[batch_size, tgt_seq_len]``.

        Returns:
            Tensor of shape ``[batch_size, tgt_seq_len, tgt_voc_size]``.

        Only the causal target mask is applied; no padding masks are passed
        to the underlying transformer.
        """
        # Apply embedding, then PE, then dropout, to both source and target:
        src = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        # Generate the causal mask for the decoder on the same device as the
        # data, so the module does not depend on any global `device` variable:
        tgt_mask = self.transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=tgt.device)
        # Apply mask and pass through the Transformer (encoder & decoder):
        out = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
        # Apply final transformation (with W_O) and return:
        return self.W_O(out)

In [None]:
# Model hyperparameters (vocabulary sizes, widths, depth, regularisation):
src_vocab_size, tgt_vocab_size = 5000, 5000
dmodel, dff = 256, 512
nheads, nlayers = 4, 2
max_seq_length = 512
drop = 0.1

# Build the model (same depth for encoder and decoder) ...
transformer = TransformerModel(
    src_voc_size=src_vocab_size,
    tgt_voc_size=tgt_vocab_size,
    d_model=dmodel,
    num_heads=nheads,
    num_encoder_layers=nlayers,
    num_decoder_layers=nlayers,
    d_ff=dff,
    max_seq_len=max_seq_length,
    dropout=drop,
)
# ... and move it to the selected device:
transformer = transformer.to(device)

# Count only the parameters that receive gradients:
total_params = sum(
    param.numel()
    for param in transformer.parameters()
    if param.requires_grad
)
print(f"Total trainable parameters = {total_params:,}")

# Layer-by-layer overview (this does not work in colab):
_ = summary(model=transformer)

Total trainable parameters = 6,481,800
Layer (type:depth-idx)                             Param #
├─Transformer: 1-1                                 --
|    └─TransformerEncoder: 2-1                     --
|    |    └─ModuleList: 3-1                        1,054,208
|    |    └─LayerNorm: 3-2                         512
|    └─TransformerDecoder: 2-2                     --
|    |    └─ModuleList: 3-3                        1,581,568
|    |    └─LayerNorm: 3-4                         512
├─Embedding: 1-2                                   1,280,000
├─Embedding: 1-3                                   1,280,000
├─PositionalEncoding: 1-4                          --
├─Linear: 1-5                                      1,285,000
├─Dropout: 1-6                                     --
Total params: 6,481,800
Trainable params: 6,481,800
Non-trainable params: 0


In [None]:
# Optimisation settings:
num_epochs, batch_size = 30, 32
lr_, eps_ = 1e-4, 1e-9
betas_ = (0.9, 0.98)

# Targets equal to index 0 are excluded from the cross-entropy loss;
# 0 is chosen because it is typically reserved for padding tokens:
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

# Adam over all model parameters:
optimizer = tc.optim.Adam(
    transformer.parameters(),
    lr=lr_,
    betas=betas_,
    eps=eps_,
)

In this tutorial, we will not use an actual dataset to train our model. We will just create a batch of random integers and just see if loss decreases or if we encounter any issues with our implementation so far.

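+ `src_data`: Random integers from 1 to `src_vocab_size - 1` (the upper bound of `randint` is exclusive, and 0 is left out because the loss ignores it as padding), representing a batch of source sequences with shape `(batch_size, max_seq_length)`.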

+ `tgt_data`: Random integers from 1 to `tgt_vocab_size - 1`, representing a batch of target sequences with shape `(batch_size, max_seq_length)`.

+ These random sequences can be used as inputs to the transformer model, simulating a batch of data with batch_size examples and sequences of length `max_seq_len`. This setup could be part of a larger script where the model is trained and evaluated on actual sequence-to-sequence tasks, such as machine translation or text summarization.

In [None]:
# Until a real dataset is available, a single batch of random token IDs
# (drawn from 1 up to, but excluding, the vocabulary size) is used to
# check that the model runs end to end:
src_data = tc.randint(
    low=1, high=src_vocab_size, size=(batch_size, max_seq_length), device=device
)
tgt_data = tc.randint(
    low=1, high=tgt_vocab_size, size=(batch_size, max_seq_length), device=device
)

## <u>Training our PyTorch Transformer Model</u>

What follows now is our main training loop. We have established that the model predicts the next token given the previous ones, so, we have to:
+ (1) Exclude the last token from the target ("shifted right" part) when calling the model, because essentially nothing follows afterwards for the model to predict.
+ (2) Exclude the first token when computing the loss because the first token the model tries to predict is not the start-of-sequence token, but the first actual token of the sequence.

In [None]:
tic = perf_counter()
# Training mode (enables dropout):
transformer.train()
# Progress bar over the epochs:
pbar = tqdm(range(num_epochs), desc='Epochs')
# Every "epoch" is one optimisation step on the same batch:
for epoch in pbar:
    # Reset gradients left over from the previous step:
    optimizer.zero_grad()
    # The decoder input is the target without its last token ("shifted right"):
    output = transformer(src_data, tgt_data[:, :-1])
    # Flatten [batch_size, seq_len, vocab_size] to [batch_size * seq_len, vocab_size]:
    output_aligned = output.reshape(-1, tgt_vocab_size)
    # The labels are the target without its first token, flattened to match:
    target_aligned = tgt_data[:, 1:].reshape(-1)
    # Loss, backward pass, parameter update:
    loss = loss_fn(output_aligned, target_aligned)
    loss.backward()
    optimizer.step()
    # Show progress:
    pbar.set_description(f"Epoch: {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")
toc = perf_counter()
print(f"Time elapsed: {toc-tic:.4f} seconds.")

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch: 30/30, Loss: 8.3667: 100%|██████████| 30/30 [00:09<00:00,  3.31it/s]

Time elapsed: 9.0783 seconds.





Let's also simulate how one could evaluate this model using a validation batch. This could be done during training to define an early-stopping criterion and to tune the hyperparameters:

In [None]:
# A fresh random batch plays the role of validation data:
val_src_data = tc.randint(
    low=1, high=src_vocab_size, size=(batch_size, max_seq_length), device=device
)
val_tgt_data = tc.randint(
    low=1, high=tgt_vocab_size, size=(batch_size, max_seq_length), device=device
)

# Evaluation mode (disables dropout):
transformer.eval()
# Same forward pass and loss as in training, but without tracking gradients:
with tc.no_grad():
    output = transformer(val_src_data, val_tgt_data[:, :-1])
    output_aligned = output.reshape(-1, tgt_vocab_size)
    target_aligned = val_tgt_data[:, 1:].reshape(-1)
    val_loss = loss_fn(output_aligned, target_aligned)
print(f"Validation Loss: {val_loss.item():.4f}")

Validation Loss: 8.6479
