# Notes: Attention Is All You Need and Transformers are RNNs

1. Attention Is All You Need
2. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

## Introduction to Transformers

Attention Is All You Need (AIAYN) introduces the Transformer architecture. The authors outline it in the following figure:

![transformer diagram](https://i.imgur.com/thw3UDN.png)

The transformer consists of an encoder and a decoder. The encoder maps a source sequence of tokens to a sequence of hidden representations. These are passed to the multi-head attention (more on this in a sec) layers of the decoder, which combines them with the tokens it has produced so far to predict the next token of the output sequence. A simple application is translating a sentence from one language into another.

### Attention

The meat of the AIAYN paper is its attention mechanism, which the authors call Scaled Dot-Product Attention and define as:

$$\mathrm{Attention} (Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})\mathrm{V}$$

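* $Q = \mathrm{Queries}$
* $K = \mathrm{Keys}$
* $V = \mathrm{Values}$
* $d_k$ = dimension of the queries and keys; scaling by $1/\sqrt{d_k}$ keeps the dot products from growing so large that the softmax is pushed into regions with extremely small gradients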

The paper also illustrates the mechanism in the following diagrams:

![attention mechanism](https://i.imgur.com/rMHl7xy.png)
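To make the formula concrete, here is a minimal pure-Python sketch of scaled dot-product attention, assuming matrices are stored as lists of rows and ignoring batching, masking, and multiple heads. The helper names (`matmul`, `transpose`, `softmax_rows`, `scaled_dot_product_attention`) are illustrative, not from the paper.

```python
import math

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix; both are lists of rows."""
    return [[sum(row[t] * b[t][c] for t in range(len(b))) for c in range(len(b[0]))]
            for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Apply softmax to each row, subtracting the row max for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with Q, K: (N x d_k) and V: (N x d_v)."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))                        # query-by-key scores
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    return matmul(softmax_rows(scaled), V)                  # weighted average of values
```

Subtracting each row's maximum before exponentiating leaves the softmax unchanged but avoids overflow when scores are large.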


## Transformers are RNNs

This paper addresses a key bottleneck of transformers: computing scaled dot-product attention takes time and memory that grow quadratically with the sequence length $N$. The authors propose replacing the softmax similarity with a kernel, which lets the attention computation be reordered so that its cost grows linearly in $N$ instead.

Given:

$$\mathrm{V'} =\mathrm{Attention} (Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})\mathrm{V}$$

we can write the $i$-th row of the output, $\mathrm{V'_i}$, as a similarity-weighted average of the values:

$$\mathrm{V'_i} = \frac{\sum^N_{j=1} \mathrm{sim}(Q_i,K_j)V_j}{\sum^N_{j=1} \mathrm{sim}(Q_i,K_j)}$$

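where $Q_i$ and $K_j$ are the $i$-th and $j$-th rows of $Q$ and $K$, and the similarity function that recovers softmax attention is:

$$\mathrm{sim}(q,k) = \exp\left(\frac{q^T k}{\sqrt{d_k}}\right)$$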
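As a sanity check on this rewrite, the sketch below computes attention one query row at a time from an arbitrary `sim` function; plugging in the exponentiated, scaled dot product reproduces softmax attention. The function names are hypothetical. The paper only requires `sim` to be non-negative, and that freedom is what the linearization below exploits.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax_sim(q, k):
    """sim(q, k) = exp(q^T k / sqrt(d_k)); this choice recovers softmax attention."""
    return math.exp(dot(q, k) / math.sqrt(len(k)))

def generalized_attention(Q, K, V, sim):
    """V'_i = sum_j sim(Q_i, K_j) V_j / sum_j sim(Q_i, K_j), one query row at a time."""
    out = []
    for q in Q:
        weights = [sim(q, k) for k in K]   # similarity of this query to every key
        total = sum(weights)               # the denominator
        out.append([sum(w * v[c] for w, v in zip(weights, V)) / total
                    for c in range(len(V[0]))])
    return out
```

Unlike a production softmax, this sketch does not subtract the row maximum, so very large scores could overflow `math.exp`.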

### Linearization 

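From here the authors introduce linearized attention: they replace $\mathrm{sim}(q,k)$ with a kernel $\phi(q)^T \phi(k)$, where $\phi(x)$ is a feature map (in their experiments, $\phi(x) = \mathrm{elu}(x) + 1$). Substituting into the equation above gives: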

$$\mathrm{V'_i} = \frac{\sum^N_{j=1} \phi(Q_i)^T \phi(K_j)V_j}{\sum^N_{j=1} \phi(Q_i)^T \phi(K_j)}$$

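Since $\phi(Q_i)$ does not depend on $j$, it can be pulled out of both sums; treating $V_j$ as a column vector, $\phi(K_j)V_j^T$ is an outer product:

$$\mathrm{V'_i} = \frac{ \phi(Q_i)^T \sum^N_{j=1}  \phi(K_j)V^T_j}{\phi(Q_i)^T \sum^N_{j=1} \phi(K_j)}$$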
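The following pure-Python sketch computes the kernelized attention both ways: directly from the first equation (explicit weights for every query/key pair), and from the factored second equation, where $\sum_j \phi(K_j)V_j^T$ and $\sum_j \phi(K_j)$ are built once and reused for every query. It assumes the $\phi(x) = \mathrm{elu}(x) + 1$ feature map, list-of-rows matrices, and no $1/\sqrt{d_k}$ scaling; the function names are hypothetical. The two functions should agree up to floating-point error.

```python
import math

def elu_plus_one(x):
    """Feature map phi(x) = elu(x) + 1, elementwise: x + 1 for x > 0, else exp(x)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def kernel_attention_quadratic(Q, K, V, phi=elu_plus_one):
    """First form: explicit weights phi(Q_i)^T phi(K_j) for every (i, j) pair, O(N^2)."""
    phi_K = [phi(k) for k in K]
    out = []
    for q in Q:
        pq = phi(q)
        w = [dot(pq, pk) for pk in phi_K]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / total
                    for c in range(len(V[0]))])
    return out

def kernel_attention_linear(Q, K, V, phi=elu_plus_one):
    """Factored form: shared sums S and Z are built once, then reused per query, O(N)."""
    phi_K = [phi(k) for k in K]
    C, M = len(phi_K[0]), len(V[0])
    # S = sum_j phi(K_j) V_j^T, a C x M matrix
    S = [[sum(pk[a] * v[b] for pk, v in zip(phi_K, V)) for b in range(M)]
         for a in range(C)]
    # Z = sum_j phi(K_j), a length-C vector
    Z = [sum(pk[a] for pk in phi_K) for a in range(C)]
    out = []
    for q in Q:
        pq = phi(q)
        denom = dot(pq, Z)
        out.append([sum(pq[a] * S[a][b] for a in range(C)) / denom for b in range(M)])
    return out
```

Because $\mathrm{elu}(x) + 1$ is always positive, every weight $\phi(Q_i)^T\phi(K_j)$ is positive and the denominator never vanishes.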

The authors explain why factoring out $\phi(Q_i)^T$ this way reduces the cost of attention:

![complexity_simplification](https://i.imgur.com/0jrSzvd.png)

Here, equation 2 refers to the original scaled dot-product (softmax) attention.
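A rough operation count makes the difference explicit. Assume queries and keys of dimension $D$, values of dimension $M$, and a feature map producing $C$-dimensional features (for $\phi(x)=\mathrm{elu}(x)+1$, $C = D$):

$$\begin{aligned}
\text{softmax attention:}\quad & \underbrace{QK^T}_{O(N^2 D)}\ \text{then}\ \underbrace{\mathrm{softmax}(\cdot)\,V}_{O(N^2 M)} &&\Rightarrow\ O\!\left(N^2 \max(D, M)\right)\\
\text{linear attention:}\quad & \underbrace{\textstyle\sum_j \phi(K_j)V_j^T}_{O(NCM)}\ \text{then}\ \underbrace{\phi(Q_i)^T S \text{ for all } i}_{O(NCM)} &&\Rightarrow\ O(NCM)
\end{aligned}$$

For example, with $N = 4096$ and $D = M = C = 64$, forming $QK^T$ alone takes $4096^2 \cdot 64 \approx 1.07 \times 10^9$ multiply-adds, while forming $\sum_j \phi(K_j)V_j^T$ takes $4096 \cdot 64 \cdot 64 \approx 1.68 \times 10^7$, a factor of $N/C = 64$ fewer. Memory follows the same pattern: the softmax version stores an $N \times N$ attention matrix, while the linear version stores a $C \times M$ matrix.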

### The connection between RNNs and Transformers

The authors make the connection between RNNs and transformers by first describing how causal masking works with linear attention. Causal masking is what lets a transformer be trained in parallel over every position of a sequence: each position may only attend to itself and earlier positions, so position $i$ uses only $j \le i$ and the sums over $j$ stop at $i$ instead of $N$:

$$\mathrm{V'_i} = \frac{ \phi(Q_i)^T \sum^N_{j=1}  \phi(K_j)V^T_j}{\phi(Q_i)^T \sum^N_{j=1} \phi(K_j)}$$

becomes

$$\mathrm{V'_i} = \frac{ \phi(Q_i)^T \sum^i_{j=1}  \phi(K_j)V^T_j}{\phi(Q_i)^T \sum^i_{j=1} \phi(K_j)}$$

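Combined with the linearization, this means the sums in the numerator and denominator never need to be recomputed from scratch: going from position $i-1$ to position $i$ only adds the new terms $\phi(K_i)V_i^T$ and $\phi(K_i)$, which can be computed once per time step and accumulated.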

They then define the prefix sums $S_i$ and $Z_i$:

$$S_i=\sum^i_{j=1}\phi(K_j)V^T_j$$

$$Z_i=\sum^i_{j=1}\phi(K_j)$$

so that

$$\mathrm{V'_i} = \frac{\phi(Q_i)^T S_i}{\phi(Q_i)^T Z_i}$$
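Assuming the $\phi(x) = \mathrm{elu}(x) + 1$ feature map and list-of-rows matrices, this pure-Python sketch (hypothetical function names) compares the causally masked equation evaluated directly, which rescans the whole prefix for every position, against a single left-to-right pass that keeps $S_i$ and $Z_i$ as running sums. Both should produce the same outputs, and the first output always equals $V_1$ because position 1 can only attend to itself.

```python
import math

def elu_plus_one(x):
    """Feature map phi(x) = elu(x) + 1, applied elementwise."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def causal_attention_naive(Q, K, V, phi=elu_plus_one):
    """Masked equation evaluated directly: position i rescans keys 0..i, O(N^2) overall."""
    out = []
    for i, q in enumerate(Q):
        pq = phi(q)
        w = [dot(pq, phi(K[j])) for j in range(i + 1)]
        total = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) / total
                    for c in range(len(V[0]))])
    return out

def causal_attention_running_sums(Q, K, V, phi=elu_plus_one):
    """One pass: S_i = S_{i-1} + phi(K_i) V_i^T and Z_i = Z_{i-1} + phi(K_i)."""
    C, M = len(phi(K[0])), len(V[0])
    S = [[0.0] * M for _ in range(C)]
    Z = [0.0] * C
    out = []
    for q, k, v in zip(Q, K, V):
        pk = phi(k)
        for a in range(C):
            Z[a] += pk[a]
            for b in range(M):
                S[a][b] += pk[a] * v[b]
        pq = phi(q)
        # V'_i = phi(Q_i)^T S_i / phi(Q_i)^T Z_i
        denom = dot(pq, Z)
        out.append([sum(pq[a] * S[a][b] for a in range(C)) / denom for b in range(M)])
    return out
```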

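Since each of $S_i$ and $Z_i$ is obtained from its predecessor by adding a single term, the authors can interpret a transformer layer as an RNN whose hidden state is the pair $(s_i, z_i)$: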

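$$s_0 = 0, \qquad z_0 = 0$$

$$s_i = s_{i-1} + \phi(x_i W_K)(x_i W_V)^T$$

$$z_i = z_{i-1} + \phi(x_i W_K)$$

$$y_i = f_l\left(\frac{\phi(x_i W_Q)^T s_i}{\phi(x_i W_Q)^T z_i} + x_i\right)$$

where $x_i$ is the $i$-th input to the layer, $W_Q$, $W_K$, $W_V$ are the query, key, and value projection matrices (so $Q_i = x_i W_Q$, and so on), and $f_l$ is the layer's feedforward transformation.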
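The recurrence can be run one token at a time. The sketch below is a minimal pure-Python version of a single layer, assuming $\phi(x) = \mathrm{elu}(x)+1$, an identity $f_l$ (a hypothetical stand-in for the layer's feedforward network), and a value projection $W_V$ that preserves the input dimension so the residual $+\,x_i$ is well defined. The class and method names are illustrative. Each call to `step` does a fixed amount of work no matter how many tokens came before, which is where the fast autoregressive inference comes from.

```python
import math

def elu_plus_one(x):
    """Feature map phi(x) = elu(x) + 1, applied elementwise."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def vecmat(x, W):
    """Row vector x (length d) times matrix W (d x m, list of rows) -> length-m list."""
    return [sum(x[t] * W[t][c] for t in range(len(x))) for c in range(len(W[0]))]

class LinearAttentionRNN:
    """One causal linear-attention layer run as an RNN with hidden state (s, z)."""

    def __init__(self, W_Q, W_K, W_V, f_l=lambda h: h):
        # The residual x_i + attention output needs W_V to map d -> d.
        assert len(W_V) == len(W_V[0]), "W_V must preserve the input dimension"
        self.W_Q, self.W_K, self.W_V, self.f_l = W_Q, W_K, W_V, f_l
        C, M = len(W_K[0]), len(W_V[0])          # elu + 1 keeps the width of x W_K
        self.s = [[0.0] * M for _ in range(C)]   # s_0 = 0 (C x M matrix)
        self.z = [0.0] * C                       # z_0 = 0 (length-C vector)

    def step(self, x):
        """Consume x_i, update (s, z), and return y_i."""
        k = elu_plus_one(vecmat(x, self.W_K))    # phi(x_i W_K)
        v = vecmat(x, self.W_V)                  # x_i W_V
        q = elu_plus_one(vecmat(x, self.W_Q))    # phi(x_i W_Q)
        for a in range(len(k)):
            self.z[a] += k[a]                    # z_i = z_{i-1} + phi(x_i W_K)
            for b in range(len(v)):
                self.s[a][b] += k[a] * v[b]      # s_i = s_{i-1} + phi(x_i W_K)(x_i W_V)^T
        denom = sum(qa * za for qa, za in zip(q, self.z))
        attn = [sum(q[a] * self.s[a][b] for a in range(len(q))) / denom
                for b in range(len(v))]
        return self.f_l([h + xi for h, xi in zip(attn, x)])  # f_l(attention + x_i)
```

Feeding the tokens one by one through `step` should match evaluating the causally masked attention (plus the residual) over the whole prefix at each position.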

### Performance

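The linear transformer is significantly faster than the softmax transformer at both inference and training time (as expected), and the authors report up to 4000x faster autoregressive prediction on very long sequences. It also uses significantly less memory because it never materializes the $N \times N$ attention matrix; during autoregressive inference it only has to keep the running sums $s_i$ and $z_i$.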