<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/6_Attention_is_All_You_Need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://www.mihaileric.com/posts/transformers-attention-in-disguise/
# https://jalammar.github.io/illustrated-transformer/
# http://nlp.seas.harvard.edu/2018/04/03/attention.html

# position embedding has a "vocabulary" size of 100, model can accept sentences up to 100 tokens long
# we use a learned positional encoding instead of a static one
# we use the standard Adam optimizer with a static learning rate instead of one with warm-up and cool-down steps
# we do not use label smoothing


In [None]:
# @title setup

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
%pip install portalocker
%pip install torchdata

# Create source and target language tokenizer. Make sure to install the dependencies.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


In [None]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# The original Multi30k download links are broken, so point the dataset at a mirror.
# See https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Translation direction: German source -> English target.
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
# spaCy tokenizers; the two language models are downloaded in the setup cell.
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


# Indices of the special tokens. They must follow the order of `special_symbols`
# because the vocabularies below are built with special_first=True.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, beginning, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

# Tokenize the training pairs: element 0 of each sample goes through the German
# tokenizer, element 1 through the English one (train_iter is iterated twice).
de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

# One vocabulary per language; min_freq=1 keeps every token seen in training.
de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
# Tokens missing from a vocabulary are looked up as UNK_IDX.
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    """Convert a raw German sentence into a 1-D tensor of vocabulary indices.

    The sentence is tokenized with ``de_tokenizer``, mapped to indices with
    ``de_vocab`` and wrapped in ``BOS_IDX`` / ``EOS_IDX``.

    Args:
        o: raw sentence string.

    Returns:
        int64 tensor ``[BOS_IDX, tok_1, ..., tok_n, EOS_IDX]``.
    """
    token_ids = de_vocab(de_tokenizer(o))
    # Build a single tensor with an explicit dtype: torch.tensor([]) on an empty
    # token list would be float32 and promote the concatenated result to float.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

def en_transform(o):
    """Convert a raw English sentence into a 1-D tensor of vocabulary indices.

    The sentence is tokenized with ``en_tokenizer``, mapped to indices with
    ``en_vocab`` and wrapped in ``BOS_IDX`` / ``EOS_IDX``.

    Args:
        o: raw sentence string.

    Returns:
        int64 tensor ``[BOS_IDX, tok_1, ..., tok_n, EOS_IDX]``.
    """
    token_ids = en_vocab(en_tokenizer(o))
    # Build a single tensor with an explicit dtype: torch.tensor([]) on an empty
    # token list would be float32 and promote the concatenated result to float.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)


from torch.nn.utils.rnn import pad_sequence
# function to collate data samples into batch tensors
def collate_fn(batch):
    """Collate raw (source, target) string pairs into two padded index tensors.

    Each string has trailing newlines stripped and is converted with
    ``de_transform`` / ``en_transform``; the resulting sequences are padded
    with ``PAD_IDX`` to the longest sequence of the batch (batch-first layout).

    Returns:
        ``(src_batch, tgt_batch)`` tensors.
    """
    encoded_pairs = [
        (de_transform(src_text.rstrip("\n")), en_transform(tgt_text.rstrip("\n")))
        for src_text, tgt_text in batch
    ]
    src_sequences = [pair[0] for pair in encoded_pairs]
    tgt_sequences = [pair[1] for pair in encoded_pairs]
    padded_src = pad_sequence(src_sequences, batch_first=True, padding_value=PAD_IDX)
    padded_tgt = pad_sequence(tgt_sequences, batch_first=True, padding_value=PAD_IDX)
    return padded_src, padded_tgt


# Seed PyTorch's global RNG.
torch.manual_seed(0)

# Fresh dataset objects for the loaders (the earlier train_iter was used to build the vocabularies).
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
batch_size = 128 # 128
# collate_fn turns each batch of raw (German, English) string pairs into padded index tensors.
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TGT_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TGT_LANGUAGE:en_transform}




