## Imports

In [1]:
import numpy as np

## Self-attention

In [2]:
def compute_qkv(X, W_q, W_k, W_v):
    """Project the input into query, key and value matrices.

    Each projection is the matrix product of ``X`` with the matching
    weight matrix. Returns the tuple ``(Q, K, V)``.
    """
    queries, keys, values = (X @ weights for weights in (W_q, W_k, W_v))
    return queries, keys, values

def softmax(X):
    """Return the softmax of ``X`` taken along axis 1.

    Every slice along axis 1 of the result is non-negative and sums to 1.
    ``X`` must therefore have at least two dimensions.
    """
    # Subtracting the per-row maximum does not change the result
    # mathematically, but keeps np.exp from overflowing on large scores.
    shifted = X - np.max(X, axis=1, keepdims=True)
    # Exponentiate once and reuse it for numerator and denominator.
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)

def self_attention(Q, K, V):
    """Scaled dot-product attention.

    The query/key similarity scores are divided by the square root of the
    number of columns in ``Q``, normalised with ``softmax``, and used as
    weights to mix the rows of ``V``.
    """
    scale = np.sqrt(Q.shape[1])
    weights = softmax((Q @ K.T) / scale)
    return np.matmul(weights, V)

In [3]:
# Toy example: the input and the query/key projections are 2x2 identity
# matrices, so Q and K equal X; only the value projection is non-trivial.
X = np.eye(2, dtype=int)
W_q = np.eye(2, dtype=int)
W_k = np.eye(2, dtype=int)
W_v = np.arange(1, 5).reshape(2, 2)


In [4]:
# Project the toy input into queries, keys and values, then attend over them.
Q, K, V = compute_qkv(X=X, W_q=W_q, W_k=W_k, W_v=W_v)
output = self_attention(Q=Q, K=K, V=V)

In [5]:
# Display the attention output computed in the previous cell.
print(output)

[[1.6604769 2.6604769]
 [2.3395231 3.3395231]]


## Multi-head attention

In [25]:
def compute_qkv(X, W_q, W_k, W_v):
    """Return ``(Q, K, V)``: ``X`` multiplied by each projection matrix."""
    return np.matmul(X, W_q), np.matmul(X, W_k), np.matmul(X, W_v)

def softmax(X):
    """Numerically stable softmax along axis 1.

    Returns an array of the same shape as ``X`` whose slices along axis 1
    are non-negative and sum to 1.
    """
    # Shift by the row maximum so the largest exponent is exp(0) == 1,
    # which avoids overflow without changing the normalised result.
    exp_shifted = np.exp(X - np.max(X, axis=1, keepdims=True))
    # Reuse the exponentials instead of evaluating np.exp a second time.
    return exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)

def self_attention(Q, K, V):
    """Scaled dot-product attention for a single head.

    Scores ``Q @ K.T`` are divided by the square root of the number of
    columns in ``K``, turned into weights by ``softmax``, and applied to
    ``V``.
    """
    scaled_scores = np.matmul(Q, K.T) / np.sqrt(K.shape[1])
    return np.matmul(softmax(scaled_scores), V)


def multi_head_attention(Q, K, V, n_heads):
    """Run scaled dot-product attention independently on ``n_heads`` heads.

    The columns of ``Q``, ``K`` and ``V`` are split into ``n_heads``
    contiguous groups of equal width, ``self_attention`` is applied to each
    group, and the per-head outputs are concatenated back along the last
    axis.

    Parameters
    ----------
    Q, K, V : 2-D arrays whose second dimension is the model width.
    n_heads : number of heads; must divide the model width evenly.

    Returns
    -------
    Array with one row per row of ``Q`` and the same width as ``Q``.

    Raises
    ------
    AssertionError
        If the model width is not divisible by ``n_heads``.
    """
    # Derive the width from the Q argument itself. The previous version read
    # a notebook-level global ``X``, which broke the function outside the
    # notebook and gave wrong shapes whenever ``X`` did not match ``Q``.
    _, d_model = Q.shape
    # Need to ensure the dimension can be divided equally for all heads
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

    d_k = d_model // n_heads

    def split_heads(M):
        # (rows, d_model) -> (n_heads, rows, d_k): one 2-D slice per head.
        return M.reshape(M.shape[0], n_heads, d_k).transpose(1, 0, 2)

    # Iterating a 3-D array yields its leading-axis slices, i.e. one head
    # at a time for each of Q, K and V.
    head_outputs = [
        self_attention(q_head, k_head, v_head)
        for q_head, k_head, v_head in zip(
            split_heads(Q), split_heads(K), split_heads(V)
        )
    ]

    # Stitch the heads back together column-wise.
    return np.concatenate(head_outputs, axis=-1)

