# Self Attention in Transformers

## Generate Data

In [6]:
import numpy as np
import math

# Seed the global NumPy RNG so "Restart & Run All" regenerates the same
# Q, K and V every time instead of a fresh random draw.
SEED = 0
np.random.seed(SEED)

# L: number of rows (sequence positions) in q, k and v.
# d_k: width of each query/key row; d_v: width of each value row.
L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [7]:
# Show the three randomly generated input matrices, one after another.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(heading, matrix)

Q
 [[-0.29071144 -1.052991   -0.98942321 -0.92868288  0.61144028  0.26903044
   0.46918848 -0.29149256]
 [-1.56133703  0.03052599  0.43176384  0.48118882  1.48503777 -0.56440178
  -1.95126407 -2.55662422]
 [-0.57329254 -0.82606486  0.65375037  0.16381062  0.07570454  1.1881824
  -2.3370799  -0.9624577 ]
 [ 1.18993814  0.81444341  0.96794046 -0.69342096 -0.28468606 -0.43967067
   0.07312889 -0.7314036 ]]
K
 [[ 0.87234172 -0.37437621 -0.22993135 -0.49516156  1.05696625  0.08139426
   0.335167    0.73810794]
 [-0.53381132 -0.02988164  2.17935585  0.26489901  1.45755181  1.75288246
   0.3314582   0.06292065]
 [-1.08549485 -0.72249727  1.60814021  1.03018781 -0.06491381  0.02085411
  -0.26663793  1.51183359]
 [ 0.14342002 -0.95578698  0.47519809  0.11451661  0.41585379 -0.28159015
  -1.41775128 -0.61641328]]
V
 [[-1.49793213 -0.9999246   0.20835007  0.81162595 -0.18394427  0.1143808
  -0.46859362  1.07741661]
 [-0.2383662  -1.80720243 -1.16076983  0.88053246 -0.60149722  0.43794888
   2.615

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [8]:
# Raw (unscaled) scores: every row of q dotted with every row of k.
q @ k.T

array([[ 1.43823534e+00, -7.15702178e-01, -2.07137236e+00,
         8.12203528e-02],
       [-2.72835811e+00,  2.26854328e+00, -5.90258892e-01,
         5.12600694e+00],
       [-1.73925976e+00,  3.15674681e+00,  1.62715616e+00,
         4.64031166e+00],
       [ 1.88535654e-03,  5.88439137e-02, -2.15382306e+00,
         1.25369256e-01]])

In [9]:
# Why the denominator needs sqrt(d_k): compare the variance of q and k
# with the variance of their unscaled product.
np.var(q), np.var(k), np.var(q @ k.T)

(0.9696122177211076, 0.7256089967229388, 5.211127799895605)

In [10]:
# Divide the raw scores by sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(0.9696122177211076, 0.7256089967229388, 0.6513909749869508)

Notice that dividing by $\sqrt{d_k}$ brings the variance of the product back down to the same order as the variances of Q and K.

In [11]:
# Display the scaled scores, i.e. the q-k product divided by sqrt(d_k).
scaled

array([[ 5.08492981e-01, -2.53038932e-01, -7.32340721e-01,
         2.87157311e-02],
       [-9.64620260e-01,  8.02051169e-01, -2.08688032e-01,
         1.81231714e+00],
       [-6.14921184e-01,  1.11607854e+00,  5.75286579e-01,
         1.64059792e+00],
       [ 6.66574196e-04,  2.08044652e-02, -7.61491446e-01,
         4.43247256e-02]])

## Masking

- Masking ensures that a word cannot take context from words generated in the future.
- It is not required in the encoder, but it is required in the decoder.

In [12]:
# Lower-triangular matrix of ones: row i has a 1 in every column j <= i.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [13]:
# Build the additive mask: 0 on and below the diagonal (allowed positions),
# -inf above it (future positions). It is rebuilt from L rather than edited
# in place, so re-running this cell always gives the same mask; the in-place
# version turned every entry into -inf on a second run.
# np.inf replaces the np.infty alias, which NumPy 2.0 removed.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [14]:
# Display the additive mask: 0 where attention is allowed, -inf where it is blocked.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [15]:
# Masked scores: blocked positions become -inf, allowed ones are unchanged.
np.add(scaled, mask)

array([[ 5.08492981e-01,            -inf,            -inf,
                   -inf],
       [-9.64620260e-01,  8.02051169e-01,            -inf,
                   -inf],
       [-6.14921184e-01,  1.11607854e+00,  5.75286579e-01,
                   -inf],
       [ 6.66574196e-04,  2.08044652e-02, -7.61491446e-01,
         4.43247256e-02]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [16]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  The maximum along the last axis is subtracted before exponentiating. This
  leaves the result mathematically unchanged but keeps ``np.exp`` from
  overflowing on large scores. Entries equal to ``-inf`` (masked positions)
  receive a weight of exactly 0 whenever the same row has a finite entry.

  Args:
    x: array-like of real-valued scores with at least one axis.

  Returns:
    np.ndarray with the same shape as ``x`` whose last axis sums to 1. A row
    consisting only of ``-inf`` yields NaN, because no position is allowed.
  """
  x = np.asarray(x, dtype=float)
  # keepdims=True lets the per-row statistics broadcast against x directly,
  # so no transposes are needed and leading batch axes are handled too.
  # initial=-inf keeps the reduction defined when the last axis is empty.
  row_max = np.max(x, axis=-1, keepdims=True, initial=-np.inf)
  exps = np.exp(x - row_max)
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [17]:
# Add the mask to the scaled scores, then normalise each row with softmax.
attention = softmax(scaled + mask)

In [18]:
# Display the masked attention weights.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.14595676, 0.85404324, 0.        , 0.        ],
       [0.10066378, 0.56837746, 0.33095876, 0.        ],
       [0.28315573, 0.28891569, 0.13213697, 0.29579161]])

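Example from a different random draw (so the numbers do not match the output above); each row shows which positions the decoder can attend to:

    ([[1.        , 0.        , 0.        , 0.      ], decoder looking at the 1st word
    [0.51359112, 0.48640888, 0.        , 0.        ], decoder looking at first 2 words
    [0.53753304, 0.27144826, 0.1910187 , 0.        ], 3 previous words
    [0.19293995, 0.03256643, 0.57960627, 0.19488734]]) 4 previous words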

### Without the mask, every position would attend to all of the values

In [19]:
# Each output row is the attention-weighted sum of the rows of v.
new_v = attention @ v
new_v

array([[-1.49793213, -0.9999246 ,  0.20835007,  0.81162595, -0.18394427,
         0.1143808 , -0.46859362,  1.07741661],
       [-0.42220836, -1.68937477, -0.96093753,  0.87047509, -0.54055254,
         0.39072193,  2.16562956,  0.1784471 ],
       [-0.13242603, -0.96663354, -0.13365541,  1.05237177, -0.33784789,
         0.07708723,  1.3995478 ,  0.01156488],
       [ 0.18039406, -0.49454708, -0.06125226,  0.56634035, -0.52337182,
         0.20335531,  0.85606823,  0.02876221]])

In [20]:
# Display the original value matrix for comparison with new_v.
v

array([[-1.49793213, -0.9999246 ,  0.20835007,  0.81162595, -0.18394427,
         0.1143808 , -0.46859362,  1.07741661],
       [-0.2383662 , -1.80720243, -1.16076983,  0.88053246, -0.60149722,
         0.43794888,  2.61582067,  0.0248124 ],
       [ 0.46484175,  0.48705698,  1.52625267,  1.42070759,  0.06812357,
        -0.55398759, -0.12102809, -0.33537337],
       [ 2.06898056,  0.83287762,  0.04544564, -0.35702105, -1.03622482,
         0.39771129,  0.84178805, -0.80856907]])

# Function

In [21]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  The maximum along the last axis is subtracted before exponentiating. This
  leaves the result mathematically unchanged but keeps ``np.exp`` from
  overflowing on large scores. Entries equal to ``-inf`` (masked positions)
  receive a weight of exactly 0 whenever the same row has a finite entry.

  Args:
    x: array-like of real-valued scores with at least one axis.

  Returns:
    np.ndarray with the same shape as ``x`` whose last axis sums to 1. A row
    consisting only of ``-inf`` yields NaN, because no position is allowed.
  """
  x = np.asarray(x, dtype=float)
  # keepdims=True lets the per-row statistics broadcast against x directly,
  # so no transposes are needed and leading batch axes are handled too.
  # initial=-inf keeps the reduction defined when the last axis is empty.
  row_max = np.max(x, axis=-1, keepdims=True, initial=-np.inf)
  exps = np.exp(x - row_max)
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q: query array of shape (..., L_q, d_k).
    k: key array of shape (..., L_k, d_k).
    v: value array of shape (..., L_k, d_v).
    mask: optional additive mask that broadcasts against the
      (..., L_q, L_k) scores; 0 keeps a position and -inf blocks it.

  Returns:
    A tuple ``(out, attention)``: ``out`` has shape (..., L_q, d_v) and
    ``attention`` holds the weights with shape (..., L_q, L_k).
  """
  d_k = q.shape[-1]
  # swapaxes transposes only the last two axes, so leading batch/head axes
  # are preserved; k.T would reverse every axis of a batched array.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [22]:
# Run the packaged attention function on the same inputs, then report the
# inputs, the new values and the attention weights.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, matrix)

Q
 [[-0.29071144 -1.052991   -0.98942321 -0.92868288  0.61144028  0.26903044
   0.46918848 -0.29149256]
 [-1.56133703  0.03052599  0.43176384  0.48118882  1.48503777 -0.56440178
  -1.95126407 -2.55662422]
 [-0.57329254 -0.82606486  0.65375037  0.16381062  0.07570454  1.1881824
  -2.3370799  -0.9624577 ]
 [ 1.18993814  0.81444341  0.96794046 -0.69342096 -0.28468606 -0.43967067
   0.07312889 -0.7314036 ]]
K
 [[ 0.87234172 -0.37437621 -0.22993135 -0.49516156  1.05696625  0.08139426
   0.335167    0.73810794]
 [-0.53381132 -0.02988164  2.17935585  0.26489901  1.45755181  1.75288246
   0.3314582   0.06292065]
 [-1.08549485 -0.72249727  1.60814021  1.03018781 -0.06491381  0.02085411
  -0.26663793  1.51183359]
 [ 0.14342002 -0.95578698  0.47519809  0.11451661  0.41585379 -0.28159015
  -1.41775128 -0.61641328]]
V
 [[-1.49793213 -0.9999246   0.20835007  0.81162595 -0.18394427  0.1143808
  -0.46859362  1.07741661]
 [-0.2383662  -1.80720243 -1.16076983  0.88053246 -0.60149722  0.43794888
   2.615