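Implement masked self-attention, a variant of the attention mechanism used in sequence-modeling tasks such as text generation. Given query (Q), key (K) and value (V) matrices and an additive attention mask $M$ (0 where attention is allowed, a large negative value where it is blocked), compute $\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$, where $d_k$ is the key dimension and the softmax is taken row-wise over the keys.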

Example:
Input:
masked_attention(Q, K, V, mask)
Output:
[[547. 490. 399. 495. 485. 439. 645. 393.]
 [547. 490. 399. 495. 485. 439. 645. 393.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]
 [471. 472. 429. 538. 377. 450. 531. 362.]]
Reasoning:
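The function computes scaled dot-product attention scores and adds the mask before the softmax, so blocked positions receive (near-)zero attention weight. With a causal mask, each position attends only to itself and earlier positions, which preserves causal dependencies.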



In [None]:
import numpy as np

def masked_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Masked scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k) + mask) @ V.

    Q:    shape (T_q, D) or batched (..., T_q, D), e.g. (B, T, D) or (B, H, T, D).
    K:    shape (T_k, D) with the same leading dims as Q.
    V:    shape (T_k, D_v) with the same leading dims as Q.
    mask: additive mask, broadcastable to the score array (..., T_q, T_k).
          Values: 0 for allowed, -1e9 (or -np.inf) for masked.
          Typical shapes: (T, T), (B, T, T), or (B, 1, T, T). When a 4-D
          (B, 1, T, T) mask is used with 3-D inputs, its singleton axis 1 is
          dropped so that it lines up with the (B, T, T) scores instead of
          broadcasting to (B, B, T, T).
    Returns: same leading dims as Q, with the last dim taken from V, i.e.
             (T_q, D_v) or (..., T_q, D_v). A query row whose keys are all
             masked with -np.inf yields zeros instead of NaN.
    """
    Q = np.asarray(Q)
    K = np.asarray(K)
    V = np.asarray(V)
    mask = np.asarray(mask)

    # Promote unbatched (T, D) inputs to a batch of one so the math is uniform.
    added_batch = Q.ndim == 2
    if added_batch:
        Q, K, V = Q[None, ...], K[None, ...], V[None, ...]

    d_k = Q.shape[-1]

    # scores: (..., T_q, T_k); swapaxes transposes only the last two axes.
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)

    # A (B,1,T,T) mask reserves axis 1 for attention heads. Against 3-D
    # (B,T,T) scores NumPy would right-align it and produce (B,B,T,T),
    # mixing batches, so drop the singleton head axis first.
    if scores.ndim == 3 and mask.ndim == 4 and mask.shape[1] == 1:
        mask = mask[:, 0]
    scores = scores + mask

    # Stable softmax over the key axis. If a whole row is -inf, its max is
    # -inf and (-inf) - (-inf) would be NaN; shifting that row by 0 instead
    # leaves exp() == 0 everywhere, and the guarded divide returns zeros.
    row_max = np.max(scores, axis=-1, keepdims=True)
    row_max = np.where(np.isneginf(row_max), 0.0, row_max)
    weights = np.exp(scores - row_max)
    denom = np.sum(weights, axis=-1, keepdims=True)
    attn = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)

    # output: (..., T_q, D_v)
    out = np.matmul(attn, V)
    return out[0] if added_batch else out

def causal_mask(T: int, *, batch: int | None = None):
    """
    Lower-triangular (causal) mask.
    Returns (T,T) if batch=None, else (batch,1,T,T) for per-batch broadcast.
    """
    m = np.triu(np.ones((T,T), dtype=np.float64), k=1)  # 1 above diagonal
    m = np.where(m==1, -1e9, 0.0)
    if batch is None:
        return m
    return m[None, None, ...].repeat(batch, axis=0)

# Example 1: a single unbatched sequence under a causal mask
np.random.seed(0)
T, D = 6, 8
Q, K, V = (np.random.randn(T, D) for _ in range(3))
mask = causal_mask(T)  # (T,T): 0 on/below the diagonal, -1e9 above it
out = masked_attention(Q, K, V, mask)
print(out.shape)  # (6,8)

# Example 2: a batch of sequences sharing one causal pattern
B, T, D = 2, 6, 8
Qb, Kb, Vb = (np.random.randn(B, T, D) for _ in range(3))
mask_b = causal_mask(T, batch=B)  # (B,1,T,T); axis 1 is kept free for a head axis
out_b = masked_attention(Qb, Kb, Vb, mask_b)
print(out_b.shape)  # (2,6,8)

In [None]:
import numpy as np

def masked_attention(Q, K, V, mask):
    """Masked scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k) + mask) @ V.

    Q, K, V: (T, D) arrays, or batched (..., T, D) with matching leading dims.
    mask: additive mask broadcastable to the (..., T, T) scores: 0 where
        attention is allowed, -1e9 (or -np.inf) to block. A 4-D (B, 1, T, T)
        mask used with 3-D inputs has its singleton axis 1 dropped so that it
        aligns with the (B, T, T) scores.
    Returns (T, D_v) or (..., T, D_v). A query row whose keys are all masked
    with -np.inf yields zeros rather than NaN.
    """
    Q, K, V, mask = (np.asarray(a) for a in (Q, K, V, mask))
    d_k = Q.shape[-1]
    # swapaxes transposes only the key/feature axes (K.T would reverse batch axes too)
    scores = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(d_k)   # (..., T, T)
    # a (B,1,T,T) mask would otherwise broadcast against (B,T,T) to (B,B,T,T)
    if scores.ndim == 3 and mask.ndim == 4 and mask.shape[1] == 1:
        mask = mask[:, 0]
    scores = scores + mask                                  # apply mask
    # stable softmax over keys; a fully -inf row is shifted by 0 to avoid inf - inf = NaN
    row_max = np.max(scores, axis=-1, keepdims=True)
    row_max = np.where(np.isneginf(row_max), 0.0, row_max)
    weights = np.exp(scores - row_max)
    denom = np.sum(weights, axis=-1, keepdims=True)
    attn = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return attn @ V                                         # (..., T, D_v)

def causal_mask(T, *, batch=None):
    """Additive causal mask that blocks attention to future positions.

    Entry [i, j] is 0.0 when j <= i (position i may attend to j) and -1e9
    when j > i. Returns a (T, T) float64 array, or, when the keyword-only
    `batch` (an int) is given, the same pattern stacked to (batch, 1, T, T)
    with a singleton axis 1 for broadcasting.
    """
    # disallow attending to future positions (strict upper triangle)
    m = np.triu(np.ones((T, T), dtype=np.float64), k=1)
    m = np.where(m == 1, -1e9, 0.0)
    if batch is None:
        return m
    return np.repeat(m[None, None, ...], batch, axis=0)

# Example: print the causal attention output for one random sequence
np.random.seed(0)
T, D = 6, 8
Q, K, V = (np.random.randn(T, D) for _ in range(3))
mask = causal_mask(T)  # (T,T): 0 on/below the diagonal, -1e9 above it
out = masked_attention(Q, K, V, mask)
print(out)