In [4]:
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def single_head_attention_with_weights(x, W_O, W_V, W_K, W_Q):
    """
    Implements single-head attention mechanism with weight matrices.
    
    Parameters:
    - x: Input matrix of shape (n, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (n, d_model)
    """
    
    # Compute Query, Key, Value matrices from input and corresponding weight matrices (n x d_model)
    Q = np.dot(x, W_Q)
    K = np.dot(x, W_K)
    V = np.dot(x, W_V)
    
    # Compute attention scores (n x n)
    attention_scores = np.dot(Q, K.T)
    
    # Apply softmax to get attention distribution
    attention_weights = softmax(attention_scores)
    
    # Compute weighted sum of value vectors (n x d_model)
    weighted_sum = np.dot(attention_weights, V)
    
    # Apply the output weight matrix W_O (n x d_model)
    output = np.dot(weighted_sum, W_O)
    
    return output

def single_head_attention_refactored(x, W_O, W_V, W_K, W_Q):
    """
    Implements single-head attention mechanism with weight matrices.
    
    Parameters:
    - x: Input matrix of shape (n, d_model)
    - W_O, W_V, W_K, W_Q: Weight matrices, each of shape (d_model, d_model)
    
    Returns:
    - Output matrix after attention and linear transformations, of shape (n, d_model)
    """
    
    # Compute Query, Key, Value matrices from input and corresponding weight matrices (n x d_model)
    Q = np.dot(x, W_Q)
    K = np.dot(x, W_K)
    
    # Compute attention scores (n x n)
    attention_scores = np.dot(Q, K.T)
    
    # Apply softmax to get attention distribution
    A = softmax(attention_scores)

    # Get W_V W_O transformation matrix
    W_V_O = np.dot(W_V, W_O)
    
    # Compute output matrix
    result = np.dot(A, np.dot(x, W_V_O))
    
    return result

# Test the function
n = 3  # Number of tokens
d_model = 12  # Dimensionality of each token

# Randomly initialize input matrix and weight matrices
x = np.random.rand(n, d_model)
W_O = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_Q = np.random.rand(d_model, d_model)

# Compute the output after single-head attention and linear transformations
output = single_head_attention_with_weights(x, W_O, W_V, W_K, W_Q)
print("Output matrix shape:", output.shape)
print(output)

refactored_output = single_head_attention_refactored(x, W_O, W_V, W_K, W_Q)
print("Refactored output matrix shape:", refactored_output.shape)
print(refactored_output)

assert np.allclose(output, refactored_output), "The refactored function produces different results!"

Output matrix shape: (3, 12)
[[4.34668214e+01 3.94642356e+01 5.56716441e+01 5.66825338e+01
  3.56639409e+01 3.73573684e+01 3.93370176e+01 4.95249003e+01
  3.57495586e+01 5.15701875e+01 4.58689552e+01 4.10114466e+01]
 [3.26192644e-16 3.02689267e-16 4.26371550e-16 4.34286530e-16
  2.76677853e-16 2.84378781e-16 3.08742926e-16 3.88121678e-16
  2.80779464e-16 4.03171632e-16 3.56820676e-16 3.20965314e-16]
 [5.32273872e-07 4.93918127e-07 6.95735281e-07 7.08652029e-07
  4.51468684e-07 4.64039867e-07 5.03788292e-07 6.33314413e-07
  4.58159905e-07 6.57869505e-07 5.82244275e-07 5.23731237e-07]]
