# **Multi-Head Latent Attention (MLA)**

Inspired by the [Hugging Face implementation](https://huggingface.co/bird-of-paradise/deepseek-mla) of MLA.

This notebook implements the MLA module described in the DeepSeek-V2 paper.

Key innovations of MLA:
1. Low-Rank Key-Value Joint Compression
2. Decoupled Rotary Position Embedding

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

import math

In [None]:
def precompute_freqs_cis(dim, end, theta=10000.0):
    """
    Build the table of unit-magnitude complex rotation factors used by RoPE.

    Entry [p, k] equals exp(i * p * theta ** (-2k / dim)), i.e. the rotation
        applied at position p to the k-th pair of embedding dimensions.

    Args:
        dim (int): embedding dimension to rotate; dim // 2 frequencies are produced
        end (int): number of positions (0 .. end - 1) to tabulate
        theta (float): base of the geometric progression of frequencies

    Returns:
        torch.Tensor: complex tensor of shape (end, dim // 2)
    """

    half_dim = dim // 2
    exponents = 2 * torch.arange(half_dim).float() / dim    # (dim // 2, )
    inv_freq = 1.0 / (theta ** exponents)                   # (dim // 2, )
    positions = torch.arange(end, device=inv_freq.device)   # (end, )

    # angle for every (position, frequency) pair
    angles = torch.outer(positions, inv_freq).float()       # (end, dim // 2)

    # magnitude 1 everywhere, so each entry is a pure rotation
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)              # (end, dim // 2)

A note on how `torch.polar()` works:
- `torch.polar(abs, angle)` returns a complex number of the form $\text{abs}\cdot \exp(i \cdot \text{angle})$
- `abs` is the magnitude of the complex number and `angle` is the angle it makes with the positive real axis in the complex plane.
- A refresher: $e^{i \cdot \theta} = \cos{\theta} + i \cdot \sin{\theta}$
- For example, with `abs = 1` and `angle = 0`, <br>
        `torch.polar(torch.tensor(1.0), torch.tensor(0.0))` =
        $$
        \begin{aligned}
            1 \cdot e^{i \cdot 0} &= 1 \cdot (\cos{0} + i \cdot \sin{0}) \\
            &= 1 \cdot (1 + 0 \cdot i) \\
            &= 1 + 0 \cdot i
        \end{aligned}
        $$
- So, `torch.polar()` returns a complex tensor (of `dtype=complex64` when the inputs are `float32`).
- However, note that each complex number still occupies only one position (one element) in the tensor. <br>
  So, the tensor returned by `torch.polar(abs, angle)` has the same shape as `angle` and `abs`.

In [None]:
def reshape_for_broadcast(freqs_cis, x):
    """
    View the RoPE frequency table with singleton axes so that it broadcasts
        against x: axis 1 and the last axis keep their size, every other
        axis becomes 1.

    Args:
        freqs_cis (torch.Tensor): frequency table whose shape must equal
            (x.shape[1], x.shape[-1])
        x (torch.Tensor): tensor the frequencies will be multiplied with;
            needs at least 2 dimensions

    Returns:
        torch.Tensor: a view of freqs_cis with x.ndim dimensions

    Raises:
        AssertionError: if x has fewer than 2 dimensions or the shape of
            freqs_cis does not match axis 1 and the last axis of x
    """

    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])

    # start from all-singleton axes, then restore the two axes that carry data
    target_shape = [1] * rank
    target_shape[1] = x.shape[1]
    target_shape[-1] = x.shape[-1]
    return freqs_cis.view(*target_shape)

