# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math

# Toy problem size: sequence length L, query/key width d_k, value width d_v.
L = 4
d_k = d_v = 8

# Draw q, k and v (in that order) from a standard normal distribution.
# q and k share the width d_k; v uses d_v.
q, k, v = (np.random.randn(L, width) for width in (d_k, d_k, d_v))

In [20]:
q.shape

(4, 8)

In [2]:
# Dump the randomly generated query, key and value matrices for inspection.
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[-1.02063382e+00  2.03048424e+00 -1.16038296e-01  1.23542950e+00
  -1.86683578e+00 -4.93252452e-01 -7.50532860e-01 -9.71111369e-01]
 [ 7.56319420e-01  7.01961219e-01 -1.02065331e+00 -1.52032162e-01
  -6.89020632e-01  7.08388230e-01  3.95367727e-01  1.77194968e+00]
 [ 1.39943116e+00 -1.31547632e+00 -3.24576911e+00  8.74456236e-01
   1.70940100e+00  6.99874739e-01  3.67117761e-01 -1.05423505e-01]
 [-1.07093347e+00 -1.22531265e+00  3.85057190e-01  4.51410227e-02
   1.17234206e+00  5.09807526e-03  1.53375644e+00  2.53841899e-03]]
K
 [[-0.46739791 -0.58959332  1.18027457 -0.47494861  0.53690126  1.16251636
  -0.22781065  0.04919067]
 [-0.47751967  0.38535486 -0.24983902  0.90406397 -0.67596343 -0.69480326
  -0.57761484  0.71547854]
 [-1.70636866  1.21441076 -0.578801   -0.28385592 -0.19336143  0.11628707
   0.89986119  0.38635698]
 [ 0.78397436 -2.44225155 -0.54793936  0.29862405 -0.09020252 -0.34433245
   0.20369221 -0.08662458]]
V
 [[-1.15986259 -1.08271319 -2.52803376 -1.50479629 -0.

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [23]:
# Raw attention scores: entry (i, j) is the dot product of query row i with key row j.
q @ k.T

array([[-2.89635117,  3.75906359,  3.17694268, -5.05711256],
       [-1.44914615,  1.03988299,  1.45180749, -0.86230591],
       [-2.48213948, -1.50294568, -2.31456117,  6.03818996],
       [ 1.94209911, -1.69628715,  1.25874512,  2.16012026]])

In [24]:
# Motivation for the sqrt(d_k) denominator: compare the variance of q and k
# with the (much larger) variance of their unscaled dot products.
np.var(q), np.var(k), np.var(q @ k.T)

(1.353665160507569, 0.6177191983301971, 7.934026736240158)

In [25]:
# Divide the raw scores by sqrt(d_k), then repeat the variance comparison.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(1.353665160507569, 0.6177191983301971, 0.9917533420300196)

Notice the reduction in variance of the product

In [26]:
scaled

array([[-1.02401478,  1.32902968,  1.12321886, -1.78795929],
       [-0.51235054,  0.36765416,  0.51329146, -0.30487118],
       [-0.87756883, -0.53137154, -0.81832095,  2.13482254],
       [ 0.68663573, -0.59972808,  0.44503361,  0.76371784]])

## Masking

- This is to ensure words don't get context from words generated in the future. 
- Not required in the encoders, but required in the decoders

In [33]:
# L x L matrix with ones on and below the diagonal and zeros above it:
# row i marks positions 0..i as visible.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [34]:
# Turn the 0/1 visibility pattern into an additive mask: blocked positions
# become -inf (so softmax assigns them weight 0) and visible positions become 0.
# Order matters: the zeros must be replaced first, because the second
# assignment creates new zeros. This mutates `mask` in place, so run the cell
# exactly once after building the 0/1 matrix.
# `np.inf` is used because the `np.infty` alias was removed in NumPy 2.0.
mask[mask == 0] = -np.inf
mask[mask == 1] = 0

In [35]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [36]:
# Adding the mask leaves entries where mask is 0 unchanged and turns entries
# where mask is -inf into -inf.
scaled + mask

array([[-1.02401478,        -inf,        -inf,        -inf],
       [-0.51235054,  0.36765416,        -inf,        -inf],
       [-0.87756883, -0.53137154, -0.81832095,        -inf],
       [ 0.68663573, -0.59972808,  0.44503361,  0.76371784]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [87]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  The maximum of each row is subtracted before exponentiating. This does not
  change the mathematical result but keeps ``np.exp`` from overflowing on
  large scores. Normalisation uses ``keepdims=True`` so inputs with any
  number of leading (batch/head) dimensions are handled.

  Args:
    x: array of scores; the softmax is taken along ``axis=-1``. Entries equal
      to ``-inf`` receive weight 0.

  Returns:
    Array with the same shape as ``x`` whose last axis sums to 1. A row made
    up entirely of ``-inf`` yields NaN, since there is nothing to normalise.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; np.max would raise on an empty last axis.
    return np.exp(x)
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [88]:
# Attention weights: softmax of the scaled scores with the mask added.
attention = softmax(scaled + mask)

In [89]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.29317681, 0.70682319, 0.        , 0.        ],
       [0.287793  , 0.40684763, 0.30535937, 0.        ],
       [0.31829162, 0.08793555, 0.24997625, 0.34379658]])

In [94]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-1.15986259, -1.08271319, -2.52803376, -1.50479629, -0.21416274,
        -0.39536532,  0.11824353,  0.77411504],
       [-0.83760971, -1.85509107, -0.60166189,  0.43156591,  0.05139406,
        -0.40142599,  0.17052387, -0.23215534],
       [-0.47545082, -1.14007767, -0.87943521,  0.11586811, -0.44648627,
        -0.28310573, -0.00372917,  0.13740197],
       [-0.01221302, -0.22224388, -0.18190689, -1.22454608, -0.53130902,
         0.1757147 , -0.28726643,  0.01812349]])

