
# Based on the article [The Attention Mechanism from Scratch](https://machinelearningmastery.com/the-attention-mechanism-from-scratch/)

The general attention mechanism makes use of three main components, namely the queries, $Q$, the keys, $K$, and the values, $V$.

Each query vector, $q=s_{t-1}$ (in an encoder-decoder model, the decoder's previous hidden state), is matched against a database of keys to compute a score value. This matching operation is computed as the dot product of the specific query under consideration with each key vector, $k_i$:

$\mathrm{e}_{\mathrm{q}, \mathrm{k}_{\mathrm{i}}}=\mathrm{q} \cdot \mathrm{k}_{\mathrm{i}}$

The scores are passed through a softmax operation to generate the weights:

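$\alpha_{\mathrm{q}, \mathrm{k}_{\mathrm{i}}}=\operatorname{softmax}\left(\mathrm{e}_{\mathrm{q}, \mathrm{k}_{\mathrm{i}}}\right)=\frac{\exp\left(\mathrm{e}_{\mathrm{q}, \mathrm{k}_{\mathrm{i}}}\right)}{\sum_{\mathrm{j}} \exp\left(\mathrm{e}_{\mathrm{q}, \mathrm{k}_{\mathrm{j}}}\right)}$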
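
The softmax normalizes each score against all the scores for the same query, so the weights are positive and sum to 1. A minimal NumPy sketch of that definition follows; the helper name `softmax_1d` is hypothetical, and subtracting the maximum score first is a common trick that avoids overflow in `np.exp` without changing the result.

```python
import numpy as np

def softmax_1d(scores):
    """Softmax over a 1-D array of scores (hypothetical helper, for illustration)."""
    shifted = scores - np.max(scores)  # subtracting the max avoids overflow; result is unchanged
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum()

scores = np.array([1.0, 2.0, 3.0])
weights = softmax_1d(scores)
print(weights)          # larger scores receive larger weights
print(weights.sum())    # the weights sum to 1 (up to rounding)

# Large scores are handled safely: a naive np.exp(1000.0) would overflow to inf.
print(softmax_1d(np.array([1000.0, 1000.0])))  # [0.5 0.5]
```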

The generalized attention is then computed by a weighted sum of the value vectors, $\mathrm{v}_{\mathrm{k}_{\mathrm{i}}}$, where each value vector is paired with a corresponding key:

$\operatorname{attention}(\mathrm{q}, \mathrm{K}, \mathrm{V})=\sum_{\mathrm{i}} \alpha_{\mathrm{q}, \mathrm{k}_{\mathrm{i}}} \mathrm{v}_{\mathrm{k}_{\mathrm{i}}}$
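
The three formulas above translate almost line by line into code for a single query. The sketch below assumes the keys and values are stacked as rows of `K` and `V` (row $i$ of `V` is the value paired with key $i$); the function name `general_attention` and the toy inputs are hypothetical. It follows the formulas as written, so no $\sqrt{d_k}$ scaling is applied here.

```python
import numpy as np

def general_attention(q, K, V):
    """attention(q, K, V) = sum_i alpha_{q,k_i} v_{k_i} for one query vector q.

    Assumed layout: one key per row of K, matching value per row of V.
    """
    e = K @ q                        # e_{q,k_i} = q . k_i for every key at once
    alpha = np.exp(e - e.max())      # softmax over the keys (max subtracted for stability)
    alpha /= alpha.sum()
    return alpha @ V                 # weighted sum of the value vectors

# Toy example: the query is closer to the first key, so the first value dominates.
q = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0],
              [0.0, 1.0]])
V = np.array([[10.0, 0.0],
              [0.0, 10.0]])
print(general_attention(q, K, V))  # weights are e/(e+1) and 1/(e+1)
```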

# The General Attention Mechanism with NumPy and SciPy

```python
import numpy as np
from scipy.special import softmax

np.set_printoptions(precision=5)
```

```python
# encoder representations of four different words
word_1 = np.array([1, 0, 0])
word_2 = np.array([0, 1, 0])
word_3 = np.array([1, 1, 0])
word_4 = np.array([0, 0, 1])
```

```python
# generating the weight matrices; seeding in the same cell keeps reruns reproducible
np.random.seed(42)
W_Q = np.random.randint(3, size=(3, 3))
W_K = np.random.randint(3, size=(3, 3))
W_V = np.random.randint(3, size=(3, 3))
```

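```python
W_Q, W_K, W_V
```

The matrices printed below are stale output from an earlier run of the notebook. `np.random.randint(3, ...)` only returns integers in {0, 1, 2}, so entries such as 8 or 7 cannot come from the cell above, and these are not the weights behind the scores and attention values computed below.

```
(array([[8, 4, 1],
        [3, 6, 7],
        [2, 0, 3]]),
 array([[1, 1, 1],
        [1, 1, 1],
        [1, 0, 2]]),
 array([[1, 1, 1],
        [1, 1, 1],
        [2, 2, 1]]))
```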

```python
# generating the queries, keys and values
query_1 = word_1 @ W_Q
key_1 = word_1 @ W_K
value_1 = word_1 @ W_V

query_2 = word_2 @ W_Q
key_2 = word_2 @ W_K
value_2 = word_2 @ W_V

query_3 = word_3 @ W_Q
key_3 = word_3 @ W_K
value_3 = word_3 @ W_V

query_4 = word_4 @ W_Q
key_4 = word_4 @ W_K
value_4 = word_4 @ W_V
```

Considering only the first word for the time being, the next step scores its query vector against all the key vectors using a dot product operation. 

```python
# scoring the first query against each of the four keys
scores = np.array([np.dot(query_1, key_1), np.dot(query_1, key_2),
                   np.dot(query_1, key_3), np.dot(query_1, key_4)])
scores
```

```
array([0, 0, 0, 2])
```
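
The four separate `np.dot` calls produce one score per key. Stacking the keys as the rows of a matrix gives the same scores from a single matrix-vector product, which is the idea the matrix form later in this notebook builds on. The query and key values below are made up for illustration and are not the notebook's.

```python
import numpy as np

query = np.array([1, 2, 0])
keys = np.array([[1, 0, 1],    # key_1
                 [0, 1, 1],    # key_2
                 [1, 1, 0],    # key_3
                 [2, 0, 0]])   # key_4

loop_scores = np.array([np.dot(query, k) for k in keys])  # one dot product per key
matrix_scores = keys @ query                               # entry i is query . key_i
print(loop_scores, matrix_scores)                          # [1 2 3 2] [1 2 3 2]
```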

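The score values are subsequently passed through a softmax operation to generate the weights. Before doing so, it is common practice to divide the score values by the square root of the dimensionality of the key vectors ($d_k = 3$ in this case). Without this scaling, dot products between higher-dimensional vectors tend to have large magnitudes, which pushes the softmax toward a near one-hot output where its gradients are very small.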
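
The effect of the scaling is easiest to see with a larger key dimension than the three used here. The sketch below assumes random standard-normal queries and keys with a hypothetical $d_k = 512$; each dot product then has a standard deviation of about $\sqrt{d_k}$, so the unscaled softmax is typically close to one-hot, while dividing by $\sqrt{d_k}$ spreads the weights out.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512                            # assumed large key dimension for the demo
q = rng.standard_normal(d_k)
K = rng.standard_normal((4, d_k))    # four random keys

def stable_softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

raw = K @ q                          # each score has standard deviation ~ sqrt(d_k)
unscaled = stable_softmax(raw)
scaled = stable_softmax(raw / np.sqrt(d_k))
print("unscaled:", unscaled.round(3))  # typically close to one-hot
print("scaled:  ", scaled.round(3))    # noticeably more spread out
```

Dividing every score by the same positive constant never makes the largest weight bigger, so the scaled distribution is always at least as spread out as the unscaled one.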

```python
# scaling by sqrt(d_k), where d_k = key_1.shape[0] = 3, then applying the softmax
weights = softmax(scores / key_1.shape[0] ** 0.5)
weights
```

```
array([0.16199, 0.16199, 0.16199, 0.51402])
```
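
As a check by hand, the scaled scores for the first word are $[0, 0, 0, 2]/\sqrt{3}$, and applying the softmax to them reproduces the printed weights:

$$
\frac{2}{\sqrt{3}} \approx 1.1547, \qquad e^{1.1547} \approx 3.1731, \qquad 3e^{0} + e^{1.1547} \approx 6.1731
$$

$$
\alpha_{\mathrm{q}, \mathrm{k}_{1}} = \alpha_{\mathrm{q}, \mathrm{k}_{2}} = \alpha_{\mathrm{q}, \mathrm{k}_{3}} \approx \frac{1}{6.1731} \approx 0.16199, \qquad \alpha_{\mathrm{q}, \mathrm{k}_{4}} \approx \frac{3.1731}{6.1731} \approx 0.51402
$$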

```python
# computing the attention by a weighted sum of the value vectors
attention = (weights[0] * value_1) + (weights[1] * value_2) + (weights[2] * value_3) + (weights[3] * value_4)
attention
```

```
array([1.67601, 1.67601, 0.64798])
```

For faster processing, the same calculations can be implemented in matrix form to generate an attention output for all four words in one go:

```python
# stacking the word embeddings into a single array
words = np.array([word_1, word_2, word_3, word_4])
```

```python
# generating the queries, keys and values
Q = words @ W_Q
K = words @ W_K
V = words @ W_V
```

```python
# scoring the query vectors against all key vectors: scores[i, j] = query_i . key_j
scores = Q @ K.transpose()
```

```python
# computing the weights by a row-wise softmax (axis=1) of the scores scaled by sqrt(d_k)
weights = softmax(scores / K.shape[1] ** 0.5, axis=1)
```
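
The `axis=1` argument matters: each row of `scores` belongs to one query, so the softmax must normalize each row separately. With its default `axis=None`, `scipy.special.softmax` would instead normalize over the whole matrix. The plain-NumPy helper `softmax_rows` below is a hypothetical stand-in for illustration; its first row reuses the notebook's scores `[0, 0, 0, 2]` and reproduces the weights computed earlier for the first word.

```python
import numpy as np

def softmax_rows(scores):
    """Softmax along axis=1: each row (one query) is normalized over its keys."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # per-row stabilization
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

scores = np.array([[0.0, 0.0, 0.0, 2.0],    # the first word's scores from above
                   [1.0, 1.0, 1.0, 1.0]])   # equal scores for a second, made-up query
weights = softmax_rows(scores / np.sqrt(3))
print(weights.sum(axis=1))  # each row sums to 1
print(weights[0])           # matches the step-by-step weights for the first word
print(weights[1])           # equal scores give equal weights of 0.25
```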

```python
# computing the attention by a weighted sum of the value vectors
attention = weights @ V
```

```python
# attention output for the first word; it matches the step-by-step result above
attention[0]
```

```
array([1.67601, 1.67601, 0.64798])
```

```python
attention
```

```
array([[1.67601, 1.67601, 0.64798],
       [1.67601, 1.67601, 0.64798],
       [1.84696, 1.84696, 0.30608],
       [1.67601, 1.67601, 0.64798]])
```
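
The matrix-form steps above can be bundled into one reusable function. The sketch below is an illustration, not the article's code: the name `scaled_dot_product_attention` is hypothetical, the softmax is written in plain NumPy rather than with SciPy, and the weight matrices come from a separate random generator, so they are not the notebook's `W_Q`, `W_K`, `W_V`. It checks that row 0 of the matrix result equals the step-by-step computation for the first word.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Matrix-form attention: softmax(Q K^T / sqrt(d_k)) V, one output row per query."""
    d_k = K.shape[1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores -= scores.max(axis=1, keepdims=True)   # stabilize the exponentials per row
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # row-wise softmax
    return weights @ V

# Hypothetical weight matrices (not the notebook's), used only to check the function.
rng = np.random.default_rng(1)
words = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
W_Q, W_K, W_V = (rng.integers(0, 3, size=(3, 3)) for _ in range(3))
Q, K, V = words @ W_Q, words @ W_K, words @ W_V

out = scaled_dot_product_attention(Q, K, V)

# Step-by-step version for the first word only.
s = np.array([Q[0] @ k for k in K]) / np.sqrt(3)
w = np.exp(s - s.max())
w /= w.sum()
print(np.allclose(out[0], w @ V))  # True
```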