In [35]:
import numpy as np

In [51]:
def softmax(x):
    """
    Softmax along the last axis of x.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow; the shift cancels out in the ratio, so the
    result is unchanged.
    """
    row_max = np.max(x, axis=-1, keepdims=True)
    numer = np.exp(x - row_max)
    denom = np.sum(numer, axis=-1, keepdims=True)
    return numer / denom

In [58]:
def multi_headed_attention(X, Wq, Wk, Wv, Wo, num_heads, mask=None):
    """
    Multi-head scaled dot-product self-attention over the rows of X.

    X : n_samples, d_model
    Wq, Wk, Wv: d_model, d_model
    Wo : d_model rows; applied to the concatenated head outputs
    num_heads : number of heads, must divide d_model evenly
    mask : flag evaluated for truthiness (pass a bool, not an array);
           when truthy a causal mask is applied so that position i only
           attends to positions <= i

    Returns (output, weights)
    output : n_samples, number of columns of Wo
    weights : num_heads, n_samples, n_samples (each row sums to 1)
    """
    n_samples, d_model = X.shape

    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    d_head = d_model // num_heads

    def split_heads(projected):
        # n_samples, d_model -> num_heads, n_samples, d_head
        return projected.reshape(n_samples, num_heads, d_head).transpose(1, 0, 2)

    Q = split_heads(X @ Wq)
    K = split_heads(X @ Wk)
    V = split_heads(X @ Wv)

    # Scale by sqrt(d_head) to keep the dot products in a range where the
    # softmax does not saturate.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # num_heads, n_samples, n_samples

    if mask:
        # Strictly-upper-triangular entries are the "future" positions.
        # The diagonal is kept, so no row is ever fully masked.
        future = np.triu(np.ones((n_samples, n_samples), dtype=bool), k=1)
        scores[:, future] = -np.inf

    weights = softmax(scores)  # num_heads, n_samples, n_samples

    per_head = weights @ V  # num_heads, n_samples, d_head

    # Put the heads back side by side: num_heads, n_samples, d_head
    # -> n_samples, num_heads, d_head -> n_samples, d_model
    concat = per_head.transpose(1, 0, 2).reshape(n_samples, d_model)

    output = concat @ Wo

    return output, weights
    