In [None]:
v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.85585446, -1.79878344,  0.67321704,  0.05607552, -0.15542661,
        -1.41264124, -0.40136933, -1.17626611],
       [ 0.50465335,  2.28693419,  0.67128338,  0.2506863 ,  1.78802234,
         0.14775751, -0.11405725,  0.88026286],
       [-0.68069105,  0.68385101,  0.17994557, -1.68013201,  0.91543969,
        -0.19108312,  0.03160471,  1.40527326]])

# Function

In [97]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  The per-row maximum is subtracted before exponentiating so that large
  scores do not overflow ``np.exp``; the result is mathematically unchanged.
  ``keepdims=True`` lets the function work for any number of leading
  dimensions.

  Args:
    x: array of scores; entries equal to ``-inf`` receive weight 0.

  Returns:
    Array shaped like ``x`` whose last axis sums to 1. A row consisting only
    of ``-inf`` yields NaN.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; np.max would raise on an empty last axis.
    return np.exp(x)
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q kᵀ / sqrt(d_k) + mask) v.

  Only the last two axes of ``k`` are transposed, so inputs may carry leading
  batch/head dimensions in addition to the plain 2-D case.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k); must be at least 2-D.
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask that is added to the scaled scores before
      the softmax (0 to keep a position, -inf to block it). It must broadcast
      against the (..., L_q, L_k) score array. ``None`` applies no mask.

  Returns:
    Tuple ``(out, attention)``: the weighted values of shape (..., L_q, d_v)
    and the attention weights of shape (..., L_q, L_k).
  """
  d_k = q.shape[-1]
  # swapaxes(-1, -2) equals k.T for 2-D keys but leaves batch axes in place.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [98]:
# Run the packaged function on q, k, v and print the inputs next to the results.
# NOTE: `mask` is expected to already be the additive 0 / -inf matrix built in
# the Masking section, since the function adds it directly to the scores.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[-1.02063382e+00  2.03048424e+00 -1.16038296e-01  1.23542950e+00
  -1.86683578e+00 -4.93252452e-01 -7.50532860e-01 -9.71111369e-01]
 [ 7.56319420e-01  7.01961219e-01 -1.02065331e+00 -1.52032162e-01
  -6.89020632e-01  7.08388230e-01  3.95367727e-01  1.77194968e+00]
 [ 1.39943116e+00 -1.31547632e+00 -3.24576911e+00  8.74456236e-01
   1.70940100e+00  6.99874739e-01  3.67117761e-01 -1.05423505e-01]
 [-1.07093347e+00 -1.22531265e+00  3.85057190e-01  4.51410227e-02
   1.17234206e+00  5.09807526e-03  1.53375644e+00  2.53841899e-03]]
K
 [[-0.46739791 -0.58959332  1.18027457 -0.47494861  0.53690126  1.16251636
  -0.22781065  0.04919067]
 [-0.47751967  0.38535486 -0.24983902  0.90406397 -0.67596343 -0.69480326
  -0.57761484  0.71547854]
 [-1.70636866  1.21441076 -0.578801   -0.28385592 -0.19336143  0.11628707
   0.89986119  0.38635698]
 [ 0.78397436 -2.44225155 -0.54793936  0.29862405 -0.09020252 -0.34433245
   0.20369221 -0.08662458]]
V
 [[-1.15986259 -1.08271319 -2.52803376 -1.50479629 -0.