In [None]:
# @title model
import torch
import torch.nn as nn
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import math
class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    The table ``pe`` of shape ``[1, max_seq_length, embedding_dim]`` is computed
    once and stored as a buffer: even columns hold ``sin(pos * w_k)``, odd
    columns hold ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / embedding_dim)``.

    Args:
        embedding_dim: size of the last dimension of the inputs (odd sizes are supported).
        max_seq_length: number of positions in the precomputed table.
        dropout: dropout probability applied to the sum.
    """
    def __init__(self, embedding_dim, max_seq_length=512, dropout=0.1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(0, max_seq_length).unsqueeze(1) # https://nlp.seas.harvard.edu/annotated-transformer/
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        angles = position * div_term # [max_seq_length, ceil(embedding_dim / 2)]
        pe = torch.zeros(max_seq_length, embedding_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` whose dim 1 is the sequence axis.

        Buffers never require grad, so no explicit ``requires_grad_`` call is needed.
        """
        # x = x*math.sqrt(self.embedding_dim) # ?
        return self.dropout(x + self.pe[:, : x.size(1)])


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: model dimension; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: accepted for backward compatibility only; the layer no longer
            keeps any device-bound state outside its parameters.
    """
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # A plain float instead of a tensor created on `device`: an unregistered
        # tensor attribute is not moved by model.to(...)/.cuda()/.double(), which
        # can leave it on a different device than the activations.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x, batch_size):
        """Reshape [batch, len, hid dim] into [batch, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim]
            mask: optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where it
                equals 0 are excluded from attention.

        Returns:
            ``(x, attention)`` with x = [batch size, query len, hid dim] and
            attention = [batch size, n heads, query len, key len] (pre-dropout).
        """
        batch_size = query.shape[0]
        Q = self._split_heads(self.fc_q(query), batch_size) #Q = [batch size, n heads, query len, head dim]
        K = self._split_heads(self.fc_k(key), batch_size) #K = [batch size, n heads, key len, head dim]
        V = self._split_heads(self.fc_v(value), batch_size) #V = [batch size, n heads, value len, head dim]
        # scaled dot-product attention
        energy = torch.matmul(Q, K.transpose(2, 3)) / self.scale #energy = [batch size, n heads, query len, key len]
        if mask is not None:
            # Most negative finite value of the dtype: a hard-coded -1e10 cannot
            # be represented in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        attention = torch.softmax(energy, dim = -1) #attention = [batch size, n heads, query len, key len]
        # Dropout is applied directly to the attention weights.
        x = torch.matmul(self.dropout(attention), V) #x = [batch size, n heads, query len, head dim]
        x = x.transpose(1, 2).contiguous() #x = [batch size, query len, n heads, head dim]
        x = x.view(batch_size, -1, self.hid_dim) #x = [batch size, query len, hid dim]
        x = self.fc_o(x) #x = [batch size, query len, hid dim]
        return x, attention

# maybe noneed?
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward block: hid_dim -> pf_dim -> hid_dim.

    The expansion is followed by ReLU and dropout before the projection back
    to ``hid_dim``.
    """
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim) # expansion
        self.fc_2 = nn.Linear(pf_dim, hid_dim) # projection back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x = [batch size, seq len, hid dim] to a tensor of the same shape."""
        expanded = self.fc_1(x) # [batch size, seq len, pf dim]
        activated = self.dropout(expanded.relu())
        return self.fc_2(activated) # [batch size, seq len, hid dim]

class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalized.
    """
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Encode ``src``.

        Args:
            src: [batch size, src len, hid dim]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            Tensor of shape [batch size, src len, hid dim].
        """
        # Self-attention sub-layer: query, key and value are all `src`.
        attended, _weights = self.self_attention(src, src, src, src_mask)
        normed = self.self_attn_layer_norm(src + self.dropout(attended))
        # Position-wise feed-forward sub-layer.
        transformed = self.positionwise_feedforward(normed)
        return self.ff_layer_norm(normed + self.dropout(transformed))

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalized.
    """
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode one layer.

        Args:
            trg: [batch size, trg len, hid dim]
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(trg, attention)`` where trg = [batch size, trg len, hid dim] and
            attention = [batch size, n heads, trg len, src len] comes from the
            encoder-attention sub-layer.
        """
        # 1) Self-attention over the target, restricted by trg_mask.
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        state = self.self_attn_layer_norm(trg + self.dropout(self_attended))
        # 2) Attention over the encoder output: queries come from the decoder state.
        cross_attended, attention = self.encoder_attention(state, enc_src, enc_src, src_mask)
        state = self.enc_attn_layer_norm(state + self.dropout(cross_attended))
        # 3) Position-wise feed-forward.
        transformed = self.positionwise_feedforward(state)
        state = self.ff_layer_norm(state + self.dropout(transformed))
        return state, attention

