### **Scaled Dot-Product Attention: Q,K -> MatMul-> Scale-> Mask(opt.)-> SoftMax,V -> MatMul**

In [None]:
import numpy as np
import math

# L= sequence length (in our case it is 4 because the sentence consists of 4 words)
# d_k = dimension of each query/key vector, d_v = dimension of each value vector
# q:what information should I look for?,
# k:how relevant am I to what the query is looking for?,
# v:actual information contained in each token
# Fixed seed so that Restart & Run All reproduces the same q, k and v
# (and therefore the same numbers in every cell below).
SEED = 42
rng = np.random.default_rng(SEED)

L, d_k, d_v = 4, 8, 8
# Standard-normal entries, one row per token.
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))
# q and k must share the same last dimension (d_k) because q is going to get MatMul-ed with k.T :D

In [None]:
# Raw (unscaled) scores: entry [i, j] is the dot product of query i with key j.
q @ k.T

array([[-0.6565045 , -3.76956714,  3.08105246,  2.43434131],
       [ 2.07802559,  1.21139921,  7.93190485, -1.00497858],
       [ 4.20399631, -0.29168475, -1.308575  , -5.64194761],
       [ 4.73150822, -1.13656296,  2.20243154, -5.90394453]])

In [None]:
# Variance of k, of q, and of the raw scores q @ k.T, side by side.
k.var(), q.var(), (q @ k.T).var()

(1.2087949575675125, 1.4246945192878466, 13.143538524054767)

In [None]:
# Divide the raw scores by sqrt(d_k): q @ k.T has a much larger variance than
# q or k themselves, and the scaling brings it back to a comparable size.
scaled = (q @ k.T) / math.sqrt(d_k)

In [None]:
# After dividing by sqrt(d_k) the score variance should be on the same order
# as the variances of k and q.
k.var(), q.var(), scaled.var()

(1.2087949575675125, 1.4246945192878466, 1.6429423155068454)

In [None]:
# Mask time. Masking is optional: in the encoder all words enter simultaneously,
# so every position may look at every other one. The lower-triangular pattern
# built here instead lets position i see only positions 0..i.
mask = np.tril(np.ones(shape=(L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [None]:
# Convert the 0/1 lower-triangular pattern into an additive mask: 0 where
# attention is allowed, -inf where it is blocked (softmax turns -inf into 0).
# It is built in one step from a fresh triangle instead of editing `mask` in
# place, so re-running this cell gives the same result (the in-place version
# turned every entry into -inf on a second run). np.inf replaces the np.infty
# alias, which NumPy 2.0 removed.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [None]:
# Adding the mask leaves allowed scores unchanged (+0) and sets blocked ones to -inf.
scaled + mask

array([[-0.23210939,        -inf,        -inf,        -inf],
       [ 0.73469299,  0.4282943 ,        -inf,        -inf],
       [ 1.48633715, -0.10312613, -0.46265113,        -inf],
       [ 1.67284077, -0.40183569,  0.77867714, -2.08735961]])

In [None]:
#softmax time
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with any number of leading dimensions. Entries
      equal to ``-inf`` (masked positions) receive probability 0.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis are
    non-negative and sum to 1. A row consisting only of ``-inf`` yields nan.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise (and np.max would raise on an empty axis).
    return np.exp(x)
  # Subtracting the per-row max does not change the result mathematically but
  # keeps np.exp from overflowing to inf, which would turn the row into nan.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  # keepdims=True lets the division broadcast for inputs of any rank.
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [None]:
# Softmax over each row of the masked scores; the -inf entries become weight 0.
attention = softmax(scaled+mask)

In [None]:
# Attention weights: row i is query i's distribution over the keys (each row
# sums to 1); the masked positions above the diagonal get weight 0.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.57600598, 0.42399402, 0.        , 0.        ],
       [0.742692  , 0.15153523, 0.10577277, 0.        ],
       [0.64192037, 0.08062331, 0.26251301, 0.01494331]])

In [None]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[-0.66661609, -0.2300184 ,  0.10529725,  0.8967182 ,  0.1888911 ,
         0.73803245, -0.49179694, -0.52730004],
       [-0.43042733, -0.3005801 ,  0.27028764,  0.20148928,  0.08388766,
         0.65991313, -1.40654681, -0.4815708 ],
       [-0.60938126, -0.25463428,  0.15628372,  0.65480118, -0.0135991 ,
         0.76663431, -0.72213858, -0.41393903],
       [-0.69240361, -0.2469094 ,  0.08791747,  0.76954556, -0.2253876 ,
         0.85725767, -0.409163  , -0.26081935]])

In [None]:
# The original value vectors, for comparison with new_v: under the causal mask
# query 0 can only attend to position 0, so new_v[0] equals v[0].
v

array([[-0.66661609, -0.2300184 ,  0.10529725,  0.8967182 ,  0.1888911 ,
         0.73803245, -0.49179694, -0.52730004],
       [-0.10955928, -0.39643985,  0.49443102, -0.74299578, -0.05876202,
         0.55378618, -2.64925633, -0.41944655],
       [-0.9235716 , -0.2243191 ,  0.029843  ,  0.95871349, -1.37069767,
         1.27240117,  0.42138193,  0.38992467],
       [-0.88377787, -0.56258775, -1.8317103 ,  0.14400398,  1.19939318,
         0.32319981,  0.6360734 ,  0.61045097]])

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k.T / sqrt(d_k) [+ mask]) @ v.

  Args:
    q: query array; its last dimension is used as d_k for the scaling.
    k: key array, multiplied with q as ``k.T``.
    v: value array, combined using the attention weights.
    mask: optional array added to the scaled scores before the softmax
      (e.g. 0 for allowed positions and -inf for blocked ones).

  Returns:
    A tuple ``(out, attention)``: the weighted values and the softmax weights.
  """
  scores = (q @ k.T) / math.sqrt(q.shape[-1])
  if mask is not None:
    scores = scores + mask
  weights = softmax(scores)
  return weights @ v, weights


In [None]:
# Same computation as the step-by-step cells, wrapped in one call; it should
# reproduce `new_v` and `attention` when the notebook is run top to bottom.
# NOTE(review): the saved output below differs from the saved step-by-step
# result in rows 1-3, which suggests q and/or k were regenerated between
# runs. TODO: re-run the notebook from a fresh kernel.
scaled_dot_product_attention(q,k,v, mask=mask)

(array([[-0.66661609, -0.2300184 ,  0.10529725,  0.8967182 ,  0.1888911 ,
          0.73803245, -0.49179694, -0.52730004],
        [-0.51113582, -0.27646833,  0.21390848,  0.43905722,  0.11976857,
          0.68660743, -1.09396597, -0.49719702],
        [-0.57636641, -0.28894883,  0.21514358,  0.31871729, -0.60417695,
          0.91032976, -0.87821687, -0.07484719],
        [-0.84876995, -0.32258153, -0.4060383 ,  0.62750641, -0.59076015,
          0.9682549 ,  0.23279007,  0.3669685 ]]),
 array([[1.        , 0.        , 0.        , 0.        ],
        [0.72088975, 0.27911025, 0.        , 0.        ],
        [0.18065601, 0.3695087 , 0.44983529, 0.        ],
        [0.0218007 , 0.07263483, 0.65240409, 0.25316038]]))