# Multi-head Latent Attention

Multi-head Latent Attention (MLA) is a variant of multi-head attention (MHA) that drastically reduces the memory footprint and compute cost of the KV cache in LLMs during inference.


Traditional MHA caches a full key vector and value vector for every head in every layer, for each token. The cache therefore grows linearly with sequence length (and with batch size), and at long context lengths it becomes a major memory bottleneck.
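A quick back-of-the-envelope sketch of the MHA cache size makes this linear growth visible. The model configuration used here (32 layers, 32 heads, head size 128, float16) is hypothetical and chosen only for illustration.

```python
def mha_kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, batch_size=1, bytes_per_elem=2):
    """Bytes needed to cache K and V in standard MHA.

    Each token stores one key and one value vector (factor 2) of size head_dim
    for every head in every layer.
    """
    return 2 * batch_size * seq_len * n_layers * n_heads * head_dim * bytes_per_elem


# Hypothetical 32-layer, 32-head model with head_dim=128, stored in float16
for seq_len in (1024, 2048, 4096):
    gib = mha_kv_cache_bytes(seq_len, n_layers=32, n_heads=32, head_dim=128) / 1024**3
    print(f"seq_len={seq_len}: {gib:.2f} GiB")
# prints 0.50, 1.00 and 2.00 GiB: doubling the sequence length doubles the cache
```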

MLA addresses this with low-rank compression. Instead of caching full K and V, a down-projection layer compresses each token's input hidden state into a single, much smaller latent representation shared by keys and values (the "latent KV" vector), and only this latent is stored in the KV cache. At attention time, the latent is "up-projected" by linear layers whose outputs are split across heads, reconstructing the K and V vectors in their respective head dimensions.
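Here is a minimal NumPy sketch of the compression step. It uses small hypothetical dimensions and random weights, so it demonstrates shapes and rank, not a trained model. Following the paper's $W x$ convention, the weights are stored as `(out_dim, in_dim)`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: the latent d_c is much smaller than n_heads * head_dim
d_model, n_heads, head_dim, d_c = 64, 4, 16, 8
seq_len = 10

h = rng.standard_normal((seq_len, d_model))                            # input hidden states h_t

# Random stand-ins for the learned projections, shape (out_dim, in_dim)
W_DKV = rng.standard_normal((d_c, d_model)) / np.sqrt(d_model)         # down-projection
W_UK = rng.standard_normal((n_heads * head_dim, d_c)) / np.sqrt(d_c)   # key up-projection
W_UV = rng.standard_normal((n_heads * head_dim, d_c)) / np.sqrt(d_c)   # value up-projection

# Only this latent is cached: one d_c-dimensional vector per token
c_KV = h @ W_DKV.T                                                     # (seq_len, d_c)

# At attention time, per-head keys and values are rebuilt from the latent
k = (c_KV @ W_UK.T).reshape(seq_len, n_heads, head_dim)
v = (c_KV @ W_UV.T).reshape(seq_len, n_heads, head_dim)

print("cached latent:", c_KV.shape)                 # (10, 8)
print("reconstructed K/V:", k.shape, v.shape)       # (10, 4, 16) (10, 4, 16)
print("elements cached per token: MHA", 2 * n_heads * head_dim, "vs MLA", d_c)  # 128 vs 8
# All keys of all heads lie in a d_c-dimensional subspace (low rank)
print("rank of stacked keys:", np.linalg.matrix_rank(k.reshape(seq_len, -1)))
```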

MLA was introduced in the DeepSeek-V2 paper [DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model](https://arxiv.org/abs/2405.04434), where ablation experiments found that MLA even performs better than standard MHA. MLA is also used in DeepSeek-V3 and DeepSeek-R1.

MLA pairs especially well with KV caching at inference time because it greatly reduces the cache's memory footprint. This makes MLA suitable for scenarios where inference speed is critical or memory is constrained, e.g.:
- long-context LLM (without hitting memory limits)
- edge and mobile devices
- efficient inference servers (serve more users, faster, on a single GPU)
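The savings can be made concrete by counting cached elements per token and per layer, using the dimensions from the formulas below. This sketch assumes the configuration reported for DeepSeek-V2 ($n_h = 128$, $d_h = 128$, KV latent size $d_c = 512$, decoupled RoPE key size $d_h^R = 64$). It compares against a hypothetical MHA with the same heads and counts elements, not bytes.

```python
# Assumed to match the configuration reported for DeepSeek-V2
n_h, d_h = 128, 128      # number of heads and per-head dimension
d_c, d_h_R = 512, 64     # KV latent dimension and decoupled RoPE key dimension

mha_per_token_per_layer = 2 * n_h * d_h   # one key and one value vector per head
mla_per_token_per_layer = d_c + d_h_R     # one latent c_t^KV plus one shared RoPE key k_t^R

print(mha_per_token_per_layer, mla_per_token_per_layer)   # 32768 576
print(f"{mha_per_token_per_layer / mla_per_token_per_layer:.1f}x fewer elements")  # 56.9x
```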

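MLA also uses a "decoupled" RoPE. Standard RoPE rotates the queries and keys (not the values) according to token position. If RoPE were applied to the keys up-projected from the latent, a position-dependent rotation would sit between $W^{UQ}$ and $W^{UK}$. Because matrix multiplication is not commutative, $W^{UK}$ could then no longer be absorbed into the query projection at inference time (the optimization quoted at the end of this section). MLA therefore carries positional information in a separate set of small RoPE queries and a single RoPE key shared across heads, which are concatenated with the content queries and keys.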
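The conflict between RoPE and weight absorption can be checked numerically. This sketch assumes the original pairwise RoPE formulation, which rotates dimensions $2k$ and $2k+1$ by angle $\text{pos} \cdot 10000^{-2k/d}$. It uses random single-head matrices named after $W^{UQ}$ and $W^{UK}$; `rope_matrix` and `score_matrix` are hypothetical helpers. If RoPE is applied to the up-projected query and key, the matrix linking the two latents depends on the token positions, so no single precomputed ${W^{UQ}}^\top W^{UK}$ can replace it.

```python
import numpy as np

def rope_matrix(pos, dim, base=10000.0):
    """RoPE rotation matrix R_pos: rotates pairs (2k, 2k+1) by pos * base**(-2k/dim)."""
    R = np.zeros((dim, dim))
    for k in range(dim // 2):
        theta = pos * base ** (-2 * k / dim)
        c, s = np.cos(theta), np.sin(theta)
        R[2 * k:2 * k + 2, 2 * k:2 * k + 2] = [[c, -s], [s, c]]
    return R

rng = np.random.default_rng(0)
d_h, d_c = 8, 4                           # hypothetical head and latent dimensions
W_UQ = rng.standard_normal((d_h, d_c))    # query up-projection for one head (random stand-in)
W_UK = rng.standard_normal((d_h, d_c))    # key up-projection for one head (random stand-in)

def score_matrix(m, n):
    """Matrix M such that q_m^T k_n = c_q^T M c_kv when RoPE is applied to q and k."""
    return (rope_matrix(m, d_h) @ W_UQ).T @ (rope_matrix(n, d_h) @ W_UK)

# Without RoPE, the absorbed matrix is the constant W_UQ^T W_UK
absorbed = W_UQ.T @ W_UK
# With RoPE, the matrix changes with the positions, so it cannot be precomputed once
print(np.allclose(score_matrix(1, 3), absorbed))            # False
print(np.allclose(score_matrix(1, 3), score_matrix(2, 4)))  # True: depends on n - m
```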

The full MLA formulation as proposed in the DeepSeek-V2 paper:

\begin{align}
\mathbf{c}_t^{Q} &= W^{DQ} \mathbf{h}_t, \tag{1} \\
\left[ \mathbf{q}_{t,1}^{C}; \mathbf{q}_{t,2}^{C}; \ldots; \mathbf{q}_{t,n_h}^{C} \right] &= \mathbf{q}_t^{C} = W^{UQ} \mathbf{c}_t^{Q}, \tag{2} \\
\left[ \mathbf{q}_{t,1}^{R}; \mathbf{q}_{t,2}^{R}; \ldots; \mathbf{q}_{t,n_h}^{R} \right] &= \mathbf{q}_t^{R} = \mathrm{RoPE} \left( W^{QR} \mathbf{c}_t^{Q} \right), \tag{3} \\
\mathbf{q}_{t,i} &= \left[ \mathbf{q}_{t,i}^{C}; \mathbf{q}_{t,i}^{R} \right], \tag{4} \\
\mathbf{c}_t^{KV} &= W^{DKV} \mathbf{h}_t, \tag{5} \\
\left[ \mathbf{k}_{t,1}^{C}; \mathbf{k}_{t,2}^{C}; \ldots; \mathbf{k}_{t,n_h}^{C} \right] &= \mathbf{k}_t^{C} = W^{UK} \mathbf{c}_t^{KV}, \tag{6} \\
\mathbf{k}_t^{R} &= \mathrm{RoPE} \left( W^{KR} \mathbf{h}_t \right), \tag{7} \\
\mathbf{k}_{t,i} &= \left[ \mathbf{k}_{t,i}^{C}; \mathbf{k}_{t}^{R} \right], \tag{8} \\
\left[ \mathbf{v}_{t,1}^{C}; \mathbf{v}_{t,2}^{C}; \ldots; \mathbf{v}_{t,n_h}^{C} \right] &= \mathbf{v}_t^{C} = W^{UV} \mathbf{c}_t^{KV}, \tag{9} \\
\mathbf{o}_{t,i} &= \sum_{j=1}^{t} \mathrm{Softmax}_j \left( \frac{ \mathbf{q}_{t,i}^\top \mathbf{k}_{j,i} }{ \sqrt{d_h + d_h^{R}} } \right) \mathbf{v}_{j,i}^{C}, \tag{10} \\
\mathbf{u}_t &= W^{O} \left[ \mathbf{o}_{t,1}; \mathbf{o}_{t,2}; \ldots; \mathbf{o}_{t,n_h} \right] \tag{11}
\end{align}

Where:
- $h_t$: input hidden state of the $t$-th token at the attention layer
- $n_h$: number of attention heads
- $W^{DQ}, W^{DKV}$: down-projection matrices producing the query latent and the joint key-value latent
- $W^{UQ}, W^{UK}, W^{UV}$: up-projection matrices producing queries, keys, and values from the latents
- $W^{QR}, W^{KR}$: linear projections producing the decoupled positional queries and key (before RoPE)
- $W^O$: output linear projection matrix
- $c_t^Q$: compressed latent query vector (down-projected from input $h_t$)
- $c_t^{KV}$: compressed latent vector shared by keys and values (also down-projected from $h_t$)
- $q_t^C, q_{t,i}^C$: content queries of all heads / head $i$
- $q_t^R, q_{t,i}^R$: positional (RoPE) queries of all heads / head $i$
- $q_{t,i}$: concatenated content and positional query of head $i$
- $k_t^C, k_{t,i}^C$: content keys of all heads / head $i$
- $k_t^R$: positional (RoPE) key, shared by all heads
- $k_{t,i}$: concatenation of head $i$'s content key and the shared positional key
- $v_t^C, v_{t,i}^C$: content values of all heads / head $i$
- $d_h, d_h^R$: per-head dimensions of the content and positional parts
- $o_{t,i}$: attention output of head $i$ at position $t$
- $u_t$: final output

Line-by-line explanation:
1. Compress the input hidden state $h_t$ into the query latent $c_t^Q$
2. Up-project the query latent $c_t^Q$ to the full dimension $q_t^C$ and split it across the $n_h$ heads for multi-head attention
3. Generate the positional queries $q_t^R$ from $c_t^Q$ and apply $\mathrm{RoPE}$, handling position separately from content (decoupled RoPE)
4. Concatenate each head's content query $q_{t,i}^C$ with its positional query $q_{t,i}^R$
5. Compress the input $h_t$ into the latent $c_t^{KV}$ shared by keys and values - *this will be cached*
6. Up-project $c_t^{KV}$ to the content keys of all heads $k_t^C$
7. Generate a single positional key $k_t^R$, shared by all heads, directly from the input $h_t$ and apply $\mathrm{RoPE}$ - *this will be cached*
8. Concatenate each head's content key $k_{t,i}^C$ with the shared positional key $k_t^R$
9. Up-project $c_t^{KV}$ to the values of all heads $v_t^C$; values need only content, because RoPE only acts on the query-key dot product
10. Compute causal scaled dot-product attention for each head, scaling by $\sqrt{d_h + d_h^R}$ because queries and keys now have $d_h + d_h^R$ dimensions
11. Concatenate the head outputs and apply the output projection $W^O$ to obtain $u_t$
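The eleven steps can be traced in a naive NumPy sketch of equations (1) to (11) for a single sequence. This is an illustration under stated assumptions, not DeepSeek's implementation. Dimensions and weights are small and random, weights are stored as `(out_dim, in_dim)`, `d_cq` stands for the query latent size $d_c'$, RoPE uses the pairwise-rotation formulation, and `mla_forward` and `rope` are hypothetical helper names. The sum over $j \le t$ in equation (10) is implemented with a causal mask.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Pairwise RoPE on the last axis of x, shape (T, ..., dim): rotates pairs (2k, 2k+1)."""
    dim = x.shape[-1]
    freqs = base ** (-np.arange(0, dim, 2) / dim)                     # (dim/2,)
    angles = np.outer(positions, freqs)                               # (T, dim/2)
    angles = angles.reshape(x.shape[0], *([1] * (x.ndim - 2)), -1)    # broadcast over heads
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def mla_forward(h, W, n_h, d_h, d_hR):
    """Naive MLA over a full sequence h of shape (T, d). Returns u and the cache (c_KV, k_R)."""
    T = h.shape[0]
    pos = np.arange(T, dtype=float)
    c_Q = h @ W["DQ"].T                                              # (1)  (T, d_cq)
    q_C = (c_Q @ W["UQ"].T).reshape(T, n_h, d_h)                     # (2)
    q_R = rope((c_Q @ W["QR"].T).reshape(T, n_h, d_hR), pos)         # (3)
    q = np.concatenate([q_C, q_R], axis=-1)                          # (4)  (T, n_h, d_h + d_hR)
    c_KV = h @ W["DKV"].T                                            # (5)  (T, d_c), cached
    k_C = (c_KV @ W["UK"].T).reshape(T, n_h, d_h)                    # (6)
    k_R = rope(h @ W["KR"].T, pos)                                   # (7)  (T, d_hR), cached
    k_R_heads = np.broadcast_to(k_R[:, None, :], (T, n_h, d_hR))     # same k_R for every head
    k = np.concatenate([k_C, k_R_heads], axis=-1)                    # (8)
    v_C = (c_KV @ W["UV"].T).reshape(T, n_h, d_h)                    # (9)
    scores = np.einsum("thd,shd->hts", q, k) / np.sqrt(d_h + d_hR)   # (10) scaled scores
    causal = np.tril(np.ones((T, T), dtype=bool))                    # query t sees keys j <= t
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                        # softmax over j
    o = np.einsum("hts,shd->thd", weights, v_C)                      # (T, n_h, d_h)
    u = o.reshape(T, n_h * d_h) @ W["O"].T                           # (11)
    return u, (c_KV, k_R)

rng = np.random.default_rng(0)
d, n_h, d_h, d_hR, d_cq, d_c, T = 32, 4, 8, 4, 16, 12, 6              # hypothetical sizes
shapes = {"DQ": (d_cq, d), "UQ": (n_h * d_h, d_cq), "QR": (n_h * d_hR, d_cq),
          "DKV": (d_c, d), "UK": (n_h * d_h, d_c), "KR": (d_hR, d),
          "UV": (n_h * d_h, d_c), "O": (d, n_h * d_h)}
W = {name: rng.standard_normal(s) / np.sqrt(s[1]) for name, s in shapes.items()}
h = rng.standard_normal((T, d))

u, (c_KV, k_R) = mla_forward(h, W, n_h, d_h, d_hR)
print(u.shape, c_KV.shape, k_R.shape)   # (6, 32) (6, 12) (6, 4)
```

Only `c_KV` ($d_c$ values) and `k_R` ($d_h^R$ values) per token would need to be kept between decoding steps; every other tensor is recomputed from them and from the current token.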

The optimization as directly quoted from the paper:

$c_t^{KV}$ and $k_t^R$ will be cached for generation. During inference, the naive formula needs to recover $k_t^C$ and $v_t^C$ from $c_t^{KV}$ for attention. Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^O$. Therefore, we do not need to compute keys and values out for each query. Through this optimization, we avoid the computational overhead for recomputing $k_t^C$ and $v_t^C$ during inference.
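The absorption can be sanity-checked for a single decoding step. This sketch makes several assumptions: small random weights split per head, and a RoPE query and cached RoPE keys that are treated as already rotated. It compares two paths. The naive path materializes $k^C$ and $v^C$ from the cache. The absorbed path precomputes ${W_i^{UQ}}^\top W_i^{UK}$ and $W_i^O W_i^{UV}$ per head and then attends directly over the cached latents. The sketch illustrates the algebra only; it is not DeepSeek's inference kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
n_h, d_h, d_hR, d_cq, d_c, d_model, T = 4, 8, 4, 16, 12, 32, 5   # hypothetical sizes

# Random stand-ins for the learned weights, already split per head: (n_h, out, in)
W_UQ = rng.standard_normal((n_h, d_h, d_cq))
W_UK = rng.standard_normal((n_h, d_h, d_c))
W_UV = rng.standard_normal((n_h, d_h, d_c))
W_O = rng.standard_normal((d_model, n_h * d_h))

# New token: query latent and (already rotated) RoPE query.
# Cache of T tokens: latents c_KV and shared (already rotated) RoPE keys k_R.
c_Q = rng.standard_normal(d_cq)
q_R = rng.standard_normal((n_h, d_hR))
c_KV = rng.standard_normal((T, d_c))
k_R = rng.standard_normal((T, d_hR))
scale = 1 / np.sqrt(d_h + d_hR)

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

# Naive path: materialize per-head keys and values from the cached latents
q_C = np.einsum("hdc,c->hd", W_UQ, c_Q)                     # (n_h, d_h)
k_C = np.einsum("hdc,tc->htd", W_UK, c_KV)                  # (n_h, T, d_h)
v_C = np.einsum("hdc,tc->htd", W_UV, c_KV)                  # (n_h, T, d_h)
scores = np.einsum("hd,htd->ht", q_C, k_C) + q_R @ k_R.T    # content + RoPE terms
o = np.einsum("ht,htd->hd", softmax(scores * scale), v_C)
u_naive = W_O @ o.reshape(-1)

# Absorbed path: fold W_UK into the query side and W_UV into W_O (computed once)
W_QK = np.einsum("hdc,hde->hce", W_UQ, W_UK)                         # (n_h, d_cq, d_c)
W_OV = np.einsum("mhd,hdc->hmc", W_O.reshape(d_model, n_h, d_h), W_UV)  # (n_h, d_model, d_c)
scores_abs = np.einsum("c,hce,te->ht", c_Q, W_QK, c_KV) + q_R @ k_R.T
ctx = softmax(scores_abs * scale) @ c_KV                    # (n_h, d_c): attention over latents
u_absorbed = np.einsum("hmc,hc->m", W_OV, ctx)

print(np.allclose(u_naive, u_absorbed))                     # True
```

The RoPE term $q^R \cdot k^R$ is added separately in both paths, and it never passes through $W^{UK}$. This separation is what keeps absorption possible under the decoupled design.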

## Code

Simplified implementation following the formula, but without RoPE (equations (3), (4), (7) and (8) are omitted, so attention is scaled by $\sqrt{d_h}$):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MaskedScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, mask=None):
        # queries/keys/values: (batch_size, num_heads, seq_len, head_dim)
        d_k = queries.size(-1)
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Positions where mask == 0 are excluded from attention
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_values = torch.matmul(attn_weights, values)
        return attn_values
```
```python
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) following the DeepSeek-V2 formula, without RoPE.
    
    Core idea: separate low-rank compression of Q and of KV, where keys and values share one latent.
    """
    def __init__(self, embed_dim, num_heads, q_latent_dim, kv_latent_dim, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, f'embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})'

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_latent_dim = q_latent_dim
        self.kv_latent_dim = kv_latent_dim

        # Down-projections (W^DQ, W^DKV) and up-projections (W^UQ, W^UK, W^UV)
        self.W_DQ = nn.Linear(embed_dim, q_latent_dim)
        self.W_UQ = nn.Linear(q_latent_dim, embed_dim)
        self.W_DKV = nn.Linear(embed_dim, kv_latent_dim)
        self.W_UK = nn.Linear(kv_latent_dim, embed_dim)
        self.W_UV = nn.Linear(kv_latent_dim, embed_dim)
        # Output projection W^O
        self.W_output = nn.Linear(embed_dim, embed_dim)

        self.attention = MaskedScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, embed_dim)
        # mask: broadcastable to (batch_size, num_heads, seq_len, seq_len), 0 = blocked,
        #       e.g. torch.tril(torch.ones(seq_len, seq_len)) for causal attention
        batch_size, seq_len, embed_dim = x.shape

        # Step 1. Compress and decompress Q
        c_Q = self.W_DQ(x) # (batch_size, seq_len, q_latent_dim)
        q_content = self.W_UQ(c_Q) # (batch_size, seq_len, embed_dim)

        # Step 2. Compress KV into c_KV
        c_KV = self.W_DKV(x) # (batch_size, seq_len, kv_latent_dim)

        # Step 3. Decompress K and V
        k_content = self.W_UK(c_KV) # (batch_size, seq_len, embed_dim)
        v_content = self.W_UV(c_KV) # (batch_size, seq_len, embed_dim)

        # Step 4. Reshape for multi-head attention -> (batch_size, num_heads, seq_len, head_dim)
        queries = q_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = k_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = v_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 5. Apply attention
        attn_output = self.attention(queries, keys, values, mask) # (batch_size, num_heads, seq_len, head_dim)

        # Step 6. Concatenate attention heads and reshape
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)

        # Step 7. Apply final output projection
        output = self.W_output(attn_output)

        return output, c_KV # return c_KV so the caller can cache it
```



### Test and Analyze

```python
# Example usage
if __name__ == "__main__":
    # Model parameters
    embed_dim = 512
    num_heads = 8
    q_latent_dim = 128  # query compression reduces activation memory; it does not shrink the KV cache
    kv_latent_dim = 128  # much smaller than 2 * embed_dim, the per-token K+V size cached by MHA
    seq_len = 1024
    batch_size = 2
    
    # Create model, input, and a causal mask (token t attends only to positions <= t)
    mla = MultiHeadLatentAttention(embed_dim, num_heads, q_latent_dim, kv_latent_dim)
    x = torch.randn(batch_size, seq_len, embed_dim)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    
    # Forward pass
    output, c_kv = mla(x, causal_mask)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Latent KV cache shape: {c_kv.shape}")
    
    # Memory comparison per layer for a single sequence (batch dimension ignored);
    # typical inference uses float16/bfloat16, not float32
    bytes_per_element = 2  # float16/bfloat16 (common in inference)
    print(f"\nMemory savings analysis (assuming float16/bfloat16 precision):")
    print(f"Standard MHA KV cache per layer: {2 * seq_len * embed_dim * bytes_per_element / 1024**2:.2f} MB")
    print(f"MLA latent cache per layer: {seq_len * kv_latent_dim * bytes_per_element / 1024**2:.2f} MB") 
    print(f"Memory reduction: {(2 * embed_dim) / kv_latent_dim:.1f}x smaller")
    
    # For comparison - float32 would be 2x larger
    print(f"\n(For reference - float32 would be: {2 * seq_len * embed_dim * 4 / 1024**2:.2f} MB for MHA)")
```

```text
Input shape: torch.Size([2, 1024, 512])
Output shape: torch.Size([2, 1024, 512])
Latent KV cache shape: torch.Size([2, 1024, 128])

Memory savings analysis (assuming float16/bfloat16 precision):
Standard MHA KV cache per layer: 2.00 MB
MLA latent cache per layer: 0.25 MB
Memory reduction: 8.0x smaller

(For reference - float32 would be: 4.00 MB for MHA)
```
