# Attention

## Motivation

Seq2seq models face a challenge: the entire encoded sequence must be represented by a single vector. That encoding represents the meaning of the source sequence as a whole.
All of the rich information in the source sequence must pass through this "information bottleneck", making it likely that some detail will be lost.

For a task such as translation, which a seq2seq model could tackle, this can make things difficult. The encoding gives you an idea of what the output should represent, but there are often many ways that the source could be translated, and getting a word-to-word translation can be difficult after everything has been summarised.

The typical and intuitive explanation here is that a human translator does not read the whole source sentence, memorise it, and then translate it. Instead, they read the whole thing to get an idea of what the translation needs to represent, and then they translate it part by part, looking back at the source sentence to translate a few words at a time. They are primed with the concept that the translation needs to represent, but they need to pay attention to parts of the source sequence as they decide the next word in the translated output.

## What's the result of this?

Vanilla seq2seq models tend to perform well on short sequences, where the information can be "memorised" within a single vector, but perform worse on longer sequences.

## The Attention Mechanism

We can mathematically define an "attention mechanism".

Overall, it looks like this:

![attention mechanism](../images/RNN%20Seq2seq%20Attention.gif)

#### Questions
- Why do we need the decoder RNN?
    - The source and target are different languages, so it makes sense to have different parameterisations
- Why not also attend to different parts of the decoded sequence?
    - Great idea - we will get to that in self-attention

- As long as the decoder contains enough information to tell it where to look back to, then it can grab more information as and when it needs it, instead of wasting effort carrying it throughout.

### The Attention Score

We only have a limited amount of attention.
But we can pay a different percentage of it to each word.
The most attention we could pay to a word is 100%, and the least is 0%.
Or 1.0 and 0.0 as proportions.

So we could give each word a number between 0 and 1 which represents the proportion of our attention we give that word.
We call this number $\alpha_t$.


$\alpha$ is a vector of the attention paid to each part of the input.

# $\alpha = \begin{bmatrix} \alpha_1 \\ \vdots \\ \alpha_t \\ \vdots \\ \alpha_T \end{bmatrix}$

In the case of translation, $\alpha$ has as many elements as the source sentence has tokens.

# $\alpha \in \mathbb{R}^T$

This is the distribution of our attention paid to each input token.


### How do we calculate the attention score?

Because the attention score is a distribution, it can be computed by applying the softmax function to a vector of logits, $e$. Those logits should have larger values where more attention should be paid.

# $\alpha = \text{softmax}(e)$

> The logits that are softmaxed to compute the attention distribution are also known as the alignment scores.

In [4]:
import torch
import torch.nn.functional as F

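# Example alignment scores (logits), one per source token
logits = torch.tensor([12, -1, 0.4, 5, 2])
# Softmax turns the logits into a distribution that sums to 1
attention_distribution = F.softmax(logits, dim=0)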


So how do we compute those attention logits (the alignment scores)?

Intuitively, it would make sense that the attention that should be paid to one word is a function of what we think about the output translation so far, and what we think about that word in the context of the input.

That is, it should be a function of the most recent decoder hidden state and the encoder hidden state for the timestep we are computing the attention for.

We call the two things that we want to find the alignment between the query and the values:
- Query, $Q$: 
    - The current decoder state
    - A vector
    - This is the state for which we want to know where to look

- Values, $V$:
    - The encoder hidden states
    - A set of vectors (perhaps a matrix)
    - This is the set of representations stored in the model's "memory" - you can think of it like model RAM

We say that the query attends to the values.

$e_t = a(h_{decoder}^{t'-1}, h_{encoder}^t) = a(Q, V_t)$

_Note that the most recently computed decoder hidden state is the one from the previous timestep $h_{decoder}^{t'-1}$. We don't have the current decoder hidden state - that's what we are trying to use the attention to compute._

We call the function $a$ the _alignment function_. Intuitively, it tells you which parts of the source sequence correspond to the current position in the target sequence. In traditional (non-neural) NLP systems, this was a function that told you which words, if any, corresponded to each other between the translation pairs.

### Cosine similarity alignment

The simplest way to compare these two vectors is to compute their cosine similarity.
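
As a minimal sketch of cosine-similarity alignment, the snippet below scores one query against every encoder hidden state. The tensor sizes and the names `query`, `values`, `cos_scores` and `cos_attention` are illustrative assumptions, and the comparison only makes sense if the decoder and encoder states have the same dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, hidden_dim = 5, 8                  # assumed: 5 source tokens, 8-dim hidden states
values = torch.randn(T, hidden_dim)   # stand-in encoder hidden states, one row per token
query = torch.randn(hidden_dim)       # stand-in for the most recent decoder hidden state

# Cosine similarity between the query and each encoder state: one score per token.
cos_scores = F.cosine_similarity(query.expand_as(values), values, dim=1)  # shape (T,)

# Softmax turns the alignment scores into an attention distribution.
cos_attention = F.softmax(cos_scores, dim=0)

print(cos_scores)
print(cos_attention, cos_attention.sum())
```

Because cosine similarity is bounded between -1 and 1, the resulting distribution tends to be fairly flat, which is one practical reason to consider other scoring functions.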


### Could we learn the alignment scoring function too?!

Assuming that cosine similarity is the right function for measuring alignment is a strong assumption.

So let's learn the function instead, as we do for the rest of the neural network, by making it a trainable neural network.

Typically, we use a 1-layer neural network, passing in a stacked vector of the two input hidden states.



## Using the attention distribution

The point of computing the attention distribution, which tells us which input tokens to pay attention to, was to use it to make a prediction for the next decoder hidden state.

We can use it to create a sum of the encoder hidden state representations, weighted by the attention paid to each of them. 

This is known as the _context_ vector, as it represents the encoder hidden states in the context of what should be paid attention to.

## $context, c = \sum_t \alpha_t h_{encoder}^t$

# TODO diagram



In [None]:

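print(attention_distribution.shape)  # one weight per source token: torch.Size([5])
# Stand-in encoder hidden states (assumed hidden size of 8), one row per source token
encoder_hidden_states = torch.randn(attention_distribution.shape[0], 8)

# Context vector: attention-weighted sum of the encoder hidden states
context = attention_distribution @ encoder_hidden_states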

The context is combined with the most recent decoder hidden state to form the input used to compute the next decoder hidden state.

## $next \ decoder \ hidden \ state \ input = \begin{bmatrix}c \\ h_{decoder}^{t'-1}\end{bmatrix}$

## TODO diagram

This is then processed as any decoder input would be.
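
A rough sketch of this step is shown below. It assumes a single `torch.nn.GRUCell` as the decoder and uses made-up sizes and random stand-in tensors, so treat it as one possible reading of the description above rather than a reference implementation.

```python
import torch

torch.manual_seed(0)

T, enc_dim, dec_dim = 5, 8, 8                      # assumed sizes
encoder_states = torch.randn(T, enc_dim)           # stand-in encoder hidden states
prev_decoder_state = torch.randn(dec_dim)          # stand-in for the most recent decoder state
attention = torch.softmax(torch.randn(T), dim=0)   # stand-in attention distribution

# Context vector: attention-weighted sum of the encoder hidden states.
context = attention @ encoder_states               # shape (enc_dim,)

# Stack the context on top of the decoder state to form the decoder input.
decoder_input = torch.cat([context, prev_decoder_state])   # shape (enc_dim + dec_dim,)

# Assumed decoder: a single GRU cell (any recurrent cell could be used here).
decoder_cell = torch.nn.GRUCell(input_size=enc_dim + dec_dim, hidden_size=dec_dim)
next_decoder_state = decoder_cell(
    decoder_input.unsqueeze(0),        # add a batch dimension of size 1
    prev_decoder_state.unsqueeze(0),
).squeeze(0)

print(decoder_input.shape, next_decoder_state.shape)
```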

In [None]:
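class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()  # dot-product scoring (an assumption) needs no parameters

    def forward(self, query, values):  # query: (d,), values: (T, d)
        return torch.softmax(values @ query, dim=0) @ values  # context: (d,)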

Cross-attention is the type of attention we have seen here, where the values come from a different source (the encoder) than the queries (which come from the decoder).

## Attention variants

There are many forms of attention, but they always include:
1. Computing the alignment scores, $e \in \mathbb{R}^N$
1. Turning these into an attention distribution, $\alpha \in \mathbb{R}^N$
1. Using the attention distribution to combine the values.

Notable variants include:
- Dot-product attention
- Multiplicative attention
- Reduced-rank multiplicative attention
- Additive attention, which uses a single-layer neural network as the alignment function instead of assuming that the cosine similarity is the right function to align values and queries with
#### $e_t = W_{alignment} \cdot \tanh(\begin{bmatrix}W_{decoder} \cdot h_{decoder}^{t'-1} \\ \\ W_{encoder} \cdot  h_{encoder}^t \end{bmatrix})$
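
The first three variants listed above differ only in how the alignment scores are computed. The sketch below uses their commonly stated forms: $e_t = q \cdot v_t$ for dot-product, $e_t = q^T W v_t$ for multiplicative, and $e_t = (Uq)^T (Pv_t)$ for reduced-rank multiplicative. The sizes and the matrix names `W`, `U` and `P` are assumptions for illustration, and random tensors stand in for what would be learned parameters.

```python
import torch

torch.manual_seed(0)

T, d_q, d_v, rank = 5, 8, 8, 3        # assumed sizes; dot-product needs d_q == d_v
query = torch.randn(d_q)              # stand-in decoder state
values = torch.randn(T, d_v)          # stand-in encoder states

# Dot-product attention: e_t = q . v_t
dot_scores = values @ query                        # shape (T,)

# Multiplicative attention: e_t = q^T W v_t (W would be learned)
W = torch.randn(d_q, d_v)
mult_scores = values @ (W.T @ query)               # shape (T,)

# Reduced-rank multiplicative attention: e_t = (U q)^T (P v_t), with rank < d_q, d_v
U = torch.randn(rank, d_q)
P = torch.randn(rank, d_v)
low_rank_scores = (values @ P.T) @ (U @ query)     # shape (T,)

# Whatever the scoring function, the remaining two steps are the same.
for name, scores in [("dot", dot_scores),
                     ("multiplicative", mult_scores),
                     ("reduced-rank", low_rank_scores)]:
    alpha = torch.softmax(scores, dim=0)           # attention distribution
    context = alpha @ values                       # weighted sum of the values
    print(name, alpha.shape, context.shape)
```

The reduced-rank form uses `rank * (d_q + d_v)` parameters instead of `d_q * d_v`, which is the usual motivation for it.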


In [None]:
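class AdditiveAttention(torch.nn.Module):
    def __init__(self, input_dim=128):
        super().__init__()
        attention_hidden_dim = 128
        # Single hidden layer that scores one (query, value) pair at a time
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(2 * input_dim, attention_hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(attention_hidden_dim, 1),  # one alignment score per pair
        )

    def forward(self, query, values):
        # query: (input_dim,), values: (T, input_dim)
        # Pair the query with every value by stacking them side by side
        stacked = torch.cat([query.expand(values.shape[0], -1), values], dim=1)
        alignments = self.layers(stacked).squeeze(-1)   # alignment scores e, shape (T,)
        attention = torch.softmax(alignments, dim=0)    # attention distribution, shape (T,)
        context = attention @ values                    # weighted sum of the values
        return context, attention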

In [None]:
class Seq2SeqWithAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: add an encoder RNN, a decoder RNN, and an attention module
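
One way the pieces could fit together is sketched below. Everything specific here is an assumption made for illustration: the class name `Seq2SeqWithAttentionSketch`, the GRU encoder and decoder, dot-product scoring, the vocabulary and hidden sizes, and feeding the decoder the context concatenated with the embedding of the previous target token (a common variant) under teacher forcing.

```python
import torch


class Seq2SeqWithAttentionSketch(torch.nn.Module):
    """Hypothetical GRU encoder-decoder with dot-product cross-attention."""

    def __init__(self, src_vocab=20, tgt_vocab=20, hidden_dim=16):
        super().__init__()
        self.src_embed = torch.nn.Embedding(src_vocab, hidden_dim)
        self.tgt_embed = torch.nn.Embedding(tgt_vocab, hidden_dim)
        self.encoder = torch.nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        # Decoder input is [context ; embedding of the previous target token]
        self.decoder_cell = torch.nn.GRUCell(2 * hidden_dim, hidden_dim)
        self.out = torch.nn.Linear(hidden_dim, tgt_vocab)

    def forward(self, src, tgt):
        # src: (B, T) source token ids, tgt: (B, T') target token ids
        enc_states, last = self.encoder(self.src_embed(src))  # (B, T, H), (1, B, H)
        h = last.squeeze(0)                                    # initial decoder state (B, H)
        logits = []
        for t in range(tgt.shape[1]):
            # Alignment scores between the decoder state and every encoder state
            scores = torch.bmm(enc_states, h.unsqueeze(2)).squeeze(2)       # (B, T)
            alpha = torch.softmax(scores, dim=1)                            # (B, T)
            # Context: attention-weighted sum of the encoder states
            context = torch.bmm(alpha.unsqueeze(1), enc_states).squeeze(1)  # (B, H)
            dec_in = torch.cat([context, self.tgt_embed(tgt[:, t])], dim=1)  # (B, 2H)
            h = self.decoder_cell(dec_in, h)                                # (B, H)
            logits.append(self.out(h))
        return torch.stack(logits, dim=1)                      # (B, T', tgt_vocab)


model = Seq2SeqWithAttentionSketch()
src = torch.randint(0, 20, (2, 7))   # batch of 2 source sequences of length 7
tgt = torch.randint(0, 20, (2, 5))   # batch of 2 target sequences of length 5
print(model(src, tgt).shape)
```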

Attention is a general technique that can be applied to many tasks (not just translation) and in many architectures (not just seq2seq).

Attention is a way to combine an arbitrary set of representations (the values) into a fixed-size representation that depends on some other representation (the query).

The query is used to determine the importance of each value in the resulting summary.
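
As a rough illustration of this general view, the hypothetical helper `attend` below uses dot-product scoring (one choice among several) and returns a summary whose size does not depend on how many values there are.

```python
import torch


def attend(query, values):
    """Dot-product attention: summarise `values` (N, d) with respect to `query` (d,)."""
    scores = values @ query                   # alignment scores, shape (N,)
    weights = torch.softmax(scores, dim=0)    # attention distribution, shape (N,)
    return weights @ values                   # fixed-size summary, shape (d,)


torch.manual_seed(0)
d = 8                                         # assumed representation size
query = torch.randn(d)
for n in (3, 10, 50):
    summary = attend(query, torch.randn(n, d))
    print(n, summary.shape)                   # same summary shape for every n
```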

# TODO diagram directed graph of indexed nodes, where each node is pointed to by only itself and those before it

Note that there is no notion of time (or generally position), which is why we need to encode it.

This makes it different from convolution, which makes explicit use of position.
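
A quick way to see the lack of position information is to shuffle the values and check that the attention output is unchanged. The sketch below uses dot-product scoring and random stand-in tensors, both of which are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

N, d = 6, 8                                   # assumed sizes
query = torch.randn(d)
values = torch.randn(N, d)


def attend(query, values):
    # Dot-product scores -> softmax -> weighted sum of the values
    weights = torch.softmax(values @ query, dim=0)
    return weights @ values


perm = torch.randperm(N)                      # a random reordering of the values
original = attend(query, values)
shuffled = attend(query, values[perm])

# The two summaries agree up to floating-point error
print(torch.allclose(original, shuffled, atol=1e-5))
```

Since the order of the values makes no difference to the result, any positional information has to be added to the representations themselves.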

Attention really works. It smashed benchmarks when it was introduced, shortly after seq2seq.

## Why does attention help?

### Attention eliminates the information bottleneck

At every timestep, the decoder can see the entire sequence of encoder hidden states.

# TODO diagram

### Attention opens the gradient superhighway

Because the entire sequence of encoder hidden states is fed directly to the decoder at every timestep, the gradient does not have to flow through many sequential layers of the model to influence the weights that affected far-away calculations, such as the first encoder hidden state.

# TODO diagram

### Attention makes the model somewhat interpretable

You can tell what is being considered at each decoding step by looking at the attention weights.
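
For example, the attention weights for a single decoder step can be printed next to the source tokens. The sentence and the alignment scores below are made up for illustration and do not come from a trained model.

```python
import torch

source_tokens = ["the", "cat", "sat", "down", "."]             # hypothetical source sentence
alignment_scores = torch.tensor([0.1, 2.5, 0.3, -1.0, -2.0])   # made-up scores for one decoder step
attention = torch.softmax(alignment_scores, dim=0)

# Print each token with its attention weight and a simple text bar
for token, weight in zip(source_tokens, attention.tolist()):
    print(f"{token:>5}  {weight:.2f}  {'#' * round(weight * 20)}")

most_attended = source_tokens[int(attention.argmax())]
print("most attended token:", most_attended)
```

Collecting these weights for every decoder step gives a matrix that is often plotted as a heatmap of target tokens against source tokens.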

# TODO diagram