In [1]:
import numpy as np

def self_attention_simple(X, WQ, WK, WV):
    """"
    self-attention 
    X: Input matrix (n x d)
    WQ: Query weights (d x d)
    WK: Key weights (d x d)
    WV: Value weights (d x d)
    Returns: Attention output (n x d)
    """
    #Compute Q, K, V in one step
    Q = X @ WQ
    K = X @ WK
    V = X @ WV
    
    #Compute attention scores and apply softmax to convert into vectpr  of probabilities
    d_k = Q.shape[1]
    scores = (Q @ K.T) / np.sqrt(d_k) #dot product of query and key pair dividing by the dimension of key vectors
    weights = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)
    
    #Compute weighted sum of values
    return weights @ V

# Example: identity inputs and query/key projections, plus a non-trivial
# value projection, so the output shows how attention mixes the value rows.
X = np.eye(2, dtype=int)
WQ = np.eye(2, dtype=int)
WK = np.eye(2, dtype=int)
WV = np.arange(1, 5).reshape(2, 2)

output = self_attention_simple(X, WQ, WK, WV)
print("Self-Attention Output:\n", output)

Self-Attention Output:
 [[1.6604769 2.6604769]
 [2.3395231 3.3395231]]