class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings + sinusoidal positions + N layers.

    Args:
        input_dim: source vocabulary size.
        hid_dim: model dimension.
        n_layers: number of ``EncoderLayer`` blocks.
        n_heads: attention heads per layer.
        pf_dim: inner size of the position-wise feed-forward layers.
        dropout: dropout probability.
        device: stored on ``self.device`` and forwarded to the layers.
        max_length: number of positions available to the positional encoder.
    """
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length = 100):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        # Learnt positional embeddings; not used by forward() (which uses the
        # sinusoidal `pos_enc`) but kept so parameter names and creation order
        # stay unchanged.
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.pos_enc = PositionalEncoder(hid_dim, max_length, dropout)
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        # Plain float rather than a tensor created on `device`: an unregistered
        # tensor attribute is not moved by model.to(...).
        self.scale = math.sqrt(hid_dim)

    def forward(self, src, src_mask):
        """Encode a batch of source token indices.

        Args:
            src: [batch size, src len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            Tensor of shape [batch size, src len, hid dim].
        """
        # Order matters: embed -> scale -> add positional encoding.
        embedded = self.tok_embedding(src) * self.scale
        # NOTE(review): pos_enc already applies dropout internally, so dropout is
        # effectively applied twice here; kept as-is to preserve training behaviour.
        src = self.dropout(self.pos_enc(embedded)) #src = [batch size, src len, hid dim]
        for layer in self.layers:
            src = layer(src, src_mask) #src = [batch size, src len, hid dim]
        return src

# src_mask, same shape as the source sentence,
# value of 1 when the token in the source sentence is not <pad> token,
# for encoder layers to mask the multi-head attention mechanisms,
# calculate and apply attention over the source sentence, so dun pay attention to <pad> tokens

class Decoder(nn.Module):
    """Transformer decoder: scaled token embeddings + sinusoidal positions + N layers + output projection.

    Args:
        output_dim: target vocabulary size (also the size of the output logits).
        hid_dim: model dimension.
        n_layers: number of ``DecoderLayer`` blocks.
        n_heads: attention heads per layer.
        pf_dim: inner size of the position-wise feed-forward layers.
        dropout: dropout probability.
        device: stored on ``self.device`` and forwarded to the layers.
        max_length: number of positions available to the positional encoder.
    """
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length = 100):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        # Learnt positional embeddings; not used by forward() (which uses the
        # sinusoidal `pos_enc`) but kept so parameter names and creation order
        # stay unchanged.
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.pos_enc = PositionalEncoder(hid_dim, max_length, dropout)
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain float rather than a tensor created on `device`: an unregistered
        # tensor attribute is not moved by model.to(...).
        self.scale = math.sqrt(hid_dim)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode a batch of target token indices against the encoded source.

        Args:
            trg: [batch size, trg len]
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(output, attention)`` with output = [batch size, trg len, output dim]
            (unnormalized logits) and attention = [batch size, n heads, trg len, src len]
            taken from the last decoder layer.
        """
        # Scale the token embeddings BEFORE adding the positional encoding, as the
        # Encoder does. Previously the scale multiplied the result after positional
        # encoding and dropout, which also blew up the positional signal.
        embedded = self.tok_embedding(trg) * self.scale
        # NOTE(review): pos_enc already applies dropout internally, so dropout is
        # effectively applied twice here; kept as-is to mirror the Encoder.
        trg = self.dropout(self.pos_enc(embedded)) #trg = [batch size, trg len, hid dim]
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg) #output = [batch size, trg len, output dim]
        return output, attention

