In [4]:
import numpy as np

In [5]:
def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray):
    """
    Compute Query (Q), Key (K), and Value (V) matrices.

    Each output is a plain linear projection of the input (no bias):
    ``Q = X @ W_q``, ``K = X @ W_k``, ``V = X @ W_v``.

    Parameters
    ----------
    X : np.ndarray
        Input features, presumably (seq_len, d_in).
    W_q, W_k, W_v : np.ndarray
        Projection matrices, presumably (d_in, d_out); their first dimension
        must match the last dimension of ``X`` for ``@`` to succeed.

    Returns
    -------
    tuple of np.ndarray
        ``(Q, K, V)``.
    """
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    return Q, K, V

In [17]:
def self_attention(Q, K, V):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Works on a single sequence (2-D inputs) and on inputs with any number of
    leading batch/head dimensions, because only the last two axes of K are
    transposed.

    Parameters
    ----------
    Q : np.ndarray
        Queries, shape (..., seq_q, d_k).
    K : np.ndarray
        Keys, shape (..., seq_k, d_k).
    V : np.ndarray
        Values, shape (..., seq_k, d_v).

    Returns
    -------
    np.ndarray
        Attention output, shape (..., seq_q, d_v).
    """
    def softmax(x):
        """Numerically stable softmax over the last axis."""
        # Subtracting the per-row max leaves the result unchanged but keeps
        # np.exp from overflowing on large scores.
        e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return e_x / np.sum(e_x, axis=-1, keepdims=True)

    d_k = Q.shape[-1]
    # swapaxes(-1, -2) equals K.T for 2-D K, but unlike K.T it does not
    # reverse the leading (batch/head) axes when K has more than two dims.
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)  # (..., seq_q, seq_k)
    weights = softmax(scores)                            # (..., seq_q, seq_k)
    return weights @ V                                   # (..., seq_q, d_v)


In [21]:
def multi_head_attention(Q, K, V, n_heads):
    """
    Multi-head attention on already-projected Q, K, V (no batch axis).

    The model dimension is split into ``n_heads`` equal, contiguous slices.
    Scaled dot-product attention (``self_attention``) runs independently on
    each slice, and the per-head outputs are concatenated back along the
    feature axis. No output projection is applied.

    Parameters
    ----------
    Q : np.ndarray
        Queries, shape (seq_q, d_model).
    K, V : np.ndarray
        Keys and values, shape (seq_k, d_model).
    n_heads : int
        Number of attention heads; must be positive and divide ``d_model``.

    Returns
    -------
    np.ndarray
        Shape (seq_q, d_model). Returned as an ndarray so callers can use
        ``.shape`` and further NumPy operations on the result.
    """
    _, d_model = Q.shape
    assert n_heads > 0, "n_heads must be positive"
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads

    def split_heads(x):
        """Reshape (seq, d_model) -> (n_heads, seq, d_head)."""
        # Use each input's own length so K/V may differ in length from Q.
        return x.reshape(x.shape[0], n_heads, d_head).transpose(1, 0, 2)

    Q_heads = split_heads(Q)
    K_heads = split_heads(K)
    V_heads = split_heads(V)

    # Each head attends independently over its own d_head-wide slice.
    heads_output = [
        self_attention(Q_heads[i], K_heads[i], V_heads[i])
        for i in range(n_heads)
    ]  # n_heads arrays of shape (seq_q, d_head)

    # Concatenating along features restores the original column layout:
    # head i occupies columns [i * d_head, (i + 1) * d_head).
    return np.concatenate(heads_output, axis=-1)  # (seq_q, d_model)

In [23]:
# Tiny hand-checkable case: Q, K and V are three separate 2x2 integer identities.
Q, K, V = [np.eye(2, dtype=int) for _ in range(3)]
n_heads = 2

In [24]:
# Run multi-head attention on the identity inputs and display the result.
multi_head_attention(Q, K, V, n_heads=n_heads)

array([[0.73105858, 0.5       ],
       [0.5       , 0.73105858]])

## Test Case 1 ##

In [25]:
np.random.seed(42)  # legacy global RNG; makes the shuffle below reproducible
m = n = 4
n_heads = 2
# The integers 0 .. m*n-1 in shuffled order, laid out as an m x n matrix.
X = np.random.permutation(m * n).reshape(m, n)
X

array([[ 0,  1,  5, 14],
       [13, 11,  8,  9],
       [ 2, 15,  4,  7],
       [10, 12,  3,  6]])

In [26]:
# Random integer projection weights drawn in order W_q, W_k, W_v; the
# (exclusive) upper bounds are 4, 5 and 6 respectively.
W_q, W_k, W_v = (np.random.randint(0, high, size=(n, n)) for high in (4, 5, 6))
Q, K, V = compute_qkv(X, W_q, W_k, W_v)

In [27]:
print(multi_head_attention(Q, K, V, n_heads=n_heads))

[[103. 109.  46.  99.]
 [103. 109.  46.  99.]
 [103. 109.  46.  99.]
 [103. 109.  46.  99.]]


## Test Case 2 ##

In [28]:
m, n = 6, 8
n_heads = 4
np.random.seed(42)  # reseed so the shuffle is reproducible
# The integers 0 .. m*n-1 in shuffled order, laid out as an m x n matrix.
X = np.random.permutation(m * n).reshape(m, n)
X

array([[27, 40, 26, 43, 24, 37, 12, 19],
       [ 4, 25,  8,  3,  6, 39, 33, 13],
       [17, 45, 15,  9, 16, 29, 32, 46],
       [ 0, 31, 30,  5, 11, 34,  1, 44],
       [21,  2, 36, 35, 23, 41, 10, 22],
       [18, 47, 20,  7, 42, 14, 28, 38]])

In [29]:
# Random integer projection weights drawn in order W_q, W_k, W_v; the
# (exclusive) upper bounds are 4, 5 and 6 respectively.
W_q, W_k, W_v = (np.random.randint(0, high, size=(n, n)) for high in (4, 5, 6))

In [30]:
Q, K, V = compute_qkv(X, W_q=W_q, W_k=W_k, W_v=W_v)

In [31]:
# X is (m, n) and W_q is (n, n), so Q should be (m, n).
np.shape(Q)

(6, 8)

In [35]:
# Single-head attention over the full feature width of Q, K, V.
s = self_attention(Q=Q, K=K, V=V)
np.shape(s)

(6, 8)

In [36]:
# Display the single-head self_attention result.
s

array([[471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.],
       [471., 472., 429., 538., 377., 450., 531., 362.]])

In [32]:
mhead_attns = multi_head_attention(Q, K, V, n_heads=n_heads)

In [33]:
# `.shape` only works if multi_head_attention returns an ndarray (not a list).
mhead_attns.shape

(6, 8)

In [34]:
# Display the multi-head attention result.
mhead_attns

array([[500., 463., 399., 495., 377., 450., 531., 362.],
       [500., 463., 399., 495., 377., 450., 531., 362.],
       [500., 463., 399., 495., 377., 450., 531., 362.],
       [500., 463., 399., 495., 377., 450., 531., 362.],
       [500., 463., 399., 495., 377., 450., 531., 362.],
       [500., 463., 399., 495., 377., 450., 531., 362.]])

## Test Case 3 ## 

In [37]:
m, n = 6, 8
n_heads = 2
np.random.seed(42)  # reseed so the shuffle is reproducible
# The integers 0 .. m*n-1 in shuffled order, laid out as an m x n matrix.
X = np.random.permutation(m * n).reshape(m, n)
X

array([[27, 40, 26, 43, 24, 37, 12, 19],
       [ 4, 25,  8,  3,  6, 39, 33, 13],
       [17, 45, 15,  9, 16, 29, 32, 46],
       [ 0, 31, 30,  5, 11, 34,  1, 44],
       [21,  2, 36, 35, 23, 41, 10, 22],
       [18, 47, 20,  7, 42, 14, 28, 38]])

In [38]:
# Random integer projection weights drawn in order W_q, W_k, W_v; the
# (exclusive) upper bounds are 4, 5 and 6 respectively.
W_q, W_k, W_v = (np.random.randint(0, high, size=(n, n)) for high in (4, 5, 6))
Q, K, V = compute_qkv(X, W_q, W_k, W_v)

In [39]:
# Multi-head attention on the Test Case 3 projections.
actual_output = multi_head_attention(Q, K, V, n_heads=n_heads)
print(actual_output)

[[547. 490. 399. 495. 377. 450. 531. 362.]
 [547. 490. 399. 495. 377. 450. 531. 362.]
 [547. 490. 399. 495. 377. 450. 531. 362.]
 [547. 490. 399. 495. 377. 450. 531. 362.]
 [547. 490. 399. 495. 377. 450. 531. 362.]
 [547. 490. 399. 495. 377. 450. 531. 362.]]
