# BERT's Anatomy Step by Step: Self-Attention

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig, AutoTokenizer
from transformers import BertForPreTraining

In [None]:
model_checkpoint = 'bert-base-uncased'

model = BertForPreTraining.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
config = AutoConfig.from_pretrained(model_checkpoint)

In [None]:
encoding = tokenizer.encode("let's tokenize something?", return_tensors="pt")

In [None]:
tokens = tokenizer.convert_ids_to_tokens(encoding.flatten())

In [None]:
seq_embedding = model.bert.embeddings.word_embeddings(encoding)
seq_embedding.shape   # (batch_size, seq_len, hidden_size)

## Self-Attention

In the context of deep learning, attention is a mechanism that enables a model to focus on specific parts of input data while processing information. It allows the model to assign varying degrees of importance or relevance to different elements of the input, rather than treating all elements uniformly.

Attention mechanisms have been widely employed in natural language processing tasks, such as machine translation, text summarization, and sentiment analysis, as well as in computer vision applications, allowing models to selectively attend to relevant regions of an image or sequence. The concept is inspired by human cognitive processes that involve selectively focusing on specific information to better comprehend and process complex data.

### A basic self-attention implementation

Let's consider the following sequence:
$$
\text{seq} = [~~\text{`this'}, ~~\text{`is'}, ~~\text{`a'}, ~~\text{`sequence'}~~]
$$

Replacing each token with its embedding vector, we can view the sequence as a `(seq_len, hidden_size)` matrix $E$ whose rows are the vectors $\vec{e}_i$:
$$
\text{seq} \equiv E = [~~~~~\vec{e}_1, ~~~~~~~~\vec{e}_2, ~~~~~~~~\vec{e}_3, ~~~~~~~\vec{e}_4~~~~]
$$ 

The relationship between token $i$ and token $j$ within the sentence can be measured by an attention score of token $i$ towards token $j$.
A natural choice is the cosine similarity between their embedding vectors.
To keep things simple, we use only the dot product and leave out the normalization:
$$
a_{ij} = \vec{e}_{i} \cdot \vec{e}_{j}
$$

Treating all embedding vectors together as the matrix $E$, we can compute all the scores at once with the matrix multiplication $EE^{\text{T}}$: `(seq_len, hidden_size)` x `(hidden_size, seq_len)` gives `(seq_len, seq_len)`. Elementwise, we have
$$
a_{ij} = \sum_k^{\mathrm{hid\ size}} e_{ik} e_{jk}
$$
where $e_{ik}$ are the elements of the matrix $E$ and $a_{ij}$ are the elements of the attention matrix $A$.
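As a quick sanity check, here is a minimal sketch (assuming a random toy matrix `E` with 4 tokens and hidden size 8 in place of real BERT embeddings) that compares $EE^{\text{T}}$ with the pairwise dot products computed one at a time.

```python
import torch

torch.manual_seed(0)

# Toy stand-in for the embedding matrix E: 4 tokens, hidden size 8 (illustrative sizes)
seq_len, hidden_size = 4, 8
E = torch.randn(seq_len, hidden_size)

# All pairwise dot products at once: (seq_len, hidden_size) @ (hidden_size, seq_len)
A = E @ E.T

# The same scores one pair at a time: a_ij = e_i . e_j = sum_k e_ik * e_jk
A_loop = torch.empty(seq_len, seq_len)
for i in range(seq_len):
    for j in range(seq_len):
        A_loop[i, j] = torch.dot(E[i], E[j])

print(A.shape)                                  # torch.Size([4, 4])
print(torch.allclose(A, A_loop, atol=1e-5))     # True
print(torch.allclose(A, A.T, atol=1e-5))        # True: raw dot-product scores are symmetric
```

Note that the raw scores are symmetric ($a_{ij} = a_{ji}$); the learned projections introduced later remove that constraint.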

We could normalize the $a_{ij}$ with a softmax over each row so that every row sums to 1, but that's not necessary at this point, since we only want to understand the concept.
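A minimal sketch of that optional normalization, again on a random toy `E`: `F.softmax(..., dim=-1)` normalizes along the last dimension, so each row of the resulting weight matrix sums to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
E = torch.randn(4, 8)          # toy embeddings: 4 tokens, hidden size 8
scores = E @ E.T               # raw dot-product scores a_ij

# Softmax over the last dimension: row i is normalized across all tokens j
weights = F.softmax(scores, dim=-1)

print(weights.sum(dim=-1))           # each row sums to 1 (up to float rounding)
print(bool((weights > 0).all()))     # True: softmax weights are strictly positive
```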

Finally, each embedding vector is replaced by a sum of all the embedding vectors weighted by the corresponding attention coefficients (a weighted average once the rows of $A$ are normalized):
$$
\text{seq}^* \equiv \text{E}^*_{ij} = \sum_k^{\mathrm{seq\ len}} a_{ik}e_{kj}
$$

That's the matrix multiplication $AE$, where the shapes are `(seq_len, seq_len)` x `(seq_len, hidden_size)` -> `(seq_len, hidden_size)`.
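The following sketch (toy random `E`, unnormalized scores) computes $E^* = AE$ and checks it against the explicit row-by-row sum $\vec{e}^*_{i} = \sum_k a_{ik}\vec{e}_{k}$.

```python
import torch

torch.manual_seed(0)
E = torch.randn(4, 8)   # toy (seq_len, hidden_size) embeddings
A = E @ E.T             # (seq_len, seq_len) attention scores, left unnormalized

# Contextual embeddings: (seq_len, seq_len) @ (seq_len, hidden_size) -> (seq_len, hidden_size)
E_star = A @ E

# The same thing row by row: e*_i = a_i1 e_1 + a_i2 e_2 + a_i3 e_3 + a_i4 e_4
E_star_explicit = torch.stack([sum(A[i, k] * E[k] for k in range(4)) for i in range(4)])

print(E_star.shape)                                          # torch.Size([4, 8])
print(torch.allclose(E_star, E_star_explicit, atol=1e-4))    # True
```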

If the averaging step isn't clear, perhaps it's easier to see written out explicitly for our example `"this is a sequence"`:
$$
\vec{e}^*_{i} = a_{i1}\vec{e}_{1} + a_{i2}\vec{e}_{2} + a_{i3}\vec{e}_{3} + a_{i4}\vec{e}_{4}
$$
The original embedding of each token has been modified by adding information about the context in which the token appears. This is called a **contextual representation**. For instance, the word "flies" is now represented differently in the famous examples "time flies like an arrow" and "fruit flies like bananas".
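To see the effect concretely, here is a sketch using a hypothetical toy vocabulary with random 16-dimensional embeddings (not BERT's) and the unnormalized attention from this section. The word "flies" enters both sentences with exactly the same vector but comes out with different contextual vectors.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy vocabulary with random 16-dimensional embeddings (not BERT's)
vocab = ["time", "flies", "like", "an", "arrow", "fruit", "bananas"]
emb = {word: torch.randn(16) for word in vocab}

def basic_self_attention(words):
    """Unnormalized dot-product self-attention from the text: E* = (E E^T) E."""
    E = torch.stack([emb[w] for w in words])   # (seq_len, hidden_size)
    A = E @ E.T                                # (seq_len, seq_len) scores
    return A @ E                               # (seq_len, hidden_size) contextual vectors

s1 = ["time", "flies", "like", "an", "arrow"]
s2 = ["fruit", "flies", "like", "bananas"]

flies_1 = basic_self_attention(s1)[s1.index("flies")]
flies_2 = basic_self_attention(s2)[s2.index("flies")]

# Same input embedding for "flies", but different contextual vectors
print(torch.allclose(flies_1, flies_2))  # False
```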

### Scaled dot product attention

> from Natural Language Processing with Transformers, Revised Edition

The basic attention mechanism above assigns a very large score to identical words in the context, and in particular to the current word itself: for normalized vectors, the dot product of a vector with itself is 1, the largest possible cosine similarity. In practice, however, the meaning of a word is better informed by complementary words in the context than by identical words. For example, the meaning of “flies” is better defined by incorporating information from “time” and “arrow” than by another mention of “flies”.
Let’s allow the model to create a different set of vectors for the query, key, and value of a token by using three different linear projections to project our initial token vector into three different spaces.
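The self-similarity claim can be checked numerically. This sketch uses random, hypothetical embeddings for a sentence that contains "flies" twice and normalizes each row to unit length, so the scores are cosine similarities: every token scores about 1 with itself, and the highest score in the row of "flies" belongs to one of the two "flies" tokens.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sentence with a repeated word; the embeddings are random and hypothetical
words = ["time", "flies", "like", "an", "arrow", "flies"]
table = {w: torch.randn(32) for w in dict.fromkeys(words)}   # one vector per distinct word
E = torch.stack([table[w] for w in words])                   # (6, 32)

# Unit-length rows, so every dot product is a cosine similarity in [-1, 1]
E_unit = F.normalize(E, dim=-1)
scores = E_unit @ E_unit.T

print(scores.diagonal())                    # each token scores about 1.0 with itself
print(scores[1].argmax().item() in (1, 5))  # True: "flies" scores highest with "flies"
```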

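$$
Q = EW^{\text{q}}
$$
$$
K = EW^{\text{k}}
$$
$$
V = EW^{\text{v}}
$$

where $W^{\text{q}}$, $W^{\text{k}}$ and $W^{\text{v}}$ are learned `(hidden_size, d_k)` matrices. They multiply $E$ from the right because the rows of $E$ are the token vectors.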
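In PyTorch these projections are usually `nn.Linear` layers. A minimal sketch, assuming BERT-base sizes for a single head and dropping the bias (`bias=False`, whereas BERT's layers do have one) so the layer is exactly a matrix multiplication: `nn.Linear` stores its weight as `(out_features, in_features)` and computes `E @ weight.T`, i.e. it multiplies the rows of $E$ by a `(hidden_size, head_dim)` matrix.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, hidden_size, head_dim = 4, 768, 64   # BERT-base sizes for one head

E = torch.randn(seq_len, hidden_size)         # toy embeddings, one row per token

# Bias disabled here only so each layer is a pure matrix multiplication
W_q = nn.Linear(hidden_size, head_dim, bias=False)
W_k = nn.Linear(hidden_size, head_dim, bias=False)
W_v = nn.Linear(hidden_size, head_dim, bias=False)

Q, K, V = W_q(E), W_k(E), W_v(E)              # each (seq_len, head_dim)

# nn.Linear computes E @ weight.T, where weight has shape (head_dim, hidden_size)
print(Q.shape)                                          # torch.Size([4, 64])
print(torch.allclose(Q, E @ W_q.weight.T, atol=1e-5))   # True
```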

Now the first matrix multiplication $EE^{\text{T}}$, which we used to get the similarities, becomes $QK^{\text{T}}$:
$$
\color{gray} EE^{\text{T}}
 ~~~ \Rightarrow  ~~~
\color{black} QK^{\text{T}}
$$
The result is then scaled by dividing by $\sqrt{d_k}$, where $d_k$ is the dimension of the query and key vectors, and a softmax is applied to each row.
Now the attention coefficients are
$$
a_{ij} = \sum_k^{\mathrm{hid\ size}} e_{ik} e_{jk}
~~~ \Rightarrow  ~~~
a_{ij} = \text{softmax} \left( \frac{Q K^{\text{T}}}{\sqrt{d_k}} \right)_{ij}
$$
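Why divide by $\sqrt{d_k}$? If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, so for large $d_k$ the scores grow and the softmax saturates toward a one-hot distribution. The sketch below, assuming random Gaussian vectors, estimates the variance of $\mathbf{q} \cdot \mathbf{k}$ before and after scaling.

```python
import torch

torch.manual_seed(0)
n_pairs = 10_000

variances = {}
for d_k in (16, 64, 256):
    q = torch.randn(n_pairs, d_k)   # random queries with unit-variance components
    k = torch.randn(n_pairs, d_k)   # random keys with unit-variance components
    dots = (q * k).sum(dim=-1)      # n_pairs independent dot products q . k
    variances[d_k] = (dots.var().item(), (dots / d_k ** 0.5).var().item())
    # raw variance grows roughly like d_k; the scaled variance stays near 1
    print(d_k, variances[d_k])
```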

In scaled dot-product attention, the $a_{ij}$ are normalized with a softmax so that each row sums to 1, and the weighted sum now runs over the value vectors instead of the raw embeddings:

$$
\text{seq}^* \equiv \text{E}^*_{ij} = \sum_k^{\mathrm{seq\ len}} a_{ik}e_{kj}
~~~ \Rightarrow  ~~~
\text{seq}^* \equiv \text{E}^*_{ij} = \sum_k^{\mathrm{seq\ len}} a_{ik}v_{kj}
$$

We often see the whole thing written in one step as
$$
\text{E}^* = \text{softmax} \left( \frac{Q K^{\text{T}}}{\sqrt{d_k}} \right)V
$$
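The one-line formula translates almost directly into PyTorch. The helper `attention` below is a hypothetical name for this sketch, operating on tensors shaped `(batch, seq_len, d)`. If your PyTorch version provides `torch.nn.functional.scaled_dot_product_attention` (added in 2.0), the result can be compared against it; its default scaling factor is $1/\sqrt{d_k}$, with $d_k$ taken from the last dimension of the query.

```python
import math

import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (batch, seq_len, d)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)                  # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q, K, V = (torch.randn(1, 5, 64) for _ in range(3))    # random toy inputs

out, weights = attention(Q, K, V)
print(out.shape)  # torch.Size([1, 5, 64])

if hasattr(F, "scaled_dot_product_attention"):          # available in PyTorch >= 2.0
    ref = F.scaled_dot_product_attention(Q, K, V)
    print(torch.allclose(out, ref, atol=1e-5))           # True
```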

In [None]:
import math

head_dim = config.hidden_size // config.num_attention_heads  # 768 // 12 = 64

# Randomly initialized projections (not BERT's pretrained weights); nn.Linear also adds a bias
query = nn.Linear(config.hidden_size, head_dim)(seq_embedding)  # Q = E W_q
key = nn.Linear(config.hidden_size, head_dim)(seq_embedding)    # K = E W_k
value = nn.Linear(config.hidden_size, head_dim)(seq_embedding)  # V = E W_v

dim_k = query.shape[-1]

scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)  # QK^T / sqrt(d_k)

weights = F.softmax(scores, dim=-1)  # softmax(QK^T / sqrt(d_k)), each row sums to 1

seq_embedding_att = torch.bmm(weights, value)   # softmax(QK^T / sqrt(d_k)) V

In [None]:
plt.matshow(weights.detach().numpy()[0], cmap='Blues', interpolation='nearest')
plt.xticks(range(weights.shape[-1]), tokens, rotation=80)
plt.yticks(range(weights.shape[-1]), tokens)
for spine in plt.gca().spines.values():
    spine.set_visible(False)
plt.grid(color='w', alpha=0.5)
plt.show()

In [None]:
import math

def scaled_dot_product_attention(seq_embedding):
    """One attention head with freshly (randomly) initialized Q, K, V projections."""
    head_dim = config.hidden_size // config.num_attention_heads
    query = nn.Linear(config.hidden_size, head_dim)(seq_embedding)  # Q = E W_q
    key = nn.Linear(config.hidden_size, head_dim)(seq_embedding)    # K = E W_k
    value = nn.Linear(config.hidden_size, head_dim)(seq_embedding)  # V = E W_v
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

In [None]:
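# One head per call (each with fresh random projections), concatenated along the last dimension
att_concat = torch.cat(
    [scaled_dot_product_attention(seq_embedding) for _ in range(config.num_attention_heads)],
    dim=-1,
)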

In [None]:
# Output projection that mixes the concatenated heads: hidden_size -> hidden_size
nn.Linear(config.hidden_size, config.hidden_size)(att_concat).shape

In [None]:
att_concat.shape   # (batch_size, seq_len, num_heads * head_dim) = (1, seq_len, 768)
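The cells above build 12 heads by calling the function 12 times and concatenating the results. A common alternative, roughly how Hugging Face's `BertSelfAttention` organizes it, uses one `hidden_size x hidden_size` projection each for Q, K and V and reshapes the result into `(batch, num_heads, seq_len, head_dim)`. The sketch below (random weights and inputs, BERT-base sizes assumed) checks that this split-heads version matches running each head on its slice of the weights and concatenating, then applies an output projection.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
batch, seq_len, hidden_size, num_heads = 1, 7, 768, 12   # BERT-base sizes
head_dim = hidden_size // num_heads                        # 64

x = torch.randn(batch, seq_len, hidden_size)               # toy input embeddings

# One hidden_size -> hidden_size projection each for Q, K, V covers all heads at once
W_q, W_k, W_v = (nn.Linear(hidden_size, hidden_size) for _ in range(3))
W_o = nn.Linear(hidden_size, hidden_size)                  # output projection

def split_heads(t):
    # (batch, seq_len, hidden_size) -> (batch, num_heads, seq_len, head_dim)
    return t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

Q, K, V = split_heads(W_q(x)), split_heads(W_k(x)), split_heads(W_v(x))
weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
heads = weights @ V                                        # (batch, num_heads, seq_len, head_dim)
merged = heads.transpose(1, 2).reshape(batch, seq_len, hidden_size)  # concatenate heads

# Reference: each head separately, using its slice of the projection weights
per_head = []
for h in range(num_heads):
    s = slice(h * head_dim, (h + 1) * head_dim)
    q = x @ W_q.weight[s].T + W_q.bias[s]
    k = x @ W_k.weight[s].T + W_k.bias[s]
    v = x @ W_v.weight[s].T + W_v.bias[s]
    w = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
    per_head.append(w @ v)
concat = torch.cat(per_head, dim=-1)

output = W_o(merged)                                       # (batch, seq_len, hidden_size)

print(merged.shape)                                  # torch.Size([1, 7, 768])
print(torch.allclose(merged, concat, atol=1e-5))     # True
```

The single large projection does the same arithmetic as 12 small ones, but in one matrix multiplication per Q, K and V, which is usually faster.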