class Seq2Seq(nn.Module):
    """Wrap an encoder and a decoder and build the attention masks they need."""
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_pad_idx, self.trg_pad_idx = src_pad_idx, trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean mask [batch size, 1, 1, src len], True where src is not padding."""
        is_token = src.ne(self.src_pad_idx) # [batch size, src len]
        return is_token[:, None, None].to(self.device)

    def make_trg_mask(self, trg):
        """Return a boolean mask [batch size, 1, trg len, trg len].

        An entry is True when the key position is not padding AND is not after
        the query position (lower-triangular "subsequent" mask).
        """
        not_pad = trg.ne(self.trg_pad_idx)[:, None, None].to(self.device) # [batch size, 1, 1, trg len]
        steps = trg.shape[1]
        causal = torch.ones((steps, steps), device = self.device).tril().bool() # [trg len, trg len]
        return not_pad & causal

    def forward(self, src, trg):
        """Run encoder and decoder on src = [batch size, src len], trg = [batch size, trg len].

        Returns:
            ``(output, attention)`` as produced by the decoder.
        """
        source_mask = self.make_src_mask(src)
        target_mask = self.make_trg_mask(trg)
        memory = self.encoder(src, source_mask)
        output, attention = self.decoder(trg, memory, target_mask, source_mask)
        return output, attention


# Vocabulary sizes: source size feeds the encoder embedding, target size feeds
# the decoder embedding and the output projection.
INPUT_DIM = len(de_vocab)
OUTPUT_DIM = len(en_vocab)
HID_DIM = 512 #256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1


# Equivalent hyperparameters as named in the PyTorch translation tutorial (kept for reference):
# emb_size = 512
# nhead = 8
# dim_feedforward = 512
# num_encoder_layers = 3
# num_decoder_layers = 3
# transformer = Transformer(num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward).to(device)


# The same PAD_IDX is used for both source and target padding.
enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)
model = Seq2Seq(enc, dec, PAD_IDX, PAD_IDX, device).to(device)

# def initialize_weights(m):
#     if hasattr(m, 'weight') and m.weight.dim() > 1:
#         nn.init.xavier_uniform_(m.weight.data)
# model.apply(initialize_weights);

# Xavier-uniform initialisation for every parameter with more than one dimension;
# 1-D parameters (e.g. biases) keep their default initialisation.
for p in model.parameters():
# for p in self.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)



In [None]:

def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
    """Embed source and target, run ``self.transformer`` and apply ``self.generator``.

    NOTE(review): defined at module level although it takes ``self``; it looks
    like a method detached from a class that is not visible here.
    """
    add_positions = self.pos_enc
    source_embedded = add_positions(self.src_tok_emb(src))
    target_embedded = add_positions(self.tgt_tok_emb(tgt))
    # The fifth positional argument is always None (presumably the memory mask;
    # confirm against the transformer's signature).
    fifth_arg = None
    decoded = self.transformer(
        source_embedded, target_embedded, src_mask, tgt_mask, fifth_arg,
        src_padding_mask, tgt_padding_mask, memory_key_padding_mask,
    )
    return self.generator(decoded)

def encode(self, src, src_mask):
    """Embed ``src`` with positions and pass it, with ``src_mask``, to ``self.transformer.encoder``."""
    source_embedded = self.pos_enc(self.src_tok_emb(src))
    return self.transformer.encoder(source_embedded, src_mask)

def decode(self, tgt, memory, tgt_mask):
    """Embed ``tgt`` with positions and pass it, with ``memory`` and ``tgt_mask``, to ``self.transformer.decoder``."""
    target_embedded = self.pos_enc(self.tgt_tok_emb(tgt))
    return self.transformer.decoder(target_embedded, memory, tgt_mask)



##text

### Multi-Head Attention Layer

One of the key, novel concepts introduced by the Transformer paper is the *multi-head attention layer*.

![](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/transformer-attention.png?raw=1)

Attention can be thought of as *queries*, *keys* and *values* - where the query is used with the key to get an attention vector (usually the output of a *softmax* operation and has all values between 0 and 1 which sum to 1) which is then used to get a weighted sum of the values.

The Transformer uses *scaled dot-product attention*, where the query and key are combined by taking the dot product between them, scaling by $\frac{1}{\sqrt{d_k}}$, then applying the softmax operation before finally multiplying by the value. $d_k$ is the *head dimension*, `head_dim`, which we will shortly explain further.

$$ \text{Attention}(Q, K, V) = \text{Softmax} \big( \frac{QK^T}{\sqrt{d_k}} \big)V $$

This is similar to standard *dot product attention* but is scaled by $\frac{1}{\sqrt{d_k}}$, which the paper states is used to stop the results of the dot products growing large, causing gradients to become too small.

However, the scaled dot-product attention isn't simply applied to the queries, keys and values. Instead of doing a single attention application the queries, keys and values have their `hid_dim` split into $h$ *heads* and the scaled dot-product attention is calculated over all heads in parallel. This means instead of paying attention to one concept per attention application, we pay attention to $h$. We then re-combine the heads into their `hid_dim` shape, thus each `hid_dim` is potentially paying attention to $h$ different concepts.

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^O $$

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

$W^O$ is the linear layer applied at the end of the multi-head attention layer, `fc_o`. $W^Q, W^K, W^V$ are the linear layers `fc_q`, `fc_k` and `fc_v`.

Walking through the module, first we calculate $QW^Q$, $KW^K$ and $VW^V$ with the linear layers, `fc_q`, `fc_k` and `fc_v`, to give us `Q`, `K` and `V`. Next, we split the `hid_dim` of the query, key and value into `n_heads` using `.view` and correctly permute them so they can be multiplied together. We then calculate the `energy` (the un-normalized attention) by multiplying `Q` and `K` together and scaling it by the square root of `head_dim`, which is calculated as `hid_dim // n_heads`. We then mask the energy so we do not pay attention over any elements of the sequence we shouldn't, then apply the softmax and dropout. We then apply the attention to the value heads, `V`, before combining the `n_heads` together. Finally, we multiply this by $W^O$, represented by `fc_o`.

Note that in our implementation the lengths of the keys and values are always the same, thus when matrix multiplying the output of the softmax, `attention`, with `V` we will always have valid dimension sizes for matrix multiplication. This multiplication is carried out using `torch.matmul` which, when both tensors are >2-dimensional, does a batched matrix multiplication over the last two dimensions of each tensor. This will be a **[query len, key len] x [value len, head dim]** batched matrix multiplication over the batch size and each head which provides the **[batch size, n heads, query len, head dim]** result.

One thing that looks strange at first is that dropout is applied directly to the attention. This means that our attention vector will most probably not sum to 1 and we may pay full attention to a token but the attention over that token is set to 0 by dropout. This is never explained, or even mentioned, in the paper however is used by the [official implementation](https://github.com/tensorflow/tensor2tensor/) and every Transformer implementation since, [including BERT](https://github.com/google-research/bert/).

### Decoder

The objective of the decoder is to take the encoded representation of the source sentence, $Z$, and convert it into predicted tokens in the target sentence, $\hat{Y}$. We then compare $\hat{Y}$ with the actual tokens in the target sentence, $Y$, to calculate our loss, which will be used to calculate the gradients of our parameters and then use our optimizer to update our weights in order to improve our predictions.

![](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/transformer-decoder.png?raw=1)

The decoder is similar to the encoder, however it now has two multi-head attention layers: a *masked multi-head attention layer* over the target sequence, and a multi-head attention layer which uses the decoder representation as the query and the encoder representation as the key and value.

The decoder uses positional embeddings and combines - via an elementwise sum - them with the scaled embedded target tokens, followed by dropout. Again, our positional encodings have a "vocabulary" of 100, which means they can accept sequences up to 100 tokens long. This can be increased if desired.

The combined embeddings are then passed through the $N$ decoder layers, along with the encoded source, `enc_src`, and the source and target masks. Note that the number of layers in the encoder does not have to be equal to the number of layers in the decoder, even though they are both denoted by $N$.

The decoder representation after the $N^{th}$ layer is then passed through a linear layer, `fc_out`. In PyTorch, the softmax operation is contained within our loss function, so we do not explicitly need to use a softmax layer here.

As well as using the source mask, as we did in the encoder to prevent our model attending to `<pad>` tokens, we also use a target mask. This will be explained further in the `Seq2Seq` model which encapsulates both the encoder and decoder, but the gist of it is that it performs a similar operation as the decoder padding in the convolutional sequence-to-sequence model. As we are processing all of the target tokens at once in parallel we need a method of stopping the decoder from "cheating" by simply "looking" at what the next token in the target sequence is and outputting it.

Our decoder layer also outputs the normalized attention values so we can later plot them to see what our model is actually paying attention to.

### Decoder Layer
The decoder layer has two multi-head attention layers, `self_attention` and `encoder_attention`.

The first performs self-attention, as in the encoder, by using the decoder representation so far as the query, key and value. This is followed by dropout, residual connection and layer normalization. This `self_attention` layer uses the target sequence mask, `trg_mask`, in order to prevent the decoder from "cheating" by paying attention to tokens that are "ahead" of the one it is currently processing as it processes all tokens in the target sentence in parallel.

The second is how we actually feed the encoded source sentence, `enc_src`, into our decoder. In this multi-head attention layer the queries are the decoder representations and the keys and values are the encoder representations. Here, the source mask, `src_mask` is used to prevent the multi-head attention layer from attending to `<pad>` tokens within the source sentence.




### Seq2Seq

Finally, we have the `Seq2Seq` module which encapsulates the encoder and decoder, as well as handling the creation of the masks.

The source mask is created by checking where the source sequence is not equal to a `<pad>` token. It is 1 where the token is not a `<pad>` token and 0 when it is. It is then unsqueezed so it can be correctly broadcast when applying the mask to the `energy`, which is of shape **_[batch size, n heads, seq len, seq len]_**.

The target mask is slightly more complicated. First, we create a mask for the `<pad>` tokens, as we did for the source mask. Next, we create a "subsequent" mask, `trg_sub_mask`, using `torch.tril`. This creates a lower-triangular matrix where the elements above the diagonal will be zero and the elements on and below the diagonal will be set to whatever the input tensor is. In this case, the input tensor will be a tensor filled with ones. So this means our `trg_sub_mask` will look something like this (for a target with 5 tokens):

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0\\
1 & 1 & 1 & 1 & 1\\
\end{matrix}$$

This shows what each target token (row) is allowed to look at (column). The first target token has a mask of **_[1, 0, 0, 0, 0]_** which means it can only look at the first target token. The second target token has a mask of **_[1, 1, 0, 0, 0]_** which means it can look at both the first and second target tokens.

The "subsequent" mask is then logically anded with the padding mask, this combines the two masks ensuring both the subsequent tokens and the padding tokens cannot be attended to. For example if the last two tokens were `<pad>` tokens the mask would look like:

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
\end{matrix}$$

After the masks are created, they are used with the encoder and decoder along with the source and target sentences to get our predicted target sentence, `output`, along with the decoder's attention over the source sequence.

Then, we'll define our training loop. This is the exact same as the one used in the previous tutorial.

As we want our model to predict the `<eos>` token but not have it be an input into our model we simply slice the `<eos>` token off the end of the sequence. Thus:

$$\begin{align*}
\text{trg} &= [sos, x_1, x_2, x_3, eos]\\
\text{trg[:-1]} &= [sos, x_1, x_2, x_3]
\end{align*}$$

$x_i$ denotes actual target sequence element. We then feed this into the model to get a predicted sequence that should hopefully predict the `<eos>` token:

$$\begin{align*}
\text{output} &= [y_1, y_2, y_3, eos]
\end{align*}$$

$y_i$ denotes predicted target sequence element. We then calculate our loss using the original `trg` tensor with the `<sos>` token sliced off the front, leaving the `<eos>` token:

$$\begin{align*}
\text{output} &= [y_1, y_2, y_3, eos]\\
\text{trg[1:]} &= [x_1, x_2, x_3, eos]
\end{align*}$$

We then calculate our losses and update our parameters as is standard.

##end text

In [None]:
# @title train eval

def train(model, dataloader, optimizer, loss_fn):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    For each ``(src, trg)`` batch the model is fed ``trg`` without its last
    position and the loss is computed against ``trg`` without its first
    position. Gradients are clipped to a total norm of 1 before each step.

    Args:
        model: callable as ``model(src, trg_in)`` returning ``(output, _)``.
        dataloader: iterable of ``(src, trg)`` batches; trg = [batch size, trg len].
        optimizer: optimizer over ``model.parameters()``.
        loss_fn: loss taking ``(logits [N, output dim], targets [N])``.

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.train()
    epoch_loss = 0
    n_batches = 0
    for src, trg in dataloader:
        src = src.to(device)
        trg = trg.to(device) #trg = [batch size, trg len]
        optimizer.zero_grad()
        output, _ = model(src, trg[:,:-1]) #output = [batch size, trg len - 1, output dim]
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim) #output = [batch size * (trg len - 1), output dim]
        trg = trg[:,1:].contiguous().view(-1) #trg = [batch size * (trg len - 1)]
        loss = loss_fn(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Count batches while iterating: len(list(dataloader)) would load, tokenize
    # and collate the whole dataset a second time just to obtain this number.
    return epoch_loss / n_batches

def evaluate(model, dataloader, loss_fn):
    """Return the mean batch loss of ``model`` over ``dataloader`` without updating it.

    Runs in eval mode under ``torch.no_grad()``. As in training, the model is
    fed ``trg`` without its last position and scored against ``trg`` without
    its first position.

    Args:
        model: callable as ``model(src, trg_in)`` returning ``(output, _)``.
        dataloader: iterable of ``(src, trg)`` batches; trg = [batch size, trg len].
        loss_fn: loss taking ``(logits [N, output dim], targets [N])``.

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    epoch_loss = 0
    n_batches = 0
    with torch.no_grad():
        for src, trg in dataloader:
            src = src.to(device)
            trg = trg.to(device) #trg = [batch size, trg len]
            output, _ = model(src, trg[:,:-1]) #output = [batch size, trg len - 1, output dim]
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim) #output = [batch size * (trg len - 1), output dim]
            trg = trg[:,1:].contiguous().view(-1) #trg = [batch size * (trg len - 1)]
            loss = loss_fn(output, trg)
            epoch_loss += loss.item()
            n_batches += 1
    # Count batches while iterating: len(list(dataloader)) would load, tokenize
    # and collate the whole dataset a second time just to obtain this number.
    return epoch_loss / n_batches



