<h1 style="text-align: center;" markdown="1">Self-Attention</h1>


In [2]:
import numpy as np
import math

We start by creating random q, k, v vectors of length 8 for each word of a 4-word sentence, e.g. "My name is Ankit".

In [3]:
# Seeded generator so that Restart & Run All reproduces the same q, k, v (and every output below).
SEED = 0
rng = np.random.default_rng(SEED)

L, d_k, d_v = 4, 8, 8  # L: sequence length (number of words); d_k: query/key size; d_v: value size
q = rng.standard_normal((L, d_k))  # randomly initialised queries, one row per word
k = rng.standard_normal((L, d_k))  # randomly initialised keys, one row per word
v = rng.standard_normal((L, d_v))  # randomly initialised values, one row per word



In [4]:
# Show the three input matrices, each preceded by its label on its own line.
for label, matrix in (("Q", q), ("K", k), ("V", v)):
    print(f"{label}\n", matrix)

Q
 [[ 3.27658360e-01 -8.60237515e-01  7.01499354e-01  4.34182897e-01
   2.80853985e-01  1.38040540e+00 -2.11988316e+00  1.78151400e-01]
 [-9.49837507e-01  2.73231377e+00 -1.61674175e+00  3.87097462e-01
   3.19327511e+00 -1.23995173e+00  1.37240127e+00 -4.59724315e-01]
 [ 1.51060386e+00  2.12124668e-01  2.90103932e-01 -2.37498202e-01
  -1.11965198e-01 -7.47350388e-01  1.36331793e+00 -6.00731544e-01]
 [ 1.11494243e+00 -1.63116249e+00 -6.52089359e-01 -3.78436575e-01
   1.83021635e-01  7.68739277e-01  1.75082700e-03 -6.52297657e-01]]
K
 [[ 0.91667915  0.55853644 -1.14464798 -0.16684445  1.82033297 -0.50047993
   1.0108848   0.80116587]
 [-0.45192925  1.24437185  0.40770284  0.50844675  0.36116926  0.45497211
   0.65543604  1.58465177]
 [-0.94197646 -0.97448117  2.64182891  0.0268378   0.10629438  0.48575333
  -1.7628804   0.32155008]
 [-0.44359675  0.03060568 -0.96249477 -0.57973624  0.52199649 -1.97860248
   0.10932422 -0.23303265]]
V
 [[ 1.08096222 -1.08672196  0.42258225  0.98401604 -0.

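$$\text{attention} = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)$$

$$V_{\text{new}} = \text{attention} \cdot V$$

where $M$ is the mask (all zeros when no masking is applied, e.g. in the encoder).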

> Every single word has to look at every other word to see whether it has an affinity towards that word or not!

In [6]:
q @ k.T  # (L, d_k) @ (d_k, L) -> (L, L): raw query-key dot products

array([[-3.23537351, -1.0894297 ,  6.88930738, -3.95651577],
       [ 9.89383476,  4.12712959, -8.85871715,  6.21407766],
       [ 2.27786849, -0.86004765, -3.8410947 ,  0.90415367],
       [ 0.3481266 , -3.60858168, -1.01353836, -0.97077642]])

In [7]:
# The raw q·k dot products have a much larger variance than q and k themselves;
# dividing them by sqrt(d_k) (next cell) brings that variance back down and stabilises the scores.
(q.var(), k.var(), (q @ k.T).var())

(1.351075926742389, 0.954693994236617, 21.421766396649787)

In [8]:
# Divide the raw scores by sqrt(d_k); compare the resulting variance with the previous cell.
scaled = (q @ k.T) / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(1.351075926742389, 0.954693994236617, 2.6777207995812233)

In [9]:
scaled  # scaled scores, shape (L, L): row i = query of word i, column j = key of word j

array([[-1.14387727, -0.38517157,  2.43573798, -1.39883957],
       [ 3.49799883,  1.45916066, -3.13202949,  2.19700823],
       [ 0.80534813, -0.30407276, -1.35803206,  0.31966659],
       [ 0.12308134, -1.27582629, -0.35833992, -0.3432213 ]])

## Masking

* Masking stops each word from seeing context from the words that come after it (future positions).
* It is only required in the decoder, not in the encoder.

In [10]:
# Lower-triangular pattern of ones: word i is allowed to see words 0..i, zeros mark future words.
mask = np.tril(np.full((L, L), 1.0))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [11]:
# Additive causal mask: 0 where attention is allowed (current and past words),
# -inf where it is blocked (future words), so softmax gives those positions weight 0.
# Built from scratch rather than mutating `mask` in place: re-running the in-place version
# turned the allowed 0 entries into -inf as well. Uses `np.inf`, because the `np.infty`
# alias no longer exists in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [12]:
mask  # additive causal mask: 0 = allowed, -inf = blocked (future position)

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [13]:
(scaled, scaled + mask)  # scores before vs. after adding the causal mask

(array([[-1.14387727, -0.38517157,  2.43573798, -1.39883957],
        [ 3.49799883,  1.45916066, -3.13202949,  2.19700823],
        [ 0.80534813, -0.30407276, -1.35803206,  0.31966659],
        [ 0.12308134, -1.27582629, -0.35833992, -0.3432213 ]]),
 array([[-1.14387727,        -inf,        -inf,        -inf],
        [ 3.49799883,  1.45916066,        -inf,        -inf],
        [ 0.80534813, -0.30407276, -1.35803206,        -inf],
        [ 0.12308134, -1.27582629, -0.35833992, -0.3432213 ]]))

## Softmax

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

In [14]:
def softmax(x): # Row adds up to 1
    """Numerically stable softmax over the last axis (each row sums to 1).

    The row maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged, but it stops ``np.exp`` from overflowing to ``inf``
    (which would turn the whole row into ``nan``) when scores are large. Reducing
    over ``axis=-1`` with ``keepdims`` also makes it work for inputs of any rank,
    not only 2-D matrices.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)  # masked -inf entries become exactly 0
    return exps / np.sum(exps, axis=-1, keepdims=True)

In [15]:
# Decoder (causal) self-attention: masked future positions get probability 0.
# For encoder self-attention, leave the mask out and use softmax(scaled) instead.
attention = softmax(scaled + mask)

In [16]:
attention # row i: how much word i attends to each word j (each row sums to 1)

array([[1.        , 0.        , 0.        , 0.        ],
       [0.88481491, 0.11518509, 0.        , 0.        ],
       [0.69219202, 0.22825023, 0.07955775, 0.        ],
       [0.40126987, 0.09906008, 0.24794648, 0.25172357]])

In [17]:
new_v = attention @ v  # (L, L) @ (L, d_v) -> (L, d_v): each row is an attention-weighted mix of the rows of v
new_v

array([[ 1.08096222, -1.08672196,  0.42258225,  0.98401604, -0.14328115,
         0.39715357, -0.6160035 ,  0.29124401],
       [ 0.7675462 , -0.94544621,  0.16186523,  0.86176733, -0.20624521,
         0.29412166, -0.51566741,  0.16162403],
       [ 0.31301225, -0.78051764, -0.16706936,  0.54650442, -0.38664307,
         0.11156445, -0.3731514 ,  0.05816239],
       [-0.0515132 , -0.79152134, -0.27007465,  0.17990589, -0.33661285,
         0.10573455, -0.11486654,  0.41754912]])

In [18]:
v  # original values for comparison with new_v; row 0 is unchanged because word 0 attends only to itself

array([[ 1.08096222, -1.08672196,  0.42258225,  0.98401604, -0.14328115,
         0.39715357, -0.6160035 ,  0.29124401],
       [-1.64001506,  0.13978882, -1.84087925, -0.07730802, -0.68991496,
        -0.49733641,  0.25508224, -0.83407513],
       [-0.76532273, -0.75673576, -0.49519144, -1.47035393, -1.63392942,
        -0.62627373, -0.06261209,  0.59005495],
       [-0.52856258, -0.72170471, -0.53433832,  0.62480097,  0.77208357,
         0.59953601,  0.48693523,  0.94152158]])

## Complete function

In [19]:
def softmax(x):
    """Numerically stable softmax over the last axis (each row sums to 1).

    The row maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged, but it stops ``np.exp`` from overflowing to ``inf``
    (which would turn the whole row into ``nan``) when scores are large. Works for
    inputs of any rank.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)  # masked -inf entries become exactly 0
    return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention( q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Parameters
    ----------
    q : array of shape (..., L_q, d_k)   -- queries
    k : array of shape (..., L_k, d_k)   -- keys
    v : array of shape (..., L_k, d_v)   -- values
    mask : array broadcastable to (..., L_q, L_k), optional
        Additive mask: 0 where attention is allowed, -inf where it is blocked.
        None (default) gives unmasked, encoder-style attention; pass a causal
        mask for decoder-style attention.

    Returns
    -------
    out : array of shape (..., L_q, d_v)
        New values: attention-weighted mixes of the rows of v.
    attention : array of shape (..., L_q, L_k)
        Attention weights; each row sums to 1.
    """
    d_k = q.shape[-1]
    # swapaxes transposes only the last two axes, so leading batch dimensions are kept
    # (a plain `.T` would reverse every axis for inputs with more than 2 dimensions).
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    attention = softmax(scores)
    out = np.matmul(attention, v)
    return out, attention # new values , attention

In [20]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)

# Print inputs, the attended values and the attention weights, each under its own label.
for label, matrix in (
    ("Q", q),
    ("K", k),
    ("V", v),
    ("New V", values),
    ("Attention", attention),
):
    print(f"{label}\n", matrix)

Q
 [[ 3.27658360e-01 -8.60237515e-01  7.01499354e-01  4.34182897e-01
   2.80853985e-01  1.38040540e+00 -2.11988316e+00  1.78151400e-01]
 [-9.49837507e-01  2.73231377e+00 -1.61674175e+00  3.87097462e-01
   3.19327511e+00 -1.23995173e+00  1.37240127e+00 -4.59724315e-01]
 [ 1.51060386e+00  2.12124668e-01  2.90103932e-01 -2.37498202e-01
  -1.11965198e-01 -7.47350388e-01  1.36331793e+00 -6.00731544e-01]
 [ 1.11494243e+00 -1.63116249e+00 -6.52089359e-01 -3.78436575e-01
   1.83021635e-01  7.68739277e-01  1.75082700e-03 -6.52297657e-01]]
K
 [[ 0.91667915  0.55853644 -1.14464798 -0.16684445  1.82033297 -0.50047993
   1.0108848   0.80116587]
 [-0.45192925  1.24437185  0.40770284  0.50844675  0.36116926  0.45497211
   0.65543604  1.58465177]
 [-0.94197646 -0.97448117  2.64182891  0.0268378   0.10629438  0.48575333
  -1.7628804   0.32155008]
 [-0.44359675  0.03060568 -0.96249477 -0.57973624  0.52199649 -1.97860248
   0.10932422 -0.23303265]]
V
 [[ 1.08096222 -1.08672196  0.42258225  0.98401604 -0.