# Attention Basics
In a typical encoder-decoder framework, we want to summarize the information in the input sequence $\mathbf{x} = x_1,\dots,x_I$ with a set of representations (hidden states). The decoder then uses these hidden states to predict $\mathbf{y} = y_1,\dots,y_J$.

The encoding operation can be described by $h_i=f^E(h_{i-1}, x_i)\in \mathbb{R}^K$, where $f^E$ is the chosen encoding function and $h_i$ is the new hidden state vector, which adds the information from $x_i$ to $h_{i-1}$. At the end, $h_I$ should encode all the information from the input sequence $\mathbf{x}$ that the decoder needs.
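The recurrence can be sketched as a short loop. This is only an illustration: `nn.GRUCell` stands in for the encoding function $f^E$, the inputs are assumed to be already embedded as $K$-dimensional vectors, and the sizes are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

I, K = 5, 8              # sequence length and hidden size (illustrative values)
x = torch.rand(I, K)     # inputs x_1..x_I, assumed already embedded as K-dim vectors

# nn.GRUCell stands in for the encoding function f^E (an assumption for this sketch)
f_E = nn.GRUCell(input_size=K, hidden_size=K)

h = torch.zeros(1, K)    # initial hidden state (batch of 1)
hidden_states = []
for i in range(I):
    # one encoding step: fold the next input into the running hidden state
    h = f_E(x[i].unsqueeze(0), h)
    hidden_states.append(h)

hidden_states = torch.cat(hidden_states, dim=0)  # (I, K): one hidden state per input
h_I = hidden_states[-1]                           # final state, summarizes the whole input
print(hidden_states.shape, h_I.shape)
```

Keeping every intermediate state in `hidden_states`, rather than only `h_I`, is what makes the attention mechanism described next possible.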

The decoder is then initialized with $s_0=h_I$ and, given $y_{j-1}$ and $s_{j-1}$, produces both the target $y_j$ and the corresponding hidden state $s_j$. The problem with this approach is that the single vector $h_I$ often cannot carry enough information for the decoder to accurately compute $y_1,\dots,y_J$ when the input length $I$ is large. An obvious remedy is to reintroduce $h_1,\dots,h_{I-1}$ so that the decoder can draw on them as well.

Toward this end, a neural network is used to learn the importance of $h_1,\dots,h_{I}$, that is, how much each of them contributes to the prediction of $y_j$. More precisely, the contribution (attention weight) for a given $i$ and $j$ is
\\[
    a_{j,i} = \text{softmax}_{i=1,\dots,I}\left( \text{score}(s_{j-1}, h_i)\right)
\\]
where the scoring function (many other options are available) is 
\\[
    \text{score}(s_{j-1}, h_i) = \frac{h_is_{j-1}^\intercal}{\sqrt{K}}.
\\]
The scoring function measures how similar $h_i$ and $s_{j-1}$ are. The scaling by $\sqrt{K}$ prevents the dot product from growing extremely large when $K$ is large, so that the magnitude of the score does not depend on the hidden size. As the final step, we take the attention-weighted sum of the encoder hidden states to produce a context vector that carries the information the decoder can use. More precisely, for $y_j$, the context vector $c_j\in \mathbb{R}^K$ is
\\[
    c_j = \sum_{i=1}^{I} a_{j,i}\, h_i.
\\]
The context vector is then passed to the decoder, together with $s_{j-1}$ and $y_{j-1}$, to compute the hidden state vector for step $j$:
\\[
    s_{j} = f( s_{j-1}, y_{j-1}, c_j).
\\]
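For a single decoder step, the attention weights and the context vector can be computed as in the sketch below. The sizes and the random tensors are placeholders, and the decoder update $f$ itself is left out.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

I, K = 4, 6                  # illustrative sizes
h = torch.rand(I, K)         # encoder hidden states h_1..h_I, one per row
s_prev = torch.rand(K)       # previous decoder state s_{j-1}

scores = h @ s_prev / math.sqrt(K)   # (I,): scaled dot product of each h_i with s_{j-1}
a_j = F.softmax(scores, dim=0)       # (I,): attention weights over the I encoder states
c_j = a_j @ h                        # (K,): context vector, the weighted sum of the h_i

print(a_j.sum(), c_j.shape)
```

The weights `a_j` are non-negative and sum to 1, so `c_j` is a convex combination of the encoder hidden states.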

## Scaled-dot-product Attention 

In the transformer setting, the same computation is written in matrix form. The context matrix of scaled dot-product attention is

\\[
    c = \text{softmax}\left(\frac{q k^\intercal }{\sqrt{K}}\right)v
\\]
where
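- $q\in \mathbb{R}^{J\times K}$ is the query (the decoder states $s$ above),
- $k\in \mathbb{R}^{I\times K}$ is the key (the encoder states $h$ above),
- $v\in \mathbb{R}^{I\times K}$ is the value (also the encoder states $h$ above),
- $c\in \mathbb{R}^{J\times K}$ is the context matrix, and
- the softmax function is applied to each row.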

The score $qk^\intercal$ is divided by $\sqrt{K}$ to keep the attention scores from becoming overly large before the softmax is applied. The operation is summarized on the left-hand side of Figure 2 below.

<img src="images/scaled_mhead_attentions.png" alt="Figure 2. From ‘Attention Is All You Need’ by Vaswani et al." style="width:80%;"/>
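The effect of the $\sqrt{K}$ scaling can be checked numerically. The sketch below assumes query and key entries drawn independently from a standard normal distribution (an assumption made only for this illustration) and compares the spread of raw and scaled dot products for several hidden sizes.

```python
import math
import torch

torch.manual_seed(0)

n = 10_000  # number of random query/key pairs per hidden size
for K in (4, 64, 1024):
    q = torch.randn(n, K)            # standard normal entries (illustrative assumption)
    k = torch.randn(n, K)
    raw = (q * k).sum(dim=-1)        # unscaled dot products, one per pair
    scaled = raw / math.sqrt(K)      # scaled as in the attention formula
    print(K, raw.std().item(), scaled.std().item())
```

Under this assumption the standard deviation of the raw dot product grows like $\sqrt{K}$, while the scaled version stays close to 1 for every $K$, which keeps the softmax from saturating as the hidden size grows.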

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# B: batch size (can mostly be ignored)
# I, J: source and target sequence lengths (max number of words in a sentence)
# K: hidden size (size of the word embedding)
B, I, J, K = 3, 9, 10, 18

In [3]:
# define dropout and softmax layers
#dropout = nn.Dropout(0.1)
softmax = nn.Softmax(dim=-1) # softmax for each j

In [4]:
# initialize the query, key, and value
q = torch.rand((B, J, K))
k = torch.rand((B, I, K))
v = torch.rand((B, I, K))

In [5]:
# score every query against every key
# (B, J, K) x (B, K, I) -> (B, J, I)
s = torch.matmul(q, k.transpose(-2, -1).contiguous())

In [6]:
# scale by sqrt(K) so the scores do not grow with the hidden size
s = s / math.sqrt(K)

In [7]:
# softmax over the keys (last dimension), written out by hand:
#a = torch.exp(s)
#a = a / a.sum(-1, keepdim=True)
# or, equivalently:
a = F.softmax(s, dim=-1)
#a = dropout(a)

In [8]:
# weight the values by each query's attention weights to get the context matrix
# (B, J, I) x (B, I, K) -> (B, J, K)
c = torch.bmm(a, v)

In [9]:
print('attention weights: {}'.format(a.shape[1:]))
print('context matrix: {}'.format(c.shape[1:]))

# check that the weights for one query sum to 1 (up to floating-point error)
print('first batch, first j, sum(i) = {}'.format(sum(a[0][0])))

attention weights: torch.Size([10, 9])
context matrix: torch.Size([10, 18])
first batch, first j, sum(i) = 1.0000001192092896


In [10]:
def scaled_dot_product_attention(q, k, v):
    a = F.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) / math.sqrt(k.shape[-1]), dim=-1)
    return torch.matmul(a, v)

In [11]:
# check if the function is correct
torch.all(c.eq(scaled_dot_product_attention(q, k, v)))

tensor(True)
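As an additional sanity check, the function can be compared against PyTorch's built-in `F.scaled_dot_product_attention`. This is a sketch that assumes PyTorch 2.0 or later, where that function is available; the `hasattr` guard simply skips the comparison on older versions. The tensors are fresh random placeholders with the same shapes as above.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def scaled_dot_product_attention(q, k, v):
    # same computation as the notebook function: softmax(q k^T / sqrt(K)) v
    a = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.shape[-1]), dim=-1)
    return torch.matmul(a, v)

B, I, J, K = 3, 9, 10, 18
q = torch.rand(B, J, K)
k = torch.rand(B, I, K)
v = torch.rand(B, I, K)

ours = scaled_dot_product_attention(q, k, v)

# the built-in only exists in newer PyTorch releases (assumed >= 2.0)
if hasattr(F, "scaled_dot_product_attention"):
    builtin = F.scaled_dot_product_attention(q, k, v)
    # compare with a tolerance: the two code paths may round differently
    print(torch.allclose(ours, builtin, atol=1e-5))
```

A tolerance-based comparison with `torch.allclose` is used here instead of exact equality because different implementations can differ in the last few floating-point bits.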