# Self Attention

In [2]:
import torch
import math
import torch.nn.functional as F

In [3]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [4]:
seq_len, d_k = 3, 2

In [5]:
Q = torch.tensor([[ 1.0463, -1.2652],
        [ 0.0063,  0.9499],
        [-0.7820, -0.7173]])

$$ Q = \begin{bmatrix} Q_{1,1} & Q_{1,2} \\ Q_{2,1} & Q_{2,2} \\ Q_{3,1} & Q_{3,2} \end{bmatrix}$$

$$ Q = \begin{bmatrix} \vec{Q_{1}} \\ \vec{Q_{2}} \end{bmatrix}$$

In [6]:
K = torch.tensor([[-1.1701, -1.6579],
        [-0.9172, -0.2540],
        [-0.1772,  0.1998]])

$$ K = \begin{bmatrix} K_{1,1} & K_{1,2} \\ K_{2,1} & K_{2,2} \\ K_{3,1} & K_{3,2} \end{bmatrix}$$

$$ K^{T} = \begin{bmatrix} \vec{K_{1}} \\ \vec{K_{2}} \end{bmatrix}$$

$$ K^{T} = \begin{bmatrix} K_{1,1} & K_{2,1} \\ K_{1,2} & K_{2,2} \\ K_{1,3} & K_{2,3} \end{bmatrix}$$

<!-- $$ K = \begin{bmatrix} \vec{K_{1}} \\ \vec{K_{2}} \\ \vec{K_{3}} \end{bmatrix}$$ -->

In [7]:
V = torch.tensor([[-0.5293,  0.1230],
        [-0.7970,  1.0645],
        [-0.6991,  0.2718]])

$$ V = \begin{bmatrix} V_{1,1} & V_{1,2} \\ V_{2,1} & V_{2,2}  \\ V_{3,1} & V_{3,2} \end{bmatrix}$$

In [8]:
values, attention = scaled_dot_product(Q, K, V)

what happens in attention

$$ QK^{T} = \begin{bmatrix} Q_{1,1}\times K_{1,1} + Q_{1,2}\times K_{1,2} &
Q_{1,1}\times K_{2,1} + Q_{1,2}\times K_{2,2} & 
Q_{1,1}\times K_{3,1} + Q_{1,2}\times K_{3,2} \\ 
Q_{2,1}\times K_{1,1} + Q_{2,2}\times K_{1,2} & 
Q_{2,1}\times K_{2,1} + Q_{2,2}\times K_{2,2} & 
Q_{2,1}\times K_{3,1} + Q_{2,2}\times K_{3,2} \\
Q_{3,1}\times K_{1,1} + Q_{3,2}\times K_{1,2} & 
Q_{3,1}\times K_{2,1} + Q_{3,2}\times K_{2,2} & 
Q_{3,1}\times K_{3,1} + Q_{3,2}\times K_{3,2} \end{bmatrix}$$

$$ QK^{T} = \begin{bmatrix} \vec{Q_{1}}\cdot \vec{K_{1}} & 
\vec{Q_{1}}\cdot \vec{K_{2}} & 
\vec{Q_{1}}\cdot \vec{K_{3}} \\ 
\vec{Q_{2}}\cdot \vec{K_{1}} & 
\vec{Q_{2}}\cdot \vec{K_{2}} & 
\vec{Q_{2}}\cdot \vec{K_{3}} \\
\vec{Q_{3}}\cdot \vec{K_{1}} & 
\vec{Q_{3}}\cdot \vec{K_{2}} & 
\vec{Q_{3}}\cdot \vec{K_{3}} \end{bmatrix}$$

In [9]:
attn_logits = torch.matmul(Q, K.T)
attn_logits

tensor([[ 0.8733, -0.6383, -0.4382],
        [-1.5822, -0.2471,  0.1887],
        [ 2.1042,  0.8994, -0.0047]])

In [10]:
# torch.matmul(Q, K.transpose(-2, -1))

what is $\vec{Q_{1}}$? It is a representation of 1st sentence part. <br>
what is K1? It is another representation of 1st sentence part. <br>

Each row of the attention matrix (softmax(QKTQK T)) contains weights representing how much a query (e.g., a word) "pays attention" to all other keys (words) in the sequence.

$$ QK^{T}/\sqrt(d_{k}) = \begin{bmatrix} \frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt(d_{k})} & 
\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) } & 
\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) } \\ 
\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{ \sqrt(d_{k}) } & 
\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) } & 
\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) } \\
\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{ \sqrt(d_{k}) } & 
\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) } & 
\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) } \end{bmatrix}$$

In [11]:
scaled_attention_logits = torch.matmul(Q, K.T)/ math.sqrt(d_k)

In [48]:
scaled_attention_logits

tensor([[ 0.6175, -0.4514, -0.3098],
        [-1.1188, -0.1747,  0.1334],
        [ 1.4879,  0.6360, -0.0034]])