In [None]:
# @title translate

def translate(model, src_sentence):
    """Greedily translate one German sentence with ``model`` and return the English string.

    The sentence is converted with ``de_transform``, encoded once, and then
    decoded token by token (arg-max at every step) until ``EOS_IDX`` is
    predicted or ``src len + 5`` tokens have been generated.

    Args:
        model: a ``Seq2Seq``-style model exposing ``encoder``, ``decoder``,
            ``make_src_mask`` and ``make_trg_mask``.
        src_sentence: raw source sentence.

    Returns:
        The predicted tokens joined by single spaces, without ``<bos>``/``<eos>``.
    """
    model.eval()
    src = de_transform(src_sentence).view(1,-1).to(device) #src = [1, src len]
    # Use the model's own mask convention (True = attend). The previous 1x1
    # all-False tensor followed the opposite convention and, after broadcasting,
    # masked out every source position, making attention over the source uniform.
    src_mask = model.make_src_mask(src)
    max_len = src.shape[1]+5
    trg_indexes = [BOS_IDX]
    with torch.no_grad():
        enc_src = model.encoder(src, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, _ = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            pred_token = output.argmax(2)[:,-1].item() # prediction for the last position
            if pred_token == EOS_IDX:
                break
            # <eos> is never appended, so nothing has to be sliced off the end;
            # the old trg_indexes[1:-1] dropped a real token when max_len was
            # reached without predicting <eos>.
            trg_indexes.append(pred_token)
    return " ".join(en_vocab.lookup_tokens(trg_indexes[1:]))

# UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))


In [None]:
# @title run
# import math
import time
# Target positions equal to PAD_IDX are ignored by the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
# optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001) # lr=0.0005
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

# Train for 20 epochs; after each one print train/validation loss, wall-clock
# time and the greedy translation of a fixed German sentence as a quick check.
for epoch in range(20):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn)
    val_loss = evaluate(model, val_loader, loss_fn)
    end_time = time.time()
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))
#  A group of people standing in front of an igloo

