## Weight Absorption in Multi-Head Latent Attention (MLA)

### 1. Low-Rank Compression of Keys and Values
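In **Multi-Head Latent Attention (MLA)**, the **Keys (K)** and **Values (V)** are not projected directly from the input. The model first compresses each token into a low-dimensional latent vector \( c \), and then reconstructs K and V from \( c \) with up-projection matrices.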

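The projections are defined as:

\[
Q = x W_q
\]

\[
c = x W_{dkv}
\]

\[
K = c W_{uk}, \qquad V = c W_{uv}
\]

where:

- \( x \) is the input token representation  
- \( c \) is the latent vector (compressed representation of \( x \))  
- \( W_q \) is the query projection matrix  
- \( W_{dkv} \) is the down-projection matrix that produces the latent  
- \( W_{uk} \) and \( W_{uv} \) are the key and value up-projection matrices  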
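
The projections can be sketched with raw weight matrices in the `x W` convention of the equations. This is only an illustration: the sizes are assumed, the weights are random stand-ins, and the names `W_dkv` and `W_uv` follow the `RopelessMLA` class later in this notebook.

```python
import torch

torch.manual_seed(0)
d_model, kv_latent_dim = 64, 16   # assumed sizes, for illustration only
seq_len = 10

x = torch.randn(seq_len, d_model)

# Random stand-ins for learned weights, in the "x @ W" convention of the equations.
W_q = torch.randn(d_model, d_model)
W_dkv = torch.randn(d_model, kv_latent_dim)   # down-projection to the latent
W_uk = torch.randn(kv_latent_dim, d_model)    # key up-projection
W_uv = torch.randn(kv_latent_dim, d_model)    # value up-projection

c = x @ W_dkv   # (seq_len, kv_latent_dim): the compressed representation
Q = x @ W_q     # (seq_len, d_model)
K = c @ W_uk    # (seq_len, d_model), reconstructed from c
V = c @ W_uv    # (seq_len, d_model), reconstructed from c

# Floats needed per token: the latent alone vs. full keys plus values.
latent_cache = c.shape[-1]
full_cache = K.shape[-1] + V.shape[-1]
print(latent_cache, full_cache)  # 16 128
```

With these assumed sizes, storing `c` instead of `K` and `V` takes one eighth of the memory per token.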

---

### 2. Weight Absorption (Matrix Fusion)

Queries and keys only ever meet inside the attention dot product \( QK^\top \), so their projection matrices can be **mathematically combined**.

By **precomputing** the product \( W_q W_{uk}^\top \), the key up-projection is *absorbed* into the query projection. This results in a **single effective projection matrix** that maps \( x \) directly into the latent space.

Conceptually:

\[
QK^\top = x W_q (c W_{uk})^\top
\]

\[
= x (W_q W_{uk}^\top) c^\top
\]

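Thus, instead of projecting the query and reconstructing the keys with two separate matrix multiplications at inference time, the model applies **one fused matrix** to \( x \) and takes the dot product directly with the latent vectors \( c \). The full keys \( K \) never need to be materialized.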
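
The identity can be checked numerically. The sketch below assumes a multi-head layout in which head `h` owns a contiguous slice of `dh` output columns of \( W_q \) and \( W_{uk} \); all sizes are illustrative, the weights are random stand-ins, and the usual \( 1/\sqrt{d_h} \) scaling is omitted because it does not affect the identity.

```python
import torch

torch.manual_seed(0)
d_model, n_heads, kv_latent_dim = 64, 4, 16   # assumed sizes
dh = d_model // n_heads
S = 10

# float64 keeps rounding differences between the two paths negligible.
x = torch.randn(S, d_model, dtype=torch.float64)
c = torch.randn(S, kv_latent_dim, dtype=torch.float64)
W_q = torch.randn(d_model, d_model, dtype=torch.float64)
W_uk = torch.randn(kv_latent_dim, d_model, dtype=torch.float64)

# Per-head column slices: head h owns output columns [h*dh, (h+1)*dh).
W_q_h = W_q.view(d_model, n_heads, dh).permute(1, 0, 2)          # (n_heads, d_model, dh)
W_uk_h = W_uk.view(kv_latent_dim, n_heads, dh).permute(1, 0, 2)  # (n_heads, kv_latent_dim, dh)

# Explicit path: build Q and K per head, then take the dot product.
Q = torch.einsum("sd,hde->hse", x, W_q_h)    # (n_heads, S, dh)
K = torch.einsum("tl,hle->hte", c, W_uk_h)   # (n_heads, S, dh)
scores_explicit = Q @ K.transpose(-1, -2)    # (n_heads, S, S)

# Absorbed path: precompute W_q W_uk^T per head once, never form K.
absorbed = W_q_h @ W_uk_h.transpose(-1, -2)  # (n_heads, d_model, kv_latent_dim)
scores_absorbed = torch.einsum("sd,hdl,tl->hst", x, absorbed, c)

print(torch.allclose(scores_explicit, scores_absorbed))  # True
```

With a single head this reduces to \( x (W_q W_{uk}^\top) c^\top \) exactly as written above.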

---

### 3. Optimization Benefits

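- **Reduced latency**: attention scores are computed against the latent vectors directly, so keys are not reconstructed at each inference step  
- **Lower computational overhead**, especially for large models and long sequences, where key reconstruction would otherwise be repeated for every cached token  
- **Inference-time optimization**: the fused matrix is computed once from the trained weights, typically after training  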
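
A rough multiply-add count makes the saving concrete for one head's attention scores in a single decode step. This is a back-of-the-envelope sketch, not a benchmark: it assumes the cache holds only latent vectors, so the non-absorbed path rebuilds keys for every cached token at every step. The function names and sizes are hypothetical.

```python
def score_macs_explicit(d_model, dh, latent, cached_tokens):
    """Multiply-adds for one head's scores when keys are rebuilt from the latent cache."""
    project_query = d_model * dh                 # q = x W_q (one head's slice)
    rebuild_keys = cached_tokens * latent * dh   # K = c W_uk for every cached token
    dot_products = cached_tokens * dh            # q . k per cached token
    return project_query + rebuild_keys + dot_products


def score_macs_absorbed(d_model, latent, cached_tokens):
    """Multiply-adds for one head's scores with the absorbed matrix W_q W_uk^T."""
    project_query = d_model * latent             # x (W_q W_uk^T), lands in latent space
    dot_products = cached_tokens * latent        # dot against each cached latent
    return project_query + dot_products


# Assumed sizes, chosen only for illustration.
explicit = score_macs_explicit(d_model=512, dh=64, latent=128, cached_tokens=4096)
absorbed = score_macs_absorbed(d_model=512, latent=128, cached_tokens=4096)
print(explicit, absorbed)  # 33849344 589824
```

The absorbed query projection is itself larger when the latent dimension exceeds `dh`, as it does here, so the gain comes from the per-token terms and grows with the length of the cache.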

---

### 4. PyTorch Implementation Detail (`register_buffer`)

In PyTorch, the absorbed weight matrix (e.g., `absorbed_k`) is stored using:

```python
self.register_buffer("absorbed_k", absorbed_k)
```

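Registering it with `register_buffer` tells PyTorch that:

- `absorbed_k` is part of the module's state and moves with it on `.to(device)` calls
- It is saved and loaded via `state_dict` (unless registered with `persistent=False`)
- It is not an `nn.Parameter`, so it does not appear in `model.parameters()`
- The optimizer therefore never updates it; it changes only when it is recomputed explicitly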
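
These properties can be observed on a toy module. `TinyAbsorb` below is a hypothetical class written only for this demonstration, and its buffer holds a random stand-in rather than a real absorbed product.

```python
import torch
import torch.nn as nn


class TinyAbsorb(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, d_model=8, latent=4):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=False)  # an ordinary trainable layer
        absorbed_k = torch.randn(d_model, latent)            # stand-in for a precomputed product
        self.register_buffer("absorbed_k", absorbed_k)


m = TinyAbsorb()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(param_names)                     # ['proj.weight']
print(buffer_names)                    # ['absorbed_k']
print("absorbed_k" in m.state_dict())  # True
print(m.absorbed_k.requires_grad)      # False
```

Because the buffer is absent from `named_parameters()`, an optimizer built from `m.parameters()` never sees it.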

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```



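```python
class RopelessMLA(nn.Module):
    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads  # per-head dimension

        self.W_q = nn.Linear(d_model, d_model, bias=False)          # query projection
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)  # down-projection to the latent
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)   # key up-projection
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias=False)   # value up-projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)          # output projection
        self.nl = nn.LayerNorm(kv_latent_dim)
        # Filled lazily in forward() with W_q.weight @ W_uk.weight, reshaped per head
        self.register_buffer("absorbed_k", None)

    def forward(self, x, kv_cache=None, past_length=0):
        B, S, D = x.size()

        # Build the absorbed matrix once, on the first call
        if self.absorbed_k is None:
            # (d_model, d_model) @ (d_model, kv_latent_dim) -> (d_model, kv_latent_dim)
            absorbed = torch.matmul(self.W_q.weight, self.W_uk.weight)
            # Split the leading dimension by head: (n_heads, dh, kv_latent_dim)
            self.absorbed_k = absorbed.view(self.n_heads, self.dh, -1)
```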