In [8]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def single_head_attention_with_weights(x, W_O, W_V, W_K, W_Q):
    """
    Implements single-head attention mechanism with weight matrices.
    
    Parameters:
    - x: Input matrix of shape (n, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (n, d_model)
    """
    
    # Compute Query, Key, Value matrices from input and corresponding weight matrices (n x d_model)
    Q = np.dot(x, W_Q)
    K = np.dot(x, W_K)
    V = np.dot(x, W_V)
    
    # Compute attention scores (n x n)
    attention_scores = np.dot(Q, K.T)
    
    # Apply softmax to get attention distribution
    attention_weights = softmax(attention_scores)
    
    # Compute weighted sum of value vectors (n x d_model)
    weighted_sum = np.dot(attention_weights, V)
    
    # Apply the output weight matrix W_O (n x d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output

def single_head_attention_refactored(x, W_O, W_V, W_K, W_Q):
    """
    Implements single-head attention mechanism with weight matrices.
    
    Parameters:
    - x: Input matrix of shape (n, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (n, d_model)
    """
    
    # Compute attention scores
    W_QK = np.dot(W_Q, W_K.T)
    attention_scores = np.dot(np.dot(x, W_QK), x.T)
    
    # Apply softmax to get attention distribution
    A = softmax(attention_scores)
    

    # Get W_V W_O transformation matrix
    W_V_O = np.dot(W_V, W_O)
    
    # Compute output matrix
    result = np.dot(A, np.dot(x, W_V_O))
    
    return result

# Test the function
n = 3  # Number of tokens
d_model = 12  # Dimensionality of each token

# Randomly initialize input matrix and weight matrices
x = np.random.rand(n, d_model)
W_O = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_Q = np.random.rand(d_model, d_model)

# Compute the output after single-head attention and linear transformations
output = single_head_attention_with_weights(x, W_O, W_V, W_K, W_Q)
print("Output matrix shape:", output.shape)
print(output)

refactored_output = single_head_attention_refactored(x, W_O, W_V, W_K, W_Q)
print("Refactored output matrix shape:", refactored_output.shape)
print(refactored_output)

assert np.allclose(output, refactored_output), "The refactored function produces different results!"

Output matrix shape: (3, 12)
[[5.30984289e+01 5.72017866e+01 5.69877472e+01 4.46998351e+01
  6.52204369e+01 4.65477197e+01 4.96062475e+01 4.84270962e+01
  5.46022068e+01 7.65112718e+01 3.47745563e+01 5.97145625e+01]
 [1.88563769e-10 2.01904032e-10 2.00673811e-10 1.57291810e-10
  2.34654430e-10 1.69160248e-10 1.80991391e-10 1.71185050e-10
  1.91803692e-10 2.77942069e-10 1.22604462e-10 2.10412911e-10]
 [6.31252502e-17 6.76329228e-17 6.72293265e-17 5.26723705e-17
  7.86233398e-17 5.67073674e-17 6.06817541e-17 5.73296135e-17
  6.42436082e-17 9.31473601e-17 4.10568552e-17 7.04968392e-17]]