In [None]:
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies rotary positional embeddings (RoPE) to query and key tensors by rotating
        each pair of embedding dimensions using position-dependent complex phases.

    Args:
        xq (torch.Tensor): query tensor of shape (batch, seq_len, ..., head_dim)
        xk (torch.Tensor): key tensor of shape (batch, seq_len, ..., head_dim)
        freqs_cis (torch.Tensor): precomputed complex RoPE frequencies of shape
            (max_seq_len, head_dim // 2), one complex entry per pair of dimensions.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the rotated query and key, each with
            the same shape and dtype as the corresponding input.

    Raises:
        AssertionError: if the last dimensions of xq and xk differ, if that
            dimension is odd, or if freqs_cis has fewer rows than the longer
            of the two sequences.
    """

    # validate input dimensions
    assert xq.shape[-1] == xk.shape[-1], "Query and Key must have same embedding dimension"
    assert xq.shape[-1] % 2 == 0, "Embedding dimension must be even"

    # get sequence length (query and key may differ in length)
    q_len = xq.shape[1]
    k_len = xk.shape[1]
    assert max(q_len, k_len) <= freqs_cis.shape[0], "freqs_cis does not cover the sequence length"

    # use appropriate part of freqs_cis for each sequence
    q_freqs = freqs_cis[:q_len]
    k_freqs = freqs_cis[:k_len]

    # split the last dimension into (head_dim // 2, 2) pairs and reinterpret
    # each pair as one complex number; the computation is done in float32
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # reshape freqs so they broadcast over every axis except seq_len and the last
    q_freqs = reshape_for_broadcast(q_freqs, xq_)
    k_freqs = reshape_for_broadcast(k_freqs, xk_)

    # complex multiplication performs the rotation; view_as_real adds a trailing
    # axis of size 2, which flatten() merges back into the embedding dimension.
    # Works for any number of middle axes as long as the last dimension of
    # freqs_cis is half the last dimension of the input.
    xq_out = torch.view_as_real(xq_ * q_freqs).flatten(xq.ndim - 1)
    xk_out = torch.view_as_real(xk_ * k_freqs).flatten(xk.ndim - 1)

    # cast back to the dtype of the inputs
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [None]:
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA), after the DeepSeek V2 paper.

    The constructor registers the low-rank down-/up-projection layers and the
    separate projections that feed the rotary position embedding.

    Args:
        d_model:        total width of the up-projected Q/K/V (all heads together).
        num_head:       number of attention heads; must divide d_model.
        d_embed:        width of the input embeddings fed to the down-projections.
        d_c:            compressed (latent) width shared by keys and values; must be < d_embed.
        d_c1:           compressed (latent) width for queries; must be < d_embed.
        d_rotate:       width of the rotary-position-embedding projection.
        dropout:        dropout rate for attention scores (stored as dropout_rate).
        bias:           whether the linear projections carry a bias term.
        max_batch_size: marked "for KV cache sizing"; not used by the code shown here.
        max_seq_len:    marked "for KV cache sizing"; not used by the code shown here.

    Derived attributes:
        d_head: d_model // num_head

    Inputs (as described by the original notes; the forward pass is not part
    of the code shown here):
        sequence: input sequence for self-attention and the query for cross-attention
        key_value_state: input for the key, values for cross-attention
    """

    def __init__(
        self,
        d_model,
        num_head,
        d_embed,
        d_c,
        d_c1,
        d_rotate,
        dropout=0.1,
        bias=True,
        max_batch_size=32,
        max_seq_len=2048,
    ):
        super().__init__()

        # sanity checks on the requested dimensions
        head_dim, leftover = divmod(d_model, num_head)
        assert leftover == 0, "d_model must be divisible by num_head"
        assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
        assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"

        # plain hyper-parameters
        self.d_model, self.num_head, self.d_head = d_model, num_head, head_dim
        self.d_embed = d_embed
        self.d_c, self.d_c1, self.d_rotate = d_c, d_c1, d_rotate
        self.dropout_rate = dropout

        def make_linear(in_features, out_features):
            # every projection shares the same bias setting
            return nn.Linear(in_features, out_features, bias=bias)

        # down-projections: compress the embeddings into the latent spaces
        self.DKV_proj = make_linear(d_embed, d_c)
        self.DQ_proj = make_linear(d_embed, d_c1)

        # up-projections: expand the latents back to the full model width
        self.UQ_proj = make_linear(d_c1, d_model)
        self.UK_proj = make_linear(d_c, d_model)
        self.UV_proj = make_linear(d_c, d_model)

        # RoPE projections: per-head rotary query from the query latent,
        # and a single d_rotate-wide rotary key from the raw embeddings
        self.RQ_proj = make_linear(d_c1, num_head * d_rotate)
        self.RK_proj = make_linear(d_embed, d_rotate)
        

        