In [26]:
# Toy problem size: X has m rows (sequence positions) and n columns (model width).
m, n = 4, 4
n_heads = 2
# Fix the seed so the permutation and the random weights below are reproducible.
np.random.seed(42)
X = np.arange(m*n).reshape(m,n)
# Shuffle the values 0..m*n-1 and reshape them back into an (m, n) input matrix.
X = np.random.permutation(X.flatten()).reshape(m, n)

# Random integer projection matrices; randint's upper bound is exclusive.
W_q = np.random.randint(0,4,size=(n,n))
W_k = np.random.randint(0,5,size=(n,n))
W_v = np.random.randint(0,6,size=(n,n))

# NOTE(review): the printed rows are all identical, presumably because these
# large integer scores saturate the softmax - confirm, and consider smaller inputs.
Q, K, V = compute_qkv(X, W_q, W_k, W_v)
print(multi_head_attention(Q, K, V, n_heads))

[[103. 109.  46.  99.]
 [103. 109.  46.  99.]
 [103. 109.  46.  99.]
 [103. 109.  46.  99.]]


## Masked self-attention

In [4]:
def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray):
    """
    Build the Query, Key and Value matrices.

    Each one is ``np.dot`` of the input ``X`` with its projection matrix;
    the three results are returned as the tuple ``(Q, K, V)``.
    """
    queries = np.dot(X, W_q)
    keys = np.dot(X, W_k)
    values = np.dot(X, W_v)
    return queries, keys, values

def softmax(X):
    """Numerically stable softmax along axis 1.

    Entries equal to ``-inf`` (masked positions) receive weight 0 as long
    as the same row contains at least one finite entry. A row made up
    entirely of ``-inf`` produces NaN, because ``-inf - (-inf)`` is
    undefined.
    """
    # Shift by the row maximum so np.exp cannot overflow; the normalised
    # result is unchanged.
    shifted = X - np.max(X, axis=1, keepdims=True)
    # Compute the exponentials once and reuse them in the denominator.
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)

def masked_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Scaled dot-product attention with an additive mask.

    ``mask`` is added to the scaled ``Q @ K.T`` scores before ``softmax``
    is applied, and the resulting weights are multiplied by ``V``.
    """
    scale = np.sqrt(K.shape[1])
    masked_logits = (Q @ K.T) / scale + mask
    weights = softmax(masked_logits)
    return weights @ V

In [21]:
# Fix the seed so the shuffled input below is reproducible.
np.random.seed(42)
X = np.arange(48).reshape(6,8)
# Shuffle the values 0..47 and reshape them back into a (6, 8) input matrix.
X = np.random.permutation(X.flatten()).reshape(6, 8)

# Additive causal mask: -inf strictly above the diagonal, 0 on and below it,
# so each row can only attend to itself and earlier rows.
mask = np.triu(np.ones((6, 6))*(-np.inf), k=1)
mask

array([[  0., -inf, -inf, -inf, -inf, -inf],
       [  0.,   0., -inf, -inf, -inf, -inf],
       [  0.,   0.,   0., -inf, -inf, -inf],
       [  0.,   0.,   0.,   0., -inf, -inf],
       [  0.,   0.,   0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.,   0.,   0.]])

In [22]:
# Random integer 8x8 projection matrices; randint's upper bound is exclusive.
# These draws continue from the RNG state left by the previous cell.
W_q = np.random.randint(0,4,size=(8,8))
W_k = np.random.randint(0,5,size=(8,8))
W_v = np.random.randint(0,6,size=(8,8))

# Project the input, then run attention with the causal mask applied.
Q, K, V = compute_qkv(X, W_q, W_k, W_v)
print(masked_attention(Q, K, V, mask))

[[547. 490. 399. 495. 485. 439. 645. 393.]
 [547. 490. 399. 495. 485. 439. 645. 393.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]]