# sine pos enc, scale after pos
# Epoch: 20, Train loss: 3.238, Val loss: 3.504, Epoch time = 44.329s
# A person is playing a trick in a race .

# sine pos enc, token,scale,pos
# Epoch: 20, Train loss: 1.506, Val loss: 2.149, Epoch time = 43.637s
# A crowd of people of a crowd .


Epoch: 1, Train loss: 5.685, Val loss: 4.895, Epoch time = 44.724s
A woman in a a a a a a .
Epoch: 2, Train loss: 4.724, Val loss: 4.455, Epoch time = 43.922s
A group of of of of of of of of .
Epoch: 3, Train loss: 4.254, Val loss: 3.989, Epoch time = 44.762s
A group of of of of of of of of of .
Epoch: 4, Train loss: 3.874, Val loss: 3.693, Epoch time = 43.549s
A group of of of of of of of of of .
Epoch: 5, Train loss: 3.571, Val loss: 3.408, Epoch time = 43.872s
A group of people of front of of of of of of of .
Epoch: 6, Train loss: 3.238, Val loss: 3.126, Epoch time = 44.194s
A crowd of a a building of a a a .
Epoch: 7, Train loss: 2.981, Val loss: 2.935, Epoch time = 44.248s
A crowd of a a crowd of a a building .
Epoch: 8, Train loss: 2.773, Val loss: 2.803, Epoch time = 44.176s
A crowd of a front of a store .
Epoch: 9, Train loss: 2.589, Val loss: 2.661, Epoch time = 43.593s
A crowd of a crowd of a crowd .
Epoch: 10, Train loss: 2.429, Val loss: 2.562, Epoch time = 44.491s
A crowd 

