In [None]:
%matplotlib inline

import collections
import random
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transformers

In 2017, another paper that revolutionised deep learning for NLP was published: [Attention Is All You Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) by Vaswani et al.
It described the **transformer**, a new way to encode words-in-context that replaces the bi-directional RNN.

The problem with RNNs is that they are inherently sequential.
You can't get the next state unless you have the previous one.
So regardless of how big your GPU is, if you want to encode a sequence with $n$ items then you will need to go through $n$ sequential time steps.

The transformer solves this problem by doing something similar to taking the average vector of the embedding vectors.
Adding all the embedding vectors together can be done in parallel, so that would solve the RNN's problem.
Unfortunately it also loses word order information, which is a big deal.
We'll get to how word order information is preserved in a bit, but for now we should focus on this part of the process.
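
As a quick illustration of why plain averaging loses word order, here is a minimal sketch. The token indexes and the randomly initialised `demo_embedding` layer are made up for this example and are not used anywhere else in this topic. Two texts with the same tokens in a different order end up with the same average vector (up to floating-point rounding).

```python
import torch

torch.manual_seed(0)
demo_embedding = torch.nn.Embedding(4, 8)  # Hypothetical vocabulary of 4 tokens.

text_a = torch.tensor([1, 2, 3])  # e.g. "I like it"
text_b = torch.tensor([3, 2, 1])  # The same tokens in reverse order.

# Averaging over the time dimension needs no sequential processing...
avg_a = demo_embedding(text_a).mean(dim=0)
avg_b = demo_embedding(text_b).mean(dim=0)

# ...but the two orderings become indistinguishable.
same_average = torch.allclose(avg_a, avg_b, atol=1e-6)
print(same_average)
```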

## Queries, keys, and values

Transformers do the following:

* Take each token vector ($t_i$) and use three separate neural layers to transform each token vector into three separate vectors: a query vector ($q_i$), a key vector ($k_i$), and a value vector ($v_i$).
* Combine each $q_i$ with each $k_j$ (including the key from the same token as $q_i$) to produce an attention value $a_{ij}$.
* Multiply each $a_{ij}$ by $v_j$ (the value vector that came from the same token as the key vector) and sum the results over $j$ to produce one context vector for every token $t_i$.

Here is a diagram illustrating this architecture:

![](encoder_full.png)

Here is the same diagram but focusing only on the first token's context vector:

![](encoder_focused.png)

Note how we're comparing each token $t_i$ to every other token $t_j$, including $t_i$ itself, to determine how important $t_j$ is for understanding $t_i$.
This importance is used to make a weighted average of all the tokens in the text in order to represent the meaning of $t_i$.
So we compare the query vector of $t_i$ to the key vector of $t_j$ to see if we should return the value vector of $t_j$, similar to how a Python dictionary works, except that a weighted average of all the values in the dictionary is returned instead of just one.

This process is done for each token which, again, happens all in parallel.
The advantage of having separate query and key vectors is that you can compare a token to itself without comparing the token's vector to itself (so the model can control how similar the token should be to itself).
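
The dictionary analogy can be made concrete with a small sketch. The keys, values, and query below are hand-picked, hypothetical vectors (in a real transformer they would come from the three neural layers), and only a single query is looked up. Instead of returning one value, the lookup returns a weighted average of all the values, with most of the weight going to the value whose key best matches the query.

```python
import torch

# Hypothetical hand-picked vectors for a text with three tokens.
demo_keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
demo_values = torch.tensor([[10.0], [20.0], [60.0]])
demo_query = torch.tensor([0.0, 5.0])  # Points the same way as the second key.

scores = demo_keys@demo_query  # One dot product per key.
weights = torch.softmax(scores, dim=0)  # Non-negative weights that sum to 1.
soft_lookup = weights@demo_values  # Weighted average of all the values.

print(weights)
print(soft_lookup)
```

The result should be close to the second value, but not equal to it, because the other two values still contribute a little.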

As explained in the previous topic, the attention maker is the dot product of the two vectors (the key and the query in this case).
Vaswani et al. add a modification to stop the dot product from getting too big when the query and key vectors have many elements: the dot product is divided by the square root of the vector size.
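
To see why the square root is a sensible choice, here is a rough sketch that assumes the elements of the query and key vectors are independent standard normal numbers (an assumption made only for this illustration; trained vectors need not behave like this). It measures how spread out the dot products are, with and without the scaling, for a few vector sizes.

```python
import torch

torch.manual_seed(0)
num_pairs = 10000
spread = {}

for vector_size in [4, 64, 1024]:
    demo_q = torch.randn(num_pairs, vector_size)
    demo_k = torch.randn(num_pairs, vector_size)
    dots = (demo_q*demo_k).sum(dim=1)  # One dot product per query-key pair.
    scaled = dots/vector_size**0.5  # Divide by the square root of the vector size.
    spread[vector_size] = (dots.std().item(), scaled.std().item())
    print(vector_size, spread[vector_size])
```

Under this assumption, the unscaled standard deviation grows roughly like the square root of the vector size, while the scaled one stays close to 1 regardless of the vector size.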

Furthermore, there is not just one word-in-context vector per token coming out at the other end.
Just like in convolutional neural networks, the embeddings are split to be processed by several branches in parallel, each of which has its own query-key-value vectors and attention values.
This is called **multihead attention** in transformers.

![](encoder_focused_branched.png)

Let's show how all of this is done in PyTorch.
First, create the token embeddings:

In [None]:
x_indexed = torch.tensor([
    [1, 2, 0],
    [1, 0, 0],
], dtype=torch.int64, device=device)

batch_size = x_indexed.shape[0]
time_steps = x_indexed.shape[1]
pad_index = 0
vocab_size = 3
num_branches = 2
embedding_size = 8

embedding = torch.nn.Embedding(vocab_size, embedding_size)
embedding.to(device)

embedded = embedding(x_indexed)
print('embedded:')
print(embedded.shape)
print(embedded)

Next, split the token vectors into branches by adding another dimension (this is much faster than working with lists like in CNNs):

In [None]:
branched_embedded = embedded.reshape(
    (batch_size, time_steps, num_branches, embedding_size//num_branches)
)
branched_embedded = branched_embedded.transpose(1, 2) # Put the branches dimension after the batch size.
print('branched_embedded:')
print(branched_embedded.shape)
print(branched_embedded)
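
To check what the reshape and transpose actually do, the following sketch uses a tensor of increasing numbers in place of real embeddings (all the `demo_` names are made up for this illustration). It confirms that each branch receives one contiguous slice of every token vector.

```python
import torch

demo_embedded = torch.arange(2*3*8, dtype=torch.float32).reshape((2, 3, 8))  # [batch, time, embedding]
demo_branches = 2
demo_branch_size = 8//demo_branches

demo_branched = demo_embedded.reshape((2, 3, demo_branches, demo_branch_size)).transpose(1, 2)

# Branch 0 holds the first half of every token vector and branch 1 holds the second half.
first_half_matches = torch.equal(demo_branched[:, 0], demo_embedded[:, :, :demo_branch_size])
second_half_matches = torch.equal(demo_branched[:, 1], demo_embedded[:, :, demo_branch_size:])
print(demo_branched.shape, first_half_matches, second_half_matches)
```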

Create the queries, keys, and values from these branched token vectors.
Note that we can just pass that 4D tensor through a linear layer and everything will be as expected, because a linear layer is applied to the last dimension only.

In [None]:
query_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
key_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
value_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
query_layer.to(device)
key_layer.to(device)
value_layer.to(device)

q = query_layer(branched_embedded)
k = key_layer(branched_embedded)
v = value_layer(branched_embedded)

print('q:')
print(q.shape)
print()
print('k:')
print(k.shape)
print()
print('v:')
print(v.shape)

Next, we produce the attention logits by taking the dot product of every query with every key (remember that this can be done with a single matrix multiplication):

In [None]:
attn_logits = q@k.transpose(2, 3)

# Divide the logits by the square root of the vector size.
sqrt_dim = np.sqrt(embedding_size)
attn_logits = attn_logits/sqrt_dim

print('attn_logits:')
print(attn_logits.shape) # Dimensions are: [batch, branch, query, key]
print(attn_logits)

Next, mask out the pad tokens by replacing their logits with negative infinity which will result in an attention of 0:

In [None]:
pad_mask = x_indexed == pad_index
attn_logits = attn_logits.masked_fill(pad_mask[:, None, None, :], float('-inf')) # Add a singleton dimension for the branch dimension and the query dimension (queries corresponding to pad tokens also need to be ignored but that can be done later).
print('attn_logits:')
print(attn_logits.shape)
print(attn_logits)

Next, we convert these logits into attention values:

In [None]:
attention = torch.softmax(attn_logits, dim=3)
print('attention:')
print(attention.shape)
print(attention)
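
The effect of the negative infinity can be seen in isolation with a tiny sketch (the logits below are made up). It also shows a caveat: if every key in a row were masked, which would happen for a text made only of pad tokens, the softmax would have nothing left to normalise. In the PyTorch versions I am aware of, that row comes out as NaNs.

```python
import torch

demo_logits = torch.tensor([[1.0, 2.0, float('-inf')]])  # The last key belongs to a pad token.
demo_attention = torch.softmax(demo_logits, dim=1)
print(demo_attention)  # The masked key gets an attention of exactly 0.

# A row in which every key is masked cannot be normalised.
all_masked = torch.softmax(torch.full((1, 3), float('-inf')), dim=1)
print(all_masked)
```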

Now we can apply the attention to the values which, again, is just a matrix multiplication:

In [None]:
branched_attended_values = attention@v
print('branched_attended_values:')
print(branched_attended_values.shape)

We can then rejoin the branches:

In [None]:
attended_values = branched_attended_values.transpose(1, 2).reshape((batch_size, time_steps, num_branches*embedding_size))
print('attended_values:')
print(attended_values.shape)

Finally, transform these vectors with another layer to squash them back down to the embedding size; the resulting vectors are your word-in-context vectors:

In [None]:
word_in_context_layer = torch.nn.Linear(num_branches*embedding_size, embedding_size)
word_in_context_layer.to(device)

word_in_context = torch.nn.functional.leaky_relu(word_in_context_layer(attended_values))
print('word_in_context:')
print(word_in_context.shape)

Some parts of Vaswani et al.'s implementation are missing here, such as residual connections and layer normalisation, but we don't need to get into those.
Also, Vaswani et al.'s transformer was actually a sequence-to-sequence model, which we'll get to later in this topic.

Let's use a transformer on the toy data set in order to re-implement the text classification at every token task we previously implemented using a bi-directional RNN.

In [None]:
train_x = [
    'I like it .'.split(' '),
    'I hate it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I don\'t like it .'.split(' '),
]
train_y = torch.tensor([
    [1],
    [0],
    [1],
    [0],
], dtype=torch.float32, device=device)

max_len = max(len(text) for text in train_x)
print('max_len:', max_len)

vocab = ['<PAD>'] + sorted({token for text in train_x for token in text})
token2index = {t: i for (i, t) in enumerate(vocab)}
pad_index = token2index['<PAD>']
print('vocab:', vocab)
print()

train_x_indexed_np = np.full((len(train_x), max_len), pad_index, np.int64)
for i in range(len(train_x)):
    for j in range(len(train_x[i])):
        train_x_indexed_np[i, j] = token2index[train_x[i][j]]
train_x_indexed = torch.tensor(train_x_indexed_np, device=device)

train_y_seq = train_y[:, None, :].tile((1, max_len, 1))
print('train_y_seq:')
print(train_y_seq)

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, embedding_size, num_branches, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.embedding_size = embedding_size
        self.num_branches = num_branches
        self.sqrt_dim = np.sqrt(embedding_size)
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.query_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.key_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.value_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.word_in_context_layer = torch.nn.Linear(num_branches*embedding_size, embedding_size)
        self.output_layer = torch.nn.Linear(embedding_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        pad_mask = x_indexed == self.pad_index

        embedded = self.embedding(x_indexed)
        
        branched_embedded = embedded.reshape(
            (batch_size, time_steps, self.num_branches, self.embedding_size//self.num_branches)
        ).transpose(1, 2)

        q = self.query_layer(branched_embedded)
        k = self.key_layer(branched_embedded)
        v = self.value_layer(branched_embedded)

        attn_logits = q@k.transpose(2, 3)
        attn_logits = attn_logits/self.sqrt_dim
        attn_logits = attn_logits.masked_fill(pad_mask[:, None, None, :], float('-inf'))
        attention = torch.softmax(attn_logits, dim=3)
        branched_attended_values = attention@v

        attended_values = branched_attended_values.transpose(1, 2).reshape(
            (batch_size, time_steps, self.num_branches*self.embedding_size)
        )
        word_in_context = torch.nn.functional.leaky_relu(self.word_in_context_layer(attended_values))

        return self.output_layer(word_in_context)

model = Model(len(vocab), embedding_size=4, num_branches=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    batch_size = train_y_seq.shape[0]
    time_steps = train_y_seq.shape[1]
    pad_mask = train_x_indexed == pad_index
    
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y_seq, reduction='none')
    train_token_errors = torch.masked_fill(train_token_errors, pad_mask[:, :, None], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, :, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text + ['<PAD>']*(max_len - len(text)), y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

Thankfully there is a built-in module for multihead attention:

    torch.nn.MultiheadAttention

This is how you use it:

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, embedding_size, num_branches, pad_index):
        super().__init__()
        self.pad_index = pad_index
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)

        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True) # batch_first is used to say that the token vectors will be provided with the batch as the first dimension rather than the second.
        
        self.output_layer = torch.nn.Linear(embedding_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        pad_mask = x_indexed == self.pad_index

        embedded = self.embedding(x_indexed)
        
        # The forward method of the multihead attention module returns two things: the words in context and the attention values; the attention values are None unless need_weights is True.
        (word_in_context, _) = self.multihead_attention_layer(query=embedded, key=embedded, value=embedded, key_padding_mask=pad_mask, need_weights=False)
        
        return self.output_layer(word_in_context)

model = Model(len(vocab), embedding_size=4, num_branches=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    batch_size = train_y_seq.shape[0]
    time_steps = train_y_seq.shape[1]
    pad_mask = train_x_indexed == pad_index
    
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y_seq, reduction='none')
    train_token_errors = torch.masked_fill(train_token_errors, pad_mask[:, :, None], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, :, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text + ['<PAD>']*(max_len - len(text)), y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

You might have noticed that the multihead attention module expects you to pass the embedded tokens 3 times, once for the query, once for the key, and once for the value.
Later on we'll see a situation where these would not be the same tensor.
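
To inspect the attention values that the built-in module computes, `need_weights` can be set to `True`. The sketch below uses random vectors in place of embedded tokens and a made-up pad mask (all the `demo_` names are assumptions for this illustration), and passes the same tensor as the query, key, and value.

```python
import torch

torch.manual_seed(0)
demo_mha = torch.nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
demo_tokens = torch.randn((2, 3, 4))  # [batch, time, embedding]
demo_pad_mask = torch.tensor([
    [False, False, True],
    [False, True, True],
])

with torch.no_grad():
    (demo_out, demo_weights) = demo_mha(
        query=demo_tokens, key=demo_tokens, value=demo_tokens,
        key_padding_mask=demo_pad_mask, need_weights=True,
    )

print(demo_out.shape)  # [batch, query, embedding]
print(demo_weights.shape)  # [batch, query, key], averaged over the heads by default.
print(demo_weights)
```

The columns of the attention values that correspond to pad keys should be zero, and every row should sum to 1.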

## Token order information

As-is, when calculating the attention, there is no way to take into account where a particular query and key pair are situated in the text, which means that the neural network won't see anything different if you shuffle the order of the tokens.
To solve this problem, the embedding vectors are modified to include positional information by adding to each embedding vector a positional vector.
A positional vector is like an embedding vector but for positions instead of tokens.
There are many ways to do this, but the simplest is by using a **positioning matrix**.
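
Before looking at the positioning matrix in more detail, the order-blindness described above can be checked directly. This is a minimal sketch that uses random vectors in place of embedded tokens and an arbitrary shuffle (all the `demo_` names are made up for this illustration). Shuffling the tokens shuffles the word-in-context vectors in the same way, but leaves each individual vector unchanged.

```python
import torch

torch.manual_seed(0)
demo_mha = torch.nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
demo_tokens = torch.randn((1, 5, 4))  # One text of 5 token vectors with no positional information.
demo_shuffle = torch.tensor([3, 0, 4, 1, 2])  # An arbitrary reordering of the tokens.
demo_shuffled_tokens = demo_tokens[:, demo_shuffle, :]

with torch.no_grad():
    (demo_plain_out, _) = demo_mha(demo_tokens, demo_tokens, demo_tokens, need_weights=False)
    (demo_shuffled_out, _) = demo_mha(demo_shuffled_tokens, demo_shuffled_tokens, demo_shuffled_tokens, need_weights=False)

# Each token gets the same vector as before; the vectors have simply moved with their tokens.
demo_order_is_invisible = torch.allclose(demo_plain_out[:, demo_shuffle, :], demo_shuffled_out, atol=1e-5)
print(demo_order_is_invisible)
```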

A positioning matrix would have a row vector for every token position, just like an embedding matrix has a row vector for every token in the vocabulary.
The problem with this is that, just like you must have a fixed number of tokens in your vocabulary, you also must have a fixed number of positions, which means that you have a maximum text length you can process.
This is a common issue in transformers.
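
The maximum text length can be seen in practice with a small sketch (the sizes are made up, and the example runs on the CPU, where an out-of-range index is reported as an `IndexError`; on a GPU the failure may look different).

```python
import torch

demo_max_len = 5
demo_positioning = torch.nn.Embedding(demo_max_len, 4)  # One row per position 0..4.

# A text of 5 tokens is fine...
print(demo_positioning(torch.arange(5)).shape)

# ...but a sixth token would need a row that does not exist.
try:
    demo_positioning(torch.arange(6))
    too_long_failed = False
except IndexError as error:
    too_long_failed = True
    print('IndexError:', error)
```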

We can get a sequence of position indexes for every token in a batch by using `torch.arange`, which gives a tensor of increasing numbers:

In [None]:
batch_size = 3
time_steps = 5
position_indexes = torch.arange(time_steps)[None, :].tile((batch_size, 1))
print(position_indexes)

These indexes can be passed into an embedding layer to convert them into vectors, which can then be either concatenated to the token vectors or simply added to them, the latter being the more common method.
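
The difference between the two options is easiest to see in the resulting shapes. In this sketch, random vectors stand in for the embedded tokens and all the `demo_` names and sizes are made up for the illustration.

```python
import torch

torch.manual_seed(0)
demo_batch_size = 2
demo_time_steps = 3
demo_token_vecs = torch.randn((demo_batch_size, demo_time_steps, 4))  # Stand-in for embedded tokens.

demo_positioning = torch.nn.Embedding(demo_time_steps, 4)
demo_position_indexes = torch.arange(demo_time_steps)[None, :].tile((demo_batch_size, 1))
demo_position_vecs = demo_positioning(demo_position_indexes)

added = demo_token_vecs + demo_position_vecs  # Keeps the vector size at 4.
concatenated = torch.cat([demo_token_vecs, demo_position_vecs], dim=2)  # Grows the vector size to 8.
print(added.shape, concatenated.shape)
```

Adding keeps the vector size unchanged, so the rest of the model does not need to know that positional information was included.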

Let's use it on the toy data set.

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, max_len, embedding_size, num_branches, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.register_buffer('positions', torch.arange(max_len)) # This creates a positions instance variable that is not a parameter but still moves to the requested device when the module's 'to' method is used.
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.positioning = torch.nn.Embedding(max_len, embedding_size)
        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)
        self.output_layer = torch.nn.Linear(embedding_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        pad_mask = x_indexed == self.pad_index
        
        embedded = self.embedding(x_indexed)
        positions = self.positions[None, :time_steps].tile((batch_size, 1)) # Trim the positions sequence to the number of time steps used.
        positioned = self.positioning(positions)
        embedded = embedded + positioned # Position vectors are added to the embedding vectors.
        
        (word_in_context, _) = self.multihead_attention_layer(query=embedded, key=embedded, value=embedded, key_padding_mask=pad_mask, need_weights=False)
        
        return self.output_layer(word_in_context)

model = Model(len(vocab), max_len=max_len, embedding_size=4, num_branches=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    batch_size = train_y_seq.shape[0]
    time_steps = train_y_seq.shape[1]
    pad_mask = train_x_indexed == pad_index
    
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y_seq, reduction='none')
    train_token_errors = torch.masked_fill(train_token_errors, pad_mask[:, :, None], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, :, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text + ['<PAD>']*(max_len - len(text)), y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()
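
The `register_buffer` call used in the model above can be examined on its own. The sketch below uses a hypothetical `BufferDemo` module, written only for this illustration, to show that a buffer is kept separate from the parameters (so an optimiser built from `parameters()` would not update it) while still being part of the module's state.

```python
import torch

class BufferDemo(torch.nn.Module):  # Hypothetical module used only for this illustration.

    def __init__(self, max_len):
        super().__init__()
        self.register_buffer('positions', torch.arange(max_len))
        self.weights = torch.nn.Parameter(torch.zeros(max_len))

buffer_demo = BufferDemo(5)
param_names = [name for (name, _) in buffer_demo.named_parameters()]
buffer_names = [name for (name, _) in buffer_demo.named_buffers()]
print('parameters:', param_names)  # What the optimiser would update.
print('buffers:', buffer_names)  # Moved by .to(device) but never trained.
print('saved in state_dict:', list(buffer_demo.state_dict().keys()))
```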

## Text representation

We've seen how to represent each word-in-context, but how do you represent a whole text with one vector?
A lot of papers do so by just taking the average of the word-in-context vectors.
The average of a batch of vector sequences that need to be masked is found as follows:

In [None]:
pad_mask = torch.tensor([
    [0, 0, 1],
    [0, 1, 1],
], dtype=torch.bool, device=device)
print('pad_mask')
print(pad_mask)
print()

batch_size = pad_mask.shape[0]
time_step = pad_mask.shape[1]
embedding_size = 4

word_in_context = torch.randn((batch_size, time_step, embedding_size), dtype=torch.float32, device=device)
print('word_in_context')
print(word_in_context)
print()

word_in_context = word_in_context.masked_fill(pad_mask[:, :, None], 0.0) # Zero out the vectors corresponding to pad tokens.
text_vecs = word_in_context.sum(dim=1)/(~pad_mask).sum(dim=1)[:, None] # Divide each sum of a vector sequence by the number of non-pad tokens in that sequence.
print('text_vecs')
print(text_vecs)
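
As a sanity check, the masked average can be compared against averaging the real tokens of each text one text at a time. In this sketch, `masked_mean` is a hypothetical helper name and the vectors are random stand-ins for word-in-context vectors. Note that a text made only of pad tokens would lead to a division by zero, so this assumes every text has at least one real token.

```python
import torch

def masked_mean(vectors, pad_mask):
    """Average over the time dimension, ignoring positions where pad_mask is True."""
    vectors = vectors.masked_fill(pad_mask[:, :, None], 0.0)
    return vectors.sum(dim=1)/(~pad_mask).sum(dim=1)[:, None]

torch.manual_seed(0)
demo_vectors = torch.randn((2, 3, 4))
demo_pad_mask = torch.tensor([
    [False, False, True],
    [False, True, True],
])
demo_text_vecs = masked_mean(demo_vectors, demo_pad_mask)

# Compare against averaging each text's real tokens one text at a time.
expected = torch.stack([demo_vectors[0, :2].mean(dim=0), demo_vectors[1, :1].mean(dim=0)])
matches = torch.allclose(demo_text_vecs, expected, atol=1e-6)
print(matches)
```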

This is how it would be used:

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, max_len, embedding_size, num_branches, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.register_buffer('positions', torch.arange(max_len))
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.positioning = torch.nn.Embedding(max_len, embedding_size)
        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)
        self.output_layer = torch.nn.Linear(embedding_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        pad_mask = x_indexed == self.pad_index
        
        embedded = self.embedding(x_indexed)
        positions = self.positions[None, :time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned
        
        (word_in_context, _) = self.multihead_attention_layer(query=embedded, key=embedded, value=embedded, key_padding_mask=pad_mask, need_weights=False)
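        word_in_context = word_in_context.masked_fill(pad_mask[:, :, None], 0.0) # Zero out the vectors at pad positions so they don't contribute to the sum.
        text_vecs = word_in_context.sum(dim=1)/(~pad_mask).sum(dim=1)[:, None]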
        
        return self.output_layer(text_vecs)

model = Model(len(vocab), max_len=max_len, embedding_size=4, num_branches=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_error = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y)
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text, y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

Another way to get a vector for the entire text is to add another special token, usually called a **class token**, which is always added to the beginning of the text.
The vector produced at this token is then used to represent the whole text.

In [None]:
train_x = [
    '<CLS> I like it .'.split(' '),
    '<CLS> I hate it .'.split(' '),
    '<CLS> I don\'t hate it .'.split(' '),
    '<CLS> I don\'t like it .'.split(' '),
]
train_y = torch.tensor([
    [1],
    [0],
    [1],
    [0],
], dtype=torch.float32, device=device)

max_len = max(len(text) for text in train_x)
print('max_len:', max_len)

vocab = ['<PAD>'] + sorted({token for text in train_x for token in text})
token2index = {t: i for (i, t) in enumerate(vocab)}
pad_index = token2index['<PAD>']
print('vocab:', vocab)
print()

train_x_indexed_np = np.full((len(train_x), max_len), pad_index, np.int64)
for i in range(len(train_x)):
    for j in range(len(train_x[i])):
        train_x_indexed_np[i, j] = token2index[train_x[i][j]]
train_x_indexed = torch.tensor(train_x_indexed_np, device=device)

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, max_len, embedding_size, num_branches, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.register_buffer('positions', torch.arange(max_len))
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.positioning = torch.nn.Embedding(max_len, embedding_size)
        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)
        self.output_layer = torch.nn.Linear(embedding_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        pad_mask = x_indexed == self.pad_index
        
        embedded = self.embedding(x_indexed)
        positions = self.positions[None, :time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned
        
        (word_in_context, _) = self.multihead_attention_layer(query=embedded, key=embedded, value=embedded, key_padding_mask=pad_mask, need_weights=False)
        text_vecs = word_in_context[:, 0, :] # Take the vector of the first token, which should be the CLS token.
        
        return self.output_layer(text_vecs)

model = Model(len(vocab), max_len=max_len, embedding_size=4, num_branches=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.001)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_error = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y)
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text, y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

## Language modelling / text generation

Transformers are also very popular for generating text.
Just like with an RNN, you predict the next token for every prefix.
Unlike an RNN, a transformer encodes both the left and the right context of each token, so it doesn't produce prefix vectors by default.
To make each word-in-context vector use only its prefix as context, we need to mask, for each token, the tokens to its right.

![](decoder.png)

This is done using a triangular mask:

In [None]:
ones = torch.ones((4, 4), dtype=torch.bool, device=device)
print(torch.triu(ones, diagonal=1))

A triangular mask can be used to make each query ignore the keys of the tokens to its right by giving those keys an attention of zero, so that they are ignored when computing the word-in-context vector.

This automatically handles pad tokens as well since any query that attends to a pad token key will also be itself a pad token and so will be masked when calculating the train error.
Remember that both the keys and the queries are just different representations of the same tokens.
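
The effect of the triangular mask on the attention values can be sketched with random logits standing in for real query-key dot products (the `demo_` names are made up for this illustration).

```python
import torch

torch.manual_seed(0)
demo_steps = 4
demo_logits = torch.randn((demo_steps, demo_steps))  # Stand-in attention logits: [query, key].
demo_tri_mask = torch.triu(torch.ones((demo_steps, demo_steps), dtype=torch.bool), diagonal=1)

demo_attention = torch.softmax(demo_logits.masked_fill(demo_tri_mask, float('-inf')), dim=1)
print(demo_attention)

# Everything above the diagonal is zero and the first query can only attend to itself.
future_is_ignored = bool((demo_attention[demo_tri_mask] == 0).all())
print(future_is_ignored, demo_attention[0, 0].item())
```

Each row still sums to 1, but the attention is now shared only among the token itself and the tokens to its left.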

Let's implement a language modelling task on the toy data set using our fully-coded transformer.

In [None]:
train_text_tokens = [
    'I like it .'.split(' '),
    'I hate it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I don\'t like it .'.split(' '),
]

max_len = max(len(text) for text in train_text_tokens) + 1
print('max_len:', max_len)

vocab = ['<PAD>', '<EDGE>'] + sorted({token for text in train_text_tokens for token in text})
token2index = {t: i for (i, t) in enumerate(vocab)}
pad_index = token2index['<PAD>']
edge_index = token2index['<EDGE>']
print('vocab:', vocab)

train_text_x_indexed_np = np.full((len(train_text_tokens), max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    train_text_x_indexed_np[i, 0] = edge_index
    for j in range(len(train_text_tokens[i])):
        train_text_x_indexed_np[i, j + 1] = token2index[train_text_tokens[i][j]]
train_text_x_indexed = torch.tensor(train_text_x_indexed_np, device=device)

train_text_y_indexed_np = np.full((len(train_text_tokens), max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    for j in range(len(train_text_tokens[i])):
        train_text_y_indexed_np[i, j] = token2index[train_text_tokens[i][j]]
    train_text_y_indexed_np[i, len(train_text_tokens[i])] = edge_index
train_text_y_indexed = torch.tensor(train_text_y_indexed_np, device=device)

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, max_len, embedding_size, num_branches):
        super().__init__()
        self.embedding_size = embedding_size
        self.num_branches = num_branches
        self.sqrt_dim = np.sqrt(embedding_size)
        self.register_buffer('positions', torch.arange(max_len))
        self.register_buffer('tri_mask', torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1)) # Make the tri mask part of the module.
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.positioning = torch.nn.Embedding(max_len, embedding_size)
        self.query_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.key_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.value_layer = torch.nn.Linear(embedding_size//num_branches, embedding_size)
        self.word_in_context_layer = torch.nn.Linear(num_branches*embedding_size, embedding_size)
        self.output_layer = torch.nn.Linear(embedding_size, vocab_size)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        
        embedded = self.embedding(x_indexed)
        positions = self.positions[None, :time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned
        
        branched_embedded = embedded.reshape(
            (batch_size, time_steps, self.num_branches, self.embedding_size//self.num_branches)
        ).transpose(1, 2)

        q = self.query_layer(branched_embedded)
        k = self.key_layer(branched_embedded)
        v = self.value_layer(branched_embedded)

        attn_logits = q@k.transpose(2, 3)
        attn_logits = attn_logits/self.sqrt_dim
        attn_logits = attn_logits.masked_fill(self.tri_mask[None, None, :time_steps, :time_steps], float('-inf')) # Use the tri mask instead of the pad mask.
        attention = torch.softmax(attn_logits, dim=3)
        branched_attended_values = attention@v

        attended_values = branched_attended_values.transpose(1, 2).reshape(
            (batch_size, time_steps, self.num_branches*self.embedding_size)
        )
        word_in_context = torch.nn.functional.leaky_relu(self.word_in_context_layer(attended_values))

        return self.output_layer(word_in_context)

model = Model(len(vocab), max_len, embedding_size=4, num_branches=2)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    pad_mask = train_text_y_indexed == pad_index # Use the pad positions of the language modelling targets, not those of the earlier classification data set.
    
    optimiser.zero_grad()
    logits = model(train_text_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_text_y_indexed, reduction='none')
    train_token_errors = train_token_errors.masked_fill(pad_mask[:, :], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

In [None]:
prefixes = sorted({tuple(text[:i]) for text in train_text_tokens for i in range(len(text) + 1)}) # Prefix must be a tuple so we can put it in a set.
prefix_max_len = max(len(prefix) for prefix in prefixes) + 1
prefix_text_x_indexed_np = np.full((len(prefixes), prefix_max_len), pad_index, np.int64)
for i in range(len(prefixes)):
    prefix_text_x_indexed_np[i, 0] = edge_index
    for j in range(len(prefixes[i])):
        prefix_text_x_indexed_np[i, j+1] = token2index[prefixes[i][j]]
prefix_text_x_indexed = torch.tensor(prefix_text_x_indexed_np, device=device)
with torch.no_grad():
    output = torch.softmax(model(prefix_text_x_indexed), dim=2).cpu().tolist()

for (prefix, y) in zip(prefixes, output):
    last_token_y = y[len(prefix)]
    top_preds = sorted(zip(last_token_y, vocab), reverse=True)[:5]
    print(['<EDGE>'] + list(prefix))
    for (x, token) in top_preds:
        print(f'   {token:6s}: {x:.8f}')
    print()

And this is how we use PyTorch's `MultiheadAttention` to do the same:

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, max_len, embedding_size, num_branches):
        super().__init__()
        self.register_buffer('positions', torch.arange(max_len))
        self.register_buffer('tri_mask', torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1))
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.positioning = torch.nn.Embedding(max_len, embedding_size)
        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)
        self.output_layer = torch.nn.Linear(embedding_size, vocab_size)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        
        embedded = self.embedding(x_indexed)
        positions = self.positions[None, :time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned
        
        # Pass the triangular mask as attn_mask so that each position can only attend to itself and to earlier positions.
        # is_causal=True is a hint to PyTorch that attn_mask is a causal mask (causal refers to a language model that predicts the next token given the previous ones).
        (word_in_context, _) = self.multihead_attention_layer(query=embedded, key=embedded, value=embedded, attn_mask=self.tri_mask[:time_steps, :time_steps], need_weights=False, is_causal=True)
        
        return self.output_layer(word_in_context)

model = Model(len(vocab), max_len, embedding_size=4, num_branches=2)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    pad_mask = train_text_y_indexed == pad_index # Mask the padded positions of the targets used in the loss below.
    
    optimiser.zero_grad()
    logits = model(train_text_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_text_y_indexed, reduction='none')
    train_token_errors = train_token_errors.masked_fill(pad_mask[:, :], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

In [None]:
prefixes = sorted({tuple(text[:i]) for text in train_text_tokens for i in range(len(text) + 1)}) # Prefix must be a tuple so we can put it in a set.
prefix_max_len = max(len(prefix) for prefix in prefixes) + 1
prefix_text_x_indexed_np = np.full((len(prefixes), prefix_max_len), pad_index, np.int64)
for i in range(len(prefixes)):
    prefix_text_x_indexed_np[i, 0] = edge_index
    for j in range(len(prefixes[i])):
        prefix_text_x_indexed_np[i, j+1] = token2index[prefixes[i][j]]
prefix_text_x_indexed = torch.tensor(prefix_text_x_indexed_np, device=device)
with torch.no_grad():
    output = torch.softmax(model(prefix_text_x_indexed), dim=2).cpu().tolist()

for (prefix, y) in zip(prefixes, output):
    last_token_y = y[len(prefix)]
    top_preds = sorted(zip(last_token_y, vocab), reverse=True)[:5]
    print(['<EDGE>'] + list(prefix))
    for (x, token) in top_preds:
        print(f'   {token:6s}: {x:.8f}')
    print()

## seq2seq

So how do you perform seq2seq with a transformer?
According to the paper by Vaswani and others, you use three transformers.
The first transformer acts as a language model and is used to produce prefix vectors from the target tokens.
These prefix vectors will be used to make query vectors.
The second transformer acts as a word-in-context model and is used to produce token vectors from the source tokens.
These token vectors will be used to make key and value vectors.
The third transformer then combines the query vectors with the key and value vectors.
This aligns the target prefixes to the source tokens (in context) to produce a translation.

Here is a diagram showing what happens in the third transformer (the first two are as shown in the previous diagrams):

![](seq2seq.png)
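
As a rough sketch of the third transformer on its own, the snippet below feeds random stand-ins for the prefix vectors and the source token vectors into a single `MultiheadAttention` layer and inspects the attention weights. The sizes, the random inputs, and the hand-written pad mask are assumptions made up for illustration; they are not taken from the toy data set.

```python
import torch

torch.manual_seed(0)

batch_size, src_len, trg_len, embedding_size, num_heads = 2, 5, 3, 4, 2

# Random stand-ins for the outputs of the first two transformers.
prefix_vecs = torch.randn(batch_size, trg_len, embedding_size)      # From the target prefixes.
word_in_context = torch.randn(batch_size, src_len, embedding_size)  # From the source tokens.

# True marks a padded source position; the second sentence ends in two pads.
src_pad_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True, True],
])

cross_attention = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)
with torch.no_grad():
    (final_prefix_vecs, weights) = cross_attention(
        query=prefix_vecs, key=word_in_context, value=word_in_context,
        key_padding_mask=src_pad_mask, need_weights=True
    )

print(final_prefix_vecs.shape)  # torch.Size([2, 3, 4]): one vector per target prefix.
print(weights.shape)            # torch.Size([2, 3, 5]): one weight per (prefix, source token) pair.
print(weights[1])               # The last two columns belong to padded source tokens.
```

Each row of `weights` is a distribution over the source tokens for one target prefix, and the padded source positions receive a weight of zero.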

Note that there is no need to use a triangular mask in the third transformer: each prefix vector is allowed to attend to every source token, and the prefix vectors themselves are already unaffected by the target tokens that come after them.
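
One way to check this claim is to change only the last target token and see whether the outputs for the earlier prefixes move. The sketch below does this with two untrained `MultiheadAttention` layers and random vectors; the `decode` helper, the sizes, and the inputs are hypothetical and only stand in for the model used in this section.

```python
import torch

torch.manual_seed(0)

embedding_size, num_heads, src_len, trg_len = 4, 2, 5, 3

self_attention = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)
cross_attention = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)

# True means "may not attend": position i cannot look at positions after i.
tri_mask = torch.triu(torch.ones((trg_len, trg_len), dtype=torch.bool), diagonal=1)

def decode(trg_vecs, src_vecs):
    # Masked self-attention over the target, then unmasked attention over the source.
    (prefix_vecs, _) = self_attention(trg_vecs, trg_vecs, trg_vecs, attn_mask=tri_mask, need_weights=False)
    (out, _) = cross_attention(prefix_vecs, src_vecs, src_vecs, need_weights=False)
    return out

src_vecs = torch.randn(1, src_len, embedding_size)
trg_vecs = torch.randn(1, trg_len, embedding_size)

# Change only the last target token.
changed_trg_vecs = trg_vecs.clone()
changed_trg_vecs[0, -1, :] += 1.0

with torch.no_grad():
    original = decode(trg_vecs, src_vecs)
    changed = decode(changed_trg_vecs, src_vecs)

# The outputs for all prefixes except the last should be unaffected by the change.
earlier_unchanged = torch.allclose(original[:, :-1, :], changed[:, :-1, :], atol=1e-6)
print('earlier prefixes unchanged:', earlier_unchanged)
```

The triangular mask is only applied in the self-attention step, yet the outputs for the earlier prefixes stay the same because the second step treats each prefix vector independently.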

Now let's use it on the toy data set for sentiment translation.

In [None]:
train_src_tokens = [
    'I like it .'.split(' '),
    'I hate it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I don\'t like it .'.split(' '),
]

train_trg_tokens = [
    'I don\'t like it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I hate it .'.split(' '),
    'I like it .'.split(' '),
]

src_max_len = max(len(text) for text in train_src_tokens)
print('src_max_len:', src_max_len)

src_vocab = ['<PAD>'] + sorted({token for text in train_src_tokens for token in text})
src_token2index = {t: i for (i, t) in enumerate(src_vocab)}
src_pad_index = src_token2index['<PAD>']
print('src_vocab:', src_vocab)
print()

train_src_indexed_np = np.full([len(train_src_tokens), src_max_len], src_pad_index, np.int64)
for i in range(len(train_src_tokens)):
    for j in range(len(train_src_tokens[i])):
        train_src_indexed_np[i, j] = src_token2index[train_src_tokens[i][j]]
train_src_indexed = torch.tensor(train_src_indexed_np, device=device)

trg_max_len = max(len(text) + 1 for text in train_trg_tokens)
print('trg_max_len:', trg_max_len)

trg_vocab = ['<PAD>', '<EDGE>'] + sorted({token for text in train_trg_tokens for token in text})
trg_token2index = {t: i for (i, t) in enumerate(trg_vocab)}
trg_pad_index = trg_token2index['<PAD>']
trg_edge_index = trg_token2index['<EDGE>']
print('trg_vocab:', trg_vocab)
print()

train_trg_x_indexed_np = np.full((len(train_trg_tokens), trg_max_len), trg_pad_index, np.int64)
for i in range(len(train_trg_tokens)):
    train_trg_x_indexed_np[i, 0] = trg_edge_index
    for j in range(len(train_trg_tokens[i])):
        train_trg_x_indexed_np[i, j + 1] = trg_token2index[train_trg_tokens[i][j]]
train_trg_x_indexed = torch.tensor(train_trg_x_indexed_np, device=device)

train_trg_y_indexed_np = np.full((len(train_trg_tokens), trg_max_len), trg_pad_index, np.int64)
for i in range(len(train_trg_tokens)):
    for j in range(len(train_trg_tokens[i])):
        train_trg_y_indexed_np[i, j] = trg_token2index[train_trg_tokens[i][j]]
    train_trg_y_indexed_np[i, len(train_trg_tokens[i])] = trg_edge_index # Add the edge token at the end.
train_trg_y_indexed = torch.tensor(train_trg_y_indexed_np, device=device)

max_len = max(src_max_len, trg_max_len)

In [None]:
class Model(torch.nn.Module):

    def __init__(self, src_vocab_size, trg_vocab_size, max_len, embedding_size, num_branches, src_pad_index):
        super().__init__()
        self.src_pad_index = src_pad_index
        self.register_buffer('positions', torch.arange(max_len))
        self.register_buffer('tri_mask', torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1))
        self.positioning = torch.nn.Embedding(max_len, embedding_size)

        # Prefixes transformer
        self.trg_embedding = torch.nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)

        # Sources transformer
        self.src_embedding = torch.nn.Embedding(src_vocab_size, embedding_size)
        self.src_multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)

        # Combining transformer
        self.multihead_attention_layer = torch.nn.MultiheadAttention(embedding_size, num_branches, batch_first=True)

        self.output_layer = torch.nn.Linear(embedding_size, trg_vocab_size)

    def forward(self, src_indexed, trg_x_indexed):
        batch_size = src_indexed.shape[0]
        src_time_steps = src_indexed.shape[1]
        trg_time_steps = trg_x_indexed.shape[1]

        #############
        # Prefixes
        #############
        
        embedded = self.trg_embedding(trg_x_indexed)
        positions = self.positions[None, :trg_time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned

        (prefix_vecs, _) = self.trg_multihead_attention_layer(query=embedded, key=embedded, value=embedded, attn_mask=self.tri_mask[:trg_time_steps, :trg_time_steps], need_weights=False, is_causal=True)

        #############
        # Sources
        #############

        src_pad_mask = src_indexed == self.src_pad_index

        embedded = self.src_embedding(src_indexed)
        positions = self.positions[None, :src_time_steps].tile((batch_size, 1))
        positioned = self.positioning(positions)
        embedded = embedded + positioned

        (word_in_context, _) = self.src_multihead_attention_layer(query=embedded, key=embedded, value=embedded, key_padding_mask=src_pad_mask, need_weights=False)

        #############
        # Combining
        #############

        (final_prefix_vecs, _) = self.multihead_attention_layer(query=prefix_vecs, key=word_in_context, value=word_in_context, key_padding_mask=src_pad_mask, need_weights=False)
        
        return self.output_layer(final_prefix_vecs)


model = Model(len(src_vocab), len(trg_vocab), max_len, embedding_size=4, num_branches=2, src_pad_index=src_pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    batch_size = train_trg_x_indexed.shape[0]
    trg_time_steps = train_trg_x_indexed.shape[1]
    trg_pad_mask = train_trg_y_indexed == trg_pad_index
    
    optimiser.zero_grad()
    logits = model(train_src_indexed, train_trg_x_indexed)
    train_token_errors = torch.nn.functional.cross_entropy(logits.transpose(1, 2), train_trg_y_indexed, reduction='none')
    train_token_errors = torch.masked_fill(train_token_errors, trg_pad_mask, 0.0)
    train_error = train_token_errors.sum()/(~trg_pad_mask).sum()
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

In [None]:
def beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_tokens, max_len, beam_size):
    src_indexed = torch.tensor(
        [[src_token2index[token] for token in src_tokens]],
        dtype=torch.int64, device=device
    ).tile((beam_size, 1))
    beam_prefixes_indexed = torch.tensor([[trg_edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_probs = np.array([1.0], np.float32)
    
    best_full_prefix_indexed = None
    best_full_prefix_prob = None
    
    with torch.no_grad():
        for _ in range(max_len):
            outputs = torch.softmax(model(src_indexed[:len(beam_prefixes_indexed), :], beam_prefixes_indexed), dim=2)
            token_probs = outputs[:, -1, :]
            new_prefixes_probs = beam_prefixes_probs[:, None]*token_probs.cpu().numpy()

            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == trg_edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))
            
            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [trg_vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)
            
            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64, device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)

    text = [trg_vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return (text, beam_prefixes_probs[0])
    
with torch.no_grad():
    for src_text in train_src_tokens:
        print(src_text)
        (trg_text, prob) = beam_generate(model, src_token2index, src_pad_index, trg_vocab, trg_edge_index, src_text, max_len=10, beam_size=3)
        print(trg_text, prob)
        print()

## Exercises

### 1) Autoencoding sentences again

Redo the sentence autoencoder exercise from the last topic but this time use transformers.
Use a class token to represent the original sentence as a single token vector.
Note that, given that you will have only one source token, you don't technically need to use a transformer for combining the prefixes with the source tokens, although you still can if you want.
If you do use a transformer, make sure that you don't use a pad mask in it because there's just one token.
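
As a starting hint rather than a solution, the sketch below shows how the class token's vector might be pulled out of the source transformer's output and what the two combining options look like in terms of shapes. All tensors are random stand-ins and all sizes are made up; concatenation in option B is just one possible choice, not necessarily the intended one.

```python
import torch

torch.manual_seed(0)

batch_size, src_len, trg_len, embedding_size, num_heads = 2, 6, 4, 4, 2

# Random stand-ins for the source transformer's output (position 0 is the class token)
# and for the prefix vectors from the target transformer.
word_in_context = torch.randn(batch_size, src_len, embedding_size)
prefix_vecs = torch.randn(batch_size, trg_len, embedding_size)

# Keep only the class token's vector, as a sequence of length 1.
sentence_vec = word_in_context[:, 0:1, :]
print(sentence_vec.shape)  # torch.Size([2, 1, 4])

# Option A: attention over the single source vector, with no pad mask.
cross_attention = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)
with torch.no_grad():
    (attended, weights) = cross_attention(
        query=prefix_vecs, key=sentence_vec, value=sentence_vec, need_weights=True
    )
print(weights.shape)  # torch.Size([2, 4, 1])

# Option B: no attention; attach the sentence vector to every prefix vector.
combined = torch.cat([prefix_vecs, sentence_vec.expand(-1, trg_len, -1)], dim=2)
print(combined.shape)  # torch.Size([2, 4, 8])
```

Because a softmax over a single key is always 1, every weight in option A is 1 and `attended` is the same for every prefix, so it would still need to be combined with `prefix_vecs` in some way (as in option B) for the output to depend on the prefix.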

In [None]:
min_freq = 3

train_df = pd.read_csv('../data_set/sentiment/train.csv')
test_df = pd.read_csv('../data_set/sentiment/test.csv')

train_text = train_df['text']
test_text = test_df['text'][0]

nltk.download('punkt')
train_text_tokens = [nltk.word_tokenize(text) for text in train_text]
test_text_tokens = nltk.word_tokenize(test_text)
max_len = max(len(text) for text in train_text_tokens) + 1 # Both src and trg have same max len because src includes the CLS token and trg includes the EDGE token.

frequencies = collections.Counter(token for text in train_text_tokens for token in text)
vocabulary = sorted(frequencies.keys(), key=frequencies.get, reverse=True)
while frequencies[vocabulary[-1]] < min_freq:
    vocabulary.pop()
vocab = ['<PAD>', '<EDGE>', '<UNK>', '<CLS>'] + vocabulary
token2index = {token: i for (i, token) in enumerate(vocab)}
pad_index = token2index['<PAD>']
edge_index = token2index['<EDGE>']
unk_index = token2index['<UNK>']
cls_index = token2index['<CLS>']

train_src_indexed_np = np.full((len(train_text_tokens), max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    train_src_indexed_np[i, 0] = cls_index
    for j in range(len(train_text_tokens[i])):
        train_src_indexed_np[i, j + 1] = token2index.get(train_text_tokens[i][j], unk_index)
train_src_indexed = torch.tensor(train_src_indexed_np, device=device)

train_trg_x_indexed_np = np.full((len(train_text_tokens), max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    train_trg_x_indexed_np[i, 0] = edge_index
    for j in range(len(train_text_tokens[i])):
        train_trg_x_indexed_np[i, j + 1] = token2index.get(train_text_tokens[i][j], unk_index)
train_trg_x_indexed = torch.tensor(train_trg_x_indexed_np, device=device)

train_trg_y_indexed_np = np.full((len(train_text_tokens), max_len), pad_index, np.int64)
for i in range(len(train_text_tokens)):
    for j in range(len(train_text_tokens[i])):
        train_trg_y_indexed_np[i, j] = token2index.get(train_text_tokens[i][j], unk_index)
    train_trg_y_indexed_np[i, len(train_text_tokens[i])] = edge_index
train_trg_y_indexed = torch.tensor(train_trg_y_indexed_np, device=device)

In [None]:
def beam_generate(model, token2index, pad_index, edge_index, unk_index, cls_index, vocab, src_tokens, max_len, beam_size):
    src_indexed = torch.tensor(
        [[cls_index] + [token2index.get(token, unk_index) for token in src_tokens]],
        dtype=torch.int64, device=device
    ).tile((beam_size, 1))
    beam_prefixes_indexed = torch.tensor([[edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_probs = np.array([1.0], np.float32)
    
    best_full_prefix_indexed = None
    best_full_prefix_prob = None
    
    with torch.no_grad():
        for _ in range(max_len):
            outputs = torch.softmax(model(src_indexed[:len(beam_prefixes_indexed), :], beam_prefixes_indexed), dim=2)
            token_probs = outputs[:, -1, :]
            new_prefixes_probs = beam_prefixes_probs[:, None]*token_probs.cpu().numpy()

            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))
            
            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)
            
            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64, device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)

    text = [vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return (text, beam_prefixes_probs[0])