Refactored output matrix shape: (3, 12)
[[5.30984289e+01 5.72017866e+01 5.69877472e+01 4.46998351e+01
  6.52204369e+01 4.65477197e+01 4.96062475e+01 4.84270962e+01
  5.46022068e+01 7.65112718e+01 3.47745563e+01 5.97145625e+01]
 [1.88563769e-10 2.01904032e-10 2.00673811e-10 1.57291810e-10
  2.34654430e-10 1.69160248e-10 1.80991391e-10 1.71185050e-10
  1.91803692e-10 2.77942069e-10 1.22604462e-10 2.10412911e

In [57]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def split_heads(x, num_heads):
    """
    Splits the last dimension of x into (num_heads, d_head).
    """
    seq_length, d_model = x.shape
    d_head = d_model // num_heads
    x = x.reshape(seq_length, num_heads, d_head)
    return x.transpose(1, 0, 2)  # (num_heads, seq_length, d_head)

def multi_head_attention_with_weights(x, W_O, W_V, W_K, W_Q, num_heads):
    """
    Implements multi-head attention mechanism with weight matrices using np.einsum.
    
    Parameters:
    - x: Input matrix of shape (seq_length, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    - num_heads: Number of attention heads
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (seq_length, d_model)
    """
    #print(x.shape, W_O.shape, W_V.shape, W_K.shape, W_Q.shape)
    seq_length, d_model = x.shape
    d_head = d_model // num_heads
    sqrt_d_head = np.sqrt(d_head)
    
    # Compute Query, Key, Value matrices for each head
    Q = split_heads(np.dot(x, W_Q), num_heads)  # (num_heads, seq_length, d_head)
    K = split_heads(np.dot(x, W_K), num_heads)  # (num_heads, seq_length, d_head)
    V = split_heads(np.dot(x, W_V), num_heads)  # (num_heads, seq_length, d_head)
    
    # Compute attention scores and weights for each head
    attention_scores = np.einsum('nqd,nkd->nqk', Q, K)  # (num_heads, seq_length, seq_length)
    attention_weights = softmax(attention_scores / sqrt_d_head)
    
    # Compute weighted sum of value vectors for each head
    weighted_sum = np.einsum('nqk,nkd->nqd', attention_weights, V)  # (num_heads, seq_length, d_head)
    
    # Concatenate heads and apply the output weight matrix W_O
    weighted_sum = weighted_sum.transpose(1, 0, 2).reshape(seq_length, d_model)  # (seq_length, d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output

def refactored_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, d_head):
    """
    Implements refactored multi-head attention mechanism.

    Parameters:
    - X: Input matrix of shape (seq_len, d_model)
    - W_Q, W_K, W_V: Weight matrices for Query, Key, and Value, each of shape (d_model, d_model)
    - W_O: Output weight matrix of shape (d_model, d_model)
    - num_heads: Number of attention heads
    - d_head: Dimension of each head

    Returns:
    - Output matrix after attention, of shape (seq_len, d_model)
    """
    seq_len, d_model = X.shape
    sqrt_d_head = np.sqrt(d_head)
    
    # Initialize the result matrix
    result = np.zeros((seq_len, d_model))

    for i in range(num_heads):
        # Extract the weight matrices for the current head
        W_Q_i = W_Q[:, i * d_head:(i + 1) * d_head]
        W_K_i = W_K[:, i * d_head:(i + 1) * d_head]
        W_V_i = W_V[:, i * d_head:(i + 1) * d_head]
        W_O_i = W_O[i * d_head:(i + 1) * d_head, :]

        # Compute the combined Query-Key weight matrix for the current head
        W_QK_i = np.dot(W_Q_i, W_K_i.T)

        # Compute the attention pattern
        A = softmax(np.dot(np.dot(X, W_QK_i), X.T) / sqrt_d_head)
        #print(f"{A=}")

        # Compute the combined Value-Output weight matrix for the current head
        W_VO_i = np.dot(W_V_i, W_O_i)
        #print(f"{W_VO_i=}")

        # Compute the result for the current head and add it to the overall result
        XWV = np.dot(X, W_VO_i)
        r_i = np.dot(A, XWV)
        # print(f"{r_i=}")
        # break
        result += r_i

    return result

# Test the function
seq_len = 3
d_model = 12
num_heads = 3
d_head = 4

# Randomly initialize input matrix and weight matrices
X = np.random.rand(seq_len, d_model)
W_Q = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_O = np.random.rand(d_model, d_model)

# Compute the output after normal multi-head attention
output = multi_head_attention_with_weights(X, W_O, W_V, W_K, W_Q, num_heads)
print("Original multi-head output matrix shape:", output.shape)
print(output)

# Compute the output after refactored multi-head attention
refactored_output = refactored_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, d_head)
print("Refactored multi-head output matrix shape:", refactored_output.shape)
print(refactored_output)

