# Multi-head Latent Attention

Multi-head Latent Attention (MLA) is a variant of multi-head attention (MHA) that drastically reduces the memory footprint and cost of the KV cache during LLM inference.

Traditional MHA caches the full K and V vectors of every head for each token. The cache therefore grows linearly with sequence length (and with batch size and number of layers), and for long contexts it becomes a major memory bottleneck.
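A quick back-of-the-envelope sketch of that growth. The helper `mha_kv_cache_bytes` and the model configuration (32 layers, 32 heads of dimension 128, 16-bit values) are illustrative assumptions, not taken from any specific model:

```python
def mha_kv_cache_bytes(batch, seq_len, n_layers, n_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache K and V for standard multi-head attention."""
    # Factor 2: one K tensor and one V tensor per layer, each [batch, n_heads, seq_len, head_dim].
    return 2 * batch * n_layers * n_heads * head_dim * seq_len * bytes_per_elem


# Hypothetical 32-layer model, 32 heads of dimension 128, float16/bfloat16 values, batch size 1.
for seq_len in (4096, 8192, 16384):
    gib = mha_kv_cache_bytes(1, seq_len, 32, 32, 128) / 1024**3
    print(f"seq_len={seq_len:>6}: {gib:.1f} GiB")
```

This prints 2.0, 4.0 and 8.0 GiB: doubling the sequence length doubles the cache, and the total is multiplied again by the batch size.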

MLA addresses this with low-rank compression. A down-projection layer maps each token's input hidden state into a single, much smaller latent vector (the "latent KV"), shared by K and V and by all heads, and only this latent is stored in the KV cache. At attention time, per-head up-projection layers map the latent back to the K and V vectors in each head's dimension.
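A minimal sketch of the down- and up-projection with plain tensors, assuming random weights in place of learned layers. The names `W_dkv`, `W_uk` and `W_uv` are illustrative, loosely following the DeepSeek-V2 notation:

```python
import torch

torch.manual_seed(0)
batch, seq_len, embed_dim = 2, 16, 512
num_heads, head_dim, latent_dim = 8, 64, 128

# Illustrative random weights; in a real model these are learned linear layers.
W_dkv = torch.randn(embed_dim, latent_dim) / embed_dim**0.5              # shared down-projection
W_uk = torch.randn(latent_dim, num_heads * head_dim) / latent_dim**0.5   # key up-projection (all heads)
W_uv = torch.randn(latent_dim, num_heads * head_dim) / latent_dim**0.5   # value up-projection (all heads)

h = torch.randn(batch, seq_len, embed_dim)  # input hidden states

# Down-projection: this latent tensor is the only thing written to the KV cache.
c_kv = h @ W_dkv  # [batch, seq_len, latent_dim]

# Up-projection at attention time, then split into heads.
k = (c_kv @ W_uk).view(batch, seq_len, num_heads, head_dim)  # [batch, seq_len, heads, head_dim]
v = (c_kv @ W_uv).view(batch, seq_len, num_heads, head_dim)

print(c_kv.shape, k.shape, v.shape)
```

Here only `c_kv` would be cached: 128 values per token, against 1,024 values per token (K and V for 8 heads of dimension 64) for standard MHA.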

MLA was introduced in the DeepSeek-V2 paper, [DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model](https://arxiv.org/abs/2405.04434), whose ablation experiments found that MLA even outperforms standard MHA. MLA is also used in DeepSeek-V3 and DeepSeek-R1.

Because it shrinks the KV cache so much, MLA is well suited to scenarios where inference speed is critical or memory is constrained, e.g.:
- long-context LLMs (without hitting memory limits)
- edge and mobile devices
- efficient inference servers (serving more users, faster, on a single GPU)

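MLA often incorporates a "decoupled" RoPE (rotary position embedding). Standard RoPE rotates the queries and keys (not the values) by a position-dependent matrix. In MLA, the key up-projection can be absorbed into the query projection at inference time, so attention scores are computed directly against the cached latent without materializing per-head keys. Applying RoPE to the reconstructed keys would place a position-dependent rotation between those two matrices, break the absorption, and force the keys of every cached token to be recomputed. Decoupled RoPE instead carries positional information in a few extra query dimensions per head and a small key vector shared across heads; that RoPE key is computed once per token and cached alongside the latent, while the compressed part stays position-free.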
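The sketch below illustrates both points for a single head. It assumes random float64 tensors, a minimal `rope` helper using one common RoPE variant (rotating interleaved pairs of dimensions), and illustrative weight names (`W_dkv`, `W_uk`, `W_kr`). Without RoPE, the key up-projection folds into the query; with decoupled RoPE, the score is that position-free term plus a term between the rotated RoPE parts, scaled by 1/sqrt(head_dim + rope_dim) as in the DeepSeek-V2 paper:

```python
import torch

torch.manual_seed(0)
dt = torch.float64
embed_dim, latent_dim, head_dim, rope_dim, seq_len = 64, 32, 16, 8, 5


def rope(x, pos, base=10000.0):
    """RoPE on interleaved pairs: rotate (x[2i], x[2i+1]) of row n by pos[n] * base**(-2i/d)."""
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=x.dtype) / half)  # [half]
    angles = pos[:, None].to(x.dtype) * freqs                     # [n, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


h = torch.randn(seq_len, embed_dim, dtype=dt)         # hidden states of the cached tokens
W_dkv = torch.randn(embed_dim, latent_dim, dtype=dt)  # down-projection
W_uk = torch.randn(latent_dim, head_dim, dtype=dt)    # key up-projection for one head
c_kv = h @ W_dkv                                      # cached latents [seq_len, latent_dim]
q_c = torch.randn(head_dim, dtype=dt)                 # position-free query part (current token)

# 1) Without RoPE, W_uk is absorbed into the query:
#    (c_kv @ W_uk) @ q_c == c_kv @ (W_uk @ q_c), so per-head keys are never materialized.
scores_explicit = (c_kv @ W_uk) @ q_c
q_absorbed = W_uk @ q_c                               # [latent_dim], computed once per query
scores_absorbed = c_kv @ q_absorbed
print(torch.allclose(scores_explicit, scores_absorbed))  # True

# 2) Decoupled RoPE: a small key part shared by all heads carries position
#    and is cached next to c_kv; the query gets its own RoPE part.
W_kr = torch.randn(embed_dim, rope_dim, dtype=dt)
k_r = rope(h @ W_kr, torch.arange(seq_len))           # [seq_len, rope_dim]
q_r = rope(torch.randn(1, rope_dim, dtype=dt), torch.tensor([seq_len - 1]))[0]  # query at last position

# Score = absorbable content term + RoPE term, scaled by 1/sqrt(head_dim + rope_dim).
scores = (c_kv @ q_absorbed + k_r @ q_r) / (head_dim + rope_dim) ** 0.5
print(scores.shape)  # torch.Size([5])
```

The RoPE term still depends only on relative positions, while the content term keeps working against the raw latent cache.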

## Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) - A memory-efficient attention mechanism.

    Core idea: project the input into a small shared latent space from which K and V
    are reconstructed. At inference only this latent needs to be cached, instead of
    the full per-head K and V tensors.
    Simplified for clarity: no decoupled RoPE and no explicit KV-cache handling.
    """

    def __init__(self, embed_dim, num_heads, latent_dim, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.latent_dim = latent_dim

        # Standard query projection
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # MLA core: compress the input into a shared latent, then up-project it to per-head K and V
        self.kv_compress = nn.Linear(embed_dim, latent_dim, bias=False)   # down-projection
        self.k_decompress = nn.Linear(latent_dim, embed_dim, bias=False)  # latent -> K for all heads
        self.v_decompress = nn.Linear(latent_dim, embed_dim, bias=False)  # latent -> V for all heads

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor [batch, seq_len, embed_dim]
            mask: Optional attention mask (0 = masked out), either [batch, seq_len, seq_len]
                  or any shape broadcastable to [batch, num_heads, seq_len, seq_len]
        """
        batch_size, seq_len, _ = x.shape

        # 1. Compress input into latent space (this is what gets cached in inference)
        latent = self.kv_compress(x)  # [batch, seq_len, latent_dim]

        # 2. Generate Q, K, V, each [batch, num_heads, seq_len, head_dim]
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_decompress(latent).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_decompress(latent).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 3. Scaled dot-product attention
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale  # [batch, num_heads, seq_len, seq_len]

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # [batch, 1, seq_len, seq_len] so it broadcasts over heads
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 4. Apply attention to values
        out = torch.matmul(attn_weights, V)  # [batch, num_heads, seq_len, head_dim]

        # 5. Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(out)
```



### Test and Analyze

```python
# Example usage
if __name__ == "__main__":
    # Model parameters
    embed_dim = 512
    num_heads = 8
    latent_dim = 128  # Much smaller than embed_dim for memory savings
    seq_len = 1024
    batch_size = 2

    # Create model and input
    mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
    x = torch.randn(batch_size, seq_len, embed_dim)

    # Forward pass
    output = mla(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Memory comparison (typical inference uses float16/bfloat16, not float32).
    # Sizes are per layer for a single sequence; multiply by batch_size for the whole batch.
    # MHA caches K and V, each [seq_len, num_heads * head_dim] = [seq_len, embed_dim];
    # MLA caches only the latent, [seq_len, latent_dim].
    bytes_per_element = 2  # float16/bfloat16 (common in inference)
    print(f"\nMemory savings analysis (assuming float16/bfloat16 precision):")
    print(f"Standard MHA KV cache per layer: {2 * seq_len * embed_dim * bytes_per_element / 1024**2:.2f} MB")
    print(f"MLA latent cache per layer: {seq_len * latent_dim * bytes_per_element / 1024**2:.2f} MB")
    print(f"Memory reduction: {(2 * embed_dim) / latent_dim:.1f}x smaller")

    # For comparison - float32 would be 2x larger
    print(f"\n(For reference - float32 would be: {2 * seq_len * embed_dim * 4 / 1024**2:.2f} MB for MHA)")
```

```
Input shape: torch.Size([2, 1024, 512])
Output shape: torch.Size([2, 1024, 512])

Memory savings analysis (assuming float16/bfloat16 precision):
Standard MHA KV cache per layer: 2.00 MB
MLA latent cache per layer: 0.25 MB
Memory reduction: 8.0x smaller

(For reference - float32 would be: 4.00 MB for MHA)
```
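The test above only checks output shapes and computes cache sizes on paper. The sketch below checks the caching claim itself: decoding one token at a time while caching only the latent gives the same outputs as a full causal pass. It assumes plain random tensors in place of the module's `nn.Linear` layers; the names `W_q`, `W_dkv`, `W_uk`, `W_uv`, `W_o` and the helpers `split_heads` and `attend` are hypothetical. It also uses a causal mask shaped `[1, 1, seq_len, seq_len]`, which broadcasts over batch and heads:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, latent_dim, seq_len = 64, 4, 16, 10
head_dim = embed_dim // num_heads

# Hypothetical weights standing in for q_proj, kv_compress, k_decompress, v_decompress, out_proj.
W_q, W_o = torch.randn(embed_dim, embed_dim) * 0.1, torch.randn(embed_dim, embed_dim) * 0.1
W_dkv = torch.randn(embed_dim, latent_dim) * 0.1
W_uk, W_uv = torch.randn(latent_dim, embed_dim) * 0.1, torch.randn(latent_dim, embed_dim) * 0.1


def split_heads(t):
    """[batch, n, embed_dim] -> [batch, num_heads, n, head_dim]"""
    return t.view(t.shape[0], t.shape[1], num_heads, head_dim).transpose(1, 2)


def attend(x_q, latent, mask=None):
    """Attention of queries from x_q over K and V rebuilt from the cached latent."""
    q = split_heads(x_q @ W_q)
    k, v = split_heads(latent @ W_uk), split_heads(latent @ W_uv)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    out = F.softmax(scores, dim=-1) @ v  # [batch, num_heads, n_q, head_dim]
    out = out.transpose(1, 2).reshape(x_q.shape[0], x_q.shape[1], embed_dim)
    return out @ W_o


x = torch.randn(1, seq_len, embed_dim)

# Full (prefill-style) pass with a causal mask.
causal = torch.tril(torch.ones(seq_len, seq_len))[None, None]  # [1, 1, seq_len, seq_len]
full = attend(x, x @ W_dkv, causal)

# Token-by-token decoding: only the latent of each new token is appended to the cache.
latent_cache = torch.empty(1, 0, latent_dim)
steps = []
for t in range(seq_len):
    x_t = x[:, t:t + 1]  # current token [1, 1, embed_dim]
    latent_cache = torch.cat([latent_cache, x_t @ W_dkv], dim=1)
    steps.append(attend(x_t, latent_cache))  # no mask needed: the cache holds only past tokens
incremental = torch.cat(steps, dim=1)

print(torch.allclose(full, incremental, atol=1e-5))  # True
print(latent_cache.shape)  # torch.Size([1, 10, 16])
```

Rebuilding K and V from the latent at every step, as done here, trades extra compute for the smaller cache; the DeepSeek-V2 paper avoids most of this cost by absorbing the up-projections into the query and output projections.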
