# IN PROGRESS

# 1. DeepSeek-V3

## 1.1 Multi-head Latent Attention (MLA)
<img src="./figures/mla.webp" >

We first introduce the standard multi-head attention (MHA) mechanism as background.
Let $d$ be the embedding dimension, $n_h$ be the number of attention heads, $d_h$ be the dimension per head, and $\mathbf{h}_{t} \in \mathbb{R}^{d}$ be the attention input of the $t$-th token at an attention layer. 
Standard MHA first produces $\mathbf{q}_{t}, \mathbf{k}_{t}, \mathbf{v}_{t} \in \mathbb{R}^{d_h n_h}$ through three matrices $W^{Q}, W^{K}, W^{V} \in \mathbb{R}^{d_h n_h \times d}$, respectively: 

$$
\mathbf{q}_{t} = W^{Q} \mathbf{h}_{t},
$$
$$
\mathbf{k}_{t} = W^{K} \mathbf{h}_{t},
$$
$$
\mathbf{v}_{t} = W^{V} \mathbf{h}_{t}
$$
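As a minimal sketch of the three projections above, each $W$ can be written as a bias-free `nn.Linear`, after which the result is split into $n_h$ heads of size $d_h$. The sizes are illustrative assumptions, not values taken from any DeepSeek model:

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions for this sketch only)
d, n_h, d_h = 512, 8, 64

# W^Q, W^K, W^V each map R^d -> R^{d_h * n_h}
W_Q = nn.Linear(d, n_h * d_h, bias=False)
W_K = nn.Linear(d, n_h * d_h, bias=False)
W_V = nn.Linear(d, n_h * d_h, bias=False)

h_t = torch.randn(1, d)  # attention input for one token, batch size 1

q_t = W_Q(h_t)  # (1, n_h * d_h)
k_t = W_K(h_t)  # (1, n_h * d_h)
v_t = W_V(h_t)  # (1, n_h * d_h)

# Split each vector into n_h heads of dimension d_h
q_heads = q_t.view(-1, n_h, d_h)  # (1, n_h, d_h)
k_heads = k_t.view(-1, n_h, d_h)
v_heads = v_t.view(-1, n_h, d_h)

# Standard MHA keeps k_t and v_t for every past token:
# 2 * n_h * d_h values per token per layer
mha_cache_per_token = k_t.numel() + v_t.numel()
print(q_heads.shape, mha_cache_per_token)
```

The last line is the quantity that the compression in the next section targets: the cache grows with both the number of heads and the head dimension.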


## 1.2 Low-Rank Key-Value Joint Compression

The core of MLA is low-rank joint compression of keys and values, which reduces the KV cache:

$$
\mathbf{c}_{t}^{KV} = W^{DKV} \mathbf{h}_{t},
$$
$$
\mathbf{k}_{t}^{C} = W^{UK} \mathbf{c}_{t}^{KV},
$$
$$
\mathbf{v}_{t}^{C} = W^{UV} \mathbf{c}_{t}^{KV}
$$


where $\mathbf{c}_{t}^{KV} \in \mathbb{R}^{d_c}$ is the compressed latent vector for keys and values; 
$d_c (\ll d_h n_h)$ denotes the KV compression dimension;
$W^{DKV} \in \mathbb{R}^{d_c \times d}$ is the down-projection matrix;
and $W^{UK},W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$ are the up-projection matrices for keys and values, respectively. 
During inference, MLA only needs to cache $\mathbf{c}_{t}^{KV}$, so its KV cache holds only $d_{c}l$ elements per token, where $l$ denotes the number of layers. 
In addition, since $W^{UK}$ can be absorbed into $W^{Q}$ and $W^{UV}$ can be absorbed into $W^{O}$ during inference, the keys and values never need to be computed explicitly for attention. 
The figure at the top of Section 1.1 illustrates how the KV joint compression in MLA reduces the KV cache. 
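
The absorption claim can be checked numerically for a single head. For head $i$, the attention score satisfies $\mathbf{q}_{t,i}^T \mathbf{k}_{j,i}^{C} = \mathbf{q}_{t,i}^T W_i^{UK} \mathbf{c}_{j}^{KV} = \big((W_i^{UK})^T \mathbf{q}_{t,i}\big)^T \mathbf{c}_{j}^{KV}$, so the key up-projection can be folded into the query side and the score computed directly against the cached latent. The sketch below uses random matrices and illustrative sizes (both assumptions) and ignores positional encoding:

```python
import torch

torch.manual_seed(0)
d_h, d_c = 64, 32  # illustrative per-head and compression dims (assumptions)

W_UK_i = torch.randn(d_h, d_c, dtype=torch.float64)  # one head's slice of W^UK
q = torch.randn(d_h, dtype=torch.float64)            # that head's query for token t
c_kv = torch.randn(d_c, dtype=torch.float64)         # cached latent of token j

# Naive path: reconstruct the key, then take the dot product
k = W_UK_i @ c_kv
score_naive = q @ k

# Absorbed path: fold W^UK into the query, then score against the latent directly
q_absorbed = W_UK_i.T @ q          # lives in R^{d_c}
score_absorbed = q_absorbed @ c_kv

print(torch.allclose(score_naive, score_absorbed))
```

Both paths give the same score, but the absorbed path never materializes the $d_h$-dimensional key. The same argument applies to $W^{UV}$ and $W^{O}$ on the output side.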

Moreover, to reduce the activation memory during training, we also perform low-rank compression for the queries, even though this does not reduce the KV cache:

$$
\mathbf{c}_{t}^{Q} = W^{DQ} \mathbf{h}_{t}, 
$$
$$
\mathbf{q}_{t}^{C} = W^{UQ} \mathbf{c}_{t}^{Q},
$$



where $\mathbf{c}_{t}^{Q} \in \mathbb{R}^{d_c^{\prime}}$ is the compressed latent vector for queries; 
$d_c^{\prime} (\ll d_h n_h)$ denotes the query compression dimension; 
and $W^{DQ} \in \mathbb{R}^{d_c^{\prime} \times d}, W^{UQ} \in \mathbb{R}^{d_h n_h \times d_c^{\prime}}$ are the down-projection and up-projection matrices for queries, respectively. 
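
The query path can be sketched the same way as the key-value path. The class name `LowRankQueryCompression` and the sizes below are hypothetical choices for illustration. Only the two projections from the equations above are modelled:

```python
import torch
import torch.nn as nn

class LowRankQueryCompression(nn.Module):
    """Hypothetical helper: down-project h_t to c_t^Q, then up-project to q_t^C."""

    def __init__(self, d_model=512, n_heads=8, d_head=64, d_q_compression=96):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.W_DQ = nn.Linear(d_model, d_q_compression, bias=False)           # down-projection
        self.W_UQ = nn.Linear(d_q_compression, n_heads * d_head, bias=False)  # up-projection

    def forward(self, h_t):
        # h_t: (batch_size, d_model)
        c_t_Q = self.W_DQ(h_t)  # (batch_size, d_q_compression)
        q_t_C = self.W_UQ(c_t_Q).view(-1, self.n_heads, self.d_head)  # (batch_size, n_heads, d_head)
        return q_t_C, c_t_Q

q_compressor = LowRankQueryCompression()
h_t = torch.randn(1, 512)
q_t_C, c_t_Q = q_compressor(h_t)
print(q_t_C.shape, c_t_Q.shape)
```

Unlike $\mathbf{c}_{t}^{KV}$, the latent $\mathbf{c}_{t}^{Q}$ is not cached: queries are only needed for the current token.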

```python
import torch
import torch.nn as nn

class LowRankKVCompression(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_head=64, d_compression=32):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.d_compression = d_compression

        # Projection matrices
        self.W_DKV = nn.Linear(d_model, d_compression, bias=False)  # Down-projection
        self.W_UK = nn.Linear(d_compression, n_heads * d_head, bias=False)  # Key up-projection
        self.W_UV = nn.Linear(d_compression, n_heads * d_head, bias=False)  # Value up-projection

    def forward(self, h_t, cache=None):
        """Process one token step; returns reconstructed K/V and the compressed latent."""
        # h_t shape: (batch_size, d_model)

        # Step 1: Joint KV compression, c_t^KV = W^DKV h_t
        c_t_KV = self.W_DKV(h_t)  # (batch_size, d_compression)

        # Step 2: Cache management (store only the compressed representation)
        if cache is not None:
            cache.append(c_t_KV.detach())

        # Step 3: Up-projection back to per-head keys and values (k_t^C and v_t^C)
        k_t_C = self.W_UK(c_t_KV).view(-1, self.n_heads, self.d_head)  # (batch_size, n_heads, d_head)
        v_t_C = self.W_UV(c_t_KV).view(-1, self.n_heads, self.d_head)  # (batch_size, n_heads, d_head)

        return k_t_C, v_t_C, c_t_KV

# Example usage
batch_size = 1
d_model = 512
d_compression = 32

# Initialize module
compressor = LowRankKVCompression(d_model=d_model, d_compression=d_compression)

# Simulate hidden state for one token
h_t = torch.randn(batch_size, d_model)  # (1, 512)

# Forward pass
reconstructed_k, reconstructed_v, c_t_KV = compressor(h_t)

# During inference, we would only cache c_t_KV
kv_cache = [c_t_KV.detach()]

print("Original hidden state size:", h_t.shape)
print("Compressed KV cache size:", c_t_KV.shape)
print("Reconstructed keys shape:", reconstructed_k.shape)
print("Reconstructed values shape:", reconstructed_v.shape)
```
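
To make the saving concrete, the per-token cache sizes can be compared using the sizes from the example above plus a hypothetical layer count. Standard MHA stores a key and a value of size $d_h n_h$ each, while the compressed scheme stores only $\mathbf{c}_{t}^{KV}$ of size $d_c$:

```python
# Sizes from the example above; n_layers is a hypothetical value for illustration
n_heads, d_head, d_compression = 8, 64, 32
n_layers = 4

mha_elems_per_token = 2 * n_heads * d_head * n_layers  # keys + values in every layer
mla_elems_per_token = d_compression * n_layers         # only c_t^KV in every layer

print("MHA cache elements per token:", mha_elems_per_token)                # 4096
print("Compressed cache elements per token:", mla_elems_per_token)         # 128
print("Reduction factor:", mha_elems_per_token // mla_elems_per_token)     # 32
```

The reduction factor is $2 d_h n_h / d_c$ and does not depend on the number of layers. Once decoupled RoPE is added (Section 1.3), the shared key $\mathbf{k}_{t}^{R}$ has to be cached as well, which adds $d_h^R$ elements per token per layer.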

## 1.3 Decoupled Rotary Position Embedding

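Standard RoPE is incompatible with the low-rank KV compression above. RoPE is position-sensitive for both queries and keys, so applying it to $\mathbf{k}_{t}^{C}$ would place a position-dependent rotation between $W^{Q}$ and $W^{UK}$, and $W^{UK}$ could no longer be absorbed into $W^{Q}$ during inference. The decoupled RoPE strategy instead uses additional multi-head queries $\mathbf{q}_{t,i}^{R} \in \mathbb{R}^{d_h^R}$ and a single key $\mathbf{k}_{t}^{R} \in \mathbb{R}^{d_h^R}$, shared across all heads, to carry RoPE, where $d_h^R$ is the per-head dimension of the decoupled queries and key. The complete MLA computation is then: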

$$
\begin{aligned}
    [\mathbf{q}_{t, 1}^{R};\mathbf{q}_{t, 2}^{R};...;\mathbf{q}_{t, n_{h}}^{R}] = \mathbf{q}_{t}^{R} &= \operatorname{RoPE}({W^{QR}} \mathbf{c}_{t}^{Q}), \\
    \mathbf{k}_{t}^{R} &= \operatorname{RoPE}({W^{KR}} \mathbf{h}_{t}), \\
    \mathbf{q}_{t, i} &= [\mathbf{q}_{t, i}^{C}; \mathbf{q}_{t, i}^{R}], \\
    \mathbf{k}_{t, i} &= [\mathbf{k}_{t, i}^{C}; \mathbf{k}_{t}^{R}], \\
    \mathbf{o}_{t, i} &= \sum_{j=1}^{t} \operatorname{Softmax}_j\left(\frac{\mathbf{q}_{t, i}^T \mathbf{k}_{j, i}}{\sqrt{d_{h} + d_{h}^{R}}}\right) \mathbf{v}_{j, i}^{C}, \\ 
    \mathbf{u}_{t} &= W^{O} [\mathbf{o}_{t, 1};\mathbf{o}_{t, 2};...;\mathbf{o}_{t, n_{h}}].
\end{aligned}
$$
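
The equations above can be sketched end to end for a single sequence. This is a minimal illustration under several assumptions: the sizes are arbitrary, `rope` uses the common "rotate-half" formulation with base 10000, there is no batching, no KV cache, and no weight absorption. The names `rope` and `MLASketch` are hypothetical:

```python
import math
import torch
import torch.nn as nn

def rope(x, positions, base=10000.0):
    """Rotate-half RoPE over the last dim; x: (..., seq, dim), positions: (seq,)."""
    half = x.shape[-1] // 2
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = positions.float()[:, None] * inv_freq[None, :]  # (seq, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

class MLASketch(nn.Module):
    def __init__(self, d=512, n_h=8, d_h=64, d_c=32, d_c_q=96, d_r=16):
        super().__init__()
        self.n_h, self.d_h, self.d_r = n_h, d_h, d_r
        self.W_DKV = nn.Linear(d, d_c, bias=False)
        self.W_UK = nn.Linear(d_c, n_h * d_h, bias=False)
        self.W_UV = nn.Linear(d_c, n_h * d_h, bias=False)
        self.W_DQ = nn.Linear(d, d_c_q, bias=False)
        self.W_UQ = nn.Linear(d_c_q, n_h * d_h, bias=False)
        self.W_QR = nn.Linear(d_c_q, n_h * d_r, bias=False)  # decoupled RoPE queries
        self.W_KR = nn.Linear(d, d_r, bias=False)            # shared RoPE key
        self.W_O = nn.Linear(n_h * d_h, d, bias=False)

    def forward(self, h):
        # h: (T, d), the hidden states of one sequence
        T = h.shape[0]
        pos = torch.arange(T)
        c_kv, c_q = self.W_DKV(h), self.W_DQ(h)
        # Content parts, split into heads: (n_h, T, d_h)
        q_c = self.W_UQ(c_q).view(T, self.n_h, self.d_h).transpose(0, 1)
        k_c = self.W_UK(c_kv).view(T, self.n_h, self.d_h).transpose(0, 1)
        v_c = self.W_UV(c_kv).view(T, self.n_h, self.d_h).transpose(0, 1)
        # RoPE parts: per-head queries (n_h, T, d_r), one key (T, d_r) for all heads
        q_r = rope(self.W_QR(c_q).view(T, self.n_h, self.d_r).transpose(0, 1), pos)
        k_r = rope(self.W_KR(h), pos)
        # Concatenate content and RoPE parts; the shared key is broadcast to every head
        q = torch.cat([q_c, q_r], dim=-1)
        k = torch.cat([k_c, k_r.expand(self.n_h, T, self.d_r)], dim=-1)
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.d_h + self.d_r)  # (n_h, T, T)
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool))            # token t sees j <= t
        scores = scores.masked_fill(~causal, float("-inf"))
        o = torch.softmax(scores, dim=-1) @ v_c                            # (n_h, T, d_h)
        return self.W_O(o.transpose(0, 1).reshape(T, self.n_h * self.d_h)) # (T, d)

torch.manual_seed(0)
mla = MLASketch()
h = torch.randn(5, 512)
u = mla(h)
print(u.shape)  # torch.Size([5, 512])
```

Only `c_kv` and `k_r` depend on past tokens' keys and values, which is why these two tensors are the ones that need caching at inference time.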


## 1.4 KV Cache

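```python
import torch

# Simplified generation loop with KV cache.
# `model` and `sample` are placeholders: `model(tokens, kv_cache=...)` is assumed to
# return (logits, updated_cache), and `sample(logits)` a (batch, 1) tensor of token ids.
def generate(model, sample, input_ids, max_length=50):
    kv_cache = []  # Stores compressed KV states
    step_input = input_ids  # First step processes the whole prompt
    for _ in range(max_length):
        # Forward pass: compute logits and update cache
        logits, kv_cache = model(step_input, kv_cache=kv_cache)
        # Sample next token
        next_token = sample(logits)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # Later steps feed only the new token; earlier tokens are already in the cache
        step_input = next_token
    return input_ids
```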
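
To see the cache grow during generation, the loop can be run against a toy model. `ToyLatentCacheModel` and `generate_with_cache` are hypothetical stand-ins written for this sketch: the model stores one compressed latent per token and replaces attention with a plain average over the cached latents, so it only illustrates the caching pattern, not MLA itself. Greedy decoding is used in place of sampling:

```python
import torch
import torch.nn as nn

class ToyLatentCacheModel(nn.Module):
    """Hypothetical stand-in for a model: caches one compressed latent per token."""

    def __init__(self, vocab_size=100, d_model=32, d_compression=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.W_DKV = nn.Linear(d_model, d_compression, bias=False)
        self.W_UV = nn.Linear(d_compression, d_model, bias=False)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, kv_cache):
        h = self.embed(input_ids)                        # (batch, new_tokens, d_model)
        c_kv = self.W_DKV(h)                             # (batch, new_tokens, d_compression)
        kv_cache = kv_cache + list(c_kv.unbind(dim=1))   # append one latent per new token
        # Toy "attention": average the up-projected latents of all cached tokens
        context = self.W_UV(torch.stack(kv_cache, dim=1)).mean(dim=1)  # (batch, d_model)
        logits = self.lm_head(h[:, -1] + context)        # (batch, vocab_size)
        return logits, kv_cache

@torch.no_grad()
def generate_with_cache(model, input_ids, max_new_tokens=5):
    kv_cache = []
    step_input = input_ids                               # first step: the whole prompt
    for _ in range(max_new_tokens):
        logits, kv_cache = model(step_input, kv_cache)
        next_token = logits.argmax(dim=-1, keepdim=True) # greedy choice, (batch, 1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        step_input = next_token                          # later steps: only the new token
    return input_ids, kv_cache

torch.manual_seed(0)
model = ToyLatentCacheModel()
prompt = torch.tensor([[1, 2, 3]])
output_ids, cache = generate_with_cache(model, prompt)
print(output_ids.shape, len(cache), cache[0].shape)
```

With a 3-token prompt and 5 new tokens, the output has 8 tokens while the cache holds 7 latents of size `d_compression`: the last generated token has not yet been fed back through the model.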

# 2. DeepSeek-R1 

## 2.1 Gate Implementation

## 2.2 Cold-start

## 2.3 Reasoning-Oriented Reinforcement Learning