assert np.allclose(output, refactored_output), "The refactored function produces different results!"


Original multi-head output matrix shape: (3, 12)
[[13.11041448 14.45759892 17.24024685 19.25915873 17.07997889 14.14956077
  11.3144583  19.49129843 15.59462499 17.40295627 17.59148781 20.68279173]
 [13.10060618 14.38749463 17.21519085 19.24719397 17.0751548  14.15010593
  11.30695468 19.46354041 15.57782665 17.3633173  17.6274578  20.64612068]
 [13.14999024 14.48320425 17.25558089 19.30573325 17.11166652 14.18967852
  11.36247037 19.52591501 15.64968297 17.49962161 17.6540556  20.76210828]]
Refactored multi-head output matrix shape: (3, 12)
[[13.11041448 14.45759892 17.24024685 19.25915873 17.07997889 14.14956077
  11.3144583  19.49129843 15.59462499 17.40295627 17.59148781 20.68279173]
 [13.10060618 14.38749463 17.21519085 19.24719397 17.0751548  14.15010593
  11.30695468 19.46354041 15.57782665 17.3633173  17.6274578  20.64612068]
 [13.14999024 14.48320425 17.25558089 19.30573325 17.11166652 14.18967852
  11.36247037 19.52591501 15.64968297 17.49962161 17.6540556  20.76210828]]