Refactored output matrix shape: (3, 12)
[[4.34668214e+01 3.94642356e+01 5.56716441e+01 5.66825338e+01
  3.56639409e+01 3.73573684e+01 3.93370176e+01 4.95249003e+01
  3.57495586e+01 5.15701875e+01 4.58689552e+01 4.10114466e+01]
 [3.26192644e-16 3.02689267e-16 4.26371550e-16 4.34286530e-16
  2.76677853e-16 2.84378781e-16 3.08742926e-16 3.88121678e-16
  2.80779464e-16 4.03171632e-16 3.56820676e-16 3.20965314e

In [5]:
# Function to debug and compare each step in the refactored and einsum functions
def debug_attention_steps(x, W_O, W_V, W_K, W_Q):
    print("Debugging Steps")
    
    # ---- Refactored Function ----
    print("Refactored Function")
    Q_ref = np.dot(x, W_Q)
    K_ref = np.dot(x, W_K)
    attention_scores_ref = np.dot(Q_ref, K_ref.T)
    A_ref = softmax(attention_scores_ref)
    W_V_O_ref = np.dot(W_V, W_O)
    result_ref = np.dot(A_ref, np.dot(x, W_V_O_ref))
    print("Q_ref:", Q_ref)
    print("K_ref:", K_ref)
    print("attention_scores_ref:", attention_scores_ref)
    print("A_ref:", A_ref)
    print("W_V_O_ref:", W_V_O_ref)
    print("result_ref:", result_ref)
    
    # ---- Einsum Function ----
    print("\nEinsum Function")
    Q_ein = np.einsum('nd,df->nf', x, W_Q)
    K_ein = np.einsum('nd,df->nf', x, W_K)
    attention_scores_ein = np.einsum('nf,mf->nm', Q_ein, K_ein)
    A_ein = softmax(attention_scores_ein)
    W_V_O_ein = np.einsum('df,fg->dg', W_V, W_O)
    result_ein = np.einsum('nm,nd,dg->ng', A_ein, x, W_V_O_ein)
    print("Q_ein:", Q_ein)
    print("K_ein:", K_ein)
    print("attention_scores_ein:", attention_scores_ein)
    print("A_ein:", A_ein)
    print("W_V_O_ein:", W_V_O_ein)
    print("result_ein:", result_ein)
    
    # Compare each step
    print("\nComparison Results")
    print("Q matrices close:", np.allclose(Q_ref, Q_ein))
    print("K matrices close:", np.allclose(K_ref, K_ein))
    print("Attention scores close:", np.allclose(attention_scores_ref, attention_scores_ein))
    print("A matrices close:", np.allclose(A_ref, A_ein))
    print("W_V_O matrices close:", np.allclose(W_V_O_ref, W_V_O_ein))
    print("Final result matrices close:", np.allclose(result_ref, result_ein))

# Debugging with a new set of random weights and input
x = np.random.rand(n, d_model)
W_O = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_Q = np.random.rand(d_model, d_model)

# Run the debugging function
debug_attention_steps(x, W_O, W_V, W_K, W_Q)

Debugging Steps
Refactored Function
Q_ref: [[2.74164853 2.16543522 3.11395906 1.81372275 3.61045065 2.25350768
  2.4131675  1.75332325 2.64397586 2.36416384 3.40024893 2.80104433]
 [3.86351318 3.06026361 3.73121894 2.31102605 4.31653228 2.19899644
  3.5773023  2.39155238 3.38761674 3.42601296 4.26252272 3.38465289]
 [3.12868281 2.78401292 2.6377517  2.07847662 3.97412076 2.39345565
  2.89580702 2.36366808 2.98321191 3.41700392 3.74121058 2.30881082]]
K_ref: [[2.55193091 2.28271    3.0672082  3.19981717 1.79893658 2.38405286
  3.20266548 3.22662727 3.01204698 2.64811881 2.62990482 2.91258509]
 [3.18274375 3.12399677 3.44239717 3.91213621 2.6854836  3.70796548
  3.30034205 4.54121761 4.24141233 3.5039551  2.9552325  4.00057431]
 [2.89827177 3.22737594 2.62614944 3.39441195 2.47542763 2.55954154
  2.85613994 3.43214411 3.47064866 3.59597897 2.54266209 3.22358044]]
attention_scores_ref: [[ 83.87261953 108.03645489  92.23718861]
 [108.20991025 138.66504892 119.28251361]
 [ 93.43469087 120.6

In [6]:
# Debugging to focus only on the final result
def debug_attention_final_result(x, W_O, W_V, W_K, W_Q):
    print("Debugging Final Result")
    
    # ---- Refactored Function ----
    A_ref = softmax(np.dot(np.dot(x, W_Q), np.dot(x, W_K).T))
    W_V_O_ref = np.dot(W_V, W_O)
    result_ref = np.dot(A_ref, np.dot(x, W_V_O_ref))
    print("result_ref:", result_ref)
    
    # ---- Einsum Function ----
    Q_ein = np.einsum('nd,df->nf', x, W_Q)
    K_ein = np.einsum('nd,df->nf', x, W_K)
    A_ein = softmax(np.einsum('nf,mf->nm', Q_ein, K_ein))
    W_V_O_ein = np.einsum('df,fg->dg', W_V, W_O)
    result_ein = np.einsum('nm,nd,dg->ng', A_ein, x, W_V_O_ein)
    print("result_ein:", result_ein)
    
    # Compare final result
    print("\nComparison Results")
    print("Final result matrices close:", np.allclose(result_ref, result_ein))

# Debugging with a new set of random weights and input
x = np.random.rand(n, d_model)
W_O = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_Q = np.random.rand(d_model, d_model)

# Run the debugging function
debug_attention_final_result(x, W_O, W_V, W_K, W_Q)

Debugging Final Result
result_ref: [[3.68681062e+01 4.15191810e+01 5.20266937e+01 2.58056177e+01
  5.47430686e+01 6.56202572e+01 7.10332117e+01 4.93797984e+01
  5.11934269e+01 4.91536849e+01 4.21172514e+01 3.58321896e+01]
 [4.83510397e-05 5.48072588e-05 6.92200023e-05 3.35715348e-05
  7.26165553e-05 8.65084539e-05 9.41101726e-05 6.51617943e-05
  6.73264760e-05 6.52249480e-05 5.61023288e-05 4.75213558e-05]
 [2.81859197e-07 3.21934515e-07 4.05123624e-07 1.94783350e-07
  4.25432524e-07 5.04764577e-07 5.50433505e-07 3.81109919e-07
  3.90191765e-07 3.82340383e-07 3.27098458e-07 2.78386508e-07]]
result_ein: [[4.09144465e+01 4.56754279e+01 5.62733940e+01 2.90068854e+01
  5.95864664e+01 7.21382600e+01 7.74379842e+01 5.42451985e+01
  5.63564979e+01 5.34810083e+01 4.53639004e+01 3.90167095e+01]
 [5.26446504e-05 5.83653105e-05 7.47315118e-05 3.70046236e-05
  7.80836877e-05 9.40544617e-05 1.01696505e-04 7.03420156e-05
  7.46888828e-05 7.00189700e-05 6.13162905e-05 5.11094287e-05]
 [2.66467313e-07 

In [7]:
# Debugging by trying alternative ways of performing the final multiplication
def debug_attention_alternative_approaches(x, W_O, W_V, W_K, W_Q):
    print("Debugging with Alternative Approaches")
    
    # ---- Common Steps ----
    # Compute the attention matrix A for both versions
    Q = np.dot(x, W_Q)
    K = np.dot(x, W_K)
    A = softmax(np.dot(Q, K.T))
    
    # Compute W_V_O transformation matrix for both versions
    W_V_O = np.dot(W_V, W_O)
    
    # ---- Approach 1: Using np.matmul for final multiplication ----
    x_transform_1 = np.dot(x, W_V_O)
    result_1 = np.matmul(A, x_transform_1)
    
    # ---- Approach 2: Using np.einsum for final multiplication ----
    x_transform_2 = np.einsum('nd,dg->ng', x, W_V_O)
    result_2 = np.einsum('nm,mg->ng', A, x_transform_2)
    
    # ---- Approach 3: Using np.dot for intermediate steps and then np.matmul ----
    x_transform_3 = np.dot(x, W_V_O)
    result_3 = np.matmul(A, x_transform_3)
    
    # ---- Approach 4: Using np.einsum for intermediate steps and then np.dot ----
    x_transform_4 = np.einsum('nd,dg->ng', x, W_V_O)
    result_4 = np.dot(A, x_transform_4)
    
    print("\nResults")
    print("Result using np.matmul in Approach 1:", result_1)
    print("Result using np.einsum in Approach 2:", result_2)
    print("Result using np.dot in intermediate steps in Approach 3:", result_3)
    print("Result using np.einsum in intermediate steps in Approach 4:", result_4)
    
    # Compare each approach
    print("\nComparison Results")
    print("Approaches 1 and 2 close:", np.allclose(result_1, result_2))
    print("Approaches 1 and 3 close:", np.allclose(result_1, result_3))
    print("Approaches 1 and 4 close:", np.allclose(result_1, result_4))
    print("Approaches 2 and 3 close:", np.allclose(result_2, result_3))
    print("Approaches 2 and 4 close:", np.allclose(result_2, result_4))
    print("Approaches 3 and 4 close:", np.allclose(result_3, result_4))

# Debugging with a new set of random weights and input
x = np.random.rand(n, d_model)
W_O = np.random.rand(d_model, d_model)
W_V = np.random.rand(d_model, d_model)
W_K = np.random.rand(d_model, d_model)
W_Q = np.random.rand(d_model, d_model)

# Run the debugging function
debug_attention_alternative_approaches(x, W_O, W_V, W_K, W_Q)

Debugging with Alternative Approaches

Results
Result using np.matmul in Approach 1: [[1.08048540e-04 1.24148156e-04 1.39292108e-04 1.36772395e-04
  1.24422408e-04 1.39042323e-04 1.30770611e-04 1.19650614e-04
  1.44268754e-04 1.17531403e-04 1.17849378e-04 1.05803953e-04]
 [4.90091778e+01 5.67871606e+01 6.47516397e+01 6.28263480e+01
  5.71823505e+01 6.38459024e+01 6.06966172e+01 5.47726209e+01
  6.60903199e+01 5.37428371e+01 5.39181156e+01 4.83328985e+01]
 [2.14384554e-16 2.46504306e-16 2.76408024e-16 2.71709242e-16
  2.47229796e-16 2.75953639e-16 2.59570572e-16 2.37588654e-16
  2.86261541e-16 2.33357345e-16 2.34195273e-16 2.10245049e-16]]
Result using np.einsum in Approach 2: [[1.08048540e-04 1.24148156e-04 1.39292108e-04 1.36772395e-04
  1.24422408e-04 1.39042323e-04 1.30770611e-04 1.19650614e-04
  1.44268754e-04 1.17531403e-04 1.17849378e-04 1.05803953e-04]
 [4.90091778e+01 5.67871606e+01 6.47516397e+01 6.28263480e+01
  5.71823505e+01 6.38459024e+01 6.06966172e+01 5.47726209e+01
  6.