In [4]:
import numpy as np

def split_heads(x, n_heads):
    """
    Split the last dimension of x into n_heads attention heads.

    x: Array of shape (batch_size x seq_len x d_model)
    n_heads: Number of heads; must be positive and divide d_model evenly

    Returns an array of shape (batch_size x n_heads x seq_len x d_head),
    where d_head = d_model // n_heads.

    Raises ValueError if n_heads is not positive or does not divide d_model.
    """
    batch_size, seq_len, d_model = x.shape

    # Validate up front: otherwise a bad n_heads surfaces as an opaque
    # reshape error (or ZeroDivisionError), and for empty inputs the floor
    # division below would silently drop the remainder features.
    if n_heads <= 0 or d_model % n_heads != 0:
        raise ValueError(
            "d_model ({}) must be divisible by a positive n_heads (got {})".format(
                d_model, n_heads
            )
        )

    d_head = d_model // n_heads
    # Carve d_model into (n_heads, d_head), then move the head axis in front
    # of the sequence axis so each head can be processed as its own batch entry.
    return x.reshape(batch_size, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)

def combine_heads(x):
    """
    Merge per-head outputs back into a single feature dimension.

    x: Array of shape (batch_size x n_heads x seq_len x d_head)

    Returns an array of shape (batch_size x seq_len x n_heads * d_head) in
    which the features of head 0 come first, then head 1, and so on.
    """
    n_batch, head_count, n_tokens, head_dim = x.shape
    # Put the token axis ahead of the head axis so that, for each token,
    # the heads sit next to their features and can be flattened together.
    tokens_first = x.swapaxes(1, 2)
    merged_shape = (n_batch, n_tokens, head_count * head_dim)
    return tokens_first.reshape(merged_shape)

def multi_head_attention_simple(X, WQ, WK, WV, n_heads=2):
    """
    Simplified multi-head self-attention (no masking, no output projection).

    X: Input array (batch_size x seq_len x d_model)
    WQ, WK, WV: Weight matrices (d_model x d_model)
    n_heads: Number of attention heads; must divide d_model evenly

    Returns the per-head attention outputs concatenated along the last
    axis, shape (batch_size x seq_len x d_model).
    """
    # Project the input into queries, keys and values
    Q = X @ WQ
    K = X @ WK
    V = X @ WV

    # (batch, seq, d_model) -> (batch, n_heads, seq, d_head)
    Q_split = split_heads(Q, n_heads)
    K_split = split_heads(K, n_heads)
    V_split = split_heads(V, n_heads)
    d_head = Q_split.shape[-1]

    # Scaled dot-product scores for all heads at once: matmul treats the
    # leading (batch, head) axes as a stack of independent matrices.
    scores = (Q_split @ K_split.swapaxes(-1, -2)) / np.sqrt(d_head)

    # Softmax over the key axis. Subtracting each row's maximum does not
    # change the result mathematically but keeps np.exp from overflowing to
    # inf (which would turn the weights into nan). `initial` only matters
    # for an empty sequence, where a plain max would raise.
    row_max = np.max(scores, axis=-1, keepdims=True, initial=-np.inf)
    exp_scores = np.exp(scores - row_max)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Weighted sum of values per head, then merge the heads back together:
    # (batch, n_heads, seq, d_head) -> (batch, seq, n_heads * d_head)
    return combine_heads(weights @ V_split)

# Example usage
# One sequence (batch_size=1) of 2 tokens, each with d_model=4 features
X = np.array([[[1, 0, 0, 1], [0, 1, 1, 0]]])
WQ = np.eye(4)  # identity projections, so Q, K and V all equal X
WK = np.eye(4)
WV = np.eye(4)
n_heads = 2  # d_head = 4 // 2 = 2 features per head

output = multi_head_attention_simple(X, WQ, WK, WV, n_heads)
print("Multi-head Attention Output:\n", output)

Multi-head Attention Output:
 [[[0.66976155 0.33023845 0.33023845 0.66976155]
  [0.33023845 0.66976155 0.66976155 0.33023845]]]