In [2]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def split_heads(x, num_heads):
    """
    Splits the last dimension of x into (num_heads, d_head) for a batch of sequences.
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    x = x.reshape(batch_size, seq_length, num_heads, d_head)
    return x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, d_head)

def batched_classic_multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    """
    Implements multi-head attention mechanism with weight matrices using np.einsum for a batch of sequences.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    - num_heads: Number of attention heads
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (batch_size, seq_length, d_model)
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    sqrt_d_head = np.sqrt(d_head)
    
    # Compute Query, Key, Value matrices for each head
    Q = split_heads(np.dot(x, W_Q), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    K = split_heads(np.dot(x, W_K), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    V = split_heads(np.dot(x, W_V), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    
    # Compute attention scores and weights for each head
    attention_scores = np.einsum('bnqd,bnkd->bnqk', Q, K)  # (batch_size, num_heads, seq_length, seq_length)
    attention_weights = softmax(attention_scores / sqrt_d_head)
    #print(f"{attention_weights=}")
    
    # Compute weighted sum of value vectors for each head
    weighted_sum = np.einsum('bnqk,bnkd->bnqd', attention_weights, V)  # (batch_size, num_heads, seq_length, d_head)
    
    # Concatenate heads and apply the output weight matrix W_O
    weighted_sum = weighted_sum.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output


def batched_refactored_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, d_head):
    """
    Implements refactored multi-head attention mechanism for a batch of sequences.

    Parameters:
    - X: Input matrix of shape (batch_size, seq_len, d_model)
    - W_Q, W_K, W_V: Weight matrices for Query, Key, and Value, each of shape (d_model, d_model)
    - W_O: Output weight matrix of shape (d_model, d_model)
    - num_heads: Number of attention heads
    - d_head: Dimension of each head

    Returns:
    - Output matrix after attention, of shape (batch_size, seq_len, d_model)
    """
    batch_size, seq_len, d_model = X.shape
    sqrt_d_head = np.sqrt(d_head)
    
    # Initialize the result matrix
    result = np.zeros((batch_size, seq_len, d_model))

    for i in range(num_heads):
        # Extract the weight matrices for the current head
        W_Q_i = W_Q[:, i * d_head:(i + 1) * d_head]
        W_K_i = W_K[:, i * d_head:(i + 1) * d_head]
        W_V_i = W_V[:, i * d_head:(i + 1) * d_head]
        W_O_i = W_O[i * d_head:(i + 1) * d_head, :]

        # Compute the combined Query-Key weight matrix for the current head
        W_QK_i = np.dot(W_Q_i, W_K_i.T) # (d_model, d_model)

        # Compute the attention pattern
        XWQ = np.einsum('bsd,de->bse', X, W_QK_i) # (batch_size, seq_len, d_model)
        A = softmax(np.einsum('bqd,bkd->bqk', XWQ, X) / sqrt_d_head) # (batch_size, seq_len, seq_len)
        #print(f"{A=}")

        # Compute the combined Value-Output weight matrix for the current head
        W_VO_i = np.dot(W_V_i, W_O_i) # (d_model, d_model)
        #print(f"{W_VO_i.shape=}")

        # Compute the result for the current head and add it to the overall result
        XWV = np.einsum('bsd,de->bse', X, W_VO_i)
        #print(f"{XWV.shape=}")
        
        r_i = np.einsum('bqk,bkd->bqd', A, XWV)
        # print(f"{r_i=}")
        # break
        result += r_i

    return result


def embedding(x, W_emb):
    """
    Embeds the input sequence using an embedding matrix.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length)
    - W_emb: Embedding matrix of shape (vocab_size, d_model)
    
    Returns:
    - Embedded input of shape (batch_size, seq_length, d_model)
    """
    return W_emb[x]

def unembedding(x, W_unemb):
    """
    Unembeds the output of the transformer using an unembedding matrix.
    
    Parameters:
    - x: Output matrix of shape (batch_size, seq_length, d_model)
    - W_unemb: Unembedding matrix of shape (d_model, vocab_size)
    
    Returns:
    - Unembedded output of shape (batch_size, seq_length, vocab_size)
    """
    return np.dot(x, W_unemb)

# Define dimensions
batch_size = 2
seq_len = 3
d_model = 12
num_heads = 3
d_head = 4
vocab_size = 20

# Initialize input, embedding matrix, and unembedding matrix
X = np.random.randint(vocab_size, size=(batch_size, seq_len))

# Randomly initialize weight matrices
W_Q = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_O = np.random.rand(d_model, d_model)

W_emb = np.random.rand(vocab_size, d_model)
W_unemb = np.random.rand(d_model, vocab_size)

# Embed the input
embedded_X = embedding(X, W_emb)

# Apply the multi-head attention mechanism
attention_output = batched_classic_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads)
print("Original multi-head attn output shape before unembed:", attention_output.shape)

# Unembed the output
final_output = unembedding(attention_output, W_unemb)
print("After unembed:", final_output.shape)
print(final_output)

# Compute the output after refactored multi-head attention
refactored_output = batched_refactored_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads, d_head)
print("Refactored multi-head attn output shape before unembed:", refactored_output.shape)
final_refactored_output = unembedding(refactored_output, W_unemb)

print("After unembed:", final_refactored_output.shape)
print(final_refactored_output)

assert np.allclose(final_output, final_refactored_output), "The refactored function produces different results!"


