## Self-Attention mechanism

References:

https://arxiv.org/abs/1706.03762

https://github.com/sooftware/attentions

https://github.com/greentfrapp/attention-primer

In [121]:
import numpy as np 
import string

# visualization packages
import seaborn as sns 
import matplotlib.pyplot as plt 
%matplotlib inline
%config InlineBackend.figure_format='retina'

from utils.viz import viz 
viz.get_style()

### Brief introduction

Self-attention is one of the most popular and widely used mechanisms in natural language processing (NLP).

Complex recurrent models are often used to capture temporal dependencies within stimulus sequences (for example, phrases or sentences), in pursuit of human-level performance in processing and generating language. Most classic recurrent models (e.g., GRUs and LSTMs trained in a seq2seq setup) handle the input sequence with a simple philosophy: inputs farther from the current stimulus have less impact. Their training algorithm, backpropagation through time (BPTT), typically produces much smaller gradients for distant inputs than for nearby ones (the vanishing-gradient problem).

Self-attention instead processes inputs according to how similar their representations are, rather than how far apart they are in time. The mechanism first estimates the similarity between every pair of inputs and then uses these similarities to weight the inputs. Because the same weights scale the gradients that flow back to each input, they also determine where learning is concentrated.
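A minimal NumPy sketch of this idea, assuming a plain dot product as the similarity and a softmax to turn similarities into weights. There are no learned parameters yet, and the function name `similarity_weighting` is hypothetical. Note that an input's position in the sequence never enters the computation: shuffling the sequence only shuffles the outputs.

```python
import numpy as np

def similarity_weighting(S):
    """Weight every input by its dot-product similarity to all inputs (no learned parameters).

    S: (N, E) array with one row per sequence element.
    Returns the (N, N) weight matrix and the (N, E) re-weighted inputs.
    """
    scores = S @ S.T                                        # (N, N) pairwise similarities
    scores = scores - scores.max(axis=1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)  # row-wise softmax
    return weights, weights @ S                             # row i mixes all inputs by similarity to input i

rng = np.random.default_rng(0)
S = rng.standard_normal((4, 8))        # 4 inputs with 8-dim representations (arbitrary sizes)
W, S_new = similarity_weighting(S)

# Shuffling the sequence only shuffles the outputs: position itself plays no role
perm = [2, 0, 3, 1]
W_perm, S_new_perm = similarity_weighting(S[perm])
print(np.allclose(S_new_perm, S_new[perm]))   # True
```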

This raises two implementation-level questions:

1. How do we calculate the similarity?
2. How do we combine the computed attention with the input?

#### Similarity

In 1_RSA, we introduced a few similarity (distance) metrics. Here, we use correlation:

$$r(X,Y) = \frac{1}{n-1}\sum_{i=1}^{n} \frac{(x_i-\bar{x})(y_i-\bar{y})}{s_x s_y}$$
where $\bar{x}=\frac{1}{n}\sum_{i=1}^{n} x_i$ and $s_x = \sqrt{\frac{1}{n-1}\sum_{i=1}^{n} (x_i-\bar{x})^2}$, with $\bar{y}$ and $s_y$ defined analogously.
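The formula translates directly into code. A sketch (the helper name `pearson_r` is our own) that follows it term by term and checks the result against NumPy's `np.corrcoef`:

```python
import numpy as np

def pearson_r(x, y):
    """Sample correlation r(X, Y), written term by term from the formula above."""
    n = len(x)
    x_bar, y_bar = x.mean(), y.mean()
    s_x = np.sqrt(((x - x_bar) ** 2).sum() / (n - 1))   # sample standard deviation of x
    s_y = np.sqrt(((y - y_bar) ** 2).sum() / (n - 1))   # sample standard deviation of y
    return ((x - x_bar) * (y - y_bar)).sum() / ((n - 1) * s_x * s_y)

rng = np.random.default_rng(1)
x = rng.standard_normal(50)
y = 0.5 * x + rng.standard_normal(50)
print(pearson_r(x, y), np.corrcoef(x, y)[0, 1])   # the two values agree
```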

Let's simplify this equation a bit. Assuming the elements of $X$ and $Y$ are sampled from the standard Gaussian distribution $N(0, 1)$, we can approximate the means as $\bar{x}\approx\bar{y}\approx 0$ and the standard deviations as $s_x s_y\approx 1$. The correlation then becomes the dot product of the two vectors divided by $n-1$:

$$r'(X,Y) = \frac{1}{n-1}\sum_{i=1}^{n} x_i y_i = \frac{1}{n-1} XY^{\top}$$
using the convention that $X, Y \in R^{1\times n}$ are row vectors.
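A quick numerical check of this approximation, assuming long vectors ($n = 10{,}000$, an arbitrary choice) so that the sample means are close to 0 and the sample standard deviations close to 1. $Y$ is constructed to have a true correlation of 0.3 with $X$ while keeping unit variance.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 10_000                        # long vectors: sample mean ~ 0, sample std ~ 1
X = rng.standard_normal((1, n))   # row vectors, as in the convention above
Y = 0.3 * X + np.sqrt(1 - 0.3 ** 2) * rng.standard_normal((1, n))  # unit variance, true correlation 0.3

r_exact = np.corrcoef(X[0], Y[0])[0, 1]   # full correlation formula
r_dot = (X @ Y.T).item() / (n - 1)        # simplified r'(X, Y) = X Y^T / (n - 1)
print(r_exact, r_dot)                     # close, but not identical
```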

When using the attention mechanism, the model input is usually a sequence of vectors $S = \{s^1, s^2, \dots\}$, where the superscript indicates the position of each vector in the sequence. We can compute the pairwise similarity between elements of this sequence. The "Attention Is All You Need" paper calls the first element of each pair the query, $Q$, and the second element the key, $K$. The similarities between an arbitrary query $Q^j$ and all keys are

$$R(j) = \frac{Q^jK^{\top}}{n}$$
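where the $-1$ in $n-1$ is dropped because it is negligible for large $n$. If the entries of $Q^j$ and $K$ are independent draws from $N(0,1)$, each entry of $Q^jK^{\top}$ is a sum of $n$ products $q_i k_i$, each with mean 0 and variance 1, so the dot product has variance $n$. Dividing by $n$ therefore gives each entry of $R(j)$ a variance of about $\frac{1}{n}$, i.e. a standard deviation of about $\frac{1}{\sqrt{n}}$. The similarity is then passed through a softmax function to form an attention distribution: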

$$A(j) = \text{softmax}\left(\frac{Q^jK^{\top}}{\sqrt{n}}\right)$$

Question: why divide by $\sqrt{n}$ here rather than by $n$?
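The question can be explored numerically. The sketch below assumes independent $N(0,1)$ entries as above and 8 keys per query (an arbitrary choice), and compares three options: no scaling, division by $\sqrt{n}$, and division by $n$. The original paper's argument for $\sqrt{d_k}$ is that unscaled dot products grow with the vector length and push the softmax into regions with extremely small gradients. Dividing by $n$ errs the other way: the scores shrink toward zero and the softmax approaches a uniform distribution. Dividing by $\sqrt{n}$ keeps the score variance near 1 for any $n$.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)    # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def score_stats(n, num_keys=8, trials=500, seed=0):
    """For length-n N(0,1) vectors, return {scaling: (std of scores, mean max attention weight)}."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((trials, n))             # one query per trial
    K = rng.standard_normal((trials, num_keys, n))   # num_keys keys per trial
    raw = np.einsum("tkn,tn->tk", K, q)              # Q^j K^T for every trial
    stats = {}
    for name, d in (("none", 1.0), ("sqrt_n", np.sqrt(n)), ("n", float(n))):
        scores = raw / d
        stats[name] = (scores.std(), softmax(scores).max(axis=-1).mean())
    return stats

for n in (4, 64, 1024):
    print(n, {k: (round(float(sd), 3), round(float(mx), 3)) for k, (sd, mx) in score_stats(n).items()})
```

A mean max weight near 1 means the softmax is saturated (almost one-hot); a value near 1/8 means it is nearly uniform and barely distinguishes the keys.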

Repeating the attention calculation for every query $Q^j$ and stacking the resulting rows gives the attention matrix $A$.
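Stacking the rows is a single matrix product followed by a row-wise softmax. A sketch with random queries and keys (the sizes and the helper name `attention_matrix` are arbitrary choices for illustration):

```python
import numpy as np

def attention_matrix(Q, K):
    """A = softmax(Q K^T / sqrt(n)), row by row; Q: (N_q, n), K: (N_k, n)."""
    n = Q.shape[-1]                                  # length of each query/key vector
    scores = Q @ K.T / np.sqrt(n)                    # (N_q, N_k); row j holds Q^j K^T / sqrt(n)
    scores -= scores.max(axis=1, keepdims=True)      # numerical stability
    A = np.exp(scores)
    return A / A.sum(axis=1, keepdims=True)          # each row is an attention distribution

rng = np.random.default_rng(4)
Q = rng.standard_normal((5, 16))   # 5 queries
K = rng.standard_normal((7, 16))   # 7 keys
A = attention_matrix(Q, K)
print(A.shape)                     # (5, 7)
```

Computing a single query on its own, e.g. `attention_matrix(Q[2:3], K)`, reproduces row 2 of `A`.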

### Combine the attention and the input

First, what are $Q$ and $K$? Both are linear transformations of the input $S$:

$$Q = S W_q$$
$$K = S W_k$$ 
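where $S\in R^{N\times E}$, $N$ is the length of the input sequence, and $E$ is the embedding dimension. $W_q, W_k \in R^{E\times H}$, where $H$ is the hidden-layer dimension. Hence $Q, K \in R^{N\times H}$, and the attention matrix, computed row by row as above with $n = H$, is $A = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{H}}\right) \in R^{N \times N}$.

Meanwhile, the input is also multiplied by a third matrix, $W_v \in R^{E\times D_v}$, giving the values $V = S W_v \in R^{N\times D_v}$, where $D_v$ is the value dimension. The reason is simply dimension matching: $V$ has one row per sequence element, so it can be multiplied by $A$.

The attended input is $S' = AV \in R^{N\times D_v}$: row $i$ is a weighted average of the value vectors, with weights given by row $i$ of $A$.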
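The whole pipeline, from $S$ to $S'$, fits in a few lines. A sketch in which random matrices stand in for the learned weights $W_q$, $W_k$, $W_v$ (scaled by $1/\sqrt{E}$ so that the entries of $Q$ and $K$ are roughly $N(0,1)$, matching the earlier assumption); all sizes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(5)
N, E, H, D_v = 6, 10, 4, 3            # sequence length, embedding, hidden, value dims (arbitrary)

S = rng.standard_normal((N, E))       # input sequence, one row per element
# random stand-ins for the learned projections (not trained)
W_q = rng.standard_normal((E, H)) / np.sqrt(E)
W_k = rng.standard_normal((E, H)) / np.sqrt(E)
W_v = rng.standard_normal((E, D_v)) / np.sqrt(E)

Q, K, V = S @ W_q, S @ W_k, S @ W_v   # (N, H), (N, H), (N, D_v)

scores = Q @ K.T / np.sqrt(H)         # (N, N) scaled dot-product similarities
scores -= scores.max(axis=1, keepdims=True)
A = np.exp(scores)
A /= A.sum(axis=1, keepdims=True)     # row-wise softmax: A[i] is element i's attention over the sequence

S_prime = A @ V                       # (N, D_v) attended input
print(Q.shape, A.shape, S_prime.shape)   # (6, 4) (6, 6) (6, 3)
```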


### A simple example: counting letters

The concept of "attention" is rooted in psychology, but the algorithm that implements attention is not psychological at all. For better illustration, we start with a simple example, **counting letters**, before discussing a psychology project.

Consider a sequence in which each element is either a randomly selected letter or a blank. The task is to count how many times each letter appears in the sequence.

The task is of course trivial, but it lets us visualize the attention the model learns. Our hypothesis is that identical letters will receive the same attention.

In [162]:
## create the task 
class task:

    def __init__(self, win_size=10, vocab_size=3):
        self.win_size = win_size 
        self.vocab_size = vocab_size
        assert vocab_size <= 26

    def next_batch(self, batch_size=100):
        # create input seq 
        seq = np.random.choice(np.arange(self.vocab_size+1), [batch_size, self.win_size])
        # one hot encoding 
        x = np.eye(self.vocab_size+1)[seq]
        # create label: how many times each letter appears (column 0, the blank, is dropped)
        lab = x.sum(1)[:, 1:].astype(np.int32)
        # one hot encoding
        y = np.eye(self.win_size+1)[lab]
        return x, y

    def toStr(self, samples, labels):
        # decode the one-hot inputs back to characters (index 0 = blank)
        samples = samples.reshape(-1, self.win_size, self.vocab_size+1)
        idx  = np.expand_dims(np.argmax(samples, axis=2), 2)
        strs = np.array(list(' ' + string.ascii_uppercase))
        # decode the one-hot labels back to per-letter counts
        labels = labels.reshape(-1, self.vocab_size, self.win_size+1)
        num  = np.expand_dims(np.argmax(labels, axis=2), 2)
        return strs[idx].reshape(-1, self.win_size), num.reshape(-1, self.vocab_size)

In [174]:
win_size   = 10 
vocab_size = 5
np.random.seed(1)

# task example 
t = task(win_size, vocab_size)
x, y = t.next_batch(15)

# get some samples 
xx = x[:2, :, :]
yy = y[:2, :, :]

# visualize samples 
x_str, y_str = t.toStr(xx,yy) 
print(f'''
Input string sequence:
{x_str}

Counts of each letter:
{y_str}
''')


Input string sequence:
[['E' 'C' 'D' ' ' 'A' 'C' 'E' ' ' ' ' 'A']
 ['D' 'E' 'D' 'A' 'B' 'D' 'E' 'B' 'D' 'C']]

Counts of each letter:
[[2 0 2 1 2]
 [1 2 1 4 2]]



In [176]:
import torch
import torch.nn as nn 
from torch.optim import Adam

Next, we design a model with attention.

In [None]:
class ATTN(nn.Module):
    '''Attention Module

    Scaled dot-product attention proposed in "Attention is all you need"

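    Args:
        d_k: dimension of the query/key vectors; scores are divided by sqrt(d_k)

    Inputs:
        Q: (B, N_q, d_k), the query array
        K: (B, N_k, d_k), the key array
        V: (B, N_k, d_v), the value array

    Returns:
        context: (B, N_q, d_v), the values weighted by attention
        attn: (B, N_q, N_k), the attention weights; each row sums to 1

    '''
    def __init__(self, d_k):
        super().__init__()
        # scaling factor from the paper: square root of the query/key dimension
        self.d = d_k ** 0.5

    def forward(self, Q, K, V):
        # (B, N_q, N_k) = (B, N_q, d_k) bmm (B, d_k, N_k)
        score = torch.bmm(Q, K.transpose(1, 2)) / self.d
        # softmax over the keys, so each query's weights sum to 1
        attn = torch.softmax(score, dim=-1)
        # (B, N_q, d_v) = (B, N_q, N_k) bmm (B, N_k, d_v)
        context = torch.bmm(attn, V)
        return context, attn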
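The same batched computation can be mirrored in NumPy, which makes it easy to probe the hypothesis stated earlier. In this sketch the helper `attention_np` is our own (not a library function) and scales by the query/key dimension, as in the paper. The letters are one-hot encoded the same way as in `task.next_batch`, and the raw one-hot vectors are used directly as Q, K and V, with no learned projections, purely for illustration. Because identical letters have identical one-hot rows, they necessarily receive identical attention rows.

```python
import numpy as np

def attention_np(Q, K, V):
    """Batched scaled dot-product attention: Q (B, N_q, d_k), K (B, N_k, d_k), V (B, N_k, d_v)."""
    scores = Q @ np.swapaxes(K, 1, 2) / np.sqrt(Q.shape[-1])   # (B, N_q, N_k)
    scores = scores - scores.max(axis=-1, keepdims=True)        # numerical stability
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)              # softmax over the keys
    return attn @ V, attn                                       # context (B, N_q, d_v), attn (B, N_q, N_k)

# one-hot letter sequences, encoded as in task.next_batch (index 0 = blank)
rng = np.random.default_rng(6)
vocab_size, win_size = 5, 10
seq = rng.integers(0, vocab_size + 1, size=(2, win_size))
x = np.eye(vocab_size + 1)[seq]                 # (2, win_size, vocab_size + 1)

# raw one-hot inputs as Q, K and V (no learned projections)
context, attn = attention_np(x, x, x)
print(seq[0])
print(attn[0].round(2))   # rows for the same letter are identical
```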