<a href="https://colab.research.google.com/github/ayushksingh28/transformers_scratch/blob/main/Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
!pip install torch



In [22]:
import torch

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Report the library version and the name of GPU 0 for reference.
    print(f"Pytorch version: {torch.__version__}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available.")

Pytorch version: 2.0.1+cu118
Device name: NVIDIA A100-SXM4-40GB


In [23]:
import numpy as np
import math

In [24]:
# Sequence length and the per-token key / value dimensions.
L, d_k, d_v = 4, 8, 8

# Seed NumPy's global generator so Restart & Run All reproduces the same q, k, v.
np.random.seed(42)

# One row per token. q and k share d_k so that q @ k.T is defined; v uses d_v.
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)


In [25]:
# Print each matrix under its label; "\n" puts the array on the line after the label.
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q/n [[-0.0331857   2.06896505 -0.11488886  0.88116544  0.11719453  1.33834359
   0.02257637  1.16781053]
 [-0.30329635 -1.07928697  1.37743666 -0.42731295 -1.63323351  0.42569583
   0.80718759 -0.41248306]
 [-0.46543505  0.29416147  0.67930279  0.54110615 -1.52135385 -0.00423845
  -0.7653162   1.35744534]
 [-1.01328392 -0.11631818  0.14606273 -0.53768603  0.9199883  -1.29632442
   0.83848266  1.2300858 ]]
K/n [[-1.74445624e+00  5.01438036e-01  1.37156200e+00  1.18732476e+00
  -1.24336869e+00  1.82895126e+00 -1.43497565e+00 -8.26174852e-01]
 [-7.63499000e-01  1.32199944e-01 -1.41747891e+00 -6.90487346e-01
  -7.74263859e-02  8.61092061e-01 -5.22362621e-01 -4.23884940e-01]
 [-6.48913917e-01  1.02524546e+00 -1.29468132e+00  2.14224226e-01
  -1.05157958e+00 -2.20927995e+00  8.34204911e-01  1.64175674e+00]
 [ 8.47225239e-02 -2.01615668e-03  3.97779898e-01  3.03064464e-01
   1.72678590e-01 -8.16785022e-01 -2.02026247e-01  1.15189116e+00]]
V/n [[-1.21830369 -1.9244519  -0.17813956  0.1238206  

# Self Attention

In [26]:
# Raw attention scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[ 3.28883808,  0.48982601,  1.33632222,  0.4820941 ],
       [ 3.36154911, -1.32232906, -2.01144638, -0.87304188],
       [ 4.39418418, -1.00376182,  3.03940548,  1.85317703],
       [-4.46303773, -1.22440431,  4.84945995,  2.27474393]])

In [27]:
# Why do we need sqrt(d_k) in the denominator?
# Compare the variance of q and k with the variance of their unscaled dot products.
q.var(), k.var(), (q @ k.T).var()

(0.8284968390662946, 1.0306241289466467, 6.308769804267942)

