In [1]:
import numpy as np

In [2]:
L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [3]:
print(f'Q\n{q}')
print(f'K\n{k}')
print(f'V\n{v}')

Q
[[ 1.81007476e+00  5.79714443e-01  1.76340488e+00 -1.36508567e+00
  -6.16563287e-01  2.25550782e-01  6.97932570e-01  4.53068254e-01]
 [ 4.40413862e-01  1.41182465e+00  1.12179795e+00  2.55940040e+00
  -6.11930195e-01 -1.29197911e+00  4.35326619e-01  5.87760655e-01]
 [ 1.37479398e-01  4.75943855e-01  1.70270553e-01 -2.99509648e-01
  -9.17416007e-01  3.41058826e-02  1.35763200e-01  7.02759309e-04]
 [-1.02789839e+00 -1.25648099e+00 -2.10596641e+00  8.53510939e-03
  -1.08815736e+00  2.73482192e-01 -2.95768431e-01 -1.63643952e+00]]
K
[[ 0.23806166  0.81478651  0.1867359  -0.23954279 -0.43809154  0.71938587
   0.32942035 -0.7737555 ]
 [ 1.30796644  0.28949965  0.46690281 -0.79771644 -0.23239687  1.46718916
   0.96671762  1.31305307]
 [-1.02961989  0.32733656  1.25499661 -1.8275007   0.46145633  0.21470431
   1.32039387  0.97140891]
 [ 0.02289806  0.28978226  0.05236321  0.60885248  0.7258343   1.19419763
  -0.34371495 -1.39029048]]
V
[[ 1.37933415 -0.10749346 -1.57829697  1.22250701 -1.189

# Self Attention

In [5]:
np.matmul(q, k.T)

array([[ 1.8712587 ,  6.19145356,  4.15940504, -1.57731701],
       [-0.12115231, -1.09390409, -2.67478445, -0.91758004],
       [ 0.99468909,  1.03143942,  0.53920418, -0.70517691],
       [ 0.17845267, -4.47881966, -4.43512822,  1.42083484]])

In [6]:
q.var(), k.var(), np.matmul(q, k.T).var()

(1.0807309759803259, 0.6559745871143683, 7.21530915931742)

In [9]:
scaled = np.matmul(q, k.T) / np.sqrt(d_k)
scaled

array([[ 0.66158986,  2.1890094 ,  1.47057175, -0.55766578],
       [-0.04283381, -0.3867535 , -0.94567911, -0.32441353],
       [ 0.3516757 ,  0.3646689 ,  0.19063747, -0.24931769],
       [ 0.06309255, -1.58350188, -1.56805462,  0.50234097]])

# Masking

In [10]:
mask = np.tril(np.ones((L, L)))
print(mask)
mask[mask == 0] = -np.infty
mask[mask == 1] = 0
print(mask)

[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]]
[[  0. -inf -inf -inf]
 [  0.   0. -inf -inf]
 [  0.   0.   0. -inf]
 [  0.   0.   0.   0.]]


In [11]:
scaled + mask

array([[ 0.66158986,        -inf,        -inf,        -inf],
       [-0.04283381, -0.3867535 ,        -inf,        -inf],
       [ 0.3516757 ,  0.3646689 ,  0.19063747,        -inf],
       [ 0.06309255, -1.58350188, -1.56805462,  0.50234097]])

In [17]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [20]:
attention = softmax(scaled + mask)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.58514235, 0.41485765, 0.        , 0.        ],
       [0.34912087, 0.35368667, 0.29719247, 0.        ],
       [0.34014176, 0.06554704, 0.06656742, 0.52774378]])

In [21]:
new_v = np.matmul(attention, v)
new_v

array([[ 1.37933415, -0.10749346, -1.57829697,  1.22250701, -1.18933662,
        -0.12370067, -0.05942792,  0.01060972],
       [ 1.0330743 , -0.28357112, -0.41087137,  1.02617547, -0.22353044,
        -0.47738086,  0.29065512, -0.22642169],
       [ 1.09051302, -0.51897789, -0.59980993,  0.60701144,  0.63950322,
        -0.53760485, -0.12560718, -0.02638372],
       [ 0.42173924, -1.45392657, -0.18455794, -0.23598787, -0.85649913,
        -0.09220333, -0.85970381,  0.12538971]])

In [23]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_prodict_attention(q, k, v, mask=None):
    d_k = q.shape[-1]
    scaled = np.matmul(q, k.T) / np.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    output = np.matmul(attention, v)
    return output, attention

In [26]:
values, attention = scaled_dot_prodict_attention(q, k, v, mask=mask)
print(f'Q\n{q}')
print(f'K\n{k}')
print(f'V\n{v}')
print(f'New V\n{values}')
print(f'Attention\n{attention}')

Q
[[ 1.81007476e+00  5.79714443e-01  1.76340488e+00 -1.36508567e+00
  -6.16563287e-01  2.25550782e-01  6.97932570e-01  4.53068254e-01]
 [ 4.40413862e-01  1.41182465e+00  1.12179795e+00  2.55940040e+00
  -6.11930195e-01 -1.29197911e+00  4.35326619e-01  5.87760655e-01]
 [ 1.37479398e-01  4.75943855e-01  1.70270553e-01 -2.99509648e-01
  -9.17416007e-01  3.41058826e-02  1.35763200e-01  7.02759309e-04]
 [-1.02789839e+00 -1.25648099e+00 -2.10596641e+00  8.53510939e-03
  -1.08815736e+00  2.73482192e-01 -2.95768431e-01 -1.63643952e+00]]
K
[[ 0.23806166  0.81478651  0.1867359  -0.23954279 -0.43809154  0.71938587
   0.32942035 -0.7737555 ]
 [ 1.30796644  0.28949965  0.46690281 -0.79771644 -0.23239687  1.46718916
   0.96671762  1.31305307]
 [-1.02961989  0.32733656  1.25499661 -1.8275007   0.46145633  0.21470431
   1.32039387  0.97140891]
 [ 0.02289806  0.28978226  0.05236321  0.60885248  0.7258343   1.19419763
  -0.34371495 -1.39029048]]
V
[[ 1.37933415 -0.10749346 -1.57829697  1.22250701 -1.189