<a href="https://colab.research.google.com/github/douhua2882/Transformer/blob/main/Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Self Attention in Transformers

### Generate Data

In [1]:
import numpy as np
import math

L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [None]:
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[-0.11993449  0.89862434 -0.76009501 -0.04604154 -1.40243564  0.31652101
   2.16818116 -0.56526164]
 [ 2.13863082  1.31234223 -0.36791218  1.40673364 -0.50086097 -1.22534046
   0.0807025  -0.46283111]
 [-1.72215265 -0.75502004 -1.42236718  1.05042606  1.09068783 -0.79608508
   0.05587742 -0.68369517]
 [-0.33438194 -0.21406342  1.24018574 -0.24386563  0.02757836  0.46461054
  -1.65045735  0.40249933]]
K
 [[-0.29338514  2.54454292  0.66040032 -1.86401072  0.69913026 -1.23585533
  -0.9788095  -1.02004732]
 [-0.43401444 -0.2558092  -0.47753853 -0.83951567 -0.90884813 -0.1018832
   1.07452724 -0.19372144]
 [-0.61532695  1.16622267 -0.36956243 -0.74272091 -0.34949601  0.36100618
   0.71014707  0.66462162]
 [ 1.54212128 -0.84179554 -0.01605768  0.40334175  0.64683248 -0.21154629
  -1.35798793  0.64093102]]
V
 [[ 0.66999533 -1.13268771 -0.06110204 -0.39308294  0.75843569  0.50818355
   2.36827749 -0.35537324]
 [ 1.20410627 -0.49431986  1.17368738  1.06898047 -0.70626029  1.48225817
   1.60

## Self Attention

In [None]:
np.matmul(q, k.T)

array([[-1.01167192,  3.90543002,  3.20534766, -5.22853396],
       [ 1.40402646, -1.51276306, -1.21192138,  2.29561373],
       [-1.92417717,  0.02029251, -1.15865456, -1.2138563 ],
       [ 1.47700078, -2.11145615, -1.06756228,  1.96509185]])

In [None]:
# Why we need sqrt(d_k) in denominator
q.var(), k.var(), np.matmul(q, k.T).var()

(1.021085278110049, 0.8490688448001151, 5.172010342081148)

In [None]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(1.021085278110049, 0.8490688448001151, 0.6465012927601435)

In [None]:
scaled

array([[-0.35768004,  1.38077803,  1.13326153, -1.84856591],
       [ 0.49639832, -0.53484251, -0.42847891,  0.81162202],
       [-0.68029936,  0.00717449, -0.40964625, -0.42916301],
       [ 0.52219863, -0.74651248, -0.37744026,  0.69476489]])

## Masking
- To ensure words don't get context from words generated in the future
- Not required in the encoders, but required in the decoders.

In [None]:
mask = np.tril(np.ones((L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [None]:
mask[mask == 0] = -np.infty
mask[mask == 1] = 0

In [None]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
scaled + mask

array([[-0.35768004,        -inf,        -inf,        -inf],
       [ 0.49639832, -0.53484251,        -inf,        -inf],
       [-0.68029936,  0.00717449, -0.40964625,        -inf],
       [ 0.52219863, -0.74651248, -0.37744026,  0.69476489]])

## Softmax

In [None]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [None]:
attention = softmax(scaled + mask)

In [None]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.73715638, 0.26284362, 0.        , 0.        ],
       [0.23258487, 0.46253816, 0.30487697, 0.        ],
       [0.34767361, 0.09776367, 0.14140459, 0.41315813]])

In [None]:
new_v = np.matmul(attention, v)
new_v

array([[ 0.66999533, -1.13268771, -0.06110204, -0.39308294,  0.75843569,
         0.50818355,  2.36827749, -0.35537324],
       [ 0.81038298, -0.9648968 ,  0.26345448, -0.00878891,  0.3734497 ,
         0.76421284,  2.16700146, -0.01268525],
       [ 1.39694629, -1.01527832,  0.31976439,  1.52382609, -0.41738949,
         1.08820765,  1.0407995 ,  0.44626275],
       [ 1.51980485, -0.224784  ,  0.13429173,  0.63911266,  0.74886157,
         1.14843433,  0.4710114 ,  0.40328325]])

## Function

In [None]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T


def scaled_dot_product_attention(q, k, v, mask=None):
  d_k = q.shape[-1]
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    scaled += mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [None]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[-0.11993449  0.89862434 -0.76009501 -0.04604154 -1.40243564  0.31652101
   2.16818116 -0.56526164]
 [ 2.13863082  1.31234223 -0.36791218  1.40673364 -0.50086097 -1.22534046
   0.0807025  -0.46283111]
 [-1.72215265 -0.75502004 -1.42236718  1.05042606  1.09068783 -0.79608508
   0.05587742 -0.68369517]
 [-0.33438194 -0.21406342  1.24018574 -0.24386563  0.02757836  0.46461054
  -1.65045735  0.40249933]]
K
 [[-0.29338514  2.54454292  0.66040032 -1.86401072  0.69913026 -1.23585533
  -0.9788095  -1.02004732]
 [-0.43401444 -0.2558092  -0.47753853 -0.83951567 -0.90884813 -0.1018832
   1.07452724 -0.19372144]
 [-0.61532695  1.16622267 -0.36956243 -0.74272091 -0.34949601  0.36100618
   0.71014707  0.66462162]
 [ 1.54212128 -0.84179554 -0.01605768  0.40334175  0.64683248 -0.21154629
  -1.35798793  0.64093102]]
V
 [[ 0.66999533 -1.13268771 -0.06110204 -0.39308294  0.75843569  0.50818355
   2.36827749 -0.35537324]
 [ 1.20410627 -0.49431986  1.17368738  1.06898047 -0.70626029  1.48225817
   1.60