In [28]:
# Divide the raw scores by sqrt(d_k), then re-check the variances.
scaled = (q @ k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()


(0.8284968390662946, 1.0306241289466467, 0.7885962255334926)

In [29]:
# Display the scaled score matrix (one row per query, one column per key).
scaled

array([[ 1.16277985,  0.17317965,  0.47246125,  0.170446  ],
       [ 1.18848709, -0.46751392, -0.71115369, -0.30866692],
       [ 1.55357872, -0.3548834 ,  1.07459211,  0.65519702],
       [-1.57792212, -0.4328923 ,  1.71454301,  0.80424343]])

# Masking

In [30]:
# Lower-triangular matrix of ones: row i has ones in columns 0..i and zeros after.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [31]:
# Build the additive mask: 0 where attention is allowed (on and below the
# diagonal) and -inf everywhere else.
# It is rebuilt from L rather than edited in place so the cell is idempotent:
# re-running the in-place version would see the zeros it had just written and
# turn them into -inf as well, leaving an all -inf mask.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [32]:
# Display the additive mask: 0 on and below the diagonal, -inf above it.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [33]:
# Adding the mask leaves the allowed scores unchanged and sets the rest to -inf.
scaled + mask

array([[ 1.16277985,        -inf,        -inf,        -inf],
       [ 1.18848709, -0.46751392,        -inf,        -inf],
       [ 1.55357872, -0.3548834 ,  1.07459211,        -inf],
       [-1.57792212, -0.4328923 ,  1.71454301,  0.80424343]])

# Softmax

In [45]:
def softmax(x):
  """Return the softmax of ``x`` taken over its last axis.

  Args:
    x: Array-like of scores. Entries equal to ``-inf`` (masked positions)
      receive probability 0, provided the same row has at least one finite
      entry; a row that is entirely ``-inf`` yields ``nan``.

  Returns:
    ``np.ndarray`` with the same shape as ``x`` whose entries along the last
    axis are non-negative and sum to 1.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result mathematically
  # but keeps np.exp from overflowing on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims=True lets the division broadcast for inputs of any rank.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

In [46]:
# Row-wise softmax over the masked scores; the -inf entries become weight 0.
attention = softmax(scaled + mask)

In [47]:
# Each row sums to 1 and has zero weight on the masked (above-diagonal) positions.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.83970046, 0.16029954, 0.        , 0.        ],
       [0.56570075, 0.08389811, 0.35040114, 0.        ],
       [0.02387776, 0.07503669, 0.64252938, 0.25855617]])

In [51]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-1.21830369, -1.9244519 , -0.17813956,  0.1238206 ,  0.99814184,
         0.43866518,  0.023729  , -0.67082976],
       [-0.93458494, -1.66200609, -0.03419924,  0.15793951,  0.86992568,
         0.12683013, -0.00846391, -0.41481851],
       [-0.88667574, -1.26400542, -0.12423146,  0.2284335 ,  0.81091651,
         0.22544575,  0.59164703, -1.0658907 ],
       [-0.57666853, -0.77402159,  0.13537193, -0.11793143,  0.57828824,
         0.21365852,  1.50002071, -0.91533911]])

In [52]:
# Original values for comparison: row 0 of new_v equals row 0 of v because
# the first position attends only to itself.
v

array([[-1.21830369, -1.9244519 , -0.17813956,  0.1238206 ,  0.99814184,
         0.43866518,  0.023729  , -0.67082976],
       [ 0.55162491, -0.28723073,  0.71980635,  0.33666532,  0.19828825,
        -1.50666192, -0.17710069,  0.92625059],
       [-0.69566188, -0.43162938, -0.23929193,  0.37141007,  0.65533725,
         0.29594258,  1.6925799 , -2.18067921],
       [-0.54915374, -1.65992012,  0.92577837, -1.48823415,  0.45832411,
         0.48765926,  1.64455891,  1.67207951]])

In [53]:
def softmax(x):
  """Return the softmax of ``x`` taken over its last axis.

  Defined in this cell so the attention function below is self-contained.

  Args:
    x: Array-like of scores. Entries equal to ``-inf`` (masked positions)
      receive probability 0, provided the same row has at least one finite
      entry; a row that is entirely ``-inf`` yields ``nan``.

  Returns:
    ``np.ndarray`` with the same shape as ``x`` whose entries along the last
    axis are non-negative and sum to 1.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result mathematically
  # but keeps np.exp from overflowing on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims=True lets the division broadcast for inputs of any rank.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute scaled dot-product attention: softmax(q @ k.T / sqrt(d_k)) @ v.

  Args:
    q: Query matrix of shape (L_q, d_k).
    k: Key matrix of shape (L_k, d_k).
    v: Value matrix of shape (L_k, d_v).
    mask: Optional additive mask that broadcasts against the (L_q, L_k) score
      matrix: 0 where attention is allowed, ``-inf`` where it is blocked.
      ``None`` (the default) applies no masking.

  Returns:
    Tuple ``(out, attention)``: ``out`` has shape (L_q, d_v) and ``attention``
    has shape (L_q, L_k) with each row summing to 1.
  """
  d_k = q.shape[-1]
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  # Only the mask addition is conditional. The softmax, the weighted sum and
  # the return must run either way, otherwise an unmasked call returns None.
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [55]:
# Run masked attention once, then print every input and output under its label.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for heading, array in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, array)

Q
 [[-0.0331857   2.06896505 -0.11488886  0.88116544  0.11719453  1.33834359
   0.02257637  1.16781053]
 [-0.30329635 -1.07928697  1.37743666 -0.42731295 -1.63323351  0.42569583
   0.80718759 -0.41248306]
 [-0.46543505  0.29416147  0.67930279  0.54110615 -1.52135385 -0.00423845
  -0.7653162   1.35744534]
 [-1.01328392 -0.11631818  0.14606273 -0.53768603  0.9199883  -1.29632442
   0.83848266  1.2300858 ]]
K
 [[-1.74445624e+00  5.01438036e-01  1.37156200e+00  1.18732476e+00
  -1.24336869e+00  1.82895126e+00 -1.43497565e+00 -8.26174852e-01]
 [-7.63499000e-01  1.32199944e-01 -1.41747891e+00 -6.90487346e-01
  -7.74263859e-02  8.61092061e-01 -5.22362621e-01 -4.23884940e-01]
 [-6.48913917e-01  1.02524546e+00 -1.29468132e+00  2.14224226e-01
  -1.05157958e+00 -2.20927995e+00  8.34204911e-01  1.64175674e+00]
 [ 8.47225239e-02 -2.01615668e-03  3.97779898e-01  3.03064464e-01
   1.72678590e-01 -8.16785022e-01 -2.02026247e-01  1.15189116e+00]]
V
 [[-1.21830369 -1.9244519  -0.17813956  0.1238206   0.