<a href="https://colab.research.google.com/github/luizvalle/TransformerFromScratch/blob/main/TransformersFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
import numpy as np


In [107]:
def attention(X, Y, W_Q, W_K, W_V):
    num_input_tokens, dim_x = X.shape
    num_context_tokens, dim_y = Y.shape
    w_q_in_dim, w_q_out_dim = W_Q.shape
    w_k_in_dim, w_k_out_dim = W_K.shape
    w_v_in_dim, w_v_out_dim = W_V.shape

    assert dim_x == w_q_in_dim
    assert dim_y == w_k_in_dim
    assert dim_y == w_v_in_dim
    assert w_q_out_dim == w_k_out_dim
    assert w_v_out_dim == w_q_out_dim

    Q = X @ W_Q
    K = Y @ W_K
    V = Y @ W_V

    A = np.exp(Q @ K.T / np.sqrt(w_q_out_dim))

    assert A.shape[0] == num_input_tokens
    assert A.shape[1] == num_context_tokens

    S = A.sum(axis=1) # Sum each row
    A_bar = A / S[:,np.newaxis] # Normalized A

    assert all(A_bar.sum(axis=1).round(3) == 1) # Ensure all rows add up to 1

    X_prime = A @ V

    assert X_prime.shape[0] == num_input_tokens
    assert X_prime.shape[1] == w_v_out_dim

    return X_prime

In [106]:
X = np.array([[0, 1], [1, 2]])
Y = np.array([[1, 4], [1, -2]])
W_Q = np.array([[1, 1, 2], [-1, -1, 2]])
W_K = np.array([[0, 1, 0], [9, -4, 2]])
W_V = np.array([[0, 1, 2], [-8, -1, 2]])
attention(X, Y, W_Q, W_K, W_V)

array([[ 2.67167691e+01,  5.17666536e+00, -3.00505381e+00],
       [-1.88421715e+08, -1.76645357e+07,  5.88817858e+07]])