In [None]:
# @title translate_sentence, display_attention
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    """Greedily translate a single sentence with the encoder-decoder `model`.

    Args:
        sentence: Either a raw string, which is tokenized with spaCy's
            `de_core_news_sm` pipeline, or an iterable of token strings.
            Tokens are lower-cased in both cases.
        src_field: Source field; its `init_token`, `eos_token` and
            `vocab.stoi` are used to turn the tokens into indexes.
        trg_field: Target field; its `init_token`, `eos_token`, `vocab.stoi`
            and `vocab.itos` are used to start, stop and decode the output.
        model: Object exposing `eval`, `make_src_mask`, `make_trg_mask`,
            `encoder` and `decoder`.
        device: Device the index tensors are moved to.
        max_len: Maximum number of decoding steps.

    Returns:
        A `(tokens, attention)` tuple. `tokens` holds the predicted target
        tokens without the initial token; the end-of-sentence token is
        included when the model produced it. `attention` is the second value
        returned by the decoder on the last step, or `None` when no decoding
        step ran (`max_len <= 0`).
    """
    model.eval()
    if isinstance(sentence, str):
        # Loading a spaCy pipeline is slow, so load it once and keep it on
        # the function object for subsequent calls.
        nlp = getattr(translate_sentence, '_nlp', None)
        if nlp is None:
            nlp = translate_sentence._nlp = spacy.load('de_core_news_sm')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    # Add a leading batch dimension of size one.
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    eos_index = trg_field.vocab.stoi[trg_field.eos_token]
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    attention = None  # stays None if the decoding loop never runs
    with torch.no_grad():
        # The source is encoded once; only the decoder is re-run per step.
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # Greedy choice: highest-scoring token at the last target position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                break
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention

def display_attention(sentence, translation, attention, n_heads = 8, n_rows = 4, n_cols = 2):
    """Draw one attention heat map per head on an `n_rows` x `n_cols` grid.

    Args:
        sentence: Iterable of source token strings; they are lower-cased and
            wrapped in `<sos>` / `<eos>` to form the column labels.
        translation: Sequence of predicted target tokens, used as row labels.
        attention: Tensor that, after `squeeze(0)`, is indexed by head to get
            the 2-D matrix drawn for that head.
            NOTE(review): assumes shape [1, n_heads, trg_len, src_len] - confirm.
        n_heads: Number of heads to draw.
        n_rows: Number of subplot rows.
        n_cols: Number of subplot columns; `n_rows * n_cols` must equal `n_heads`.
    """
    assert n_rows * n_cols == n_heads
    x_labels = ['<sos>'] + [t.lower() for t in sentence] + ['<eos>']
    y_labels = list(translation)
    heads = attention.squeeze(0)
    fig = plt.figure(figsize=(15,25))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i+1)
        _attention = heads[i].cpu().detach().numpy()
        ax.matshow(_attention, cmap='bone')
        ax.tick_params(labelsize=12)
        # Pin one tick per matrix column/row before labelling. Setting labels
        # without fixed tick positions leaves the alignment to whatever ticks
        # the locator happens to generate (the old code padded the labels with
        # '' to compensate for an extra tick at -1).
        ax.set_xticks(list(range(len(x_labels))))
        ax.set_xticklabels(x_labels, rotation=45)
        ax.set_yticks(list(range(len(y_labels))))
        ax.set_yticklabels(y_labels)
    plt.show()
    plt.close()

# print(train_iter[0])

# Show one source/target pair, translate the source and plot the attention
# returned for that translation.
# NOTE(review): `src` and `trg` are read here but are only assigned by the loop
# on the last line of this cell, so this relies on them already existing from
# an earlier run - confirm the intended cell/statement order.
print(f'src = {src}')
print(f'trg = {trg}')

translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f'predicted trg = {translation}')
display_attention(src, translation, attention)

# Take the first (src, trg) item yielded by valid_loader and stop iterating.
# NOTE(review): the training loop uses `val_loader`; verify that `valid_loader`
# is actually defined.
for src,trg in valid_loader: break


In [None]:
# @title bleu
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    """Translate every example in `data` one by one and return the corpus BLEU.

    Args:
        data: Iterable of examples whose instance attributes include `src`
            and `trg` (read through `vars(example)`).
        src_field: Source field, forwarded to `translate_sentence`.
        trg_field: Target field, forwarded to `translate_sentence`; its
            `eos_token` is used to recognise a trailing end-of-sentence token.
        model: Model forwarded to `translate_sentence`.
        device: Device forwarded to `translate_sentence`.
        max_len: Maximum number of decoding steps per sentence.

    Returns:
        The value returned by `bleu_score` for the predicted token lists
        against the single-reference lists built from each example's `trg`.
    """
    trgs = []
    pred_trgs = []
    for datum in data:
        example = vars(datum)
        pred_trg, _ = translate_sentence(example['src'], src_field, trg_field, model, device, max_len)
        # Drop the trailing <eos> only when it is really there: a translation
        # that stopped at max_len ends in an ordinary token, which must be kept.
        if pred_trg and pred_trg[-1] == trg_field.eos_token:
            pred_trg = pred_trg[:-1]
        pred_trgs.append(pred_trg)
        # bleu_score is given a list of references per candidate; there is one.
        trgs.append([example['trg']])
    return bleu_score(pred_trgs, trgs)
