# 2025 USA-NA-AIO Round 2, Problem 2: ANSWERS

## Problem 2 (100 points)

Multi-head attention (MHA) is a major breakthrough in AI. Many variants have since been proposed to improve on its original form.

In this problem, you are asked to study multi-head attention and its variants.

We use the following notation in this problem.

- $B$: batch size. $b$: index of a sample.
- $L_1$: length of an attending sequence. $l_1$: index of a position in this sequence.
- $L_2$: length of a being attended sequence. $l_2$: index of a position in this sequence.
- $D_1$: dimension of a hidden state/token in an attending sequence.
- $D_2$: dimension of a hidden state/token in a being attended sequence.
- $H$: number of heads. $h$: index of a head.
- $D_v$: dimension of a value vector.
- $D_{qk}$: dimension of a query/key vector.

Before starting this problem, make sure to run the following code first without any change:


In [None]:
# Run code in this cell

"""
DO NOT MAKE ANY CHANGE IN THIS CELL.
"""
import torch
import torch.nn as nn
import numpy as np

---

## $\color{red}{\text{WARNING !!!}}$

Beyond importing libraries/modules/classes/functions in the preceding cell, you are **NOT** allowed to import anything else for the following purposes:

- As a part of your final solution. For instance, if a problem asks you to build a model without using sklearn but you use it, then you will not earn points.

- Temporarily import something to assist you to get a solution. For instance, if a problem asks you to manually compute eigenvalues but you temporarily use `np.linalg.eig` to get an answer and then delete your code, then you violate the rule.

**Rule of thumb:** Each part is designed to test something specific. Do not look for shortcuts that circumvent the rules.


---

## Part 1 (5 points, non-coding task)

Do the following tasks (Reasoning is not required).

1. For each hidden state at position $l_1$ in an attending sequence, $x_{l_1} \in \mathbb{R}^{D_1}$, we project it into a query vector for head $h$ according to

$$q_{l_1, h} = W^Q_h x_{l_1}$$

What is the shape of $W^Q_h$?

2. For each hidden state at position $l_2$ in a being attended sequence $y_{l_2} \in \mathbb{R}^{D_2}$, we project it into a key vector for head $h$ according to

$$k_{l_2, h} = W^K_h y_{l_2}$$

What is the shape of $W^K_h$?

3. For each hidden state at position $l_2$ in a being attended sequence $y_{l_2} \in \mathbb{R}^{D_2}$, we project it into a value vector for head $h$ according to

$$v_{l_2, h} = W^V_h y_{l_2}$$

What is the shape of $W^V_h$?


**Answer:**

1. The shape of $W^Q_h$ is $(D_{qk}, D_1)$.

2. The shape of $W^K_h$ is $(D_{qk}, D_2)$.

3. The shape of $W^V_h$ is $(D_v, D_2)$.

""" END OF THIS PART """

---

## Part 2 (5 points, non-coding task)

For $M \in \{Q, K, V\}$, we concatenate the $M$-projection matrices $\{W^M_h : h \in \{0, 1, \cdots, H-1\}\}$ along axis 0 as

$$W^M = \begin{bmatrix} W^M_0 \\ W^M_1 \\ \vdots \\ W^M_{H-1} \end{bmatrix}$$

At each position $l_1$ in an attending sequence, we concatenate queries $\{q_{l_1, h} : h \in \{0, 1, \cdots, H-1\}\}$ along axis 0 to get

$$q_{l_1} = \begin{bmatrix} q_{l_1, 0} \\ q_{l_1, 1} \\ \vdots \\ q_{l_1, H-1} \end{bmatrix}$$

At each position $l_2$ in a being attended sequence, we concatenate keys/values $m \in \{k, v\}$ $\{m_{l_2, h} : h \in \{0, 1, \cdots, H-1\}\}$ along axis 0 to get

$$m_{l_2} = \begin{bmatrix} m_{l_2, 0} \\ m_{l_2, 1} \\ \vdots \\ m_{l_2, H-1} \end{bmatrix}$$

Do the following tasks (Reasoning is not required).

1. What is the shape of $W^M$ for $M \in \{Q, K, V\}$?
2. What is the shape of $q_{l_1}$?
3. What is the relationship between $q_{l_1}$ and $W^Q$?
4. For $m \in \{k, v\}$, what is the shape of $m_{l_2}$?
5. What is the relationship between $m_{l_2}$ and $W^M$?


**Answer:**

1. The shape of $W^Q$ is $(H \cdot D_{qk}, D_1)$.

   The shape of $W^K$ is $(H \cdot D_{qk}, D_2)$.

   The shape of $W^V$ is $(H \cdot D_v, D_2)$.

2. The shape of $q_{l_1}$ is $(H \cdot D_{qk},)$.

3. $q_{l_1} = W^Q x_{l_1}$.

4. The shape of $k_{l_2}$ is $(H \cdot D_{qk},)$.

   The shape of $v_{l_2}$ is $(H \cdot D_v,)$.

5. $k_{l_2} = W^K y_{l_2}$.

   $v_{l_2} = W^V y_{l_2}$.

""" END OF THIS PART """


---
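
The stacking relationship in Part 2 can be checked numerically. The following is a minimal NumPy sketch, assuming small illustrative sizes (`H = 3`, `D_qk = 4`, `D_1 = 5`) and random matrices; the variable names are chosen here for illustration and are not part of the problem.

```python
import numpy as np

rng = np.random.default_rng(0)
H, D_qk, D_1 = 3, 4, 5  # small illustrative sizes (assumed)

# Per-head query projections W^Q_h, each of shape (D_qk, D_1)
W_Q_heads = [rng.standard_normal((D_qk, D_1)) for _ in range(H)]
x = rng.standard_normal(D_1)  # one hidden state x_{l_1}

# Stack the per-head matrices along axis 0: shape (H * D_qk, D_1)
W_Q = np.concatenate(W_Q_heads, axis=0)

# Concatenate the per-head queries along axis 0: shape (H * D_qk,)
q_concat = np.concatenate([W_h @ x for W_h in W_Q_heads], axis=0)

print(W_Q.shape, q_concat.shape)
print(np.allclose(W_Q @ x, q_concat))
```

The same check applies to keys and values after replacing `D_1` with `D_2` (and `D_qk` with `D_v` for values).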

## Part 3 (10 points, non-coding task)

Define function $\text{Softmax}: \mathbb{R}^d \to \mathbb{R}^d$, with the $i$th output value as

$$\text{Softmax}_i(z) = \frac{\exp(z_i)}{\sum_{j=0}^{d-1} \exp(z_j)}$$

At position $l_1$ in the attending sequence, its attention score to position $l_2$ in the being attended sequence for head $h$ is denoted as $\alpha_{h, l_1}^{l_2}$.

We can write $\alpha_{h, l_1}^{l_2}$ in the following form:

$$\alpha_{h, l_1}^{l_2} = \text{Softmax}_{l_2}\left( \boxed{\color{red}{???}} \right)$$

What is the formula in the above red box (reasoning is not required)?


**Answer:**

$$\alpha_{h, l_1}^{l_2} = \text{Softmax}_{l_2}\left( \frac{q_{l_1, h}^\top K_h^\top}{\sqrt{D_{qk}}} \right)$$

where

$$K_h = \begin{bmatrix} k_{0, h}^\top \\ k_{1, h}^\top \\ \vdots \\ k_{L_2-1, h}^\top \end{bmatrix} \in \mathbb{R}^{L_2 \times D_{qk}}$$

so that the $l_2$-th entry of the Softmax argument is $q_{l_1, h}^\top k_{l_2, h} / \sqrt{D_{qk}}$.

""" END OF THIS PART """


---
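
As a sanity check on the Part 3 formula, the sketch below computes the attention scores of one head in NumPy. It assumes random queries and keys and small sizes; `softmax`, `Q_h`, and `alpha_h` are names introduced here for illustration only.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
L_1, L_2, D_qk = 2, 5, 4  # small illustrative sizes (assumed)

Q_h = rng.standard_normal((L_1, D_qk))  # row l_1 is q_{l_1, h}^T
K_h = rng.standard_normal((L_2, D_qk))  # row l_2 is k_{l_2, h}^T

# Row l_1 of alpha_h holds the scores of position l_1 over all l_2
alpha_h = softmax(Q_h @ K_h.T / np.sqrt(D_qk))  # shape (L_1, L_2)

print(alpha_h.shape)
print(alpha_h.sum(axis=-1))  # each row should sum to 1
```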

## Part 4 (5 points, non-coding task)

At position $l_1$ in an attending sequence, for head $h$, the information extracted from attending to a being attended sequence is given by

$$o_{h, l_1} = \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} v_{l_2, h}$$

We hereafter call $o_{h, l_1}$ a pre-out-projection output vector.

Do the following tasks.

1. What is the shape of vector $o_{h, l_1}$?

2. We concatenate $\{o_{h, l_1} : h \in \{0, 1, \cdots, H-1\}\}$ along axis 0:

$$o_{l_1} = \begin{bmatrix} o_{0, l_1} \\ o_{1, l_1} \\ \vdots \\ o_{H-1, l_1} \end{bmatrix}$$

What is the shape of $o_{l_1}$?

3. We project $o_{l_1}$ to a post-out-projection output vector via an out-projection matrix:

$$x^{out}_{l_1} = W^O o_{l_1} \in \mathbb{R}^{D_1}$$

where

$$W^O = \begin{bmatrix} W^O_0 & W^O_1 & \cdots & W^O_{H-1} \end{bmatrix}$$

What is the shape of $W^O_h$ for each $h \in \{0, 1, \cdots, H-1\}$ and $W^O$?


**Answer:**

1. The shape of $o_{h, l_1}$ is $(D_v,)$.

2. The shape of $o_{l_1}$ is $(H \cdot D_v,)$.

3. For each head $h$, the shape of $W^O_h$ is $(D_1, D_v)$.

   The shape of $W^O$ is $(D_1, H \cdot D_v)$.

""" END OF THIS PART """


---
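
The block structure of $W^O$ in Part 4 means the out-projection is a sum of per-head projections, $W^O o_{l_1} = \sum_h W^O_h o_{h, l_1}$, which is the form used later in Equation (2). A minimal NumPy sketch of this identity, assuming random data and small sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
H, D_v, D_1 = 3, 4, 6  # small illustrative sizes (assumed)

o_heads = [rng.standard_normal(D_v) for _ in range(H)]           # o_{h, l_1}
W_O_heads = [rng.standard_normal((D_1, D_v)) for _ in range(H)]  # W^O_h

o = np.concatenate(o_heads)              # shape (H * D_v,)
W_O = np.concatenate(W_O_heads, axis=1)  # shape (D_1, H * D_v)

x_out = W_O @ o
x_out_sum = sum(W_h @ o_h for W_h, o_h in zip(W_O_heads, o_heads))

print(W_O.shape, x_out.shape)
print(np.allclose(x_out, x_out_sum))
```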

## Part 5 (10 points, coding task)

In this part, you are asked to build your own multi-head attention module that subclasses `nn.Module`.

For simplicity, we ignore any masking. That is, each position in an attending sequence attends to all positions in a being attended sequence.

In your code, you do not need to worry about efficiency in autoregressive token generation when your module is used for inference in a GPT-like task.

That is, if we use your code in a GPT-like task to autoregressively generate tokens, it is totally fine if you repeatedly recompute the same key and value at a given position rather than storing them in a cache.

1. The class name is `MyMHA`.

2. **Attributes:**
   - `D_1`: Dimension of a hidden state/token in an attending sequence.
   - `D_2`: Dimension of a hidden state/token in a being attended sequence.
   - `D_v`: Dimension of a value vector.
   - `D_qk`: Dimension of a query/key vector.
   - `H`: Number of heads.
   - `W_Q`: A linear module whose weight is the query-projection matrix. Its shape should be consistent with your answer in Part 2. No bias.
   - `W_K`: A linear module whose weight is the key-projection matrix. Its shape should be consistent with your answer in Part 2. No bias.
   - `W_V`: A linear module whose weight is the value-projection matrix. Its shape should be consistent with your answer in Part 2. No bias.
   - `W_O`: A linear module whose weight is the out-projection matrix. Its shape should be consistent with your answer in Part 4. No bias.

3. **Method `__init__`:**
   - Inputs: `D_1`, `D_2`, `D_qk`, `D_v`, `H`
   - Outputs: None
   - What to do inside this method: Initialize attribute values

4. **Method `forward`:**
   - Inputs:
     - An attending sequence (tensor) with shape `(B, L_1, D_1)`
     - A being attended sequence (tensor) with shape `(B, L_2, D_2)`
   - Outputs: Post-out-projection outputs with shape `(B, L_1, D_1)`
   - What to do inside this method:
     - Compute the outputs
     - After each operation, add a comment on the tensor shape
     - Do not use any loop


In [None]:
### WRITE YOUR SOLUTION HERE ###

class MyMHA(nn.Module):
    def __init__(self, D_1, D_2, D_qk, D_v, H):
        super().__init__()
        self.D_1 = D_1
        self.D_2 = D_2
        self.D_qk = D_qk
        self.D_v = D_v
        self.H = H

        self.W_Q = nn.Linear(in_features=D_1, out_features=H*D_qk, bias=False)
        self.W_K = nn.Linear(in_features=D_2, out_features=H*D_qk, bias=False)
        self.W_V = nn.Linear(in_features=D_2, out_features=H*D_v, bias=False)
        self.W_O = nn.Linear(in_features=H*D_v, out_features=D_1, bias=False)

    def forward(self, x, y):
        B = x.shape[0] # batch size
        L_1 = x.shape[1] # the length of sequence x
        L_2 = y.shape[1] # the length of sequence y

        Q = self.W_Q(x) # shape: (B,L_1,H*D_qk)
        K = self.W_K(y) # shape: (B,L_2,H*D_qk)
        V = self.W_V(y) # shape: (B,L_2,H*D_v)

        Q = Q.reshape(B,L_1,self.H,self.D_qk) # shape: (B,L_1,H,D_qk)
        K = K.reshape(B,L_2,self.H,self.D_qk) # shape: (B,L_2,H,D_qk)
        V = V.reshape(B,L_2,self.H,self.D_v) # shape: (B,L_2,H,D_v)

        Q = Q.permute(0,2,1,3) # shape: (B,H,L_1,D_qk)
        K = K.permute(0,2,1,3) # shape: (B,H,L_2,D_qk)
        V = V.permute(0,2,1,3) # shape: (B,H,L_2,D_v)

        logits = Q @ K.transpose(-2,-1) / (self.D_qk**0.5) # shape: (B,H,L_1,L_2)
        alpha = torch.softmax(logits, dim=-1) # shape: (B,H,L_1,L_2)

        O = alpha @ V # shape: (B,H,L_1,D_v)

        O = O.permute(0,2,1,3) # shape: (B,L_1,H,D_v)
        O = O.reshape(B,L_1,self.H*self.D_v) # shape: (B,L_1,H*D_v)
        return self.W_O(O) # shape: (B,L_1,D_1)

""" END OF THIS PART """

---
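
The head-splitting step in `MyMHA.forward` (reshape to `(B, L, H, D_qk)`, then permute to `(B, H, L, D_qk)`) is easy to get wrong. The sketch below reproduces it in NumPy, where `transpose` with an axis tuple plays the role of `permute`, and compares it with direct slicing of head `h`. It also shows that reshaping straight to `(B, H, L, D_qk)` gives a different tensor. Sizes and names are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, H, D_qk = 2, 3, 4, 5  # small illustrative sizes (assumed)

Q_flat = rng.standard_normal((B, L, H * D_qk))  # output of the Q projection

# Split heads: (B, L, H*D_qk) -> (B, L, H, D_qk) -> (B, H, L, D_qk)
Q_heads = Q_flat.reshape(B, L, H, D_qk).transpose(0, 2, 1, 3)

h = 2
Q_slice = Q_flat[:, :, h * D_qk:(h + 1) * D_qk]  # head h by direct slicing

correct_matches = np.array_equal(Q_heads[:, h], Q_slice)

# Reshaping directly to (B, H, L, D_qk) mixes positions and heads
Q_wrong = Q_flat.reshape(B, H, L, D_qk)
wrong_matches = np.array_equal(Q_wrong[:, h], Q_slice)

print(correct_matches, wrong_matches)
```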

Next, let us study a variant of MHA: **Group Query Attention (GQA)**.

Recall that in MHA, queries, keys, and values all have the same number of heads, $H$. Thus, query $q_{l_1, h}$ attends to key $k_{l_2, h}$ with the same head index $h$.

In GQA, we relax this constraint by allowing keys and values to have $G$ heads ($G \leq H$), where $G$ is a factor of $H$. For instance, if $H = 12$, then $G \in \{1, 2, 3, 4, 6, 12\}$.

In GQA, a query $q_{l_1, h}$ with head index $h$ attends to the key $k_{l_2, g}$ and uses the value $v_{l_2, g}$ with head index $g$ to compute its output if

$$h \equiv g \pmod{G}$$

Thus, each head in keys and values is mapped to $\frac{H}{G} \geq 1$ heads in queries.

As an example, suppose $H = 12$ and $G = 3$. Then
- Head $g = 0$ in keys and values is associated with heads $h = 0, 3, 6, 9$ in queries.
- Head $g = 1$ in keys and values is associated with heads $h = 1, 4, 7, 10$ in queries.
- Head $g = 2$ in keys and values is associated with heads $h = 2, 5, 8, 11$ in queries.
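
The mapping in this example follows directly from the rule $h \equiv g \pmod{G}$. A short plain-Python sketch (the dictionary name `groups` is chosen here for illustration):

```python
H, G = 12, 3

# For each key/value head g, list the query heads h with h % G == g
groups = {g: [h for h in range(H) if h % G == g] for g in range(G)}

for g, heads in groups.items():
    print(f"g = {g}: h = {heads}")
```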

---

## Part 6 (5 points, non-coding task)

For $M \in \{K, V\}$, denote the $M$-projection matrix as

$$W^{M, GQA} = \begin{bmatrix} W^{M, GQA}_0 \\ \vdots \\ W^{M, GQA}_{G-1} \end{bmatrix}$$

Now, we concatenate $\frac{H}{G}$ copies of the above matrix along axis 0:

$$\tilde{W}^{M, GQA} = \begin{bmatrix} W^{M, GQA} \\ W^{M, GQA} \\ \vdots \\ W^{M, GQA} \end{bmatrix}$$

What is the relationship between $\text{rank}(\tilde{W}^{M, GQA})$ and $\text{rank}(W^{M, GQA})$?

Reasoning is required.


**Answer:**

Let $r = \text{rank}(W^{M, GQA})$, and let $\{w^*_i : i \in \{0, 1, \cdots, r-1\}\}$ be $r$ linearly independent row vectors that span the row space of $W^{M, GQA}$.

Every row of $\tilde{W}^{M, GQA}$ is a copy of some row of $W^{M, GQA}$ (each row appears $\frac{H}{G}$ times), so the two matrices have the same row space. Hence $\{w^*_i : i \in \{0, 1, \cdots, r-1\}\}$ is also a basis of the row space of $\tilde{W}^{M, GQA}$.

Therefore,

$$\text{rank}(\tilde{W}^{M, GQA}) = \text{rank}(W^{M, GQA})$$

""" END OF THIS PART """


---
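
The rank equality in Part 6 can also be observed numerically. The sketch below assumes a random $W^{M, GQA}$ with small illustrative sizes and uses `np.linalg.matrix_rank`; it is a check, not a substitute for the proof.

```python
import numpy as np

rng = np.random.default_rng(0)
G, d, D = 2, 3, 12  # small illustrative sizes (assumed)
H = D // d          # number of query heads

W = rng.standard_normal((G * d, D))  # W^{M,GQA}
W_tilde = np.tile(W, (H // G, 1))    # H/G stacked copies, shape (H*d, D)

rank_W = np.linalg.matrix_rank(W)
rank_tilde = np.linalg.matrix_rank(W_tilde)

print(W.shape, W_tilde.shape)
print(rank_W, rank_tilde)
```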

## Part 7 (10 points, coding task)

In this part, please build your own GQA module called `MyGQA`.

The requirements are the same as in Part 5, with the following notes:

- Do NOT create $\frac{H}{G}$ copies of key-projection and value-projection matrices. Otherwise, you will use too much unnecessary memory.
- No loop is allowed.


In [None]:
### WRITE YOUR SOLUTION HERE ###

class MyGQA(nn.Module):
    def __init__(self, D_1, D_2, D_qk, D_v, H, G):
        super().__init__()
        self.D_1 = D_1
        self.D_2 = D_2
        self.D_qk = D_qk
        self.D_v = D_v
        self.H = H
        self.G = G

        self.W_Q = nn.Linear(in_features=D_1, out_features=H*D_qk, bias=False)
        self.W_K = nn.Linear(in_features=D_2, out_features=G*D_qk, bias=False)
        self.W_V = nn.Linear(in_features=D_2, out_features=G*D_v, bias=False)
        self.W_O = nn.Linear(in_features=H*D_v, out_features=D_1, bias=False)

    def forward(self, x, y):
        B = x.shape[0] # batch size
        L_1 = x.shape[1] # the length of sequence x
        L_2 = y.shape[1] # the length of sequence y
        num_copies = self.H // self.G

        Q = self.W_Q(x) # shape: (B,L_1,H*D_qk)
        K = self.W_K(y) # shape: (B,L_2,G*D_qk)
        V = self.W_V(y) # shape: (B,L_2,G*D_v)

        Q = Q.reshape(B,L_1,num_copies,self.G,self.D_qk) # shape: (B,L_1,num_copies,G,D_qk)
        K = K.reshape(B,L_2,1,self.G,self.D_qk) # shape: (B,L_2,1,G,D_qk)
        V = V.reshape(B,L_2,1,self.G,self.D_v) # shape: (B,L_2,1,G,D_v)

        Q = Q.permute(0,2,3,1,4) # shape: (B,num_copies,G,L_1,D_qk)
        K = K.permute(0,2,3,1,4) # shape: (B,1,G,L_2,D_qk)
        V = V.permute(0,2,3,1,4) # shape: (B,1,G,L_2,D_v)

        logits = Q @ K.transpose(-2,-1) / (self.D_qk**0.5) # shape: (B,num_copies,G,L_1,L_2)
        alpha = torch.softmax(logits, dim=-1) # shape: (B,num_copies,G,L_1,L_2)

        O = alpha @ V # shape: (B,num_copies,G,L_1,D_v)

        O = O.permute(0,3,1,2,4) # shape: (B,L_1,num_copies,G,D_v)
        O = O.reshape(B,L_1,-1) # shape: (B,L_1,H*D_v)
        return self.W_O(O) # shape: (B,L_1,D_1)

""" END OF THIS PART """

---
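
The broadcasting trick in `MyGQA.forward` relies on writing each query head index as $h = c \cdot G + g$, so that reshaping to `(num_copies, G)` places head $h$ next to key head $g = h \bmod G$. The NumPy sketch below checks the resulting logits against an explicit per-head reference. The reference uses a loop for clarity only; the sizes and names are assumptions, and a single sample (no batch axis) is used to keep it short.

```python
import numpy as np

rng = np.random.default_rng(0)
L_1, L_2, H, G, d = 2, 3, 6, 2, 4  # small illustrative sizes (assumed)

Q = rng.standard_normal((L_1, H * d))  # queries for all H heads
K = rng.standard_normal((L_2, G * d))  # keys for only G heads

# Broadcast form, mirroring the reshape/permute pattern in MyGQA
Qb = Q.reshape(L_1, H // G, G, d).transpose(1, 2, 0, 3)  # (H/G, G, L_1, d)
Kb = K.reshape(L_2, 1, G, d).transpose(1, 2, 0, 3)       # (1,   G, L_2, d)
logits_b = Qb @ Kb.transpose(0, 1, 3, 2)                 # (H/G, G, L_1, L_2)
logits_b = logits_b.reshape(H, L_1, L_2)                 # head h = c * G + g

# Reference: head h explicitly uses key head g = h % G
logits_ref = np.stack([
    Q[:, h * d:(h + 1) * d] @ K[:, (h % G) * d:(h % G + 1) * d].T
    for h in range(H)
])

print(np.allclose(logits_b, logits_ref))
```

No copy of `K` is materialized in the broadcast form; the size-1 axis is expanded implicitly during the matrix product.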

## Part 8 (5 points, non-coding task)

MHA is a special case of GQA. Explain why.


**Answer:**

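When $G = H$, the condition $h \equiv g \pmod{G}$ with $g, h \in \{0, 1, \cdots, H-1\}$ holds only for $g = h$. Keys and values then have $H$ heads, and each query head attends to the key/value head with the same index, so GQA becomes MHA.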

""" END OF THIS PART """


---

Now, let us study another variant of MHA: **Multi-head Latent Attention (MLA)**. MLA was introduced by DeepSeek. It is a core component of DeepSeek's large language model (LLM).

The key intuition of MLA is as follows. In MHA, the key and value projection matrices

$$W^{K, MHA} \in \mathbb{R}^{H \cdot D_{qk} \times D_2}, \quad W^{V, MHA} \in \mathbb{R}^{H \cdot D_v \times D_2}$$

may be high dimensional.

For instance, suppose $H \cdot D_{qk} = H \cdot D_v = D_2 = 4096$.

However, these matrices do not necessarily have high rank (such as 4096). Their actual ranks (or the number of leading singular values needed for a truncated singular value decomposition (SVD) to approximate them closely) may be much lower than that.

To capture the low-rank feature, MLA proposed the following model:

$$W^{K, MHA} = W^{UK, MLA} W^{DKV, MLA}$$

$$W^{V, MHA} = W^{UV, MLA} W^{DKV, MLA}$$

where

- $W^{DKV, MLA} \in \mathbb{R}^{r \times D_2}$: down-projection matrix for computing keys and values.
- $W^{UK, MLA} \in \mathbb{R}^{H \cdot D_{qk} \times r}$: up-projection matrix for computing keys.
- $W^{UV, MLA} \in \mathbb{R}^{H \cdot D_v \times r}$: up-projection matrix for computing values.

In practice, rank $r$ is typically much smaller than $\min\{H \cdot D_{qk}, H \cdot D_v, D_2\}$.
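
To get a feel for the savings, the sketch below counts key/value projection parameters for the $4096$-dimensional example above. The rank `r = 512` is a hypothetical value chosen only for illustration; the text does not specify one.

```python
D_2 = 4096
HD_qk = 4096  # H * D_qk
HD_v = 4096   # H * D_v
r = 512       # hypothetical rank, chosen for illustration

# MHA: full key and value projection matrices
params_mha = HD_qk * D_2 + HD_v * D_2

# MLA: one shared down-projection plus two up-projections
params_mla = r * D_2 + HD_qk * r + HD_v * r

print(params_mha, params_mla, params_mla / params_mha)
```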

---

In all remaining parts of this problem, to simplify your analysis and highlight the relationships of MHA, GQA and MLA, we make the following assumptions:

- $D_1 = D_2 = D$.
- $D_{qk} = D_v = d$.
- $d$ is a factor of $D$.

Under these assumptions, the number of heads $H$ satisfies

$$H = \frac{D}{d}$$

---

## Part 9 (10 points, non-coding task)

In this part, you are asked to prove that GQA can be equivalently represented by MLA.

In your solution, it is sufficient for you to prove that for $M \in \{K, V\}$, for matrix

$$\tilde{W}^{M, GQA} = \begin{bmatrix} W^{M, GQA} \\ W^{M, GQA} \\ \vdots \\ W^{M, GQA} \end{bmatrix} \in \mathbb{R}^{D \times D}$$

(defined in Part 6), which is the concatenation of $\frac{H}{G}$ copies of

$$W^{M, GQA} = \begin{bmatrix} W^{M, GQA}_0 \\ \vdots \\ W^{M, GQA}_{G-1} \end{bmatrix} \in \mathbb{R}^{G \cdot d \times D}$$

matrix $\tilde{W}^{M, GQA}$ can be decomposed as

$$\tilde{W}^{M, GQA} = W^{UM, MLA} W^{DKV, MLA}$$

where

- $W^{DKV, MLA} \in \mathbb{R}^{r \times D}$: down-projection matrix for computing keys and values.
- $W^{UM, MLA} \in \mathbb{R}^{D \times r}$: up-projection matrix for computing $M$ (keys or values).
- $r = G \cdot d$.


**Answer:**

We have

$$\text{rank}(\tilde{W}^{M, GQA}) = \text{rank}(W^{M, GQA}) \leq \min\{G \cdot d, D\} = \min\{r, D\} = r$$

where the first equality follows from Part 6.

Therefore, SVD implies

$$\tilde{W}^{M, GQA} = \sum_{i=0}^{r-1} \sigma_i u_i v_i^\top = \underbrace{\begin{bmatrix} u_0 & u_1 & \cdots & u_{r-1} \end{bmatrix}}_{W^{UM, MLA}} \underbrace{\begin{bmatrix} \sigma_0 v_0^\top \\ \sigma_1 v_1^\top \\ \vdots \\ \sigma_{r-1} v_{r-1}^\top \end{bmatrix}}_{W^{DKV, MLA}}$$

""" END OF THIS PART """


---

## Part 10 (10 points, coding task)

You are asked to define a function called `GQA_2_MLA` that performs the following tasks:

**Input:**

- `W_M_GQA`: A numpy array with shape `(r, D)`, where `r` is guaranteed to be a factor of `D` (not something you need to worry about).

**Outputs:**

- `W_DKV_MLA`: A numpy array with shape `(r, D)`.
- `W_UM_MLA`: A numpy array with shape `(D, r)`.

**Things to do inside this function:**

- Compute `W_M_GQA_tilde` that concatenates `D/r` copies of `W_M_GQA` along axis 0.
- Print the shapes of `W_UM_MLA` and `W_DKV_MLA`.
- Print the mean-squared error between `W_M_GQA_tilde` and `W_UM_MLA @ W_DKV_MLA`.

**Hints:**

- You may use `np.linalg`.
- PyTorch is not allowed.
- No loop in your code.

After defining this function, test it with the input `np.random.randn(4, 24)`.


In [None]:
### WRITE YOUR SOLUTION HERE ###

def GQA_2_MLA(W_M_GQA):
    r = W_M_GQA.shape[0]
    D = W_M_GQA.shape[1]
    num_copies = D // r

    # Stack D/r copies of W_M_GQA along axis 0: shape (D, D)
    W_M_GQA_tilde = np.concatenate([W_M_GQA] * num_copies, axis=0)
    # np.linalg.svd returns V transposed (Vh); singular values are in descending order
    U, S, Vh = np.linalg.svd(W_M_GQA_tilde)
    W_UM_MLA = U[:, :r] # shape: (D, r)
    W_DKV_MLA = S[:r, None] * Vh[:r, :] # shape: (r, D)
    print(f"Shape of W_UM_MLA: {W_UM_MLA.shape}")
    print(f"Shape of W_DKV_MLA: {W_DKV_MLA.shape}")

    MSE = np.mean((W_M_GQA_tilde - W_UM_MLA @ W_DKV_MLA)**2)
    print(f"Mean-squared error: {MSE}")

    return W_DKV_MLA, W_UM_MLA

GQA_2_MLA(np.random.randn(4,24))

""" END OF THIS PART """

---
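
The SVD is one way to obtain the factorization, but it is not the only one. As an alternative sketch (not the construction used in the answer above), stacking $\frac{D}{r}$ identity matrices as the up-projection and using $W^{M, GQA}$ itself as the down-projection also reproduces $\tilde{W}^{M, GQA}$. The names `W_UM_alt` and `W_DKV_alt` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
r, D = 4, 24  # same shape as the test input in Part 10

W_M_GQA = rng.standard_normal((r, D))
W_M_GQA_tilde = np.tile(W_M_GQA, (D // r, 1))  # shape (D, D)

# Alternative factors: stacked identities times the original matrix
W_UM_alt = np.tile(np.eye(r), (D // r, 1))  # shape (D, r)
W_DKV_alt = W_M_GQA                         # shape (r, D)

err = np.max(np.abs(W_M_GQA_tilde - W_UM_alt @ W_DKV_alt))
print(W_UM_alt.shape, W_DKV_alt.shape, err)
```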

## Part 11 (10 points, non-coding task)

So far, we have proved that GQA can always be represented by MLA.

In this part, you are asked to prove that GQA is not equivalent to MLA. It suffices to find one MLA example that cannot be represented as GQA.

To be specific, please do the following things:

1. Construct $W^{DKV, MLA} \in \mathbb{R}^{1 \times 2}$.

2. Construct $W^{UM, MLA} \in \mathbb{R}^{2 \times 1}$.

3. Do matrix multiplication $W^{UM, MLA} W^{DKV, MLA}$.

4. Show that this product matrix is not the concatenation of two copies of a 1-by-2 matrix along axis 0.


**Answer:**

Define

$$W^{DKV, MLA} = \begin{bmatrix} 1 & 2 \end{bmatrix}$$

and

$$W^{UM, MLA} = \begin{bmatrix} 3 \\ 4 \end{bmatrix}$$

Hence,

$$W^{UM, MLA} W^{DKV, MLA} = \begin{bmatrix} 3 \\ 4 \end{bmatrix} \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 3 & 6 \\ 4 & 8 \end{bmatrix}$$

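The two rows of this product matrix, $\begin{bmatrix} 3 & 6 \end{bmatrix}$ and $\begin{bmatrix} 4 & 8 \end{bmatrix}$, are not identical, so the product is not the concatenation of two copies of any 1-by-2 matrix.

Therefore, this MLA example cannot be represented by GQA.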

""" END OF THIS PART """


---
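
The counterexample in Part 11 can be reproduced in a few lines of NumPy. The variable names below are chosen for illustration.

```python
import numpy as np

W_DKV = np.array([[1.0, 2.0]])   # shape (1, 2)
W_UM = np.array([[3.0], [4.0]])  # shape (2, 1)

P = W_UM @ W_DKV  # shape (2, 2)

# A GQA-style matrix with G = 1 would need two identical rows
rows_identical = np.array_equal(P[0], P[1])

print(P)
print(rows_identical)
```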

MLA is not only more general than MHA and GQA; it is also computationally more efficient.

**An intuitive approach to computing MLA.**

1. Compute the key-projection matrix $W^{UK, MLA} W^{DKV, MLA} \in \mathbb{R}^{D \times D}$ and the value-projection matrix $W^{UV, MLA} W^{DKV, MLA} \in \mathbb{R}^{D \times D}$.
2. Follow the standard steps in MHA.

This approach is hereafter called the **vanilla approach**. It fails to exploit the low-rank structure of $W^{DKV, MLA}$, $W^{UK, MLA}$, and $W^{UV, MLA}$.

---

## Part 12 (10 points, non-coding task)

In this part, you are asked to study an alternative approach to compute MLA.

1. Find a head-independent reduced key-projection matrix $\hat{W}^{K, MLA} \in \mathbb{R}^{r \times D}$ and a reduced query-projection matrix $\hat{W}^{Q, MLA} \in \mathbb{R}^{H \cdot r \times D}$, such that

   - The reduced key at position $l_2$ for head $h$ in a being attended sequence is head-independent and is given by:

     $$\hat{k}_{l_2} = \hat{W}^{K, MLA} y_{l_2} \in \mathbb{R}^r$$

   - The reduced query at position $l_1$ for head $h$ in an attending sequence is given by:

     $$\hat{q}_{l_1, h} = \hat{W}^{Q, MLA}_h x_{l_1} \in \mathbb{R}^r$$

     where

     $$\hat{W}^{Q, MLA} = \begin{bmatrix} \hat{W}^{Q, MLA}_0 \\ \hat{W}^{Q, MLA}_1 \\ \vdots \\ \hat{W}^{Q, MLA}_{H-1} \end{bmatrix}$$

   - The attention score (query-key similarity) is invariant in both the original and the reduced forms. That is

     $$\frac{q_{l_1, h}^\top k_{l_2, h}}{\sqrt{D/H}} = \frac{\hat{q}_{l_1, h}^\top \hat{k}_{l_2}}{\sqrt{r}} \tag{1}$$

2. Find a head-independent reduced value-projection matrix $\hat{W}^{V, MLA} \in \mathbb{R}^{r \times D}$ and a reduced out-projection matrix $\hat{W}^{O, MLA} \in \mathbb{R}^{D \times H \cdot r}$, such that

   - The reduced value with head $h$ on position $l_2$ in a being attended sequence is head-independent and is given by:

     $$\hat{v}_{l_2} = \hat{W}^{V, MLA} y_{l_2} \in \mathbb{R}^r$$

   - Post-out-projection is invariant in both the original and the reduced forms.

     Let

     $$\hat{W}^{O, MLA} = \begin{bmatrix} \hat{W}^{O, MLA}_0 & \hat{W}^{O, MLA}_1 & \cdots & \hat{W}^{O, MLA}_{H-1} \end{bmatrix}$$

     Then we must have

     $$\sum_{h=0}^{H-1} W^O_h \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} v_{l_2, h} = \sum_{h=0}^{H-1} \hat{W}^{O, MLA}_h \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} \hat{v}_{l_2} \tag{2}$$

Your answer of $\hat{W}^{K, MLA}$, $\hat{W}^{V, MLA}$, $\hat{W}^{Q, MLA}$, and $\hat{W}^{O, MLA}$ should be written in terms of $W^{DKV}$, $W^{UK}$, $W^{UV}$, $W^Q$, and $W^O$.


**Answer:**

**First, we study Equation (1).**

For the LHS in (1), we have

$$\frac{q_{l_1, h}^\top k_{l_2, h}}{\sqrt{D/H}} = \frac{1}{\sqrt{D/H}} (W^Q_h x_{l_1})^\top (W^{UK}_h W^{DKV} y_{l_2}) = \frac{1}{\sqrt{D/H}} x_{l_1}^\top W^{Q, \top}_h W^{UK}_h W^{DKV} y_{l_2}   \tag{1.1}$$

For the RHS in (1), we have

$$\frac{\hat{q}_{l_1, h}^\top \hat{k}_{l_2}}{\sqrt{r}} = \frac{1}{\sqrt{r}} (\hat{W}^{Q, MLA}_h x_{l_1})^\top (\hat{W}^{K, MLA} y_{l_2}) = \frac{1}{\sqrt{r}} x_{l_1}^\top \hat{W}^{Q, MLA, \top}_h \hat{W}^{K, MLA} y_{l_2} \tag{1.2}$$

By equating (1.1) and (1.2), we can set

$$\boxed{\hat{W}^{K, MLA} = W^{DKV}}$$

and

$$\hat{W}^{Q, MLA}_h = \frac{\sqrt{r}}{\sqrt{D/H}} W^{UK, \top}_h W^Q_h$$

Therefore,

$$\boxed{\hat{W}^{Q, MLA} = \frac{\sqrt{r}}{\sqrt{D/H}} \begin{bmatrix} W^{UK, \top}_0 W^Q_0 \\ W^{UK, \top}_1 W^Q_1 \\ \vdots \\ W^{UK, \top}_{H-1} W^Q_{H-1} \end{bmatrix}}$$

---

**Second, we study Equation (2).**

For the LHS in (2), we have

$$\sum_{h=0}^{H-1} W^O_h \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} v_{l_2, h} = \sum_{h=0}^{H-1} \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} W^O_h W^{UV}_h W^{DKV} y_{l_2} \tag{2.1}$$

For the RHS in (2), we have

$$\sum_{h=0}^{H-1} \hat{W}^{O, MLA}_h \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} \hat{v}_{l_2} = \sum_{h=0}^{H-1} \sum_{l_2=0}^{L_2-1} \alpha_{h, l_1}^{l_2} \hat{W}^{O, MLA}_h \hat{W}^{V, MLA} y_{l_2} \tag{2.2}$$

By equating (2.1) and (2.2), we can set

$$\boxed{\hat{W}^{V, MLA} = W^{DKV}}$$

and

$$\hat{W}^{O, MLA}_h = W^O_h W^{UV}_h$$

Therefore

$$\boxed{\hat{W}^{O, MLA} = \begin{bmatrix} W^O_0 W^{UV}_0 & W^O_1 W^{UV}_1 & \cdots & W^O_{H-1} W^{UV}_{H-1} \end{bmatrix}}$$

""" END OF THIS PART """


---
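
The boxed results of Part 12 can be checked for a single head and a single pair of positions. The NumPy sketch below assumes random matrices and small sizes (`D = 8`, `H = 2`, `r = 3`); it verifies the score identity in Equation (1) and the per-term identity behind Equation (2).

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, r = 8, 2, 3  # small illustrative sizes (assumed)
d = D // H

W_DKV = rng.standard_normal((r, D))
W_UK = rng.standard_normal((D, r))
W_UV = rng.standard_normal((D, r))
W_Q = rng.standard_normal((D, D))
W_O = rng.standard_normal((D, D))
x = rng.standard_normal(D)  # x_{l_1}
y = rng.standard_normal(D)  # y_{l_2}

h = 1
rows = slice(h * d, (h + 1) * d)
W_Q_h, W_UK_h, W_UV_h = W_Q[rows], W_UK[rows], W_UV[rows]  # row blocks
W_O_h = W_O[:, rows]                                       # column block (D, d)

# Equation (1): original score vs reduced score
lhs1 = (W_Q_h @ x) @ (W_UK_h @ W_DKV @ y) / np.sqrt(d)
W_Q_hat_h = np.sqrt(r / d) * (W_UK_h.T @ W_Q_h)  # shape (r, D)
rhs1 = (W_Q_hat_h @ x) @ (W_DKV @ y) / np.sqrt(r)

# Equation (2): one (h, l_2) term, with the attention weight factored out
lhs2 = W_O_h @ (W_UV_h @ W_DKV @ y)
W_O_hat_h = W_O_h @ W_UV_h  # shape (D, r)
rhs2 = W_O_hat_h @ (W_DKV @ y)

print(np.isclose(lhs1, rhs1), np.allclose(lhs2, rhs2))
```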

## Part 13 (5 points, coding task)

Do the following tasks:

1. Define a function called `reduced_matrices`.
   - **Input arguments:** `W_DKV`, `W_UK`, `W_UV`, `W_Q`, `W_O`, `H`
   - **Outputs:** `W_K_MLA_hat`, `W_V_MLA_hat`, `W_Q_MLA_hat`, `W_O_MLA_hat`
   - **Requirement of your code:**
     - The code of computing each output must be in one line
     - Loop is not allowed

2. Set your device as gpu:
   ```python
   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
   ```

3. Construct the following synthetic data:
   ```python
   D = 1024
   H = 32
   D_qkv = D // H
   r = 50
   
   W_DKV = torch.randn(r, D)
   W_UK = torch.randn(D, r)
   W_UV = torch.randn(D, r)
   W_Q = torch.randn(D, D)
   W_O = torch.randn(D, D)
   
   B = 32
   L_1 = 100
   L_2 = 300
   
   x = torch.randn(B, L_1, D).to(device)
   y = torch.randn(B, L_2, D).to(device)
   ```

4. Study a vanilla attention model
   - Initialize the model
     ```python
     model_MHA_vanilla = MyMHA(D, D, D_qkv, D_qkv, H)
     ```
   - Update model parameters
     - `model_MHA_vanilla.W_K.weight`, `model_MHA_vanilla.W_V.weight`, `model_MHA_vanilla.W_Q.weight`, `model_MHA_vanilla.W_O.weight`
   - Compute the output
     ```python
     output_vanilla = model_MHA_vanilla(x, y)
     ```

5. Study a reduced attention model
   - Initialize the model
     ```python
     model_MHA_reduced = MyMHA(D, D, r, r, H)
     ```
   - Update model parameters
     - `model_MHA_reduced.W_K.weight`, `model_MHA_reduced.W_V.weight`, `model_MHA_reduced.W_Q.weight`, `model_MHA_reduced.W_O.weight`
   - Compute the output
     ```python
     output_reduced = model_MHA_reduced(x, y)
     ```

6. Check the correctness of the reduced model by computing and printing a relative error:
   ```python
   relative_error = mse_output**.5 / torch.mean(output_vanilla**2)**.5
   ```


In [None]:
### WRITE YOUR SOLUTION HERE ###

# Function
def reduced_matrices(W_DKV, W_UK, W_UV, W_Q, W_O, H):
    r = W_DKV.shape[0]
    D = W_DKV.shape[1]

    W_K_MLA_hat = W_DKV
    W_V_MLA_hat = W_DKV
    W_Q_MLA_hat = (W_UK.reshape(H, -1, r).transpose(-2, -1) @ W_Q.reshape(H, -1, D)).reshape(-1, D) * (r/(D/H))**.5
    W_O_MLA_hat = (W_O.reshape(D, H, -1).transpose(0, 1) @ W_UV.reshape(H, -1, r)).transpose(0, 1).reshape(D, -1)

    return W_K_MLA_hat, W_V_MLA_hat, W_Q_MLA_hat, W_O_MLA_hat

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data
D = 1024
H = 32
D_qkv = D // H
r = 50

W_DKV = torch.randn(r, D)
W_UK = torch.randn(D, r)
W_UV = torch.randn(D, r)
W_Q = torch.randn(D, D)
W_O = torch.randn(D, D)

B = 32
L_1 = 100
L_2 = 300

x = torch.randn(B, L_1, D).to(device)
y = torch.randn(B, L_2, D).to(device)

# Vanilla model
model_MHA_vanilla = MyMHA(D, D, D_qkv, D_qkv, H)
model_MHA_vanilla.W_K.weight = nn.Parameter(W_UK @ W_DKV)
model_MHA_vanilla.W_V.weight = nn.Parameter(W_UV @ W_DKV)
model_MHA_vanilla.W_Q.weight = nn.Parameter(W_Q)
model_MHA_vanilla.W_O.weight = nn.Parameter(W_O)

model_MHA_vanilla.to(device)
output_vanilla = model_MHA_vanilla(x, y)

# Reduced model
model_MHA_reduced = MyMHA(D, D, r, r, H)
W_K_MLA_hat, W_V_MLA_hat, W_Q_MLA_hat, W_O_MLA_hat = reduced_matrices(W_DKV, W_UK, W_UV, W_Q, W_O, H)
model_MHA_reduced.W_K.weight = nn.Parameter(torch.concatenate([W_K_MLA_hat] * H, dim=0))
model_MHA_reduced.W_V.weight = nn.Parameter(torch.concatenate([W_V_MLA_hat] * H, dim=0))
model_MHA_reduced.W_Q.weight = nn.Parameter(W_Q_MLA_hat)
model_MHA_reduced.W_O.weight = nn.Parameter(W_O_MLA_hat)

model_MHA_reduced.to(device)
output_reduced = model_MHA_reduced(x, y)

# Check the correctness of the reduced model
mse_output = torch.mean((output_vanilla - output_reduced)**2)
relative_error = mse_output**.5 / torch.mean(output_vanilla**2)**.5

print(f"Relative error: {relative_error.item()}")

""" END OF THIS PART """

---

## Part 14 (5 points, non-coding task)

In generative AI, such as GPT, we autoregressively generate tokens. For a given position $l$, the key and value at this position, $k_l$ and $v_l$, are repeatedly used in generating tokens for positions $l' > l$.

Therefore, the values of $k_l$ and $v_l$ are typically stored in a cache (no need to revise your code in earlier parts if your code does not support this). We call this storage the **kv-cache**.

Do the following tasks to compute the kv-cache size in different models during autoregressive inference (reasoning is required):

1. In MHA, the kv-cache at each position is $2D$. Explain why.

2. In MLA, what is the kv-cache at each position?


**Answer:**

1. In MHA, $k_l, v_l \in \mathbb{R}^D$. Therefore, the kv-cache at each position is $2D$.

2. In MLA, because $W^{DKV} \in \mathbb{R}^{r \times D}$, we have $\hat{k}_l, \hat{v}_l \in \mathbb{R}^r$.

   In addition, because $\hat{k}_l = \hat{v}_l$ (both are computed as $W^{DKV} y_l$), the kv-cache at each position is $r$.
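
For a concrete comparison, the sketch below plugs in the sizes from Part 13 (`D = 1024`, `H = 32`, `r = 50`) and counts cached numbers per position. The GQA line is an extension not asked for in this part: it assumes a hypothetical `G = 8` and that the $G$ key heads and $G$ value heads of dimension $d$ are cached.

```python
D, H, r = 1024, 32, 50  # sizes from Part 13
d = D // H              # per-head dimension
G = 8                   # hypothetical number of key/value heads for GQA

cache_mha = 2 * D       # k_l and v_l, each in R^D
cache_gqa = 2 * G * d   # assumed: G key heads and G value heads of size d
cache_mla = r           # shared latent W^{DKV} y_l in R^r

print(cache_mha, cache_gqa, cache_mla)
```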

""" END OF THIS PART """
