In [None]:
%matplotlib inline

import collections
import random
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import sklearn.metrics
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Recurrent neural networks 2

Last time we saw how to use an RNN to turn a sequence of tokens into a single vector, regardless of the number of tokens.
This was useful for classifying a whole text but sometimes we want to classify each token of the text instead, such as for part-of-speech tagging (identify nouns, verbs, etc.).
The problem with classifying tokens is that you can't just look at each token in isolation as you'll need to see how that token is used, that is, its context.
Traditionally, this was solved by taking a fixed number of tokens around each token as context (for example, two to the left and two to the right) and passing the chunk of tokens as input, similarly to how the CNN works.
In this topic we'll see how to use the entire text as a context to each token by using RNNs.
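For comparison, the traditional fixed-window approach described above can be sketched as follows. This is only an illustration under assumptions. The helper name `context_windows`, the window of two tokens on each side, and the use of `'<PAD>'` to fill the edges of the text are not code from this topic.

```python
def context_windows(tokens, left=2, right=2, pad='<PAD>'):
    """Hypothetical helper: return, for each token, the tokens within `left`/`right` positions of it.

    The text is padded at both ends so that every window has the same size.
    """
    padded = [pad]*left + tokens + [pad]*right
    # Token i of the original text sits at position i + left of `padded`.
    return [padded[i:i + left + 1 + right] for i in range(len(tokens))]

windows = context_windows('I like it .'.split(' '))
for window in windows:
    print(window)
# ['<PAD>', '<PAD>', 'I', 'like', 'it']
# ['<PAD>', 'I', 'like', 'it', '.']
# ['I', 'like', 'it', '.', '<PAD>']
# ['like', 'it', '.', '<PAD>', '<PAD>']
```

Each window would then be classified as a fixed-size input. Anything outside the window is invisible to the classifier, which is the limitation that using the whole text as context removes.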

## Using the intermediate states

Refer back to this diagram we saw last time:

![](rnn_chain.png)

Up to now we've used the final state (state<sub>3</sub>) for representing the entire text (input<sub>1</sub>, input<sub>2</sub>, and input<sub>3</sub>), but we know that the intermediate states also have a useful meaning: state<sub>2</sub> represents input<sub>1</sub> and input<sub>2</sub> whilst state<sub>1</sub> represents input<sub>1</sub>.

To introduce the use of these intermediate states, we'll try classifying the sentiment of a text (using the toy data set) but using prefixes of the text as input rather than the full text only.
So we'll train the model to learn that "I", "I like", "I like it", and "I like it ." should all be predicted as having a positive sentiment.
This will be done by classifying each intermediate state produced by the RNN when consuming the text rather than only the final state as we've done up to now.
To do this, we'll need to learn some new techniques that will help us.
You may be wondering what happens when a prefix is found in both positive and negative texts; wait and see.

Let's start from how to collect the intermediate states from an RNN.
What you do is, for every time step in the `for` loop, collect each matrix of states (each row being from a different text in the batch) into a list, and then join them all together in a 3D tensor after the loop ends.
If you think this is very inefficient, note that it's actually what PyTorch does in its [example code](https://pytorch.org/docs/stable/generated/torch.nn.LSTMCell.html) (see the bottom of the web page).
Note that we will not be masking out the pad tokens just yet because that's something that we'll do later.

In [None]:
embedding_size = 4
state_size = 5

x_indexed = torch.tensor([
    [1, 0, 0],
    [1, 2, 0],
], dtype=torch.int64, device=device)
pad_index = 0
print('x_indexed:')
print(x_indexed)
print()

batch_size = x_indexed.shape[0]
time_steps = x_indexed.shape[1]

embedding = torch.nn.Embedding(3, embedding_size)
embedding.to(device)
embedded = embedding(x_indexed)
print('embedded:')
print(embedded)
print()

rnn = torch.nn.LSTMCell(embedding_size, state_size)
rnn.to(device)
s0 = torch.zeros((state_size,), dtype=torch.float32, device=device)
c0 = torch.zeros((state_size,), dtype=torch.float32, device=device)

state = s0[None, :].tile((batch_size, 1))
c = c0[None, :].tile((batch_size, 1))
interm_states_list = []
for t in range(time_steps):
    (state, c) = rnn(embedded[:, t, :], (state, c))
    interm_states_list.append(state)

for (t, state) in enumerate(interm_states_list, start=1):
    print('state at time', t)
    print(state)
    print()

To join all of these states into a 3D tensor (number of texts by number of state vectors by state vector size), we need to use the `stack` function:

In [None]:
interm_states = torch.stack(interm_states_list, dim=1)
print(interm_states.shape)
print(interm_states)
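`stack` is needed rather than `cat` because it creates a new (time) dimension instead of joining the matrices along an existing one. The sketch below compares the two on random stand-in states; the sizes are arbitrary assumptions chosen to match the example above.

```python
import torch

# Three hypothetical per-time-step state matrices: 2 texts by 5 state values each.
states = [torch.randn(2, 5) for _ in range(3)]

stacked = torch.stack(states, dim=1)  # adds a new time dimension at position 1
catted = torch.cat(states, dim=1)     # joins along the existing dimension 1 instead

print(stacked.shape)  # torch.Size([2, 3, 5])
print(catted.shape)   # torch.Size([2, 15])

# The matrix from time step t ends up at stacked[:, t, :].
print(torch.equal(stacked[:, 1, :], states[1]))  # True
```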

The next problem we have is how to apply the output layer on each one of these states.
You might think that you cannot apply a linear layer on a 3D tensor; thankfully this is not the case.

In [None]:
num_outputs = 1

layer = torch.nn.Linear(state_size, num_outputs)
layer.to(device)

logits = layer(interm_states)
print(f'logits for {batch_size} texts by {time_steps} state vectors each')
print(logits)

Note that the matrix multiplication operator (`@`) also lets you multiply a 3D tensor by a matrix: it automatically multiplies each vector in the 3D tensor by the weight matrix and keeps the first two dimensions of the result (texts by time steps).
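As a sanity check, the sketch below applies the same linear layer to a 3D tensor both through `torch.nn.Linear` and through `@`, using random stand-in states. It relies on `nn.Linear` storing its weight with shape (outputs by inputs), hence the transpose; the sizes are assumptions matching the example above.

```python
import torch

torch.manual_seed(0)
(batch_size, time_steps, state_size, num_outputs) = (2, 3, 5, 1)

interm_states = torch.randn(batch_size, time_steps, state_size)
layer = torch.nn.Linear(state_size, num_outputs)

with torch.no_grad():
    # The weight is (num_outputs, state_size), so it is transposed before multiplying.
    manual = interm_states @ layer.weight.T + layer.bias
    via_layer = layer(interm_states)

print(manual.shape)  # torch.Size([2, 3, 1])
print(torch.allclose(manual, via_layer, atol=1e-6))  # True

# Each individual state vector is multiplied by the same weights.
print(torch.allclose(manual[1, 2], interm_states[1, 2] @ layer.weight.T.detach() + layer.bias.detach(), atol=1e-6))  # True
```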

Finally we need to think about how to compute the error.
We can calculate the cross-entropy of all the logits in the 3D tensor by comparing them with a 3D tensor of target values of the same shape.

In [None]:
train_y = torch.tensor([
    [[1], [1], [0]],
    [[0], [1], [0]],
], dtype=torch.float32, device=device)
print('train_y')
print(train_y)
print()

error = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y)
print('error')
print(error)

But this is ignoring the fact that we'll have pad tokens in our inputs, which should be ignored.
To do this we'll need to do several things.

First, we need to calculate the individual cross entropies of all logits.
The cross entropy function, by default, calculates the mean of the cross entropy of each logit value.
You can make it give you individual cross entropies using the parameter `reduction='none'`:

In [None]:
token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y, reduction='none')
print(token_errors)
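To make the relationship between the two reductions concrete, here is a small sketch on random logits and targets (the shapes are assumptions matching the example above). It checks that averaging the `reduction='none'` errors gives the default result, and that each individual error is the usual binary cross-entropy $-(y \log \sigma(z) + (1 - y) \log(1 - \sigma(z)))$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 1)
targets = torch.randint(0, 2, (2, 3, 1)).float()

mean_error = F.binary_cross_entropy_with_logits(logits, targets)  # default reduction='mean'
token_errors = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

# The same per-logit errors computed by hand from the sigmoid of each logit.
probs = torch.sigmoid(logits)
manual = -(targets*torch.log(probs) + (1 - targets)*torch.log(1 - probs))

print(token_errors.shape)  # torch.Size([2, 3, 1]): one error per logit
print(torch.allclose(token_errors.mean(), mean_error, atol=1e-6))  # True
print(torch.allclose(token_errors, manual, atol=1e-5))  # True
```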

Next, we need to mask out all the errors that are from pad tokens such that they are replaced by zeros.

In [None]:
pad_mask = x_indexed == pad_index

print('pad_mask:')
print(pad_mask)
print()

print('masked errors')
token_errors = token_errors.masked_fill(pad_mask[:, :, None], 0.0)
print(token_errors)

Finally we need to calculate the average error of the unmasked errors.
You can't just use the `torch.mean` function because that will include the zeroed out errors as well.
Instead, we'll compute it ourselves:

In [None]:
error = token_errors.sum()/(~pad_mask).sum() # Total error divided by the total number of non-pad tokens.
print(error)

Note that `~` is for complementing/negating a mask so that what's true becomes false and what's false becomes true.

In [None]:
print('pad_mask:')
print(pad_mask)
print()
print('~pad_mask:')
print(~pad_mask)
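The difference between the naive mean and the masked mean is easiest to see with hand-picked numbers. The per-token errors below are made up for illustration: text 1 has one real token and text 2 has two.

```python
import torch

# Hypothetical per-token errors (texts by tokens by 1).
token_errors = torch.tensor([
    [[0.2], [0.4], [0.9]],
    [[0.6], [0.8], [0.7]],
])
pad_mask = torch.tensor([
    [False, True, True],   # text 1: 1 real token
    [False, False, True],  # text 2: 2 real tokens
])

masked = token_errors.masked_fill(pad_mask[:, :, None], 0.0)
naive = masked.mean()                       # divides the total by all 6 positions
correct = masked.sum()/(~pad_mask).sum()    # divides the total by the 3 non-pad tokens

print(naive.item())    # about 0.2667, i.e. (0.2 + 0.6 + 0.8)/6
print(correct.item())  # about 0.5333, i.e. (0.2 + 0.6 + 0.8)/3
```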

Now we can put it all together to make a neural network that tries to classify the sentiment of each prefix of a text.

In [None]:
train_x = [
    'I like it .'.split(' '),
    'I hate it .'.split(' '),
    'I don\'t hate it .'.split(' '),
    'I don\'t like it .'.split(' '),
]
train_y = torch.tensor([
    [1],
    [0],
    [1],
    [0],
], dtype=torch.float32, device=device)

max_len = max(len(text) for text in train_x)
print('max_len:', max_len)

vocab = ['<PAD>'] + sorted({token for text in train_x for token in text})
token2index = {t: i for (i, t) in enumerate(vocab)}
pad_index = token2index['<PAD>']
print('vocab:', vocab)
print()

train_x_indexed_np = np.full((len(train_x), max_len), pad_index, np.int64)
for i in range(len(train_x)):
    for j in range(len(train_x[i])):
        train_x_indexed_np[i, j] = token2index[train_x[i][j]]
train_x_indexed = torch.tensor(train_x_indexed_np, device=device)

# The target value of each text will be replicated for each prefix (number of prefixes == number of tokens).
train_y_seq = train_y[:, None, :].tile((1, max_len, 1))
print('train_y_seq:')
print(train_y_seq)

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, embedding_size, state_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.rnn_s0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_c0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_cell = torch.nn.LSTMCell(embedding_size, state_size)
        self.output_layer = torch.nn.Linear(state_size, 1)

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]

        embedded = self.embedding(x_indexed)
        state = self.rnn_s0[None, :].tile((batch_size, 1))
        c = self.rnn_c0[None, :].tile((batch_size, 1)) # The cell state starts from its own initial vector.
        interm_states_list = []
        for t in range(time_steps):
            # No need to mask anything here, because we'll be masking the output at the end.
            (state, c) = self.rnn_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        interm_states = torch.stack(interm_states_list, dim=1)
        return self.output_layer(interm_states)

model = Model(len(vocab), embedding_size=2, state_size=2)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.1)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    pad_mask = train_x_indexed == pad_index
    
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y_seq, reduction='none')
    train_token_errors = train_token_errors.masked_fill(pad_mask[:, :, None], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, :, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text + ['<PAD>']*(max_len - len(text)), y)

(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

In the above outputs, we're seeing the sentiment prediction for each prefix starting from the first token.
So, for example, we're seeing how the neural network has learned to classify "I", "I like", "I like it", "I like it .", and "I like it . \<PAD>" (the output at a pad position is meaningless, as it was masked out of the error during training).
We can see how the first prefix ("I") is always close to 0.5, because it is equally likely to be in a positive or a negative text.
As soon as there is a prefix that is unique to a text category, the neural network gives that prefix and all the following prefixes the correct output.

How did the neural network give an output of 0.5 for the ambiguous prefixes if it wasn't trained to do so?
It is because 0.5 is the single output that gives the smallest training error (cross-entropy) when the same input has to be given a target of 0 in one text and a target of 1 in another: it is the best compromise between the two desired outputs.
If a particular prefix was in 2 training set texts with a target output of 1 and 1 text with a target of 0, its prediction would instead settle on $\frac{2}{2 + 1} = \frac{2}{3} \approx 0.67$.
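In general, if a prefix is in $a$ training set texts with a target output of 1 and $b$ texts with a target of 0, a sigmoid output will settle somewhere close to $\frac{a}{a+b}$, because that is the value of $p$ that minimises the total cross-entropy $-a \log p - b \log(1 - p)$ (setting its derivative $-\frac{a}{p} + \frac{b}{1 - p}$ to zero gives $p = \frac{a}{a+b}$).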
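The claim can be checked numerically with a toy case. The sketch below assumes that the model is free to give the ambiguous prefix any logit it likes, so it optimises a single logit shared by three occurrences of the prefix ($a = 2$ with target 1, $b = 1$ with target 0). Plain gradient descent is used instead of Adam only because it converges smoothly here.

```python
import torch

# Hypothetical toy case: a = 2 occurrences with target 1 and b = 1 occurrence with target 0.
targets = torch.tensor([1.0, 1.0, 0.0])
logit = torch.zeros(1, requires_grad=True)
optimiser = torch.optim.SGD([logit], lr=1.0)

for step in range(500):
    optimiser.zero_grad()
    # The same logit is the prediction for all three occurrences of the prefix.
    error = torch.nn.functional.binary_cross_entropy_with_logits(logit.expand(3), targets)
    error.backward()
    optimiser.step()

output = torch.sigmoid(logit).item()
print(output)  # approximately 0.6667, i.e. a/(a+b) = 2/3
```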
Remember this for the next topic.

## Bi-directional RNNs

Intermediate states are not only useful for making predictions about prefixes (although that is a big use for them, as we'll see in the next topic).
You can also use them to make **bi-directional RNNs**.
For tasks that require information about the full context of a token rather than just its preceding tokens (prefix), such as part-of-speech tagging, the intermediate states of a single RNN are not enough, as they only tell us what is on one side of a token.
We need to have information about both the tokens before the token being tagged as well as the tokens that come after.
For this purpose we need two RNNs: one that goes forward in the text and one that goes backward, hence bi-directional.
For short, we call the forward RNN 'fw' and the backward RNN 'bw'.

![](bidirectional_rnn.png)

By concatenating the two sets of intermediate states, as shown at the bottom of the diagram, each token will be represented by information about all the tokens before and after it.

Unfortunately, we can't just run the `for` loop backwards over the sequences for the bw RNN: since the pad tokens are at the end of each text, the bw RNN would start by consuming the pad tokens, and they would then influence the representation of the real tokens.
So for the bw RNN we'll need to use the masking technique we used in the previous topic so that the initial state remains as the current state until the first non-pad token is encountered.
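The effect of this masking can be checked on a single toy text. The sketch below runs a bw `LSTMCell` over the same text with and without two pad tokens at the end. The helper `bw_states`, the sizes, and the zero initial states are assumptions for illustration (the model in this topic learns its initial states instead).

```python
import torch

torch.manual_seed(0)
(embedding_size, state_size, pad_index) = (3, 4, 0)
embedding = torch.nn.Embedding(5, embedding_size)
cell = torch.nn.LSTMCell(embedding_size, state_size)

def bw_states(x_indexed, use_mask=True):
    """Hypothetical helper: backward LSTMCell pass, returning the states in forward time order."""
    (batch_size, time_steps) = x_indexed.shape
    non_pad_mask = x_indexed != pad_index
    embedded = embedding(x_indexed)
    state = torch.zeros(batch_size, state_size)
    c = torch.zeros(batch_size, state_size)
    interm_states_list = []
    for t in reversed(range(time_steps)):
        (new_state, new_c) = cell(embedded[:, t, :], (state, c))
        if use_mask:
            # Keep the old state and cell state wherever the current token is a pad token.
            state = torch.where(non_pad_mask[:, t, None], new_state, state)
            c = torch.where(non_pad_mask[:, t, None], new_c, c)
        else:
            (state, c) = (new_state, new_c)
        interm_states_list.append(state)
    return torch.stack(interm_states_list[::-1], dim=1)

with torch.no_grad():
    unpadded = bw_states(torch.tensor([[1, 2, 3]]))
    masked = bw_states(torch.tensor([[1, 2, 3, 0, 0]]))
    unmasked = bw_states(torch.tensor([[1, 2, 3, 0, 0]]), use_mask=False)

# With masking, the bw states at the 3 real tokens are unaffected by the padding.
print(torch.allclose(masked[:, :3, :], unpadded))    # True
print(torch.allclose(unmasked[:, :3, :], unpadded))  # False: the pad tokens leaked in
```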

Also, we'll need to reverse the bw intermediate states before concatenating them to the fw intermediate states (see the above diagram).
Since the list of intermediate states is a normal Python list (before being stacked into a single tensor), we can reverse the list easily as follows:

In [None]:
print(['a', 'b', 'c', 'd'][::-1])

After reversing the list it can then be stacked into a single 3D tensor.
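The sketch below illustrates the reversal and the final concatenation with made-up states whose values equal their time step, so the order is easy to read. It also shows that flipping the stacked tensor along the time dimension with `torch.flip` would be an equivalent alternative to reversing the list; the sizes are assumptions.

```python
import torch

# Hypothetical bw states in the order the loop produces them: t = 2, then 1, then 0 (2 texts, state size 4).
collected = [torch.full((2, 4), float(t)) for t in [2, 1, 0]]

# Reversing the Python list before stacking puts time step 0 first again.
interm_states_bw = torch.stack(collected[::-1], dim=1)
print(interm_states_bw[0, :, 0])  # tensor([0., 1., 2.])

# Equivalent alternative: stack in loop order, then flip along the time dimension.
print(torch.equal(interm_states_bw, torch.flip(torch.stack(collected, dim=1), dims=[1])))  # True

# Concatenating with fw states of the same shape gives 2*state_size values per token.
interm_states_fw = torch.zeros(2, 3, 4)
interm_states = torch.cat((interm_states_fw, interm_states_bw), dim=2)
print(interm_states.shape)  # torch.Size([2, 3, 8])
```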

In [None]:
class Model(torch.nn.Module):

    def __init__(self, vocab_size, embedding_size, state_size, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.rnn_fw_s0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_fw_c0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_fw_cell = torch.nn.LSTMCell(embedding_size, state_size)
        self.rnn_bw_s0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_bw_c0 = torch.nn.Parameter(torch.zeros((state_size,), dtype=torch.float32))
        self.rnn_bw_cell = torch.nn.LSTMCell(embedding_size, state_size)
        self.output_layer = torch.nn.Linear(2*state_size, 1) # The input to this layer will be the fw and bw states concatenated together.

    def forward(self, x_indexed):
        batch_size = x_indexed.shape[0]
        time_steps = x_indexed.shape[1]
        non_pad_mask = x_indexed != self.pad_index

        embedded = self.embedding(x_indexed)
        
        state = self.rnn_fw_s0[None, :].tile((batch_size, 1))
        c = self.rnn_fw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in range(time_steps):
            (state, c) = self.rnn_fw_cell(embedded[:, t, :], (state, c))
            interm_states_list.append(state)
        interm_states_fw = torch.stack(interm_states_list, dim=1)

        state = self.rnn_bw_s0[None, :].tile((batch_size, 1))
        c = self.rnn_bw_c0[None, :].tile((batch_size, 1))
        interm_states_list = []
        for t in reversed(range(time_steps)): # Go backwards from time_steps-1 to 0.
            (new_state, new_c) = self.rnn_bw_cell(embedded[:, t, :], (state, c))
            state = torch.where(non_pad_mask[:, t, None], new_state, state)
            c = torch.where(non_pad_mask[:, t, None], new_c, c) # We also need to mask the cell state now because it also needs to remain as the initial state while we're on a pad token.
            interm_states_list.append(state)
        interm_states_bw = torch.stack(interm_states_list[::-1], dim=1) # Reverse the bw intermediate states.

        interm_states = torch.concat((interm_states_fw, interm_states_bw), dim=2)

        return self.output_layer(interm_states)

model = Model(len(vocab), embedding_size=2, state_size=2, pad_index=pad_index)
model.to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.1)

print('epoch', 'error')
train_errors = []
for epoch in range(1, 1000+1):
    pad_mask = train_x_indexed == pad_index
    
    optimiser.zero_grad()
    logits = model(train_x_indexed)
    train_token_errors = torch.nn.functional.binary_cross_entropy_with_logits(logits, train_y_seq, reduction='none')
    train_token_errors = train_token_errors.masked_fill(pad_mask[:, :, None], 0.0)
    train_error = train_token_errors.sum()/(~pad_mask).sum()
    train_errors.append(train_error.detach().cpu().tolist())
    train_error.backward()
    optimiser.step()

    if epoch%100 == 0:
        print(epoch, train_errors[-1])
print()

with torch.no_grad():
    print('text', 'output')
    output = torch.sigmoid(model(train_x_indexed))[:, :, 0].cpu().tolist()
    for (text, y) in zip(train_x, output):
        print(text + ['<PAD>']*(max_len - len(text)), y)
        
(fig, ax) = plt.subplots(1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('$E$')
ax.plot(range(1, len(train_errors) + 1), train_errors, color='blue', linestyle='-', linewidth=3)
ax.grid()

Why are even the previously ambiguous prefixes getting classified correctly now?

## Exercises

### 1) Vote-based classification

Rewrite the movie reviews classification program as follows:

* Use a bi-directional RNN to classify each token in the texts as shown above.
* *When measuring the accuracy on the test set*, convert all the individual token classifications into a single classification by taking the average of the token-level predictions of the non-pad tokens (this is a form of classifier voting or **ensembling** meant to improve the performance).
    Do not do this while training!

Preprocessing has been done for you.

In [None]:
min_freq = 3

train_df = pd.read_csv('../data_set/sentiment/train.csv')
test_df = pd.read_csv('../data_set/sentiment/test.csv')

train_x = train_df['text']
train_y = train_df['class']
test_x = test_df['text']
test_y = test_df['class']
categories = ['neg', 'pos']
cat2idx = {cat: i for (i, cat) in enumerate(categories)}

train_y_indexed = torch.tensor(
    train_y.map(cat2idx.get).to_numpy()[:, None],
    dtype=torch.float32, device=device
)
test_y_indexed = test_y.map(cat2idx.get).to_numpy()[:, None]

nltk.download('punkt')
nltk.download('punkt_tab') # Needed by word_tokenize in newer versions of NLTK.
train_x_tokens = [nltk.word_tokenize(text) for text in train_x]
test_x_tokens = [nltk.word_tokenize(text) for text in test_x]
max_len = max(max(len(text) for text in train_x_tokens), max(len(text) for text in test_x_tokens))

frequencies = collections.Counter(token for text in train_x_tokens for token in text)
vocabulary = sorted(frequencies.keys(), key=frequencies.get, reverse=True)
while frequencies[vocabulary[-1]] < min_freq:
    vocabulary.pop()
vocab = ['<PAD>', '<UNK>'] + vocabulary
token2index = {token: i for (i, token) in enumerate(vocab)}
pad_index = token2index['<PAD>']
unk_index = token2index['<UNK>']

train_x_indexed_np = np.full((len(train_x_tokens), max_len), pad_index, np.int64)
for i in range(len(train_x_tokens)):
    for j in range(len(train_x_tokens[i])):
        train_x_indexed_np[i, j] = token2index.get(train_x_tokens[i][j], unk_index)
train_x_indexed = torch.tensor(train_x_indexed_np, device=device)

test_x_indexed_np = np.full((len(test_x_tokens), max_len), pad_index, np.int64)
for i in range(len(test_x_tokens)):
    for j in range(len(test_x_tokens[i])):
        test_x_indexed_np[i, j] = token2index.get(test_x_tokens[i][j], unk_index)
test_x_indexed = torch.tensor(test_x_indexed_np, device=device)

train_y_indexed_seq = train_y_indexed[:, None, :].tile([1, max_len, 1])