# Attention

This notebook uses material from the PyTorch tutorial [Language Translation with nn.Transformer and torchtext](https://pytorch.org/tutorials/beginner/translation_transformer.html).

I used Black to format the code in this notebook. If you want to contribute, use the following to format code cells upon running them.

In [None]:
import jupyter_black

jupyter_black.load(
    lab=False,
    line_length=120,
)

## Prerequisites

In addition to packages you've installed in previous notebooks, you'll need the natural language processing tools `torchtext` and `spacy`. Also, to download the dataset we'll be using, you'll need `portalocker`.

In [None]:
# !pip install torchtext spacy portalocker

Tokenizers split a string into symbols. There are tokenizers for different languages and genres. Download the necessary spaCy tokenizers by entering the following into a terminal.

```sh
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
```

In [None]:
import base64
import copy
import time
from typing import Iterable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
from torchtext.datasets import Multi30k
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import tqdm

You should use a GPU to run this notebook. If you don't, you might die before completing it.

In [None]:
DEVICE = torch.device("cuda" if (have_cuda := torch.cuda.is_available()) else "cpu")
print(f"""You {"don't " if not have_cuda else ""}have a GPU.""")

# Why

I'm going to make up a language, called Nyan, and teach it to you through parallel examples of Nyan and English. Let's go.

* nyan nyaan: Look at me. 
* nyaa nyaan: Look behind you.
* nyaa ya nyauu a: There is a human behind you.
* nyauu u nyao: Follow the human.

Now, translate the following English sentence into Nyan.

* Follow me.

Uncomment the following cell to see the expected translation:

In [None]:
# print(base64.b64decode(b'bnlhbiB1IG55YW8=').decode())

While processing languages, we often split a stream of information into tokens, which we treat as individual symbols. While looking at those Nyan example sentences, you probably treated each sequence of space-separated characters as a token.

In the previous examples, I gave one sentence with the idea of *following*, and one with the idea of *me*. I showed at least two examples of the other major concepts *looking*, *behind*, and *human*.

Then, when I asked you to translate "Follow me," you probably placed a lot of importance on the tokens *nyan*, *u*, and *nyao*, since you saw those only once, and in those cases, the other tokens had established correlations with other concepts through two examples.

In the sentence pair "nyaa ya nyauu a: There is a human behind you", which Nyan token corresponds most strongly to the English token "human"?

In [None]:
# print(base64.b64decode(b'bnlhdXU=').decode())

Thus, when presented with the token *human* you're able to focus almost exclusively on this token to learn how to express this concept quickly. If you were to have a correspondence representation in the form of a probability distribution over the tokens [nyaa, ya, nyauu, a], it would be something like $[\approx 0, \approx 0, \approx 1, \approx 0]$.
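
As a rough illustration of such a correspondence distribution, the sketch below turns made-up scores into probabilities with a softmax. The scores and the `softmax` helper are assumptions chosen for illustration; nothing here is learned from data.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Turn arbitrary scores into a probability distribution."""
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [e / total for e in exps]


nyan_tokens = ["nyaa", "ya", "nyauu", "a"]
# Hypothetical correspondence scores between each Nyan token and "human".
scores = [0.0, 0.0, 10.0, 0.0]
weights = softmax(scores)

for token, weight in zip(nyan_tokens, weights):
    print(f"{token:>5}: {weight:.4f}")
```

Almost all of the probability mass lands on *nyauu*, which is the kind of focus that attention weights are meant to express.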

# Multi-Head Attention in Transformers

Now, please read [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/). You should get to know at least the following concepts.
* The encoder-decoder structure
* Why we want many attention heads
* Why we want positional encodings
* What masks are for

In this notebook, you'll implement masked multi-head attention, and stick it into a Transformer, diagrammed in the following figure.

<div>
<img src="data/transformer.png" width="400"/>
</div>

(Transformer diagram from [Attention Is All You Need](https://commons.wikimedia.org/wiki/File:Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg). The Transformer is composed of an encoder shown on the left half and a decoder shown on the right half.)

We will then use the Transformer to translate German to English using the [Multi30k](https://www.statmt.org/wmt16/multimodal-task.html#task1) dataset. This dataset consists of English and German descriptions of images, among other things, but we'll use only the English and German sentences.

First, we will create the input embedding and output embedding layers.

<div>
<img src="data/transformer_embedding.png" width="400"/>
</div>

They take integer token indices and learn vector representations of them. They are just `nn.Embedding` modules with the output scaled by the square root of the number of elements in each vector.

In [None]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        """
        :param vocab_size: Number of tokens in the token space.
        :param embed_dim: Number of elements in embedding vectors.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        :param tokens: Tensor of integers.
        :returns: Embeddings of tokens, shaped (*(tokens shape), embedding size)
        """
        return self.embedding(tokens) * self.embed_dim**0.5

Worth discussing with your friends:
> Why scale it by the square root of the embedding size? What happens during attention calculation if we don't?

Next, the positional encoding layers.

<div>
<img src="data/transformer_positional.png" width="400"/>
</div>

In ancient times, when recurrent neural networks were typically used for language translation, we didn't need to encode the positions of input tokens, since they were processed in order (front-to-back and back-to-front). We don't want to do that now because it's slow, but we still need some way to mark where in a sequence an input token occurred. There are a few ways of doing this, but the Transformer paper interleaved some scaled sine and cosine values, to be added to the embeddings.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, dropout: float, max_length: int = 5000):
        """
        :param embed_dim: Number of elements in embedding vectors.
        :param dropout: Probability of zeroing an output element.
        :param max_length: Longest sequence length supported.
        """
        super().__init__()
        ln_10000 = 9.21034049987793
        positional_encoding = torch.zeros((max_length, embed_dim))
        positions = torch.arange(0, max_length).reshape(max_length, 1)
        scale = torch.exp(-torch.arange(0, embed_dim, 2) * ln_10000 / embed_dim)
        positional_encoding[:, 0::2] = torch.sin(positions * scale)
        positional_encoding[:, 1::2] = torch.cos(positions * scale)
        positional_encoding = positional_encoding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("positional_encoding", positional_encoding)

    def forward(self, token_embedding: torch.Tensor):
        """
        :param token_embedding: Token embeddings shaped (sequence length, batch size, embedding size)
        :returns: Embeddings with positional encoding added, possibly with some dropouts.
        """
        return self.dropout(token_embedding + self.positional_encoding[: token_embedding.size(0), :])
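
As a cross-check on the class above, here is a minimal pure-Python sketch of the same interleaved sine and cosine values for a single position. The helper name `sinusoidal_encoding` is hypothetical, and the sketch assumes an even `embed_dim`.

```python
import math


def sinusoidal_encoding(position: int, embed_dim: int) -> list[float]:
    """Return the encoding vector for one position (assumes an even embed_dim)."""
    encoding = []
    for i in range(0, embed_dim, 2):
        # Same scale as in PositionalEncoding: 10000 ** (-i / embed_dim).
        scale = math.exp(-i * math.log(10000.0) / embed_dim)
        encoding.append(math.sin(position * scale))  # even element
        encoding.append(math.cos(position * scale))  # odd element
    return encoding


first = sinusoidal_encoding(0, 8)
second = sinusoidal_encoding(1, 4)
print(first)
print(second)
```

Position 0 alternates between sin(0) and cos(0), and later element pairs oscillate more slowly with position because their scale is smaller.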

As an example, suppose there were embeddings of size 64, and a sequence of length 100. The following positional encodings would be added to the sequence.

In [None]:
plt.pcolormesh(PositionalEncoding(64, 0, 100).positional_encoding[:, 0, :].permute(1, 0).numpy(), cmap="coolwarm")
plt.colorbar()
plt.title("positional encoding")
plt.xlabel("position")
plt.ylabel("encoding element")
plt.gca().invert_yaxis()
plt.show()

Next, we'll implement (masked) multi-head attention.

<div>
<img src="data/transformer_attention.png" width="400"/>
</div>

All we have to do is implement masked multi-head attention and use it without a mask where appropriate. These attention modules sit inside an encoder or decoder layer, and such layers are composed multiple times to make an encoder or a decoder. A mask is applied in the decoder attention module to prevent attending to future tokens when predicting the current token.

Inside a multi-head attention module we do the following. Look for references to these steps in the comments to help you implement it.
1. We take three copies of the input embeddings (called the *query*, *key*, and *value*), and do a linear transformation on each to a typically smaller size. This is sometimes called *in-projection*.

1. We calculate the attention weights as scaled dot products of the in-projected queries and keys.

1. If a mask is provided, we add it to the weights.

1. We apply softmax to the attention weights, then dropout at some given dropout probability.

1. We multiply the attention weights with the in-projected values.

1. We apply a linear transformation, called *out-projection*.
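
The core of these steps (2 through 5) can be sketched for a single head in plain Python. This is a simplified illustration, not the module implemented in this notebook: it skips the in-projection, dropout, and out-projection, and the names `single_head_attention` and `softmax` are hypothetical.

```python
import math


def softmax(row: list[float]) -> list[float]:
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]  # exp(-inf) is 0.0, so masked entries vanish
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(queries, keys, values, mask=None):
    head_dim = len(queries[0])
    weights = []
    for t, query in enumerate(queries):
        # Step 2: scaled dot product of one query with every key.
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(head_dim) for key in keys]
        # Step 3: add the mask; -inf entries remove a key from consideration.
        if mask is not None:
            scores = [score + m for score, m in zip(scores, mask[t])]
        # Step 4, without dropout: normalize the scores into weights.
        weights.append(softmax(scores))
    # Step 5: each output is a weighted sum of the values.
    outputs = [[sum(w * value[j] for w, value in zip(row, values)) for j in range(len(values[0]))] for row in weights]
    return outputs, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Causal mask: position t may only attend to positions s <= t.
causal = [[0.0 if s <= t else float("-inf") for s in range(3)] for t in range(3)]
outputs, weights = single_head_attention(tokens, tokens, tokens, causal)
for row in weights:
    print([round(w, 3) for w in row])
```

With the causal mask, the first position can only attend to itself, so its weight row puts everything on the first key.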

In [None]:
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, kdim: int = None, vdim: int = None):
        """
        :param embed_dim: Size of embeddings, must be a multiple of num_heads.
        :param num_heads: Number of parallel attention heads. Each head attends to embed_dim / num_heads elements.
        :param dropout: Dropout probability on attention weights.
        :param kdim: Total number of features for keys. If None, kdim=embed_dim.
        :param vdim: Total number of features for values. If None, vdim=embed_dim.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_probability = dropout

        # STEP 1: IN-PROJECTION
        self.query_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_linear = nn.Linear(self.kdim, self.embed_dim)
        self.value_linear = nn.Linear(self.vdim, self.embed_dim)
        # END STEP 1

        # STEP 6: OUT-PROJECTION
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # END STEP 6

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param query: query embeddings shaped (target sequence length, batch size, embedding size)
        :param key: key embeddings shaped (source sequence length, batch size, embedding size)
        :param value: value embeddings shaped (source sequence length, batch size, embedding size)
        :param key_padding_mask: floating-point mask shaped (batch size, source sequence length).
            This mask will be added to the attention mask before softmax.
            To ignore elements in the key embedding, put -inf in corresponding places in the mask.
        :param attention_mask: floating-point mask shaped (target sequence length, source sequence length).
            This mask will be added to the attention weights before softmax.
        :returns: tuple of (attention output, attention weight).
            Attention output is shaped (target sequence length, batch size, embedding size).
            Attention weight is shaped (batch size, number of heads, target sequence length, source sequence length).

        Simplification of nn.functional.multi_head_attention_forward
        """

        target_length, batch_size, _ = query.shape
        source_length, _, _ = key.shape

        # STEP 1: In-projection
        q = self.query_linear(query)
        k = self.key_linear(key)
        v = self.value_linear(value)
        # END STEP 1

        # Split the embedding dimension into heads so that attention is computed per head.
        # q: (target sequence length, batch size, embedding size)
        # -> (batch size * number of heads, target sequence length, head size)
        q = q.view(target_length, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        # k: (source sequence length, batch size, embedding size)
        # -> (batch size * number of heads, source sequence length, head size)
        k = k.view(k.shape[0], batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        # v: (source sequence length, batch size, embedding size)
        # -> (batch size * number of heads, source sequence length, head size)
        v = v.view(v.shape[0], batch_size * self.num_heads, self.head_dim).transpose(0, 1)

        q = q.view(batch_size, self.num_heads, target_length, self.head_dim)
        k = k.view(batch_size, self.num_heads, source_length, self.head_dim)
        v = v.view(batch_size, self.num_heads, source_length, self.head_dim)

        # STEP 2: Calculate attention weights
        attention_weight = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
        # END STEP 2

        # STEP 3: Apply masks
        attention_mask = self._combine_masks(key_padding_mask, attention_mask, target_length, batch_size, source_length)
        if attention_mask is not None:
            attention_weight += attention_mask
        # END STEP 3

        # STEP 4: Softmax and dropout
        attention_weight = nn.functional.dropout(
            torch.softmax(attention_weight, dim=-1), self.dropout_probability, self.training
        )
        # END STEP 4

        # STEP 5
        attention_output = attention_weight @ v
        # END STEP 5

        # STEP 6
        attention_output = (
            attention_output.permute(2, 0, 1, 3).contiguous().view(batch_size * target_length, self.embed_dim)
        )
        attention_output = self.out_proj(attention_output)
        # END STEP 6

        attention_output = attention_output.view(target_length, batch_size, attention_output.size(1))

        return attention_output, attention_weight

    def _combine_masks(
        self,
        key_padding_mask: torch.Tensor,
        attention_mask: torch.Tensor,
        target_length: int,
        batch_size: int,
        source_length: int,
    ) -> torch.Tensor:
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask.unsqueeze(0)

        if key_padding_mask is not None:
            key_padding_mask = (
                key_padding_mask.view(batch_size, 1, 1, source_length)
                .expand(-1, self.num_heads, -1, -1)
                .reshape(batch_size * self.num_heads, 1, source_length)
            )
            if attention_mask is None:
                # Without this branch, a padding mask given on its own would be silently dropped.
                attention_mask = key_padding_mask
            else:
                attention_mask = attention_mask + key_padding_mask  # Don't say +=.

        if attention_mask is not None:
            if attention_mask.shape[0] == 1:  # batch size 1 and 1 head
                attention_mask = attention_mask.unsqueeze(0)
            else:
                # -1 lets a padding-only mask keep a broadcastable target dimension of size 1.
                attention_mask = attention_mask.view(batch_size, self.num_heads, -1, source_length)
        return attention_mask
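
To see what adding the two kinds of mask does, here is a small pure-Python sketch for one batch element and one head. The helpers `causal_mask` and `combine` are hypothetical stand-ins for the tensor operations in `_combine_masks`, and the masks are assumed to be additive floating-point masks as described in the docstring.

```python
NEG_INF = float("-inf")


def causal_mask(size: int) -> list[list[float]]:
    """Additive mask shaped (target length, source length) that hides future positions."""
    return [[NEG_INF if s > t else 0.0 for s in range(size)] for t in range(size)]


def combine(attention_mask: list[list[float]], key_padding_row: list[float]) -> list[list[float]]:
    """Add one key padding row to every target row, like a broadcast addition."""
    return [[a + p for a, p in zip(row, key_padding_row)] for row in attention_mask]


padding = [0.0, 0.0, NEG_INF]  # the last source position is padding
combined = combine(causal_mask(3), padding)
for row in combined:
    print(row)
```

A position ends up masked if either mask hides it, because adding `-inf` to anything finite (or to `-inf`) gives `-inf`.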

Next, the following classes combine multi-head attention, layer norm, and linear layers into an encoder layer, and compose a bunch of these into an encoder. Note that each layer is a deep copy of the layer instance passed in, so the layers share a structure but not their parameters.

<div>
<img src="data/transformer_encoder.png" width="400"/>
</div>

In [None]:
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: callable = torch.nn.functional.relu,
        layer_norm_eps: float = 1e-5,
    ):
        """
        :param d_model: the number of expected features in the input.
        :param nhead: the number of heads in the multiheadattention models.
        :param dim_feedforward: the dimension of the feedforward network model.
        :param dropout: the dropout value (default=0.1).
        :param activation: the activation function of the intermediate layer
        :param layer_norm_eps: the eps value in layer normalization components.
        """
        super().__init__()
        self.self_attention = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation

    def forward(
        self, source: torch.Tensor, source_mask: torch.Tensor = None, source_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layer.

        :param source: the sequence to the encoder layer.
        :param source_mask: the mask for the src sequence.
        :param source_key_padding_mask: the mask for the src keys per batch.
        """
        x = source
        x = self.norm1(x + self._self_attention_block(x, source_mask, source_key_padding_mask))
        x = self.norm2(x + self._feedforward_block(x))

        return x

    def _self_attention_block(
        self, x: torch.Tensor, attention_mask: torch.Tensor, key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
        x, _ = self.self_attention(x, x, x, attention_mask=attention_mask, key_padding_mask=key_padding_mask)
        return self.dropout1(x)

    def _feedforward_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None):
        """
        TransformerEncoder is a stack of encoder layers. Users can build the
        BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

        :param encoder_layer: an instance of the TransformerEncoderLayer class.
        :param num_layers: the number of sub-encoder-layers in the encoder.
        :param norm: the layer normalization component.
        """
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self, src: torch.Tensor, mask: torch.Tensor = None, source_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layers in turn.
        :param src: the sequence to the encoder.
        :param mask: the mask for the src sequence.
        :param source_key_padding_mask: the mask for the src keys per batch.
        """
        output = src
        for layer in self.layers:
            output = layer(output, source_mask=mask, source_key_padding_mask=source_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
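
The `copy.deepcopy` call in `TransformerEncoder.__init__` is what keeps the stacked layers independent. The following sketch shows the difference with a hypothetical `TinyLayer` class standing in for a real layer.

```python
import copy


class TinyLayer:
    """Hypothetical stand-in for a layer with trainable parameters."""

    def __init__(self):
        self.weights = [0.0, 0.0]


prototype = TinyLayer()
shared = [prototype for _ in range(3)]  # three references to one object
independent = [copy.deepcopy(prototype) for _ in range(3)]  # three separate objects

shared[0].weights[0] = 1.0  # changes every entry of `shared`
independent[0].weights[1] = 5.0  # changes only the first copy

print([layer.weights for layer in shared])
print([layer.weights for layer in independent])
```

With deep copies, updating one layer's parameters during training leaves the other layers untouched.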

As with the encoder, we make a decoder layer and compose a bunch of decoder layers into a decoder.

<div>
<img src="data/transformer_decoder.png" width="400"/>
</div>

In [None]:
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: callable = torch.nn.functional.relu,
        layer_norm_eps: float = 1e-5,
    ):
        """
        :param d_model: the number of expected features in the input.
        :param nhead: the number of heads in the multiheadattention models.
        :param dim_feedforward: the dimension of the feedforward network model.
        :param dropout: the dropout probability.
        :param activation: the activation function of the intermediate layer.
        :param layer_norm_eps: the eps value in layer normalization components.
        """
        super().__init__()
        self.self_attention = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attention = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation

    def forward(
        self,
        target: torch.Tensor,
        memory: torch.Tensor,
        target_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Pass the inputs (and mask) through the decoder layer.
        :param target: the sequence to the decoder layer.
        :param memory: the sequence from the last layer of the encoder.
        :param target_mask: the mask for the target sequence.
        :param memory_mask: the mask for the memory sequence.
        :param target_key_padding_mask: the mask for the target keys per batch.
        :param memory_key_padding_mask: the mask for the memory keys per batch.
        """
        x = target

        x = self.norm1(x + self._self_attention_block(x, target_mask, target_key_padding_mask))
        x = self.norm2(x + self._multihead_attention_block(x, memory, memory_mask, memory_key_padding_mask))
        x = self.norm3(x + self._feedforward_block(x))

        return x

    def _self_attention_block(self, x, attention_mask, key_padding_mask) -> torch.Tensor:
        x, _ = self.self_attention(x, x, x, attention_mask=attention_mask, key_padding_mask=key_padding_mask)
        return self.dropout1(x)

    def _multihead_attention_block(self, x, mem, attention_mask, key_padding_mask) -> torch.Tensor:
        x, _ = self.multihead_attention(x, mem, mem, attention_mask=attention_mask, key_padding_mask=key_padding_mask)
        return self.dropout2(x)

    def _feedforward_block(self, x) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module = None):
        """
        TransformerDecoder is a stack of decoder layers

        :param decoder_layer: an instance of the TransformerDecoderLayer class.
        :param num_layers: the number of sub-decoder-layers in the decoder.
        :param norm: the layer normalization component.
        """
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        target,
        memory,
        target_mask=None,
        memory_mask=None,
        target_key_padding_mask=None,
        memory_key_padding_mask=None,
    ) -> torch.Tensor:
        """
        Pass the inputs (and mask) through the decoder layer in turn.

        :param target: the sequence to the decoder.
        :param memory: the sequence from the last layer of the encoder.
        :param target_mask: the mask for the target sequence.
        :param memory_mask: the mask for the memory sequence.
        :param target_key_padding_mask: the mask for the target keys per batch.
        :param memory_key_padding_mask: the mask for the memory keys per batch.
        """
        output = target

        for layer in self.layers:
            output = layer(
                output,
                memory,
                target_mask=target_mask,
                memory_mask=memory_mask,
                target_key_padding_mask=target_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output
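
Both the encoder and decoder layers wrap each sub-block in the same pattern, `norm(x + block(x))`. Here is a minimal pure-Python sketch of that pattern for one vector. It leaves out dropout and the learned gain and bias of `nn.LayerNorm`, and the helpers `layer_norm`, `residual_block`, and `double` are hypothetical.

```python
import math


def layer_norm(vector: list[float], eps: float = 1e-5) -> list[float]:
    """Normalize a vector to roughly zero mean and unit variance."""
    mean = sum(vector) / len(vector)
    variance = sum((x - mean) ** 2 for x in vector) / len(vector)
    return [(x - mean) / math.sqrt(variance + eps) for x in vector]


def residual_block(x: list[float], sublayer) -> list[float]:
    """Add the sub-layer output back onto its input, then normalize."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def double(vector: list[float]) -> list[float]:
    """Toy sub-layer standing in for attention or the feedforward block."""
    return [2.0 * v for v in vector]


x = [1.0, 2.0, 3.0, 4.0]
y = residual_block(x, double)
print([round(v, 3) for v in y])
```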

The composition of the encoder and the decoder will form the Transformer. This is very similar to the `nn.Transformer` module in PyTorch 2.0.
<div>
<img src="data/transformer_torch.png" width="400"/>
</div>

In [None]:
class Transformer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: callable = torch.nn.functional.relu,
        layer_norm_eps: float = 1e-5,
    ):
        """
        :param d_model: the number of expected features in the encoder or decoder inputs.
        :param nhead: the number of heads in the multiheadattention models.
        :param num_encoder_layers: the number of sub-encoder-layers in the encoder.
        :param num_decoder_layers: the number of sub-decoder-layers in the decoder.
        :param dim_feedforward: the dimension of the feedforward network model.
        :param dropout: the dropout probability.
        :param activation: the activation function of encoder and decoder intermediate layer
        :param layer_norm_eps: the eps value in layer normalization components.
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps)
        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps)
        decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        for parameter in self.parameters():
            if parameter.dim() > 1:
                torch.nn.init.xavier_uniform_(parameter)

        self.d_model = d_model
        self.nhead = nhead

    def forward(
        self,
        src: torch.Tensor,
        target: torch.Tensor,
        source_mask: torch.Tensor = None,
        target_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        source_key_padding_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Take in and process masked source/target sequences.
        :param src: the sequence to the encoder.
        :param target: the sequence to the decoder.
        :param source_mask: the additive mask for the src sequence.
        :param target_mask: the additive mask for the target sequence.
        :param memory_mask: the additive mask for the encoder output.
        :param source_key_padding_mask: the Tensor mask for src keys per batch.
        :param target_key_padding_mask: the Tensor mask for target keys per batch.
        :param memory_key_padding_mask: the Tensor mask for memory keys per batch.
        """
        memory = self.encoder(src, mask=source_mask, source_key_padding_mask=source_key_padding_mask)
        output = self.decoder(
            target,
            memory,
            target_mask=target_mask,
            memory_mask=memory_mask,
            target_key_padding_mask=target_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device="cpu") -> torch.Tensor:
        """
        Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

You might notice a few differences between this `Transformer` module and the Transformer diagram we've been looking at.
* This module doesn't include the embedding and positional encoding layers.
* This module doesn't include the final linear and softmax layers.

If we want to use it for natural language translation, we have to add these layers, except softmax, to the model that we'll train to translate sentences, which is the following `Seq2SeqTransformer`. We don't need the final softmax because we'll just take the argmax of the outputs to get the predicted token.
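
To see why skipping the final softmax doesn't change the predicted token, note that softmax preserves the ordering of its inputs. The sketch below checks this on made-up logits; the values and the `softmax` helper are assumptions for illustration.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.5, -0.3, 4.2, 0.7]  # hypothetical outputs of the final linear layer
probabilities = softmax(logits)

predicted_from_logits = max(range(len(logits)), key=logits.__getitem__)
predicted_from_probabilities = max(range(len(probabilities)), key=probabilities.__getitem__)
print(predicted_from_logits, predicted_from_probabilities)
```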

In [None]:
class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int,
        num_decoder_layers: int,
        emb_size: int,
        nhead: int,
        source_vocab_size: int,
        target_vocab_size: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ):
        """
        Sequence-to-sequence transformer for translating natural languages.
        """
        super(Seq2SeqTransformer, self).__init__()
        self.source_token_embedding = TokenEmbedding(source_vocab_size, emb_size)
        self.target_token_embedding = TokenEmbedding(target_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, target_vocab_size)

    def forward(
        self,
        source: torch.Tensor,
        target: torch.Tensor,
        source_mask: torch.Tensor,
        target_mask: torch.Tensor,
        source_padding_mask: torch.Tensor,
        target_padding_mask: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        return self.generator(
            self.transformer(
                self.positional_encoding(self.source_token_embedding(source)),
                self.positional_encoding(self.target_token_embedding(target)),
                source_mask,
                target_mask,
                None,
                source_padding_mask,
                target_padding_mask,
                memory_key_padding_mask,
            )
        )

    def encode(self, src: torch.Tensor, source_mask: torch.Tensor) -> torch.Tensor:
        return self.transformer.encoder(self.positional_encoding(self.source_token_embedding(src)), source_mask)

    def decode(self, target: torch.Tensor, memory: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
        return self.transformer.decoder(
            self.positional_encoding(self.target_token_embedding(target)), memory, target_mask
        )

# Sequence-to-sequence translation

To translate natural languages, we have to prepare the tokenizers and build the vocabularies (token space) for English and German. We will prepare special beginning-of-sequence (BOS) and end-of-sequence (EOS) tokens to mark sequence boundaries, and a padding (PAD) token to fill out shorter sequences in a batch. We also have an unknown (UNK) token to handle subsequences outside the tokenizer's vocabulary, or tokens that appear too infrequently for us to care.

In [None]:
LANGUAGES = (SOURCE_LANGUAGE := "de", TARGET_LANGUAGE := "en")
SPECIAL_SYMBOLS = (UNK := "<unk>", PAD := "<pad>", BOS := "<bos>", EOS := "<eos>")
SPECIAL_SYMBOL_INDICES = {symbol: index for index, symbol in enumerate(SPECIAL_SYMBOLS)}
MINIMUM_FREQUENCY = 2

token_transform = {
    SOURCE_LANGUAGE: torchtext.data.utils.get_tokenizer("spacy", language="de_core_news_sm"),
    TARGET_LANGUAGE: torchtext.data.utils.get_tokenizer("spacy", language="en_core_web_sm"),
}


def yield_tokens(data_iter: Iterable, language: str) -> Iterable[list[str]]:
    for data_sample in data_iter:
        yield token_transform[language](data_sample[{SOURCE_LANGUAGE: 0, TARGET_LANGUAGE: 1}[language]])


vocab_transform = {}
for language in LANGUAGES:
    vocab_transform[language] = torchtext.vocab.build_vocab_from_iterator(
        yield_tokens(Multi30k(split="train", language_pair=LANGUAGES), language),
        min_freq=MINIMUM_FREQUENCY,
        specials=SPECIAL_SYMBOLS,
        special_first=True,
    )
    # Use the UNK index as the default index. This is returned when a token is not found.
    vocab_transform[language].set_default_index(SPECIAL_SYMBOL_INDICES[UNK])
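
To make the roles of `min_freq`, `specials`, and the default index concrete, here's a dependency-free sketch of the same idea in plain Python. It is not how `torchtext` is implemented. The names `build_toy_vocab` and `lookup` are made up for illustration, and the ordering rule (most frequent first, ties alphabetical) is my assumption for the sketch.

```python
from collections import Counter

SPECIALS = ("<unk>", "<pad>", "<bos>", "<eos>")


def build_toy_vocab(tokenized_sentences, min_freq=2, specials=SPECIALS):
    """Map tokens to indices: specials first, then tokens seen at least min_freq times."""
    counts = Counter(token for sentence in tokenized_sentences for token in sentence)
    # Assumed ordering for this sketch: most frequent first, ties broken alphabetically.
    kept = sorted((t for t, c in counts.items() if c >= min_freq), key=lambda t: (-counts[t], t))
    return {token: index for index, token in enumerate((*specials, *kept))}


def lookup(vocab, tokens, default="<unk>"):
    """Convert tokens to indices, falling back to the default (UNK) index."""
    return [vocab.get(token, vocab[default]) for token in tokens]


toy_corpus = [["ein", "hund", "rennt"], ["ein", "hund", "bellt"], ["ein", "mann", "rennt"]]
toy_vocab = build_toy_vocab(toy_corpus)
print(toy_vocab)
print(lookup(toy_vocab, ["ein", "mann", "bellt"]))  # [4, 0, 0]
```

"mann" and "bellt" each appear only once, so they fall below the frequency cutoff and map to the UNK index.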

These are some helper functions for messing with natural language datasets.

In [None]:
def sequential_transforms(*transforms):
    """
    Compose a bunch of transforms.
    """

    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input

    return func


def tensor_transform(token_ids: list[int]):
    """
    Add BOS/EOS and create tensor for input sequence indices
    """
    return torch.cat(
        (
            torch.tensor([SPECIAL_SYMBOL_INDICES[BOS]]),
            torch.tensor(token_ids),
            torch.tensor([SPECIAL_SYMBOL_INDICES[EOS]]),
        )
    )


# Source and target language text transforms to convert raw strings into tensors of token indices
text_transform = {
    language: sequential_transforms(token_transform[language], vocab_transform[language], tensor_transform)
    for language in LANGUAGES
}


def collate_fn(batch):
    """
    Collate data samples into batch tensors
    """
    source_batch, target_batch = [], []
    for source_sample, target_sample in batch:
        source_batch.append(text_transform[SOURCE_LANGUAGE](source_sample.rstrip("\n")))
        target_batch.append(text_transform[TARGET_LANGUAGE](target_sample.rstrip("\n")))

    source_batch = nn.utils.rnn.pad_sequence(source_batch, padding_value=SPECIAL_SYMBOL_INDICES[PAD])
    target_batch = nn.utils.rnn.pad_sequence(target_batch, padding_value=SPECIAL_SYMBOL_INDICES[PAD])
    return source_batch, target_batch
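
If the shapes coming out of `collate_fn` are hard to picture, this plain-Python sketch mimics what happens to two sentences of different lengths. The helpers `wrap` and `pad_time_major` are hypothetical stand-ins for `tensor_transform` and `pad_sequence`. The point to notice is that, with its default settings, `pad_sequence` returns a batch indexed by position first and example second.

```python
PAD_INDEX, BOS_INDEX, EOS_INDEX = 1, 2, 3  # same order as SPECIAL_SYMBOLS


def wrap(token_ids):
    """Add BOS and EOS around a list of token indices."""
    return [BOS_INDEX, *token_ids, EOS_INDEX]


def pad_time_major(sequences, padding_value=PAD_INDEX):
    """Pad to the longest sequence; rows are positions, columns are examples."""
    longest = max(len(sequence) for sequence in sequences)
    padded = [sequence + [padding_value] * (longest - len(sequence)) for sequence in sequences]
    # Transpose from (batch, sequence) to (sequence, batch).
    return [list(step) for step in zip(*padded)]


toy_batch = pad_time_major([wrap([4, 5, 6]), wrap([4])])
for step in toy_batch:
    print(step)
```

Each printed row is one time step across the batch, so the second column (the shorter sentence) ends in padding.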

We'll use the following function to create triangular masks and padding masks.

In [None]:
def create_masks(source, target):
    """
    Create attention masks (sequence x sequence) and padding masks (batch x sequence)
    """
    source_seq_len = source.shape[0]
    target_seq_len = target.shape[0]

    target_mask = Transformer.generate_square_subsequent_mask(target_seq_len, DEVICE)
    source_mask = torch.zeros((source_seq_len, source_seq_len), device=DEVICE, dtype=torch.float32)

    source_padding_mask = torch.zeros_like(source, dtype=torch.float32)
    target_padding_mask = torch.zeros_like(target, dtype=torch.float32)
    source_padding_mask[source == SPECIAL_SYMBOL_INDICES[PAD]] = -torch.inf
    target_padding_mask[target == SPECIAL_SYMBOL_INDICES[PAD]] = -torch.inf
    source_padding_mask = source_padding_mask.transpose(0, 1)
    target_padding_mask = target_padding_mask.transpose(0, 1)

    return source_mask, target_mask, source_padding_mask, target_padding_mask
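
These masks are additive: they get added to the attention scores before the softmax, so a `-inf` entry ends up with zero attention weight and a `0.0` entry changes nothing. Here's a small plain-Python sketch of the two kinds of mask, with made-up helper names, just to show the patterns.

```python
NEG_INF = float("-inf")
PAD_INDEX = 1


def causal_mask(size):
    """Row i may attend to columns 0..i (0.0); later columns are blocked (-inf)."""
    return [[0.0 if column <= row else NEG_INF for column in range(size)] for row in range(size)]


def padding_mask(batch_first_ids, pad_index=PAD_INDEX):
    """One row per example: -inf where the token is padding, 0.0 elsewhere."""
    return [[NEG_INF if token == pad_index else 0.0 for token in example] for example in batch_first_ids]


for row in causal_mask(3):
    print(row)
print(padding_mask([[2, 4, 5, 3], [2, 4, 3, 1]]))
```

The triangular mask stops a target position from peeking at later positions, and the padding mask hides the filler tokens added during batching.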

Here, we define the hyperparameters for our Transformer. If you run out of GPU memory or training is too slow, try reducing them (for example, the batch size or the number of layers).

In [None]:
SOURCE_VOCAB_SIZE = len(vocab_transform[SOURCE_LANGUAGE])
TARGET_VOCAB_SIZE = len(vocab_transform[TARGET_LANGUAGE])
EMBED_SIZE = 512
NUM_HEADS = 8
FEEDFORWARD_HIDDEN_SIZE = 512
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
BATCH_SIZE = 64

LEARNING_RATE = 0.0001
NUM_EPOCHS = 15
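
If you do shrink these, one constraint to keep in mind: multi-head attention splits the embedding evenly across heads, so the embedding size has to be divisible by the number of heads. I'm assuming here that the notebook's `Transformer` follows `torch.nn.Transformer` in this respect. The helper `head_size` is hypothetical, a quick check you could run before building the model.

```python
def head_size(embed_size, num_heads):
    """Return the per-head width, or raise if the embedding can't be split evenly."""
    if embed_size % num_heads != 0:
        raise ValueError(f"embed_size={embed_size} is not divisible by num_heads={num_heads}")
    return embed_size // num_heads


print(head_size(512, 8))  # 64
print(head_size(256, 4))  # 64
```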

We'll make the train and validation data loaders. I happen to know the number of training and validation examples, which we need in order to give the progress bars a total.

In [None]:
train_dataloader = DataLoader(
    Multi30k(split="train", language_pair=(SOURCE_LANGUAGE, TARGET_LANGUAGE)),
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
validation_dataloader = DataLoader(
    Multi30k(split="valid", language_pair=(SOURCE_LANGUAGE, TARGET_LANGUAGE)),
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

# Used for a working progress bar.
N_TRAIN = 29001
N_VALIDATION = 1050
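
The progress bars in the training loop turn these counts into a number of batches with a ceiling division, since the last batch may be only partly full. A quick sketch of that arithmetic (the function name `num_batches` is mine):

```python
def num_batches(num_examples, batch_size):
    """Ceiling division: the last batch may be smaller than batch_size."""
    return (num_examples + batch_size - 1) // batch_size


print(num_batches(29001, 64))  # 454
print(num_batches(1050, 64))  # 17
print(num_batches(128, 64))  # 2
```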

Now we instantiate a model, loss function, and optimizer.

In [None]:
model = Seq2SeqTransformer(
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
    EMBED_SIZE,
    NUM_HEADS,
    SOURCE_VOCAB_SIZE,
    TARGET_VOCAB_SIZE,
    FEEDFORWARD_HIDDEN_SIZE,
)
for parameter in model.parameters():
    if parameter.dim() > 1:
        nn.init.xavier_uniform_(parameter)
model = model.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(ignore_index=SPECIAL_SYMBOL_INDICES[PAD])
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)

And train.

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0
    total = 0

    for source, target in (
        progress_bar := tqdm.tqdm(
            train_dataloader, total=(N_TRAIN + BATCH_SIZE - 1) // BATCH_SIZE, desc=f"Epoch {epoch}"
        )
    ):
        source = source.to(DEVICE)
        target = target.to(DEVICE)

        target_input = target[:-1, :]
        source_mask, target_mask, source_padding_mask, target_padding_mask = create_masks(source, target_input)
        logits = model(
            source,
            target_input,
            source_mask,
            target_mask,
            source_padding_mask,
            target_padding_mask,
            source_padding_mask,
        )

        optimizer.zero_grad()

        target_out = target[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), target_out.reshape(-1))
        loss.backward()

        optimizer.step()

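        # The criterion returns a per-batch mean, so weight it by the batch size before averaging.
        train_loss += loss.item() * target.shape[1]
        total += target.shape[1]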

        progress_bar.set_postfix_str(f"train loss: {train_loss / total:.3f}")

    model.eval()
    validation_loss = 0
    total = 0

    # Gradients aren't needed for validation, so skip tracking them.
    with torch.no_grad():
        for source, target in (
            progress_bar := tqdm.tqdm(
                validation_dataloader, total=(N_VALIDATION + BATCH_SIZE - 1) // BATCH_SIZE, desc="Validation"
            )
        ):
            source = source.to(DEVICE)
            target = target.to(DEVICE)
            target_input = target[:-1, :]
            source_mask, target_mask, source_padding_mask, target_padding_mask = create_masks(source, target_input)
            logits = model(
                source,
                target_input,
                source_mask,
                target_mask,
                source_padding_mask,
                target_padding_mask,
                source_padding_mask,
            )
            target_out = target[1:, :]
            loss = criterion(logits.reshape(-1, logits.shape[-1]), target_out.reshape(-1))
            # The criterion returns a per-batch mean, so weight it by the batch size before averaging.
            validation_loss += loss.item() * target.shape[1]
            total += target.shape[1]

            progress_bar.set_postfix_str(f"validation loss: {validation_loss / total:.3f}")
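
The `target[:-1]` and `target[1:]` slices in the loop are easy to gloss over. The decoder is fed the target sentence minus its last token, and at each position it's asked to predict the token one step ahead. Here's a plain-Python sketch with a single made-up sentence of token indices.

```python
BOS_INDEX, EOS_INDEX = 2, 3

target = [BOS_INDEX, 7, 8, 9, EOS_INDEX]  # one target sentence as token indices

target_input = target[:-1]  # what the decoder sees: <bos> 7 8 9
target_out = target[1:]  # what it should predict: 7 8 9 <eos>

# With the triangular mask, position `step` only sees target_input[: step + 1].
pairs = [(target_input[: step + 1], target_out[step]) for step in range(len(target_out))]
for context, expected in pairs:
    print(context, "->", expected)
```

The last pair shows the model learning to emit EOS once it has seen the whole sentence.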

You can use the following `translate` function to try translating a German sentence to English. Note that we trained on **descriptions of images**, so the model will do better on sentences like that. 

In [None]:
@torch.no_grad()
def greedy_decode(model, source, source_mask, max_length, start_symbol):
    """
    Generate tokens one at a time, always picking the highest-scoring next token
    """
    source = source.to(DEVICE)
    memory = model.encode(source, source_mask)
    predicted_tokens = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    for _ in range(max_length - 1):
        target_mask = Transformer.generate_square_subsequent_mask(predicted_tokens.size(0)).to(DEVICE)
        out = model.decode(predicted_tokens, memory, target_mask)
        out = out.transpose(0, 1)
        # Only the last position's logits are needed to choose the next token.
        logits = model.generator(out[:, -1])
        next_word = torch.argmax(logits, dim=1).item()

        predicted_tokens = torch.cat(
            [predicted_tokens, torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)], dim=0
        )
        if next_word == SPECIAL_SYMBOL_INDICES[EOS]:
            break
    return predicted_tokens

def translate(model: torch.nn.Module, source_sentence: str):
    model.eval()
    source = text_transform[SOURCE_LANGUAGE](source_sentence).view(-1, 1)
    num_tokens = source.shape[0]
    source_mask = None
    target_tokens = greedy_decode(
        model, source, source_mask, max_length=num_tokens + 5, start_symbol=SPECIAL_SYMBOL_INDICES[BOS]
    ).flatten()
    return (
        " ".join(vocab_transform[TARGET_LANGUAGE].lookup_tokens(target_tokens.tolist()))
        .replace(BOS, "")
        .replace(EOS, "")
        .strip()
    )
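
Stripped of the tensors, greedy decoding is a short loop: start with BOS, ask for scores over the vocabulary, append the best-scoring token, and stop at EOS or the length limit. Here's a plain-Python sketch. `toy_scores` is a hypothetical stand-in for the decoder and generator, scripted to prefer a fixed sequence so the result is predictable.

```python
BOS_INDEX, EOS_INDEX = 2, 3


def toy_scores(prefix):
    """Hypothetical stand-in for the model: score 6 vocabulary entries given the prefix."""
    script = {1: 4, 2: 5, 3: EOS_INDEX}  # prefix length -> token this toy "model" prefers
    preferred = script.get(len(prefix), EOS_INDEX)
    return [1.0 if index == preferred else 0.0 for index in range(6)]


def greedy(score_fn, max_length):
    """Repeatedly append the highest-scoring token until EOS or max_length."""
    tokens = [BOS_INDEX]
    for _ in range(max_length - 1):
        scores = score_fn(tokens)
        next_token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(next_token)
        if next_token == EOS_INDEX:
            break
    return tokens


print(greedy(toy_scores, max_length=10))  # [2, 4, 5, 3]
print(greedy(toy_scores, max_length=3))  # [2, 4, 5]
```

The second call shows why `translate` allows a few tokens more than the source length: if the limit is too tight, the output gets cut off before EOS.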

In [None]:
translate(model, "Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen")

The target translation was "A group of men are loading cotton onto a truck."