In [1]:
import numpy as np
# Seed NumPy's legacy global RNG so the np.random.rand draws later in the
# notebook produce the same values on every Restart & Run All.
np.random.seed(42)

In [2]:
def softmax(x):
  """Softmax along the last axis of ``x``.

  The maximum of each slice along the last axis is subtracted before
  exponentiating. This leaves the mathematical result unchanged but keeps
  ``np.exp`` from overflowing on large inputs.

  Args:
    x: Array-like with at least one dimension.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1.
  """
  # Shift so the largest entry of every slice is 0 (exp never exceeds 1).
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  # keepdims=True lets the normaliser broadcast back over the last axis.
  normaliser = exps.sum(axis=-1, keepdims=True)
  return exps / normaliser

In [3]:
def scaled_dot_product_attention(Q, K, V):
  """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

  Works on plain 2-D matrices and on stacks of matrices with leading
  batch/head dimensions, which are broadcast against each other.

  Args:
    Q: Queries, shape ``(..., n_queries, d_k)``.
    K: Keys, shape ``(..., n_keys, d_k)``.
    V: Values, shape ``(..., n_keys, d_v)``.

  Returns:
    A tuple ``(output, weights)``:
      output: attention-weighted values, shape ``(..., n_queries, d_v)``.
      weights: attention distribution over the keys for every query,
        shape ``(..., n_queries, n_keys)``; each row sums to 1.
  """
  d_k = K.shape[-1]

  # Transpose only the last two axes. ``K.T`` would reverse *every* axis,
  # which is wrong as soon as a batch/head dimension is present. A 1-D K has
  # no second axis to swap and is used as is (``K.T`` is a no-op there too).
  K_t = np.swapaxes(K, -1, -2) if K.ndim > 1 else K

  # matmul treats leading dimensions as a batch; for 1-D/2-D operands it
  # matches np.dot, so results for un-batched inputs are unchanged.
  score = np.matmul(Q, K_t)

  # Dividing by sqrt(d_k) keeps the logits from growing with the key size.
  scaled_score = score / np.sqrt(d_k)
  weights = softmax(scaled_score)
  output = np.matmul(weights, V)

  return output, weights

In [4]:
# Toy input: 2 rows (sequence positions) with 3 features each.
X = np.random.rand(2, 3)

# Square projection matrices sized by the feature dimension of X. They are
# drawn in the order W_q, W_k, W_v so the seeded random stream is consumed
# in the same sequence as three separate assignments would.
W_q, W_k, W_v = (np.random.rand(X.shape[-1], X.shape[-1]) for _ in range(3))

# Project the input into query, key and value spaces.
Q = X @ W_q
K = X @ W_k
V = X @ W_v

# Self-attention: every row of X attends over all rows of X.
output, weights = scaled_dot_product_attention(Q, K, V)

In [5]:
# One print call, one labelled array per line.
print(f"X: {X}", f"Output: {output}", f"Weights: {weights}", sep="\n")

# Reference values from a run with the seed set at the top of the notebook:
# X: [[0.37454012 0.95071431 0.73199394]
#     [0.59865848 0.15601864 0.15599452]]
# Output: [[0.90262448 0.86128361 0.15745389]
#           [0.83269156 0.8194087  0.15429674]]
# Weights: [[0.69255977 0.30744023]
#           [0.58601833 0.41398167]]

X: [[0.37454012 0.95071431 0.73199394]
 [0.59865848 0.15601864 0.15599452]]
Output: [[0.90262448 0.86128361 0.15745389]
 [0.83269156 0.8194087  0.15429674]]
Weights: [[0.69255977 0.30744023]
 [0.58601833 0.41398167]]
