In [3]:
import numpy as np

def softmax(x, axis=-1):
    """Compute softmax values for each row of x."""
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Fixed implementation of scaled dot-product attention.
    
    Args:
        Q: Queries matrix (batch_size, seq_len_q, d_k)
        K: Keys matrix (batch_size, seq_len_k, d_k)
        V: Values matrix (batch_size, seq_len_k, d_v)
        mask: Optional mask (batch_size, seq_len_q, seq_len_k) or (batch_size, 1, seq_len_k)
    
    Returns:
        output: Attention-weighted values (batch_size, seq_len_q, d_v)
        attention_weights: Softmax scores (batch_size, seq_len_q, seq_len_k)
    """
    # Compute dot products between queries and keys
    matmul_qk = np.matmul(Q, K.transpose(0, 2, 1))  # (batch_size, seq_len_q, seq_len_k)
    
    # Scale by square root of key dimension
    d_k = K.shape[-1]
    scaled_attention_logits = matmul_qk / np.sqrt(d_k)
    
    # Apply mask (if provided)
    if mask is not None:
        # Expand mask dimensions if needed
        if len(mask.shape) == 2:
            mask = np.expand_dims(mask, 1)  # (batch_size, 1, seq_len_k)
        scaled_attention_logits += (mask * -1e9)  # Large negative for masked positions
    
    # Softmax to get attention weights
    attention_weights = softmax(scaled_attention_logits, axis=-1)  # (batch_size, seq_len_q, seq_len_k)
    
    # Weighted sum of values
    output = np.matmul(attention_weights, V)  # (batch_size, seq_len_q, d_v)
    
    return output, attention_weights

# Example dimensions
batch_size = 2
seq_len = 4
d_k = 8  # depth of Q, K
d_v = 6  # depth of V

# Random queries, keys, values
Q = np.random.randn(batch_size, seq_len, d_k)
K = np.random.randn(batch_size, seq_len, d_k)
V = np.random.randn(batch_size, seq_len, d_v)

# Correct mask shape (batch_size, seq_len_q, seq_len_k)
mask = np.array([
    [[1, 1, 0, 0],  # First sequence
     [1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 0, 0]],
    
    [[1, 1, 1, 0],   # Second sequence
     [1, 1, 1, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 0]]
])

# Or simpler broadcastable mask (batch_size, 1, seq_len_k)
mask = np.array([
    [[1, 1, 0, 0]],  # First sequence (applies to all queries)
    
    [[1, 1, 1, 0]]   # Second sequence (applies to all queries)
])

# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)

print("Attention weights shape:", attention_weights.shape)
print("Output shape:", output.shape)
print("\nSample attention weights (first sequence, first query):")
print(attention_weights[0, 0])

Attention weights shape: (2, 4, 4)
Output shape: (2, 4, 6)

Sample attention weights (first sequence, first query):
[0.         0.         0.89447968 0.10552032]
