# Self Attention in Transformers

## Generate Data

In [1]:
import math

import numpy as np

# Fix the RNG seed so that "Restart & Run All" reproduces the same Q, K, V
# (and therefore every downstream output) on each run.
SEED = 42
rng = np.random.default_rng(SEED)

# Sequence length L, query/key dimension d_k and value dimension d_v.
L, d_k, d_v = 4, 8, 8
# Random stand-ins for the query, key and value matrices: one row per position.
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Show the three randomly generated input matrices, each under its label.
for label, matrix in (("Q", q), ("K", k), ("V", v)):
    print(f"{label}\n", matrix)

Q
 [[-1.76349031 -0.07880513 -1.15513671 -0.21057209  0.74424039  0.82765225
   0.0192568   1.48961222]
 [-0.54103414 -0.50396646 -0.18011104  0.54465742 -1.61525081 -0.87904977
  -1.07513707 -0.04925841]
 [ 0.03908773 -0.56277519  0.15265451 -0.79751986 -1.12610502 -0.54959184
   0.29082436 -0.65783831]
 [ 1.06499905  0.27957301 -0.39346881  0.24551245  0.50918087 -0.39092876
   2.09970951  0.18345928]]
K
 [[-0.33779617  0.58951322  2.01563811  0.49571726  0.55097651  0.38274338
   1.06731527 -0.22021791]
 [ 1.64793253 -0.31745302 -0.26702562 -0.770391    0.26460754  1.28309513
   0.95870637 -2.46016015]
 [-0.36042034  1.17402281 -0.58846035 -0.25589226  0.43745058 -0.55249164
   0.23008756  0.44399825]
 [ 0.96153579 -0.35424022 -0.50452882 -0.892767    1.43270988 -2.644971
  -1.8568502  -0.02916864]]
V
 [[-2.57545071  0.53723243  0.0453602   1.43544355  0.06780701  0.55553782
   0.68318127  0.55479496]
 [-0.36574114 -2.19010554  0.49446074  1.72304628 -1.2750318   1.70822737
  -0.207

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}} + M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [3]:
# Raw (unscaled) scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[-1.46412703, -4.79775699,  1.81082808, -2.09899411],
       [-2.57045658, -3.56797621, -0.92022489,  1.27160117],
       [-0.8081562 ,  1.71075636, -0.97468528,  0.19136634],
       [ 1.46524573,  2.77702094,  1.1163928 , -1.23636393]])

In [4]:
# Why we need sqrt(d_k) in the denominator: compare the variance of the inputs
# with the variance of their raw dot products.
(q.var(), k.var(), (q @ k.T).var())

(0.6939387331648185, 1.1178149717330133, 4.23453464871582)

In [5]:
# Divide the raw scores by sqrt(d_k), then re-check the variances.
scaled = q @ k.T / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(0.6939387331648185, 1.1178149717330133, 0.5293168310894775)

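Notice how dividing by $\sqrt{d_k}$ reduces the variance of the scores, bringing it back toward the scale of the inputs. Without this scaling, large scores would push the softmax toward a nearly one-hot distribution with vanishing gradients.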

In [6]:
# Display the scaled scores: same signs as the raw q @ k.T, smaller magnitudes.
scaled

array([[-0.51764708, -1.69626325,  0.64022441, -0.74210648],
       [-0.90879364, -1.26147009, -0.32534863,  0.4495789 ],
       [-0.28572637,  0.60484371, -0.34460329,  0.06765822],
       [ 0.5180426 ,  0.98182517,  0.39470446, -0.43712066]])

## Masking

- This ensures that a word cannot take context from words that come later in the sequence (words that have not been generated yet).
- Not required in the encoder, but required in the decoder.

In [7]:
# Lower-triangular 0/1 matrix: row i has ones in columns 0..i (the diagonal
# and everything left of it) and zeros to the right of the diagonal.
mask = np.tri(L, L, dtype=float)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
# Additive causal mask: 0 where attention is allowed (on/below the diagonal),
# -inf where it is blocked (above the diagonal). It is rebuilt from L instead
# of being mutated in place, so re-running this cell gives the same result.
# The previous in-place version turned every entry into -inf on a second run.
# `np.infty` was removed in NumPy 2.0; `np.inf` is the supported spelling.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [9]:
# Inspect the additive mask: 0 where attention is allowed, -inf where it is blocked.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
# Adding the mask leaves allowed scores unchanged and sends blocked ones to -inf.
np.add(scaled, mask)

array([[-0.51764708,        -inf,        -inf,        -inf],
       [-0.90879364, -1.26147009,        -inf,        -inf],
       [-0.28572637,  0.60484371, -0.34460329,        -inf],
       [ 0.5180426 ,  0.98182517,  0.39470446, -0.43712066]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [11]:
def softmax(x):
    """Softmax along the last axis of ``x``.

    The row-wise maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but stops ``np.exp`` from overflowing to
    ``inf`` (and producing NaN) for large scores. Entries equal to ``-inf``,
    such as masked positions, receive probability 0.

    Args:
        x: array-like of scores of any rank; normalisation is over ``axis=-1``.

    Returns:
        Array with the same shape as ``x`` whose last-axis slices sum to 1.
        A slice that is entirely ``-inf`` still yields NaN, as before.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis, so broadcasting works for any rank.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

In [12]:
# Convert the masked, scaled scores into attention weights (each row sums to 1).
attention = softmax(np.add(scaled, mask))

In [13]:
# Attention weights: each row sums to 1, and entries above the diagonal are 0
# because of the causal mask.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.58726646, 0.41273354, 0.        , 0.        ],
       [0.22834486, 0.55636641, 0.21528874, 0.        ],
       [0.25914851, 0.41206614, 0.22907814, 0.09970721]])

In [14]:
# Each output row is a weighted average of the rows of v, using that row's
# attention weights.
new_v = attention @ v
new_v

array([[-2.57545071,  0.53723243,  0.0453602 ,  1.43544355,  0.06780701,
         0.55553782,  0.68318127,  0.55479496],
       [-1.66342946, -0.58843143,  0.23071905,  1.55414685, -0.48642761,
         1.03129146,  0.31570587,  0.06708916],
       [-0.4702781 , -1.10252401,  0.12453667,  1.55806606, -0.6207903 ,
         1.04165357,  0.28813533, -0.13992943],
       [-0.51074625, -0.81849502, -0.10216045,  1.25896969, -0.46801633,
         0.79053741,  0.56498604, -0.10143245]])

In [15]:
# Original values for comparison. Row 0 of new_v equals row 0 of v because
# the first position can only attend to itself.
v

array([[-2.57545071,  0.53723243,  0.0453602 ,  1.43544355,  0.06780701,
         0.55553782,  0.68318127,  0.55479496],
       [-0.36574114, -2.19010554,  0.49446074,  1.72304628, -1.2750318 ,
         1.70822737, -0.20716411, -0.62685314],
       [ 1.49240926, -0.03110761, -0.74747264,  1.26176986,  0.33959596,
        -0.16536377,  1.14912445,  0.38156212],
       [-0.3459211 , -0.48264915, -1.46866364, -1.12405396, -0.38096535,
        -0.19507932,  2.10683506, -0.74527535]])

# Function

- Pass a causal mask only for decoder self-attention; otherwise (e.g. in the encoder) use `mask=None`.

In [16]:
def softmax(x):
    """Numerically stable softmax along the last axis of ``x``.

    Subtracting the per-slice maximum before ``np.exp`` avoids overflow for
    large scores. ``-inf`` entries (masked positions) get probability 0.
    Works for inputs of any rank.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute ``softmax(q @ k^T / sqrt(d_k) + mask) @ v``.

    Args:
        q: queries, shape (..., L_q, d_k).
        k: keys, shape (..., L_k, d_k).
        v: values, shape (..., L_k, d_v).
        mask: optional additive mask broadcastable to (..., L_q, L_k), with 0
            for allowed positions and -inf for blocked ones (e.g. a causal
            mask in a decoder). Use None for unmasked attention.

    Returns:
        Tuple ``(out, attention)``. ``out`` has shape (..., L_q, d_v).
        ``attention`` has shape (..., L_q, L_k); each of its last-axis
        slices sums to 1.
    """
    d_k = q.shape[-1]
    # Swap only the last two axes of k so any leading batch/head axes are
    # kept; for 2-D k this is identical to k.T.
    scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [17]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
# Print the inputs, then the function's two outputs, each under its label.
for title, arr in [("Q", q), ("K", k), ("V", v), ("New V", values), ("Attention", attention)]:
    print(f"{title}\n", arr)

Q
 [[-1.76349031 -0.07880513 -1.15513671 -0.21057209  0.74424039  0.82765225
   0.0192568   1.48961222]
 [-0.54103414 -0.50396646 -0.18011104  0.54465742 -1.61525081 -0.87904977
  -1.07513707 -0.04925841]
 [ 0.03908773 -0.56277519  0.15265451 -0.79751986 -1.12610502 -0.54959184
   0.29082436 -0.65783831]
 [ 1.06499905  0.27957301 -0.39346881  0.24551245  0.50918087 -0.39092876
   2.09970951  0.18345928]]
K
 [[-0.33779617  0.58951322  2.01563811  0.49571726  0.55097651  0.38274338
   1.06731527 -0.22021791]
 [ 1.64793253 -0.31745302 -0.26702562 -0.770391    0.26460754  1.28309513
   0.95870637 -2.46016015]
 [-0.36042034  1.17402281 -0.58846035 -0.25589226  0.43745058 -0.55249164
   0.23008756  0.44399825]
 [ 0.96153579 -0.35424022 -0.50452882 -0.892767    1.43270988 -2.644971
  -1.8568502  -0.02916864]]
V
 [[-2.57545071  0.53723243  0.0453602   1.43544355  0.06780701  0.55553782
   0.68318127  0.55479496]
 [-0.36574114 -2.19010554  0.49446074  1.72304628 -1.2750318   1.70822737
  -0.207