Implement masked self-attention, a variant of the attention mechanism used in autoregressive sequence-modeling tasks such as text generation. Your task is to compute masked self-attention from the query (Q), key (K), and value (V) matrices and an attention mask that is added to the attention scores.

Example:
Input:
masked_attention(Q, K, V, mask)
Output:
[[547. 490. 399. 495. 485. 439. 645. 393.]
 [547. 490. 399. 495. 485. 439. 645. 393.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]]
Reasoning:
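The function computes scaled dot-product scores between the queries and keys, adds the mask so that blocked positions (typically future tokens) receive no weight after the softmax, and returns the resulting weighted sum of the value rows. Each position's output therefore depends only on the positions it is allowed to see, which preserves causal dependencies.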
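As a minimal sketch of how such a mask restricts information flow, the snippet below builds an additive causal mask and applies a row-wise softmax to toy scores. The names `seq_len`, `future`, and the all-zero `scores` are illustrative assumptions, not part of the exercise; the mask convention (0 for allowed, -inf for blocked) is likewise assumed.

```python
import numpy as np

seq_len = 4  # hypothetical toy sequence length

# True above the diagonal marks "future" positions that must be hidden.
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

# Additive mask (assumed convention): 0 keeps a score, -inf removes it.
mask = np.where(future, -np.inf, 0.0)

# Toy scores: every key is equally relevant to every query.
scores = np.zeros((seq_len, seq_len))

# Row-wise softmax of the masked scores.
masked = scores + mask
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights)
```

With equal scores, row i spreads its weight uniformly over positions 0 through i and assigns exactly zero to every later position.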



import numpy as np

def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray):
    """Project the input X into query, key, and value matrices."""
    Q = np.dot(X, W_q)
    K = np.dot(X, W_k)
    V = np.dot(X, W_v)
    return Q, K, V

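def masked_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention with an additive mask.

    Q, K: (seq_len, d_k); V: (seq_len, d_v); mask: (seq_len, seq_len), assumed to
    hold 0 where attention is allowed and -inf (or a large negative value) where it is blocked.
    """
    d_k = Q.shape[1]
    # Similarity of every query with every key, scaled by sqrt(d_k).
    scores = np.matmul(Q, K.T) / np.sqrt(d_k)
    # Blocked positions become very negative and get (near-)zero weight after the softmax.
    scores = scores + mask
    # Row-wise softmax; subtracting the row maximum avoids overflow in exp.
    attention_weights = np.exp(scores - np.max(scores, axis=1, keepdims=True))
    attention_weights = attention_weights / np.sum(attention_weights, axis=1, keepdims=True)
    return np.matmul(attention_weights, V)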
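
A possible way to exercise these functions end to end is sketched below. The inputs are random toy data, not the inputs behind the example output above; the seed, the sizes, and the names `rng`, `X2`, `first_row_ok`, and `causal_ok` are assumptions made for illustration. The two functions are restated in compact form only so the snippet runs on its own.

```python
import numpy as np

# Compact restatements of the functions above, so this snippet is self-contained.
def compute_qkv(X, W_q, W_k, W_v):
    return X @ W_q, X @ W_k, X @ W_v

def masked_attention(Q, K, V, mask):
    scores = Q @ K.T / np.sqrt(Q.shape[1]) + mask
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)   # arbitrary seed, illustration only
seq_len, d_model = 5, 4          # hypothetical toy sizes
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Causal mask (assumed convention): 0 on and below the diagonal, -inf above it.
mask = np.where(np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1), -np.inf, 0.0)

Q, K, V = compute_qkv(X, W_q, W_k, W_v)
out = masked_attention(Q, K, V, mask)

# Position 0 can attend only to itself, so its output equals V[0].
first_row_ok = np.allclose(out[0], V[0])

# Perturbing the last token must leave all earlier outputs unchanged.
X2 = X.copy()
X2[-1] += 10.0
out2 = masked_attention(*compute_qkv(X2, W_q, W_k, W_v), mask)
causal_ok = np.allclose(out[:-1], out2[:-1])

print(first_row_ok, causal_ok)  # True True
```

Both checks follow from the mask alone: a query never assigns weight to a later key, so changes to later tokens cannot reach earlier outputs.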