Original multi-head attn output shape before unembed: (2, 3, 12)
After unembed: (2, 3, 20)
[[[146.81187383 163.95858761 134.85684915 154.77074487 147.7483432
   183.25228493 209.37117439 163.32475811 144.59898727 125.35787835
   117.17723472 143.57535574 176.25787664 167.85282957 151.87931502
   124.01945561 155.22062407 161.30247235 171.30248467 177.80024107]
  [146.88605344 164.04509407 134.92249762 154.85118363 147.81875626
   183.35077024 209.48031051 163.40620493 144.67221005 125.42087017
   117.23783112 143.65309427 176.3454822  167.94105749 151.95675649
   124.08419015 155.30505857 161.38253771 171.39291331 177.88518633]
  [146.46580151 163.55742799 134.54858107 154.39717396 147.41749169
   182.79696926 208.86486401 162.94305965 144.25748274 125.06482097
   116.8960692  143.21650813 175.84878217 167.44471323 151.51912288
   123.71903293 154.83066532 160.92819316 170.8837564  177.40123355]]

 [[111.28055273 123.94937867 101.8669829  117.02385175 111.7720689
   138.80348445 158.43

In [11]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def split_heads(x, num_heads):
    """
    Splits the last dimension of x into (num_heads, d_head) for a batch of sequences.
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    x = x.reshape(batch_size, seq_length, num_heads, d_head)
    return x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, d_head)

def batched_classic_multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    """
    Implements multi-head attention mechanism with weight matrices using np.einsum for a batch of sequences.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    - num_heads: Number of attention heads
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (batch_size, seq_length, d_model)
    """
    batch_size, seq_length, d_model = x.shape
    d_head = d_model // num_heads
    sqrt_d_head = np.sqrt(d_head)
    
    # Compute Query, Key, Value matrices for each head
    Q = split_heads(np.dot(x, W_Q), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    K = split_heads(np.dot(x, W_K), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    V = split_heads(np.dot(x, W_V), num_heads)  # (batch_size, num_heads, seq_length, d_head)
    
    # Compute attention scores and weights for each head
    attention_scores = np.einsum('bnqd,bnkd->bnqk', Q, K)  # (batch_size, num_heads, seq_length, seq_length)

    # Apply autoregressive masking
    mask = np.tril(np.ones((seq_length, seq_length)), k=0)
    attention_scores = attention_scores - ((1 - mask) * 1e9)
    print(f"{attention_scores[0]=}")

    attention_weights = softmax(attention_scores / sqrt_d_head)
    
    # Compute weighted sum of value vectors for each head
    weighted_sum = np.einsum('bnqk,bnkd->bnqd', attention_weights, V)  # (batch_size, num_heads, seq_length, d_head)
    
    # Concatenate heads and apply the output weight matrix W_O
    weighted_sum = weighted_sum.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output


def batched_refactored_multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, d_head):
    """
    Implements refactored multi-head attention mechanism for a batch of sequences.

    Parameters:
    - X: Input matrix of shape (batch_size, seq_len, d_model)
    - W_Q, W_K, W_V: Weight matrices for Query, Key, and Value, each of shape (d_model, d_model)
    - W_O: Output weight matrix of shape (d_model, d_model)
    - num_heads: Number of attention heads
    - d_head: Dimension of each head

    Returns:
    - Output matrix after attention, of shape (batch_size, seq_len, d_model)
    """
    batch_size, seq_len, d_model = X.shape
    sqrt_d_head = np.sqrt(d_head)
    
    # Initialize the result matrix
    result = np.zeros((batch_size, seq_len, d_model))

    # Autoregressive mask
    mask = np.tril(np.ones((seq_len, seq_len)), k=0).astype(bool)
    inverted_mask = np.logical_not(mask)

    for i in range(num_heads):
        # Extract the weight matrices for the current head
        W_Q_i = W_Q[:, i * d_head:(i + 1) * d_head]
        W_K_i = W_K[:, i * d_head:(i + 1) * d_head]
        W_V_i = W_V[:, i * d_head:(i + 1) * d_head]
        W_O_i = W_O[i * d_head:(i + 1) * d_head, :]

        # Compute the combined Query-Key weight matrix for the current head
        W_QK_i = np.dot(W_Q_i, W_K_i.T) # (d_model, d_model)

        # Compute the attention pattern
        XWQ = np.einsum('bsd,de->bse', X, W_QK_i) # (batch_size, seq_len, d_model)
        A_scores = np.einsum('bqd,bkd->bqk', XWQ, X)
        
        # Apply autoregressive masking
        
        A_scores[:, inverted_mask] = -1e9
        print(f"{A_scores[0]=}")

        # compute attention weights
        A = softmax(A_scores / sqrt_d_head) # (batch_size, seq_len, seq_len)

        # Compute the combined Value-Output weight matrix for the current head
        W_VO_i = np.dot(W_V_i, W_O_i) # (d_model, d_model)
        #print(f"{W_VO_i.shape=}")

        # Compute the result for the current head and add it to the overall result
        XWV = np.einsum('bsd,de->bse', X, W_VO_i)
        #print(f"{XWV.shape=}")
        
        r_i = np.einsum('bqk,bkd->bqd', A, XWV)
        # print(f"{r_i=}")
        # break
        result += r_i

    return result


def embedding(x, W_emb):
    """
    Embeds the input sequence using an embedding matrix.
    
    Parameters:
    - x: Input matrix of shape (batch_size, seq_length)
    - W_emb: Embedding matrix of shape (vocab_size, d_model)
    
    Returns:
    - Embedded input of shape (batch_size, seq_length, d_model)
    """
    return W_emb[x]

def unembedding(x, W_unemb):
    """
    Unembeds the output of the transformer using an unembedding matrix.
    
    Parameters:
    - x: Output matrix of shape (batch_size, seq_length, d_model)
    - W_unemb: Unembedding matrix of shape (d_model, vocab_size)
    
    Returns:
    - Unembedded output of shape (batch_size, seq_length, vocab_size)
    """
    return np.dot(x, W_unemb)

# Define dimensions
batch_size = 2
seq_len = 3
d_model = 12
num_heads = 3
d_head = 4
vocab_size = 20

# Initialize input, embedding matrix, and unembedding matrix
X = np.random.randint(vocab_size, size=(batch_size, seq_len))

# Randomly initialize weight matrices
W_Q = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_O = np.random.rand(d_model, d_model)

W_emb = np.random.rand(vocab_size, d_model)
W_unemb = np.random.rand(d_model, vocab_size)

# Embed the input
embedded_X = embedding(X, W_emb)

# Apply the multi-head attention mechanism
attention_output = batched_classic_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads)
print("Original multi-head attn output shape before unembed:", attention_output.shape)

# Unembed the output
final_output = unembedding(attention_output, W_unemb)
print("After unembed:", final_output.shape)
#print(final_output)

# Compute the output after refactored multi-head attention
refactored_output = batched_refactored_multi_head_attention(embedded_X, W_Q, W_K, W_V, W_O, num_heads, d_head)
print("Refactored multi-head attn output shape before unembed:", refactored_output.shape)
final_refactored_output = unembedding(refactored_output, W_unemb)

print("After unembed:", final_refactored_output.shape)
#print(final_refactored_output)

assert np.allclose(final_output, final_refactored_output), "The refactored function produces different results!"


attention_scores[0]=array([[[ 1.06025957e+01, -9.99999982e+08, -9.99999988e+08],
        [ 2.29134931e+01,  3.92934814e+01, -9.99999975e+08],
        [ 1.25861977e+01,  2.17198193e+01,  1.36278582e+01]],

       [[ 1.59389491e+01, -9.99999975e+08, -9.99999986e+08],
        [ 2.82133068e+01,  4.45525559e+01, -9.99999974e+08],
        [ 1.74981948e+01,  2.80794414e+01,  1.60657226e+01]],

       [[ 1.57744620e+01, -9.99999972e+08, -9.99999986e+08],
        [ 2.85692821e+01,  5.09878051e+01, -9.99999974e+08],
        [ 1.35442650e+01,  2.36867968e+01,  1.19094057e+01]]])
Original multi-head attn output shape before unembed: (2, 3, 12)
After unembed: (2, 3, 20)
A_scores[0]=array([[ 1.06025957e+01, -1.00000000e+09, -1.00000000e+09],
       [ 2.29134931e+01,  3.92934814e+01, -1.00000000e+09],
       [ 1.25861977e+01,  2.17198193e+01,  1.36278582e+01]])
A_scores[0]=array([[ 1.59389491e+01, -1.00000000e+09, -1.00000000e+09],
       [ 2.82133068e+01,  4.45525559e+01, -1.00000000e+09],
       [ 