$$ softmax(QK^{T}/\sqrt(d_{k})) = \begin{bmatrix} \frac{exp({\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) }{ exp({\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}})} & 
\frac{exp(\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}})  } & 
\frac{exp(\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } \\ 
\frac{exp(\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } & 
\frac{exp(\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } & 
\frac{exp(\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } \\
\frac{exp(\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{ \sqrt(d_{k}) })}{  exp({\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } & 
\frac{exp(\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}}) } & 
\frac{exp(\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{ \sqrt(d_{k}) })}{ exp({\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt(d_{k})}}) + exp({\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt(d_{k})}}) + \exp({\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt(d_{k})}})  } \end{bmatrix}$$

In [49]:
attention_weights = F.softmax(scaled_attention_logits, dim=-1)

In [50]:
attention_weights

tensor([[0.5750, 0.1975, 0.2275],
        [0.1415, 0.3637, 0.4949],
        [0.6054, 0.2583, 0.1363]])

you can see that values in each row add up to 1.

In [52]:
attention_weights[0].sum()

tensor(1.)

In [53]:
attention_weights[1].sum()

tensor(1.)

In [54]:
attention_weights[2].sum()

tensor(1.0000)

In [65]:
attention_weights

tensor([[0.5750, 0.1975, 0.2275],
        [0.1415, 0.3637, 0.4949],
        [0.6054, 0.2583, 0.1363]])

In [66]:
V

tensor([[-0.5293,  0.1230],
        [-0.7970,  1.0645],
        [-0.6991,  0.2718]])

In [63]:
torch.matmul(attention_weights, V)

tensor([[-0.6208,  0.3428],
        [-0.7107,  0.5390],
        [-0.6216,  0.3864]])

In [64]:
V

tensor([[-0.5293,  0.1230],
        [-0.7970,  1.0645],
        [-0.6991,  0.2718]])

In [11]:
Q[0]

tensor([ 1.0463, -1.2652])

In [12]:
K.T[:,0]

tensor([-1.1701, -1.6579])

In [13]:
torch.dot(Q[0],  K.T[:,0])

tensor(0.8733)

In [25]:
d_k = Q.size()[-1]
attn_logits = torch.matmul(Q, K.transpose(-2, -1))

In [26]:
attn_logits

tensor([[ 0.8733, -0.6383, -0.4382],
        [-1.5822, -0.2471,  0.1887],
        [ 2.1042,  0.8994, -0.0047]])

In [27]:
torch.matmul(K, Q.transpose(-2, -1))

tensor([[ 0.8733, -1.5822,  2.1042],
        [-0.6383, -0.2471,  0.8994],
        [-0.4382,  0.1887, -0.0047]])

In [28]:
attn_logits[0][0]/math.sqrt(d_k)

tensor(0.6175)

In [29]:
attn_logits[0][1]/math.sqrt(d_k)

tensor(-0.4514)

In [30]:
attn_logits = attn_logits / math.sqrt(d_k)

In [31]:
attn_logits

tensor([[ 0.6175, -0.4514, -0.3098],
        [-1.1188, -0.1747,  0.1334],
        [ 1.4879,  0.6360, -0.0034]])

In [33]:
mask = None

In [34]:
if mask is not None:
    attn_logits = attn_logits.masked_fill(mask == 0, -9e15)    

In [35]:
attention = F.softmax(attn_logits, dim=-1)

In [36]:
attention

tensor([[0.5750, 0.1975, 0.2275],
        [0.1415, 0.3637, 0.4949],
        [0.6054, 0.2583, 0.1363]])

In [37]:
torch.sum(attention[:, 1])

tensor(0.8194)

In [38]:
torch.sum(attention[1, :])

tensor(1.)

In [39]:
torch.sum(attn_logits[0, :])

tensor(-0.1437)

In [40]:
torch.exp(attn_logits[0, :])/torch.sum(torch.exp(attn_logits[0, :]))

tensor([0.5750, 0.1975, 0.2275])

In [42]:
torch.exp(attn_logits[0, :])

tensor([1.8543, 0.6368, 0.7336])

In [41]:
torch.exp(attn_logits[1, :])/torch.sum(torch.exp(attn_logits[1, :]))

tensor([0.1415, 0.3637, 0.4949])

In [48]:
torch.exp(attn_logits[2, :])/torch.sum(torch.exp(attn_logits[2, :]))

tensor([0.4905, 0.4635, 0.0460])

In [27]:
values = torch.matmul(attention, v)

In [28]:
values

tensor([[ 0.1270,  0.9533],
        [ 0.2498,  0.9650],
        [-0.2065,  0.8016]])

In [29]:
v

tensor([[-0.5255,  1.2225],
        [ 0.0399,  0.3073],
        [ 0.7107,  1.2939]])

In [52]:
attention[0,:]

tensor([0.2965, 0.3237, 0.3798])

In [50]:
torch.dot(attention[0,:], v[:,0])

tensor(0.1270)

In [51]:
torch.dot(attention[0,:], v[:,1])

tensor(0.9533)