# Keep the result under its own name: assigning it to `bleu_score` would replace
# the imported torchtext function with a float and make every later call to
# calculate_bleu fail with "'float' object is not callable".
test_bleu = calculate_bleu(test_data, SRC, TRG, model, device)
print(f'BLEU score = {test_bleu*100:.2f}')
# 36.52, which beats the ~34 of the convolutional sequence-to-sequence model and ~28 of the attention based RNN model.

def translate_sentence_vectorized(src_tensor, src_field, trg_field, model, device, max_len=50):
    """Greedily translate a whole batch of already-indexed source sentences.

    Args:
        src_tensor: `torch.Tensor` of source token indexes; `len(src_tensor)`
            is taken as the number of sentences in the batch.
        src_field: Not used by this function.
        trg_field: Target field; its `init_token`, `eos_token`, `vocab.stoi`
            and `vocab.itos` are used to start, stop and decode the output.
        model: Object exposing `eval`, `make_src_mask`, `make_trg_mask`,
            `encoder` and `decoder`.
        device: Device the target index tensor is moved to on every step.
        max_len: Maximum number of decoding steps.

    Returns:
        A `(pred_sentences, attention)` tuple. `pred_sentences` contains one
        list of token strings per input sentence, without the initial token
        and cut just before the first end-of-sentence token. `attention` is
        the second value returned by the decoder on the last step, or `None`
        when no decoding step ran (`max_len <= 0`).
    """
    assert isinstance(src_tensor, torch.Tensor)

    model.eval()
    src_mask = model.make_src_mask(src_tensor)

    batch_size = len(src_tensor)
    sos_index = trg_field.vocab.stoi[trg_field.init_token]
    eos_index = trg_field.vocab.stoi[trg_field.eos_token]
    itos = trg_field.vocab.itos

    trg_indexes = [[sos_index] for _ in range(batch_size)]
    # Sentences that already produced <eos> are still fed through the decoder,
    # because the batch is processed as a whole; decoding stops only once every
    # sentence has produced <eos> (or max_len is reached).
    translations_done = [False] * batch_size
    attention = None  # stays None if the decoding loop never runs
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).to(device)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # Highest-scoring token at the last position of every sentence,
            # converted to plain ints so the index lists hold no tensors.
            pred_tokens = output.argmax(2)[:, -1].tolist()
            for row, pred_token in enumerate(pred_tokens):
                trg_indexes[row].append(pred_token)
                if pred_token == eos_index:
                    translations_done[row] = True
            if all(translations_done):
                break

    # Decode each sentence, skipping the initial token and dropping the first
    # <eos> together with everything generated after it.
    pred_sentences = []
    for trg_sentence in trg_indexes:
        pred_sentence = []
        for token_index in trg_sentence[1:]:
            if token_index == eos_index:
                break
            pred_sentence.append(itos[token_index])
        pred_sentences.append(pred_sentence)
    return pred_sentences, attention

from torchtext.data.metrics import bleu_score

def calculate_bleu_alt(iterator, src_field, trg_field, model, device, max_len = 50):
    """Compute corpus BLEU by translating whole batches at a time.

    Args:
        iterator: Iterable of batches exposing `src` and `trg` attributes.
            NOTE(review): iterating `batch.trg` is assumed to yield one
            sentence of token indexes at a time (batch-first) - confirm.
        src_field: Source field, forwarded to `translate_sentence_vectorized`.
        trg_field: Target field; its `eos_token`, `pad_token`, `vocab.stoi`
            and `vocab.itos` are used to decode the reference sentences.
        model: Model forwarded to `translate_sentence_vectorized`.
        device: Device forwarded to `translate_sentence_vectorized`.
        max_len: Maximum number of decoding steps per batch.

    Returns:
        A `(pred_trgs, trgs, bleu)` tuple: the predicted token lists, the
        reference token lists (one single-element reference list per
        sentence) and the value returned by `bleu_score` for them.
    """
    eos_index = trg_field.vocab.stoi[trg_field.eos_token]
    pad_index = trg_field.vocab.stoi[trg_field.pad_token]
    itos = trg_field.vocab.itos

    trgs = []
    pred_trgs = []
    with torch.no_grad():
        for batch in iterator:
            src = batch.src
            trg = batch.trg
            for sentence in trg:
                reference = []
                # Skip the first position, which holds the start token.
                for token_index in sentence[1:]:
                    token_index = int(token_index)
                    # Targets are padded, so stop at the first <eos> or padding token.
                    if token_index in (eos_index, pad_index):
                        break
                    reference.append(itos[token_index])
                trgs.append([reference])
            # Forward max_len so the caller's limit is actually honoured.
            pred_trg, _ = translate_sentence_vectorized(src, src_field, trg_field, model, device, max_len)
            pred_trgs += pred_trg
    return pred_trgs, trgs, bleu_score(pred_trgs, trgs)
