# In [None]:
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """
    Implements a Multi-Head Latent Attention (MLA) mechanism, inspired by
    concepts aiming for reduced KV cache size through latent compression
    of Keys and Values.

    Keys and values are first down-projected to a small latent dimension by a
    single shared layer (``W_kv_down``) and then expanded back to ``d_model``
    by separate up-projections before ordinary scaled dot-product attention.

    Attributes:
        d_model (int): The dimensionality of the input and output features.
        n_heads (int): The number of attention heads.
        d_kv_comp (int): The latent dimension for compressed Keys and Values.
                         This is the dimension that would be stored in the KV cache.
        d_k (int): The dimensionality of each attention head (d_model // n_heads).
                   Used for the Query, Key and Value heads alike.
        W_q (nn.Linear): Linear layer to project the input query to the query space.
        W_kv_down (nn.Linear): Linear layer to compress input key/value to the latent space.
                               The same layer is applied to both key and value.
        W_k_up (nn.Linear): Linear layer to project latent key back to full key space.
        W_v_up (nn.Linear): Linear layer to project latent value back to full value space.
        W_o (nn.Linear): Linear layer to project the concatenated output of all attention heads.
        dropout (nn.Dropout): Dropout layer applied to the attention weights.
    """
    def __init__(self, d_model, n_heads, d_kv_comp, dropout=0.1):
        """
        Initializes the MultiHeadLatentAttention module.

        Args:
            d_model (int): The dimensionality of the input and output features.
            n_heads (int): The number of attention heads.
            d_kv_comp (int): The latent dimension for compressed Keys and Values.
            dropout (float, optional): Dropout probability for the attention weights. Default is 0.1.

        Raises:
            AssertionError: If d_model is not divisible by n_heads.
        """
        super().__init__()
        # Raised explicitly (rather than via an `assert` statement) so the
        # check still runs under `python -O`; the exception type is kept for
        # callers that already catch AssertionError.
        if d_model % n_heads != 0:
            raise AssertionError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_kv_comp = d_kv_comp
        self.d_k = d_model // n_heads # Dimension per head (Q, K and V)

        # Query projection remains standard
        self.W_q = nn.Linear(d_model, d_model)

        # Key and Value compression (down-projection)
        # Projects from d_model (input feature dim) to d_kv_comp (latent dim);
        # one layer is shared by keys and values.
        self.W_kv_down = nn.Linear(d_model, d_kv_comp)

        # Key and Value expansion (up-projection)
        # These project from d_kv_comp (latent dim) back to d_model for attention computation
        self.W_k_up = nn.Linear(d_kv_comp, d_model)
        self.W_v_up = nn.Linear(d_kv_comp, d_model)

        # Output projection remains standard
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Performs the forward pass of the Multi-Head Latent Attention mechanism.

        Args:
            query (torch.Tensor): The input query tensor of shape (batch_size, seq_len_q, d_model).
            key (torch.Tensor): The input key tensor of shape (batch_size, seq_len_kv, d_model).
            value (torch.Tensor): The input value tensor of shape (batch_size, seq_len_kv, d_model).
            mask (torch.Tensor, optional): Mask broadcastable to
                (batch_size, n_heads, seq_len_q, seq_len_kv). Positions where the mask
                equals 0 (or False) are excluded from attention. Default is None.
                Note: a query row whose keys are all masked receives -inf for every
                score, and softmax yields NaN for that row.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - output (torch.Tensor): The output tensor of shape (batch_size, seq_len_q, d_model).
                - attention_weights (torch.Tensor): The attention weights (after dropout)
                  of shape (batch_size, n_heads, seq_len_q, seq_len_kv).

        Raises:
            ValueError: If key and value have different sequence lengths.
        """
        batch_size = query.size(0)
        seq_len_q = query.size(1) # Sequence length for Query
        seq_len_kv = key.size(1) # Sequence length for Key/Value

        # The value tensor is reshaped with the key's sequence length below, so
        # reject a mismatch here with a readable message instead of letting
        # `view` fail with an opaque shape error.
        if value.size(1) != seq_len_kv:
            raise ValueError(
                "key and value must have the same sequence length, got "
                f"{seq_len_kv} and {value.size(1)}"
            )

        # 1. Project Query
        # Q: (batch_size, seq_len_q, d_model)
        Q = self.W_q(query)

        # 2. Compress Key and Value into latent space
        # K_latent, V_latent: (batch_size, seq_len_kv, d_kv_comp)
        # This is the representation that would be cached during inference.
        K_latent = self.W_kv_down(key)
        # When key and value are the very same tensor the shared down-projection
        # gives the same result, so compute it once.
        V_latent = K_latent if value is key else self.W_kv_down(value)

        # 3. Up-project latent K and V for attention computation
        # K_up, V_up: (batch_size, seq_len_kv, d_model)
        # These are the full-dimensional K and V that will be split into heads
        K_up = self.W_k_up(K_latent)
        V_up = self.W_v_up(V_latent)

        # 4. Reshape for Multi-Head Attention
        # Q: (batch_size, n_heads, seq_len_q, d_k)
        # K_up, V_up: (batch_size, n_heads, seq_len_kv, d_k)
        Q = Q.view(batch_size, seq_len_q, self.n_heads, self.d_k).transpose(1, 2)
        K_up = K_up.view(batch_size, seq_len_kv, self.n_heads, self.d_k).transpose(1, 2)
        V_up = V_up.view(batch_size, seq_len_kv, self.n_heads, self.d_k).transpose(1, 2)

        # 5. Calculate Attention Scores
        # scores: (batch_size, n_heads, seq_len_q, seq_len_kv)
        scores = torch.matmul(Q, K_up.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 6. Apply mask (e.g., causal mask for autoregressive models)
        if mask is not None:
            # Mask should be broadcastable to scores shape
            # (batch_size, 1, seq_len_q, seq_len_kv) or (1, 1, seq_len_q, seq_len_kv)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 7. Apply Softmax and Dropout
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 8. Compute weighted sum of Values
        # output: (batch_size, n_heads, seq_len_q, d_k)
        output = torch.matmul(attention_weights, V_up)

        # 9. Concatenate heads and apply final linear layer
        # output: (batch_size, seq_len_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        output = self.W_o(output)

        return output, attention_weights

# Example Usage to demonstrate similarity in style and functionality:
if __name__ == "__main__":
    # Smoke-test script: runs the attention blocks on random inputs and prints
    # the resulting tensor shapes for self-attention and cross-attention.
    d_model = 512
    n_heads = 8
    d_kv_comp = 128  # Choose a compression dimension, e.g., 1/4 of d_model
    seq_len = 100
    batch_size = 4

    # Dummy input tensors: three independent random tensors stand in for
    # Q, K and V (in real self-attention they would be the same tensor).
    query = torch.randn(batch_size, seq_len, d_model)
    key = torch.randn(batch_size, seq_len, d_model)
    value = torch.randn(batch_size, seq_len, d_model)

    # Example causal mask for self-attention
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)

    # The baseline MultiHeadAttention class is expected to come from elsewhere
    # (e.g. an earlier notebook cell). Look it up defensively so that a missing
    # baseline does not abort the MLA demonstration with a NameError.
    try:
        baseline_cls = MultiHeadAttention
    except NameError:
        baseline_cls = None

    print("--- Testing MultiHeadAttention (Original Style) ---")
    if baseline_cls is None:
        print("MultiHeadAttention is not defined; skipping baseline comparison.")
    else:
        mha_block = baseline_cls(d_model, n_heads)
        mha_output, mha_attn_weights = mha_block(query, key, value, mask=causal_mask)
        print(f"MHA Output shape: {mha_output.shape}")
        print(f"MHA Attention weights shape: {mha_attn_weights.shape}")

    print("\n--- Testing MultiHeadLatentAttention (MLA Style) ---")
    mla_block = MultiHeadLatentAttention(d_model, n_heads, d_kv_comp)
    mla_output, mla_attn_weights = mla_block(query, key, value, mask=causal_mask)
    print(f"MLA Output shape: {mla_output.shape}")
    print(f"MLA Attention weights shape: {mla_attn_weights.shape}")

    print(f"\nMLA KV compression ratio (d_model/d_kv_comp): {d_model / d_kv_comp:.2f}x")
    print(f"MLA KV cache size (relative to original): {d_kv_comp / d_model:.2f}")

    # Demonstrate cross-attention scenario where key/value sequence length differs from query
    print("\n--- Testing MLA with cross-attention (different KV sequence length) ---")
    encoder_output_len = 50
    encoder_key = torch.randn(batch_size, encoder_output_len, d_model)
    encoder_value = torch.randn(batch_size, encoder_output_len, d_model)

    # Mask for cross-attention, typically (batch_size, 1, query_len, key_len);
    # all-ones here, so nothing is actually masked out.
    cross_attn_mask = torch.ones(batch_size, 1, seq_len, encoder_output_len).bool()

    mla_cross_output, mla_cross_attn_weights = mla_block(query, encoder_key, encoder_value, mask=cross_attn_mask)
    print(f"MLA Cross-Attention Output shape: {mla_cross_output.shape}")
    print(f"MLA Cross-Attention weights shape: {mla_cross_attn_weights.shape}")