In [None]:
%matplotlib inline

import collections
import random
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# seq2seq and attention

In this topic we'll look at one of the most important developments in neural networks for NLP: attention.
It is what allowed for the development of the next big thing after RNNs, **transformers**, which we will be talking about in the next topic.
But first, we need to start from the basic task that attention was designed to help: **sequence-to-sequence** mapping or **seq2seq**.

seq2seq is simply the conditioning of a language model on a sequence.
This includes applications like machine translation (condition a French language model on an English text) and summarisation (condition a summary language model on a long text).
The conditioning sequence is called the **source sequence** and the sequence that gets generated by the language model is called the **target sequence**.
seq2seq has its own challenges compared to the tasks seen in the previous topic because the source sequence needs to be converted into a single vector in order to condition the language model.
Unlike in sentiment analysis, that vector needs to store a lot of information about the source sequence for the model to be able to generate a faithful translation or summary of it.

## Prompting

The simplest way to do this is to avoid conditioning the language model at all and simply **prompt** the language model instead.
This is when you train a language model to generate a single long sequence consisting of the source and target sequences joined together with a special **separator token** in between, for example `I like it . SEP J'aime ça .`.

![](prompting_seq2seq.png)

The predictions made by the language model while the source sequence is being fed to it are ignored.
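
As a rough sketch of how such a training sequence could be put together, the snippet below joins a source and a target text with a separator and marks which predictions would be scored.
The token name `SEP` and the variable names are assumptions made for illustration, and the edge token that would normally end the target is left out to keep the example short.

```python
# 'SEP' is an assumed separator token; all names here are illustrative only.
src_tokens = 'I like it .'.split(' ')
trg_tokens = "J'aime ça .".split(' ')

# Join the source and target into one long sequence with a separator in between.
full_sequence = src_tokens + ['SEP'] + trg_tokens

# The language model's inputs are the sequence without its last token and
# its targets are the same sequence shifted by one position.
lm_inputs = full_sequence[:-1]
lm_targets = full_sequence[1:]

# Predictions made while the source is being fed in are ignored:
# a position is only scored once the separator (or a later token) is the input.
sep_position = len(src_tokens)
loss_mask = [i >= sep_position for i in range(len(lm_inputs))]

for (x, y, keep) in zip(lm_inputs, lm_targets, loss_mask):
    print(f'{x: <8s}{y: <8s}{keep}')
```

Only the rows marked `True` would contribute to the training error, and these are exactly the target tokens.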

In the case of translation, this would require that the vocabulary of the language model includes tokens from both languages so that both the source and the target sequence can be embedded.
Unfortunately, this is a very resource-intensive technique because the language model must be trained on very long sequences, and it is very hard to make RNNs pick up on very long-range dependencies (the first source token could be essential for predicting the last target token).
With recent hardware and transformers this is now more commonplace but it's good to know of the other solutions.

## Traditional neural translation

In 2014, Sutskever, Vinyals, and Le published a paper titled [Sequence to Sequence Learning with Neural Networks](https://proceedings.neurips.cc/paper/2014/hash/a14ac55a4f27472c5d894ec1c3c743d2-Abstract.html) where they showed how to make a neural network that performs translation from English to French.
The basic idea was that an LSTM encodes an English text such that the final state is then init-injected into a neural language model that generates French text.
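
The idea can be sketched in a few lines using PyTorch's built-in `torch.nn.LSTM`.
This is only an illustration under assumed layer sizes and made-up token indexes, not the paper's implementation, and passing both the hidden state and the cell state to the decoder is a simplifying assumption.

```python
import torch

# Assumed sizes, for illustration only.
src_vocab_size = 6
trg_vocab_size = 7
embedding_size = 4
state_size = 8

src_embedding = torch.nn.Embedding(src_vocab_size, embedding_size)
trg_embedding = torch.nn.Embedding(trg_vocab_size, embedding_size)
encoder = torch.nn.LSTM(embedding_size, state_size, batch_first=True)
decoder = torch.nn.LSTM(embedding_size, state_size, batch_first=True)
output_layer = torch.nn.Linear(state_size, trg_vocab_size)

src_indexed = torch.tensor([[1, 2, 3, 4]])  # One source text of 4 tokens (made-up indexes).
trg_x_indexed = torch.tensor([[0, 5, 6]])   # One target prefix of 3 tokens (made-up indexes).

# Encode the source text and keep only the final states.
(_, (final_h, final_c)) = encoder(src_embedding(src_indexed))

# Init-inject: the encoder's final states become the decoder's initial states.
(decoder_states, _) = decoder(trg_embedding(trg_x_indexed), (final_h, final_c))

# One vector of logits over the target vocabulary for every target time step.
logits = output_layer(decoder_states)
print(logits.shape)
```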

![](traditional_seq2seq.png)

The part that deals with the source text is typically called the **encoder** and the part that deals with the target text is typically called the **decoder**, with the whole architecture called an **encoder-decoder architecture**.
The name comes from image encoder-decoders which transform an image into another image and the name 'decoder' makes sense there because a neural network can directly output an image.
In the case of text generation tasks, 'decoder' is a bit of a misnomer because neural networks do not directly output text and it is only with some search algorithm that text is generated.
Still, this didn't stop anyone from calling a conditioned language model a decoder.

Mathematically, the language model in an encoder-decoder architecture is described as follows:

$P(t_1^t, t_2^t, \dots, t_n^t | t_1^s, t_2^s, \dots, t_m^s)$

that is, the probability of a sequence of tokens, called the target tokens, given another sequence of tokens of a possibly different length, called the source tokens.
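
As with any language model, this probability is normally broken down with the chain rule so that the model only has to predict one target token at a time.
Written with the same symbols as above (the index $i$ is introduced here for illustration):

$P(t_1^t, t_2^t, \dots, t_n^t | t_1^s, t_2^s, \dots, t_m^s) = \prod_{i=1}^{n} P(t_i^t | t_1^t, \dots, t_{i-1}^t, t_1^s, t_2^s, \dots, t_m^s)$

Each factor is the probability of the next target token given the target tokens generated so far and the whole source sequence.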

Let's try using this formulation on our toy sentiment analysis data set to translate a positive text into a negative one and vice versa.

In [None]:
train_src_tokens = [
    'I like it .'.split(' '),
    'I hate it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I don\'t like it .'.split(' '),
]

train_trg_tokens = [ # This is just the corresponding source text but with the word "don't" inserted or removed.
    'I don\'t like it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I hate it .'.split(' '),
    'I like it .'.split(' '),
]

src_max_len = max(len(text) for text in train_src_tokens)
print('src_max_len:', src_max_len)

src_vocab = ['<PAD>'] + sorted({token for text in train_src_tokens for token in text}) # No edge token in the source vocabulary.
src_token2index = {t: i for (i, t) in enumerate(src_vocab)}
src_pad_index = src_token2index['<PAD>']
print('src_vocab:', src_vocab)
print()

train_src_indexed_np = np.full([len(train_src_tokens), src_max_len], src_pad_index, np.int64)
for i in range(len(train_src_tokens)):
    for j in range(len(train_src_tokens[i])):
        train_src_indexed_np[i, j] = src_token2index[train_src_tokens[i][j]]
train_src_indexed = torch.tensor(train_src_indexed_np, device=device)

trg_max_len = max(len(text) + 1 for text in train_trg_tokens)
print('trg_max_len:', trg_max_len)

trg_vocab = ['<PAD>', '<EDGE>'] + sorted({token for text in train_trg_tokens for token in text}) # Include the edge token.
trg_token2index = {t: i for (i, t) in enumerate(trg_vocab)}
trg_pad_index = trg_token2index['<PAD>']
trg_edge_index = trg_token2index['<EDGE>']
print('trg_vocab:', trg_vocab)
print()

train_trg_x_indexed_np = np.full((len(train_trg_tokens), trg_max_len), trg_pad_index, np.int64)
for i in range(len(train_trg_tokens)):
    train_trg_x_indexed_np[i, 0] = trg_edge_index
    for j in range(len(train_trg_tokens[i])):
        train_trg_x_indexed_np[i, j + 1] = trg_token2index[train_trg_tokens[i][j]]
train_trg_x_indexed = torch.tensor(train_trg_x_indexed_np, device=device)

train_trg_y_indexed_np = np.full((len(train_trg_tokens), trg_max_len), trg_pad_index, np.int64)
for i in range(len(train_trg_tokens)):
    for j in range(len(train_trg_tokens[i])):
        train_trg_y_indexed_np[i, j] = trg_token2index[train_trg_tokens[i][j]]
    train_trg_y_indexed_np[i, len(train_trg_tokens[i])] = trg_edge_index # Add the edge token at the end.
train_trg_y_indexed = torch.tensor(train_trg_y_indexed_np, device=device)

In [None]:
# This beam search algorithm was adapted for conditioning on a list of tokens.
def beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_tokens, max_len, beam_size):
    src_indexed = torch.tensor(
        [[src_token2index[token] for token in src_tokens]],
        dtype=torch.int64, device=device
    ).tile((beam_size, 1))
    beam_prefixes_indexed = torch.tensor([[trg_edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_probs = np.array([1.0], np.float32)
    
    best_full_prefix_indexed = None
    best_full_prefix_prob = None
    
    with torch.no_grad():
        for _ in range(max_len):
            outputs = torch.softmax(model(src_indexed[:len(beam_prefixes_indexed), :], beam_prefixes_indexed), dim=2)
            token_probs = outputs[:, -1, :]
            new_prefixes_probs = beam_prefixes_probs[:, None]*token_probs.cpu().numpy()

            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == trg_edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))
            
            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [trg_vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)
            
            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64, device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)

    text = [trg_vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return (text, beam_prefixes_probs[0])

In [None]:
class Model(torch.nn.Module):

    def __init__(self, src_vocab_size, trg_vocab_size, embedding_size, state_size, src_pad_index):
        super().__init__()
        self.src_pad_index = src_pad_index
        
        self.src_embedding = torch.nn.Embedding(src_vocab_size, embedding_size)
        self.src_rnn_s0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.src_rnn_c0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.src_rnn_cell = torch.nn.LSTMCell(embedding_size, state_size)
        
        self.trg_embedding = torch.nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_rnn_c0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.trg_rnn_cell = torch.nn.LSTMCell(embedding_size, state_size)
        
        self.output_layer = torch.nn.Linear(state_size, trg_vocab_size)

    def forward(self, src_indexed, trg_x_indexed):
        batch_size = src_indexed.shape[0]
        src_time_steps = src_indexed.shape[1]
        trg_time_steps = trg_x_indexed.shape[1]

        src_non_pad_mask = src_indexed != self.src_pad_index
        
        ####################
        # Source processing
        ####################
        
        src_embedded = self.src_embedding(src_indexed)
        
        state = self.src_rnn_s0[None, :].tile((batch_size, 1))
        c = self.src_rnn_c0[None, :].tile((batch_size, 1))
        for t in range(src_time_steps):
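            (new_state, new_c) = self.src_rnn_cell(src_embedded[:, t, :], (state, c))
            # Keep the previous hidden and cell states wherever the current token is a pad.
            state = torch.where(src_non_pad_mask[:, t, None], new_state, state)
            c = torch.where(src_non_pad_mask[:, t, None], new_c, c)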
        
        ####################
        # Target processing
        ####################
        
        trg_embedded = self.trg_embedding(trg_x_indexed)

        # Use source RNN hidden state as the initial hidden state.
        c = self.trg_rnn_c0[None, :].tile((batch_size, 1))
        interm_states = []
        for t in range(trg_time_steps):
            (state, c) = self.trg_rnn_cell(trg_embedded[:, t, :], (state, c))
            interm_states.append(state)
        interm_states = torch.stack(interm_states, dim=1)
        
        ####################
        # Output processing
        ####################
        
        return self.output_layer(interm_states)

model = Model(len(src_vocab), len(trg_vocab), embedding_size=2, state_size=2, src_pad_index=src_pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 10000+1):
    batch_size = train_trg_x_indexed.shape[0]
    trg_time_steps = train_trg_x_indexed.shape[1]
    trg_pad_mask = train_trg_y_indexed == trg_pad_index
    
    optimiser.zero_grad()
    logits = model(train_src_indexed, train_trg_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_trg_y_indexed, reduction='none')
    train_token_errors = torch.masked_fill(train_token_errors, trg_pad_mask, 0.0)
    train_error = train_token_errors.sum()/(~trg_pad_mask).sum()
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%1000 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    for src_text in train_src_tokens:
        print(src_text)
        (trg_text, prob) = beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_text, max_len=10, beam_size=3)
        print(trg_text, prob)
        print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

## Attention

Given what we said before about translation requiring the conditioning vector to contain a lot of information about the source text, beyond a few source tokens the model performance will start to break down as it cannot contain all the necessary information in the fixed-sized vector.
Ideally, we represent the source text without having to compress it into a single vector.
This is where **attention** comes in.

### Self-attentive sentence embedding

Let's consider our toy data set.
Do we need to know *all* of the tokens in the source text in order to generate the target text?
All you need to know is whether there is the token "don't" and whether 'like' or 'hate' are used.
Everything else is unnecessary information because the rest of the texts are completely predictable.
The source RNN should take this into account with enough training, but it would be better to explicitly make the model select the tokens to consider, that is, selectively pay more **attention** to some tokens over others.
Furthermore, it would be useful if the model could also explicitly tell us how much attention it was giving each source token.

What we want is a neural network that produces a number for each source token that quantifies the amount of importance it gives each token.
To do this, we'll have a separate layer that takes in a source token vector (embedding) and outputs a number, like this:

In [None]:
num_tokens = 3
embedding_size = 2

src_embedded = torch.randn((1, num_tokens, embedding_size), dtype=torch.float32, device=device)
importance_layer = torch.nn.Linear(embedding_size, 1).to(device)

print('Token vectors:')
print(src_embedded)
print()

token_importances = importance_layer(src_embedded)

print('Token importances:')
print(token_importances)

The amount of importance of each token is then treated as logits and fed to a softmax, like this:

In [None]:
attention = torch.softmax(token_importances, dim=1)
print(attention)

This transforms the amount of importance of each token into a vector of fractions called an **attention vector**.
The attention vector is multiplied by the source token vectors in order to make unimportant tokens become (almost) zero vectors, like this:

In [None]:
print('original token vectors:')
print(src_embedded)
print()
print('after applying attention:')
print(attention*src_embedded)

This **weights** the token vectors such that the less important they are, the closer they are to zero vectors.
When you weight a set of numbers or vectors with positive weights that add up to 1 and then add up the results, the total is called a **weighted average**, which is similar to a normal average except that the numbers are not necessarily given equal importance (a normal average is a weighted average where all the weights are the same).
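
A small numeric check of this, using made-up values and weights chosen only for illustration:

```python
import torch

values = torch.tensor([2.0, 4.0, 9.0])

# The weights are positive and add up to 1.
weights = torch.tensor([0.5, 0.25, 0.25])
weighted_average = (weights*values).sum()
print('weighted average:', weighted_average.item())

# With equal weights, the weighted average is just the normal average.
equal_weights = torch.full((3,), 1/3)
equal_weighted_average = (equal_weights*values).sum()
print('equal weights:', equal_weighted_average.item())
print('normal average:', values.mean().item())
```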

If we add the weighted token vectors together, we'll produce a weighted average of the vectors, creating a single conditioning vector, like this:

In [None]:
cond_vec = (attention*src_embedded).sum(dim=1)
print(cond_vec)

This vector would be the weighted average of all the source token vectors and will be most similar to the token with the highest attention.
As a vector, it can be used to represent the whole source text as it contains all the information that the importance layer considered necessary to condition the language model.

Note that multiplying the rows of a matrix of token vectors by the values in an attention vector and then adding up the rows is just a matrix multiplication, which means that we can just do this:

In [None]:
cond_vec = attention.transpose(1, 2)@src_embedded
print(cond_vec)

Note that the result is now 3-dimensional instead of 2-dimensional as expected (a vector for every source text in the batch), so we need to get rid of the singleton dimension in the middle since that singleton dimension represents tokens (and we have text vectors).

In [None]:
cond_vec = cond_vec[:, 0, :]
print(cond_vec)

Should the source token vectors used here be the embedding vectors?
There is not enough information in an embedding vector to determine if it's an important token.
Context is needed so that the neural network can take decisions based on how the source tokens are used rather than based on what tokens they are.
So we use a bi-directional RNN to produce contextual source token vectors and these are what is used to determine the amount of importance of the tokens as well as to make the conditional vector.
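
A minimal sketch of this, assuming PyTorch's built-in `torch.nn.LSTM` with `bidirectional=True` and made-up sizes, could look as follows.
Note that the built-in layer does not skip pad tokens by itself, so this sketch only applies to a batch without padding.

```python
import torch

# Assumed sizes, for illustration only.
batch_size = 1
num_tokens = 3
embedding_size = 2
state_size = 4

src_embedded = torch.randn((batch_size, num_tokens, embedding_size))

# A bi-directional LSTM reads the tokens left-to-right and right-to-left
# and concatenates the two states at every token.
bi_rnn = torch.nn.LSTM(embedding_size, state_size, batch_first=True, bidirectional=True)
(contextual_vecs, _) = bi_rnn(src_embedded)
print(contextual_vecs.shape)  # (batch_size, num_tokens, 2*state_size)

# The importance layer now scores contextual vectors instead of embeddings.
importance_layer = torch.nn.Linear(2*state_size, 1)
attention = torch.softmax(importance_layer(contextual_vecs), dim=1)

# The conditioning vector is the weighted average of the contextual vectors.
cond_vec = (attention*contextual_vecs).sum(dim=1)
print(cond_vec.shape)  # (batch_size, 2*state_size)
```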

We can now have a neural network that outputs both the generated text as well as the attention vector so that we can see what the neural network was looking at in the source text.

![](self_attentive.png)

This architecture was not the first kind that used attention, but it is the simplest.
It was described in 2017 by Lin and others in a paper called [A Structured Self-Attentive Sentence Embedding](https://openreview.net/forum?id=BJC_jUqxe).
The paper does not describe machine translation but sentence-level classification such that the conditioning vector is used in a softmax layer.
The architecture in the paper is also slightly more complex as several parallel branches of conditioning vectors are produced from the same source sentence using the same technique shown in the convolutional neural networks topic.
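
To give a rough idea of what parallel branches could look like, the sketch below makes the importance layer output one number per branch instead of a single number.
This is a simplification and not the paper's exact formulation (the paper's scoring function is more elaborate than a single linear layer), and the sizes and names are assumptions.

```python
import torch

# Assumed sizes, for illustration only.
batch_size = 1
num_tokens = 3
vec_size = 4
num_branches = 2

src_vecs = torch.randn((batch_size, num_tokens, vec_size))

# One importance score per token per branch.
importance_layer = torch.nn.Linear(vec_size, num_branches)
attention = torch.softmax(importance_layer(src_vecs), dim=1)  # Softmax over the tokens.

# (batch, branches, tokens) @ (batch, tokens, vec) -> (batch, branches, vec)
cond_vecs = attention.transpose(1, 2)@src_vecs
print(cond_vecs.shape)
```

Each branch has its own attention vector over the same source tokens and so produces its own conditioning vector.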

We haven't mentioned anything about how to handle batches with variable length source texts, that is, how to handle pad tokens in the source text.
In order for pad tokens to not influence the output, we need to make sure that the pad tokens do not get any attention.
We can't just replace the attention numbers returned by softmax with zero, because then the remaining numbers will not add up to one.
What we can do is manipulate the logits in order to make the softmax return numbers that are very close to zero wherever there are pad tokens.
This is done by making the pad logits very large negative numbers such as negative infinity.

In [None]:
pad_mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool, device=device)
print('pad mask:')
print(pad_mask)
print()

logits = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device=device)
print('logits:')
print(logits)
print()

new_logits = torch.masked_fill(logits, pad_mask, float('-inf'))
print('new_logits:')
print(new_logits)
print()

softmax = torch.softmax(new_logits, dim=0)
print('softmax:')
print(softmax)
print()

print('softmax of [1, 2] for reference:')
print(torch.softmax(torch.tensor([1, 2], dtype=torch.float32, device=device), dim=0))

Let's use this on our toy data set.

In [None]:
# This beam search algorithm was adapted for ignoring the attention values returned by the model.
def beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_tokens, max_len, beam_size):
    src_indexed = torch.tensor(
        [[src_token2index[token] for token in src_tokens]],
        dtype=torch.int64, device=device
    ).tile((beam_size, 1))
    beam_prefixes_indexed = torch.tensor([[trg_edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_probs = np.array([1.0], np.float32)
    
    best_full_prefix_indexed = None
    best_full_prefix_prob = None
    
    with torch.no_grad():
        for _ in range(max_len):
            (logits, _) = model(src_indexed[:len(beam_prefixes_indexed), :], beam_prefixes_indexed)
            outputs = torch.softmax(logits, dim=2)
            token_probs = outputs[:, -1, :]
            new_prefixes_probs = beam_prefixes_probs[:, None]*token_probs.cpu().numpy()

            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == trg_edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))
            
            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [trg_vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)
            
            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64, device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)

    text = [trg_vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return (text, beam_prefixes_probs[0])

In [None]:
class Model(torch.nn.Module):

    def __init__(self, src_vocab_size, trg_vocab_size, src_embedding_size, src_state_size, src_pad_index, trg_embedding_size, trg_state_size):
        super().__init__()
        self.src_pad_index = src_pad_index
        
        self.src_embedding = torch.nn.Embedding(src_vocab_size, src_embedding_size)
        self.src_rnn_fw_s0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_fw_c0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_fw_cell = torch.nn.LSTMCell(src_embedding_size, src_state_size)
        self.src_rnn_bw_s0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_bw_c0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_bw_cell = torch.nn.LSTMCell(src_embedding_size, src_state_size)
        self.context_layer = torch.nn.Linear(src_state_size*2, trg_state_size) # Squash the bi-RNN states into the size of the target RNN state.
        self.attention_layer = torch.nn.Linear(trg_state_size, 1)
        
        self.trg_embedding = torch.nn.Embedding(trg_vocab_size, trg_embedding_size)
        self.trg_rnn_c0 = torch.nn.Parameter(torch.zeros((trg_state_size,), dtype=torch.float32))
        self.trg_rnn_cell = torch.nn.LSTMCell(trg_embedding_size, trg_state_size)
        
        self.output_layer = torch.nn.Linear(trg_state_size, trg_vocab_size)

    def forward(self, src_indexed, trg_x_indexed):
        batch_size = src_indexed.shape[0]
        src_time_steps = src_indexed.shape[1]
        trg_time_steps = trg_x_indexed.shape[1]
        src_pad_mask = src_indexed == self.src_pad_index
        src_non_pad_mask = ~src_pad_mask
        
        ####################
        # Source processing
        ####################
        
        embedded = self.src_embedding(src_indexed)
        
        state = self.src_rnn_fw_s0[None, :].tile((batch_size, 1))
        c = self.src_rnn_fw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in range(src_time_steps):
            (state, c) = self.src_rnn_fw_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        interm_states_fw = torch.stack(interm_states_list, dim=1)

        state = self.src_rnn_bw_s0[None, :].tile((batch_size, 1))
        c = self.src_rnn_bw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in reversed(range(src_time_steps)):
            (new_state, new_c) = self.src_rnn_bw_cell(embedded[:, t, :], (state, c))
            state = torch.where(src_non_pad_mask[:, t, None], new_state, state)
            c = torch.where(src_non_pad_mask[:, t, None], new_c, c)
            interm_states_list.append(state)
        interm_states_bw = torch.stack(interm_states_list[::-1], dim=1)

        src_interm_states = torch.concat((interm_states_fw, interm_states_bw), dim=2)
        src_context_token_vecs = self.context_layer(src_interm_states)
        
        # Attention!
        attn_logits = self.attention_layer(src_context_token_vecs)
        attn_logits = attn_logits.masked_fill(src_pad_mask[:, :, None], float('-inf'))
        attention = torch.softmax(attn_logits, dim=1)
        cond_vector = (attention.transpose(1, 2)@src_context_token_vecs)[:, 0, :]
        
        ####################
        # Target processing
        ####################
        
        embedded = self.trg_embedding(trg_x_indexed)

        state = cond_vector # Use conditioning vector as the initial hidden state.
        c = self.trg_rnn_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in range(trg_time_steps):
            (state, c) = self.trg_rnn_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        trg_interm_states = torch.stack(interm_states_list, dim=1)
        
        ####################
        # Output processing
        ####################
        
        return (self.output_layer(trg_interm_states), attention[:, :, 0]) # Return the attention values as well.

model = Model(len(src_vocab), len(trg_vocab), src_embedding_size=2, src_state_size=2, src_pad_index=src_pad_index, trg_embedding_size=2, trg_state_size=2)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 10000+1):
    batch_size = train_trg_x_indexed.shape[0]
    trg_time_steps = train_trg_x_indexed.shape[1]
    trg_pad_mask = train_trg_y_indexed == trg_pad_index # Mask the error using the target outputs, as in the previous training loop.
    
    optimiser.zero_grad()
    (logits, _) = model(train_src_indexed, train_trg_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_trg_y_indexed, reduction='none')
    train_token_errors = train_token_errors.masked_fill(trg_pad_mask, 0.0)
    train_error = train_token_errors.sum()/(~trg_pad_mask).sum() # Divide by the number of non-pad target tokens.
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%1000 == 0:
        print(epoch, train_errors[-1])
print()

# This function displays the attention by passing the source text through the model again and ignoring the predictions.
def show_attention(model, src_token2index, src_tokens, trg_edge_index):
    with torch.no_grad():
        src_indexed = torch.tensor([[src_token2index[token] for token in src_tokens]], dtype=torch.int64, device=device)
        trg_x_indexed = torch.tensor([[trg_edge_index]], dtype=torch.int64, device=device) # The target prefix doesn't matter for getting the attention values so use the shortest possible prefix.
        (_, attention) = model(src_indexed, trg_x_indexed)
        print(''.join(f'{token: <6s}' for token in src_tokens))
        print(''.join(f'{attn: <6.3f}' for attn in attention[0, :].cpu().tolist()))

with torch.no_grad():
    for src_text in train_src_tokens:
        print('---------------------------------------')
        print(src_text)
        (trg_text, prob) = beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_text, max_len=10, beam_size=3)
        print(trg_text, prob)
        print()
        show_attention(model, src_token2index, src_text, trg_edge_index)
        print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

## Dynamic attention / translation with alignment

Up to now we have always represented our text as a single vector which doesn't change, which is why it was possible to use init-inject.
But the problem with a single vector representation is that the entire text needs to be crammed into this fixed size vector, regardless of how long the text is, which has obvious consequences.
Someone had to do something about this.

In the same year that Sutskever and his co-authors published their seq2seq model, Bahdanau, Cho, and Bengio published another paper called [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473).
The idea in this paper was that you don't need to represent your source text as a single vector but that you can have a dynamically changing source vector as the target text is being generated.
This is done by changing the attention values as the target text is being generated token by token.
Do you need to know all of the tokens in the source text in order to generate the first token in the target text?
The first token in the source text is probably all you need to know in order to be able to produce the first token in the translation.
As you're translating, you'll only care about a few source tokens at a time.

To do this, the attention values need to be produced based on what the next token to generate needs to be, which can be determined from the target RNN's state vector.
So the attention maker needs to receive the contextual source token vector together with the target RNN's state vector in order to produce an attention value.
Since a different conditioning vector is being used for each time step in the target RNN, init-inject and pre-inject cannot be used.
Par-inject or merge must be used instead as these allow you to provide a new conditioning vector for every target token being generated.
We will be using merge.
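
A single target time step of this process could be sketched as follows.
The sketch assumes that the attention maker is a dot product between the target state and each contextual source token vector, which is only one possible choice, and it uses random stand-in tensors and made-up sizes instead of real RNN outputs.

```python
import torch

# Assumed sizes, for illustration only.
batch_size = 2
src_time_steps = 4
state_size = 3
trg_vocab_size = 5

# Stand-ins for the contextual source token vectors and the current target RNN state.
src_context_token_vecs = torch.randn((batch_size, src_time_steps, state_size))
trg_state = torch.randn((batch_size, state_size))
src_pad_mask = torch.tensor([
    [False, False, False, False],
    [False, False, True, True],
])

# Score every source token against the target state with a dot product.
attn_logits = (src_context_token_vecs@trg_state[:, :, None])[:, :, 0]
attn_logits = attn_logits.masked_fill(src_pad_mask, float('-inf'))
attention = torch.softmax(attn_logits, dim=1)

# A fresh conditioning vector for this target time step.
cond_vector = (attention[:, None, :]@src_context_token_vecs)[:, 0, :]

# Merge: combine the conditioning vector with the target state just before the output layer.
output_layer = torch.nn.Linear(state_size + state_size, trg_vocab_size)
logits = output_layer(torch.concat((cond_vector, trg_state), dim=1))
print(attention)
print(logits.shape)
```

Repeating this for every target time step gives a different attention vector, and so a different conditioning vector, for every generated token.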

![](align_attention.png)

How should we define the attention maker?
We could concatenate the contextual source token vectors to the target state vectors and pass that through a layer, but there's actually a much faster alternative: dot product.
The dot product is a common way to measure the similarity between two vectors in neural networks, mostly because it is fast and it's easy to calculate its gradient.
It only works well as a similarity measure when the numbers in the vectors are a mix of positive and negative numbers and are within a limited range.
The neural network will thus learn to generate vectors that meet these requirements.
If you have two groups of vectors and you want to measure the dot product between all the pairings in the two groups, just make each group a matrix of vectors and then perform a matrix multiplication:

In [None]:
vector_size = 2
num_vecs1 = 3
num_vecs2 = 4

vecs1 = torch.randn((num_vecs1, vector_size), dtype=torch.float32, device=device)
print('vecs1:')
print(vecs1)
vecs2 = torch.randn((num_vecs2, vector_size), dtype=torch.float32, device=device)
print('vecs2:')
print(vecs2)
print()

naive_dotprod = torch.zeros((num_vecs1, num_vecs2), dtype=torch.float32, device=device)
for i in range(num_vecs1):
    for j in range(num_vecs2):
        naive_dotprod[i, j] = (vecs1[i]*vecs2[j]).sum()
print('naive_dotprod:')
print(naive_dotprod)
print()

fast_dotprod = vecs1@(vecs2.transpose(0, 1))
print('fast_dotprod:')
print(fast_dotprod)

And now, the translation model:

In [None]:
class Model(torch.nn.Module):

    def __init__(self, src_vocab_size, trg_vocab_size, src_embedding_size, src_state_size, src_pad_index, trg_embedding_size, trg_state_size):
        super().__init__()
        self.src_pad_index = src_pad_index
        
        self.src_embedding = torch.nn.Embedding(src_vocab_size, src_embedding_size)
        self.src_rnn_fw_s0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_fw_c0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_fw_cell = torch.nn.LSTMCell(src_embedding_size, src_state_size)
        self.src_rnn_bw_s0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_bw_c0 = torch.nn.Parameter(torch.zeros((src_state_size,), dtype=torch.float32))
        self.src_rnn_bw_cell = torch.nn.LSTMCell(src_embedding_size, src_state_size)
        self.context_layer = torch.nn.Linear(2*src_state_size, trg_state_size)
        
        self.trg_embedding = torch.nn.Embedding(trg_vocab_size, trg_embedding_size)
        self.trg_rnn_s0 = torch.nn.Parameter(torch.zeros((trg_state_size,), dtype=torch.float32))
        self.trg_rnn_c0 = torch.nn.Parameter(torch.zeros((trg_state_size,), dtype=torch.float32))
        self.trg_rnn_cell = torch.nn.LSTMCell(trg_embedding_size, trg_state_size)
        
        self.output_layer = torch.nn.Linear(trg_state_size + trg_state_size, trg_vocab_size) # Input is the attention-weighted source vector (projected to trg_state_size by context_layer) concatenated with the target state.

    def forward(self, src_indexed, trg_x_indexed):
        batch_size = src_indexed.shape[0]
        src_time_steps = src_indexed.shape[1]
        trg_time_steps = trg_x_indexed.shape[1]
        src_pad_mask = src_indexed == self.src_pad_index
        src_non_pad_mask = ~src_pad_mask
        
        ####################
        # Source processing
        ####################
        
        embedded = self.src_embedding(src_indexed)
        
        state = self.src_rnn_fw_s0[None, :].tile((batch_size, 1))
        c = self.src_rnn_fw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in range(src_time_steps):
            (state, c) = self.src_rnn_fw_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        interm_states_fw = torch.stack(interm_states_list, dim=1)

        state = self.src_rnn_bw_s0[None, :].tile((batch_size, 1))
        c = self.src_rnn_bw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in reversed(range(src_time_steps)):
            (new_state, new_c) = self.src_rnn_bw_cell(embedded[:, t, :], (state, c))
            state = torch.where(src_non_pad_mask[:, t, None], new_state, state)
            c = torch.where(src_non_pad_mask[:, t, None], new_c, c)
            interm_states_list.append(state)
        interm_states_bw = torch.stack(interm_states_list[::-1], dim=1)

        interm_states = torch.concat((interm_states_fw, interm_states_bw), dim=2)
        src_context_token_vecs = self.context_layer(interm_states)

        ####################
        # Target processing
        ####################
        
        embedded = self.trg_embedding(trg_x_indexed)

        state = self.trg_rnn_s0[None, :].tile((batch_size, 1))
        c = self.trg_rnn_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in range(trg_time_steps):
            (state, c) = self.trg_rnn_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        trg_interm_states = torch.stack(interm_states_list, dim=1)
        
        ####################
        # Output processing
        ####################
        
        attn_logits = src_context_token_vecs@(trg_interm_states.transpose(1, 2)) # Dot product similarities with shape (batch, source time steps, target time steps).
        attn_logits = attn_logits.masked_fill(src_pad_mask[:, :, None], float('-inf')) # Pad source tokens get zero attention after the softmax.
        attention = torch.softmax(attn_logits, dim=1) # Normalise over the source tokens.
        cond_vectors = (attention.transpose(1, 2)@src_context_token_vecs) # Dimension 1 should not be removed now because there is a different attention vector for each target token.
        trg_interm_states = torch.concat((cond_vectors, trg_interm_states), dim=2) # Merge.
        return (self.output_layer(trg_interm_states), attention)

model = Model(len(src_vocab), len(trg_vocab), src_embedding_size=2, src_state_size=2, src_pad_index=src_pad_index, trg_embedding_size=2, trg_state_size=2)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 10000+1):
    batch_size = train_trg_x_indexed.shape[0]
    trg_time_steps = train_trg_x_indexed.shape[1]
    trg_pad_mask = train_trg_x_indexed == trg_pad_index
    
    optimiser.zero_grad()
    (logits, _) = model(train_src_indexed, train_trg_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_trg_y_indexed, reduction='none')
    train_token_errors = train_token_errors.masked_fill(trg_pad_mask, 0.0)
    train_error = train_token_errors.sum()/(~trg_pad_mask).sum()
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%1000 == 0:
        print(epoch, train_errors[-1])
print()

# Re-implement the attention showing function so that it shows the whole grid of attention for each source-target token combination.
def show_attention(model, src_token2index, src_tokens, trg_edge_index, trg_token2index, trg_tokens):
    if len(trg_tokens) == 0: # If the model isn't trained well, you could end up with an empty text.
        print('N/A')
    else:
        with torch.no_grad():
            src_indexed = torch.tensor(
                [[src_token2index[token] for token in src_tokens]],
                dtype=torch.int64, device=device
            )
            trg_x_indexed = torch.tensor( # Use all target tokens so that we can extract the attention values for every target token.
                [[trg_edge_index] + [trg_token2index[token] for token in trg_tokens]],
                dtype=torch.int64, device=device
            )
            (_, attention) = model(src_indexed, trg_x_indexed)
            print(' '*6, ''.join(f'{src_token: <6s}' for src_token in src_tokens))
            for (attn_row, trg_token) in zip(attention[0, :, :].transpose(0, 1).cpu().tolist(), ['<EDGE>']+trg_tokens):
                print(f'{trg_token: <6s}', ''.join(f'{attn: <6.3f}' for attn in attn_row))

with torch.no_grad():
    for src_text in train_src_tokens:
        print('---------------------------------------')
        print(src_text)
        (trg_text, prob) = beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_text, max_len=10, beam_size=3)
        print(trg_text, prob)
        print()
        show_attention(model, src_token2index, src_text, trg_edge_index, trg_token2index, trg_text)
        print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

Note that attention tends to need a lot of training data before the attention matrices become interesting. With enough data, you can actually read off an alignment between the source words and their equivalent target words.
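
As a minimal sketch of how such an alignment could be read off, assume a made-up attention matrix laid out like `attention[0]` in the model above (rows are source tokens, columns are target tokens, each column sums to 1). The values and the tokenisation below are invented for illustration and are not produced by the trained model.

```python
import torch

# Made-up tokens and attention values for illustration only.
src_tokens = ['I', 'like', 'it']
trg_tokens = ["J'", 'aime', 'ça']
attention = torch.tensor([
    [0.8, 0.1, 0.2],
    [0.1, 0.7, 0.1],
    [0.1, 0.2, 0.7],
])

# For every target token (column), pick the source token (row) with the largest attention value.
best_src = attention.argmax(dim=0).tolist()
alignment = [(trg_tokens[j], src_tokens[i]) for (j, i) in enumerate(best_src)]
print(alignment)
```

Taking the maximum per column is only one simple choice. It forces every target token to align with exactly one source token, which will not suit every language pair.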

## Exercises

### 1) Autoencoding sentences

Take the texts in the movie reviews data set and make a traditional translation model that outputs the same sentence it is given.
Preprocessing has been done for you.
Given that the source and target are equivalent, there is no need for separate vocabularies.

This is useful in practice for obtaining sentence vectors (the conditioning vector) that contain as much information about the sentence as possible.

Take the given test text and use beam search to regenerate it using the trained seq2seq model.

In [None]:
min_freq = 3

train_df = pd.read_csv('../data_set/sentiment/train.csv')
test_df = pd.read_csv('../data_set/sentiment/test.csv')

train_text = train_df['text']
test_text = test_df['text'][0]

nltk.download('punkt')
train_text_tokens = [nltk.word_tokenize(text) for text in train_text]
test_text_tokens = nltk.word_tokenize(test_text)
src_max_len = max(len(text) for text in train_text_tokens)
trg_max_len = src_max_len + 1

frequencies = collections.Counter(token for text in train_text_tokens for token in text)
vocabulary = sorted(frequencies.keys(), key=frequencies.get, reverse=True)
while frequencies[vocabulary[-1]] < min_freq:
    vocabulary.pop()
vocab = ['<PAD>', '<EDGE>', '<UNK>'] + vocabulary
token2index = {token: i for (i, token) in enumerate(vocab)}
pad_index = token2index['<PAD>']
edge_index = token2index['<EDGE>']
unk_index = token2index['<UNK>']

train_src_indexed_np = np.full((len(train_text_tokens), src_max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    for j in range(len(train_text_tokens[i])):
        train_src_indexed_np[i, j] = token2index.get(train_text_tokens[i][j], unk_index)
train_src_indexed = torch.tensor(train_src_indexed_np, device=device)

train_trg_x_indexed_np = np.full((len(train_text_tokens), trg_max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    train_trg_x_indexed_np[i, 0] = edge_index
    for j in range(len(train_text_tokens[i])):
        train_trg_x_indexed_np[i, j + 1] = token2index.get(train_text_tokens[i][j], unk_index)
train_trg_x_indexed = torch.tensor(train_trg_x_indexed_np, device=device)

train_trg_y_indexed_np = np.full((len(train_text_tokens), trg_max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    for j in range(len(train_text_tokens[i])):
        train_trg_y_indexed_np[i, j] = token2index.get(train_text_tokens[i][j], unk_index)
    train_trg_y_indexed_np[i, len(train_text_tokens[i])] = edge_index
train_trg_y_indexed = torch.tensor(train_trg_y_indexed_np, device=device)
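
To make the layout of the three tensors easier to picture, here is a small sketch that applies the same padding and edge-token scheme to two made-up texts. It uses token strings instead of indexes, and the `pad` helper is a hypothetical name introduced only for this illustration.

```python
# Toy illustration with made-up texts; the cell above does the same thing with token indexes.
texts = [['good', 'film'], ['bad']]
src_max_len = max(len(text) for text in texts)
trg_max_len = src_max_len + 1  # One extra position for the edge token.

def pad(tokens, length):
    # Hypothetical helper: right-pad a token list up to the given length.
    return tokens + ['<PAD>'] * (length - len(tokens))

src = [pad(text, src_max_len) for text in texts]
trg_x = [pad(['<EDGE>'] + text, trg_max_len) for text in texts]  # Target input: starts with the edge token.
trg_y = [pad(text + ['<EDGE>'], trg_max_len) for text in texts]  # Target output: ends with the edge token.

for (s, x, y) in zip(src, trg_x, trg_y):
    print('src:  ', s)
    print('trg_x:', x)
    print('trg_y:', y)
    print()
```

Note that `trg_y` is `trg_x` shifted one position to the left, so the token at position `t` of `trg_y` is what the model should predict after reading `trg_x` up to position `t`.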

In [None]:
# Beam search without attention.
def beam_generate(model, token2index, pad_index, edge_index, unk_index, vocab, src_tokens, max_len, beam_size):
    src_indexed = torch.tensor(
        [[token2index.get(token, unk_index) for token in src_tokens]],
        dtype=torch.int64, device=device
    ).tile((beam_size, 1))
    beam_prefixes_indexed = torch.tensor([[edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_probs = np.array([1.0], np.float32)
    
    best_full_prefix_indexed = None
    best_full_prefix_prob = None
    
    with torch.no_grad():
        for _ in range(max_len):
            outputs = torch.softmax(model(src_indexed[:len(beam_prefixes_indexed), :], beam_prefixes_indexed), dim=2)
            token_probs = outputs[:, -1, :]
            new_prefixes_probs = beam_prefixes_probs[:, None]*token_probs.cpu().numpy()

            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))
            
            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)
            
            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64, device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)

    text = [vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return (text, beam_prefixes_probs[0])