In [68]:
import numpy as np

In [69]:
def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray):
    """
    Project the input X into query, key, and value matrices.

    Args:
        X: Input matrix; must be dot-compatible with each weight matrix.
        W_q: Query projection weights.
        W_k: Key projection weights.
        W_v: Value projection weights.

    Returns:
        Tuple (Q, K, V) with Q = np.dot(X, W_q), K = np.dot(X, W_k), V = np.dot(X, W_v).
    """
    projections = (W_q, W_k, W_v)
    return tuple(np.dot(X, W) for W in projections)

In [84]:
def masked_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Compute masked scaled dot-product attention.

    Returns softmax(Q @ K^T / sqrt(d_k) + mask) @ V, where the softmax runs over
    the key axis (last axis of the score matrix).

    Args:
        Q: Queries, shape (..., seq_q, d_k).
        K: Keys, shape (..., seq_k, d_k).
        V: Values, shape (..., seq_k, d_v).
        mask: Additive mask broadcastable to (..., seq_q, seq_k). Use 0 to keep a
            position and -inf to block it, e.g. an upper-triangular -inf causal mask.

    Returns:
        Attention output of shape (..., seq_q, d_v).

    Note:
        A query row whose mask is -inf at every key has nothing to attend to;
        its softmax is -inf - (-inf) = NaN, so that output row is NaN.
    """
    def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
        # Shift by the max for numerical stability, then reuse a single exp().
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    d_k = K.shape[-1]
    # swapaxes(-1, -2) transposes only the last two axes, so leading batch/head
    # dimensions are preserved; K.T would reverse *every* axis when K.ndim > 2.
    scores = Q @ K.swapaxes(-1, -2) * d_k ** (-0.5) + mask
    return softmax(scores) @ V

In [85]:
# Seed NumPy's global RNG so the later random draws are reproducible.
np.random.seed(42)
# 6 rows x 8 columns holding 0..47 in row-major order.
X = np.arange(6 * 8).reshape(6, 8)
X

array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47]])

In [86]:
# Shuffle the values 0..47 into a (6, 8) input matrix. Build from a fresh arange
# and re-seed here instead of permuting the previous X into itself, so re-running
# this cell reproduces the same X instead of compounding shuffles.
np.random.seed(42)
X = np.random.permutation(np.arange(48)).reshape(6, 8)
X

array([[27, 40, 26, 43, 24, 37, 12, 19],
       [ 4, 25,  8,  3,  6, 39, 33, 13],
       [17, 45, 15,  9, 16, 29, 32, 46],
       [ 0, 31, 30,  5, 11, 34,  1, 44],
       [21,  2, 36, 35, 23, 41, 10, 22],
       [18, 47, 20,  7, 42, 14, 28, 38]])

In [87]:
# Additive causal mask: 0 where query i may attend to key j (j <= i) and -inf above
# the diagonal, so softmax gives future positions zero weight. Sized from the
# number of rows in X so it always matches the sequence length passed to compute_qkv.
mask = np.triu(np.full((X.shape[0], X.shape[0]), -np.inf), k=1)
mask

array([[  0., -inf, -inf, -inf, -inf, -inf],
       [  0.,   0., -inf, -inf, -inf, -inf],
       [  0.,   0.,   0., -inf, -inf, -inf],
       [  0.,   0.,   0.,   0., -inf, -inf],
       [  0.,   0.,   0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.,   0.,   0.]])

In [88]:
# Integer projection weights drawn from the global RNG in q, k, v order:
# W_q in [0, 3], W_k in [0, 4], W_v in [0, 5].
W_q, W_k, W_v = (np.random.randint(0, high, size=(8, 8)) for high in (4, 5, 6))

In [89]:
np.shape(W_q)

(8, 8)

In [90]:
# Inspect the sampled query projection weights (integers from randint(0, 4)).
W_q

array([[0, 3, 1, 1, 1, 0, 1, 0],
       [1, 3, 3, 2, 3, 2, 3, 0],
       [3, 2, 2, 1, 0, 3, 1, 3],
       [3, 1, 1, 1, 1, 1, 3, 1],
       [0, 2, 1, 1, 3, 1, 1, 1],
       [3, 1, 2, 3, 2, 3, 1, 2],
       [3, 0, 1, 3, 0, 3, 0, 1],
       [2, 0, 3, 1, 0, 3, 3, 3]])

In [91]:
Q, K, V = compute_qkv(X, W_q=W_q, W_k=W_k, W_v=W_v)

In [92]:
tuple(m.shape for m in (Q, K, V))

((6, 8), (6, 8), (6, 8))

In [93]:
# Queries: np.dot(X, W_q), one row per row of X.
Q

array([[432, 381, 409, 366, 336, 429, 420, 288],
       [300, 157, 254, 300, 178, 338, 180, 183],
       [392, 286, 435, 376, 267, 481, 377, 298],
       [329, 214, 370, 257, 199, 405, 315, 307],
       [412, 263, 315, 294, 213, 389, 298, 324],
       [330, 340, 418, 345, 320, 443, 370, 279]])

In [94]:
# Keys: np.dot(X, W_k). Entries are in the hundreds because X and W_k are unscaled integers.
K

array([[570, 504, 491, 531, 304, 449, 307, 408],
       [357, 344, 315, 311, 185, 369, 285, 292],
       [622, 434, 417, 529, 265, 539, 363, 428],
       [427, 284, 284, 307, 176, 376, 360, 349],
       [473, 411, 400, 491, 235, 371, 358, 342],
       [701, 418, 350, 553, 294, 498, 286, 382]])

In [95]:
# Values: np.dot(X, W_v). Each attention output row is a softmax-weighted combination of these rows.
V

array([[547, 490, 399, 495, 485, 439, 645, 393],
       [203, 293, 199, 399, 220, 163, 279, 174],
       [471, 472, 429, 538, 377, 450, 531, 362],
       [398, 388, 244, 338, 224, 385, 413, 295],
       [541, 331, 366, 492, 475, 367, 423, 422],
       [500, 463, 514, 447, 403, 500, 556, 496]])

In [96]:
# Note: Q and K entries are in the hundreds (unscaled integer inputs and weights),
# so the scaled scores are huge and the softmax is effectively one-hot. Each output
# row equals a single row of V (rows 0 and 2 here) rather than a blend.
masked_attentions = masked_attention(Q=Q, K=K, V=V, mask=mask)
masked_attentions

array([[547., 490., 399., 495., 485., 439., 645., 393.],
       [547., 490., 399., 495., 485., 439., 645., 393.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.]])