In [1]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def split_heads(x, num_heads):
    """
    Splits the last dimension of x into (num_heads, d_head) for a batch of sequences.
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    x = x.reshape(batch_size, seq_length, num_heads, d_head)
    return x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, d_head)

def batched_classic_multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    """
    Implements multi-head attention mechanism with weight matrices using np.einsum for a batch of sequences.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    - num_heads: Number of attention heads
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (batch_size, seq_length, d_model)
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    sqrt_d_head = np.sqrt(d_head)
    
    # Compute Query, Key, Value matrices for each head
    Q = split_heads(np.dot(x, W_Q), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    K = split_heads(np.dot(x, W_K), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    V = split_heads(np.dot(x, W_V), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    
    # Compute attention scores and weights for each head
    attention_scores = np.einsum('bnqd,bnkd->bnqk', Q, K)  # (batch_size, num_heads, seq_length, seq_length)

    # Apply autoregressive masking
    mask = np.tril(np.ones((seq_length, seq_length)), k=0)
    attention_scores = attention_scores - ((1 - mask) * 1e9)
    print(f"{attention_scores[0]=}")

    attention_weights = softmax(attention_scores / sqrt_d_head)
    
    # Compute weighted sum of value vectors for each head
    weighted_sum = np.einsum('bnqk,bnkd->bnqd', attention_weights, V)  # (batch_size, num_heads, seq_length, d_head)
    
    # Concatenate heads and apply the output weight matrix W_O
    weighted_sum = weighted_sum.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output


def batched_refactored_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, d_head):
    """
    Implements refactored multi-head attention mechanism for a batch of sequences.

    Parameters:
    - X: Input matrix of shape (batch_size, seq_len, d_model)
    - W_Q, W_K, W_V: Weight matrices for Query, Key, and Value, each of shape (d_model, d_model)
    - W_O: Output weight matrix of shape (d_model, d_model)
    - num_heads: Number of attention heads
    - d_head: Dimension of each head

    Returns:
    - Output matrix after attention, of shape (batch_size, seq_len, d_model)
    """
    batch_size, seq_len, d_model = X.shape
    sqrt_d_head = np.sqrt(d_head)
    
    # Initialize the result matrix
    result = np.zeros((batch_size, seq_len, d_model))

    # Autoregressive mask
    mask = np.tril(np.ones((seq_len, seq_len)), k=0).astype(bool)
    inverted_mask = np.logical_not(mask)

    for i in range(num_heads):
        # Extract the weight matrices for the current head
        W_Q_i = W_Q[:, i * d_head:(i + 1) * d_head]
        W_K_i = W_K[:, i * d_head:(i + 1) * d_head]
        W_V_i = W_V[:, i * d_head:(i + 1) * d_head]
        W_O_i = W_O[i * d_head:(i + 1) * d_head, :]

        # Compute the combined Query-Key weight matrix for the current head
        W_QK_i = np.dot(W_Q_i, W_K_i.T) # (d_model, d_model)

        # Compute the attention pattern
        XWQ = np.einsum('bsd,de->bse', X, W_QK_i) # (batch_size, seq_len, d_model)
        A_scores = np.einsum('bqd,bkd->bqk', XWQ, X)
        
        # Apply autoregressive masking
        
        A_scores[:, inverted_mask] = -1e9
        print(f"{A_scores[0]=}")

        # compute attention weights
        A = softmax(A_scores / sqrt_d_head) # (batch_size, seq_len, seq_len)

        # Compute the combined Value-Output weight matrix for the current head
        W_VO_i = np.dot(W_V_i, W_O_i) # (d_model, d_model)
        #print(f"{W_VO_i.shape=}")

        # Compute the result for the current head and add it to the overall result
        XWV = np.einsum('bsd,de->bse', X, W_VO_i)
        #print(f"{XWV.shape=}")
        
        r_i = np.einsum('bqk,bkd->bqd', A, XWV)
        # print(f"{r_i=}")
        # break
        result += r_i

    return result


def embedding(x, W_emb):
    """
    Embeds the input sequence using an embedding matrix.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length)
    - W_emb: Embedding matrix of shape (vocab_size, d_model)
    
    Returns:
    - Embedded input of shape (batch_size, seq_length, d_model)
    """
    return W_emb[x]

def unembedding(x, W_unemb):
    """
    Unembeds the output of the transformer using an unembedding matrix.
    
    Parameters:
    - x: Output matrix of shape (batch_size, seq_length, d_model)
    - W_unemb: Unembedding matrix of shape (d_model, vocab_size)
    
    Returns:
    - Unembedded output of shape (batch_size, seq_length, vocab_size)
    """
    return np.dot(x, W_unemb)

# Define dimensions
batch_size = 2
seq_len = 3
d_model = 12
num_heads = 3
d_head = 4
vocab_size = 20

# Initialize input, embedding matrix, and unembedding matrix
X = np.random.randint(vocab_size, size=(batch_size, seq_len))

# Randomly initialize weight matrices
W_Q = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_O = np.random.rand(d_model, d_model)

W_emb = np.random.rand(vocab_size, d_model)
W_unemb = np.random.rand(d_model, vocab_size)

# Embed the input
embedded_X = embedding(X, W_emb)

# Apply the multi-head attention mechanism
attention_output = batched_classic_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads)
print("Original multi-head attn output shape before unembed:", attention_output.shape)

# Unembed the output
final_output = unembedding(attention_output, W_unemb)
print("After unembed:", final_output.shape)
#print(final_output)

# Compute the output after refactored multi-head attention
refactored_output = batched_refactored_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads, d_head)
print("Refactored multi-head attn output shape before unembed:", refactored_output.shape)
final_refactored_output = unembedding(refactored_output, W_unemb)

print("After unembed:", final_refactored_output.shape)
#print(final_refactored_output)

assert np.allclose(final_output, final_refactored_output), "The refactored function produces different results!"


attention_scores[0]=array([[[ 3.33760652e+01, -9.99999948e+08, -9.99999976e+08],
        [ 5.04926032e+01,  7.87519439e+01, -9.99999964e+08],
        [ 2.47624994e+01,  3.86156738e+01,  1.77009859e+01]],

       [[ 3.76492073e+01, -9.99999950e+08, -9.99999973e+08],
        [ 5.38064569e+01,  7.14070105e+01, -9.99999961e+08],
        [ 3.06108894e+01,  4.05976343e+01,  2.22454985e+01]],

       [[ 4.13436573e+01, -9.99999947e+08, -9.99999971e+08],
        [ 5.80094571e+01,  7.30794967e+01, -9.99999960e+08],
        [ 2.82867568e+01,  3.56925183e+01,  1.97417509e+01]]])
Original multi-head attn output shape before unembed: (2, 3, 12)
After unembed: (2, 3, 20)
A_scores[0]=array([[ 3.33760652e+01, -1.00000000e+09, -1.00000000e+09],
       [ 5.04926032e+01,  7.87519439e+01, -1.00000000e+09],
       [ 2.47624994e+01,  3.86156738e+01,  1.77009859e+01]])
A_scores[0]=array([[ 3.76492073e+01, -1.00000000e+09, -1.00000000e+09],
       [ 5.38064569e+01,  7.14070105e+01, -1.00000000e+09],
       [ 