# (this work is in progress)

I plan on just training a sentence transformer bi-encoder and seeing what happens.
Baseline to beat is random slices. SOTA is Matryoshka.

## Background

Most deep learning architectures end with a linear layer which maps a vector of activations with dimension $d$ to a vector which is used for the loss function. For example, in multi-class classification with $L$ classes, the final linear layer, $W$, is $d \times L$.

The goal of [Matryoshka Representation Learning](https://arxiv.org/abs/2205.13147)<sup>1</sup> (MRL) is to train this linear layer (along with the rest of the network) such that multiple slices of $W$ are viable for performing the downstream task at inference time. For example, if $d = 64$, we might want the following slices, each keeping the first rows of $W$, to work well: $W_{1:16}$ (which is $16 \times L$), $W_{1:32}$, and $W_{1:64} = W$. This feature is useful in applications where we want to trade off computational efficiency and statistical accuracy without having to train multiple models, and without introducing engineering complexity at inference time.

For example, consider a retrieval task where we store and query a database of vector embeddings. (Vector embeddings come from a model trained to optimize a metric learning loss function.) Reducing the embedding dimension from $64$ to $16$ cuts storage by $4\times$ and reduces query latency, at (ideally) only a small cost in query accuracy. To make this trade-off at inference time, the query service simply slices the full embedding, i.e., changes `v` to `v[:16]`.
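A minimal sketch of that query-time trade-off, assuming cosine similarity as the retrieval metric and float32 storage (neither is specified above). The embeddings are deterministic stand-ins, not outputs of a real model, and `cosine` is a hypothetical helper.

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity; normalizes explicitly, so prefixes need not be unit length."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


d, k = 64, 16
query = [math.sin(i + 1) for i in range(d)]  # stand-in for a full query embedding
doc = [math.cos(i + 1) for i in range(d)]    # stand-in for a stored embedding

score_full = cosine(query, doc)
score_small = cosine(query[:k], doc[:k])  # the `v[:16]` slice from the text

bytes_full, bytes_small = 4 * d, 4 * k  # float32 storage per stored vector
print(score_full, score_small, bytes_full, bytes_small)
```

The slice itself is free; whether `score_small` ranks documents like `score_full` does is exactly what MRL training is meant to ensure.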

Given a labeled input-output pair, $(\mathbf{x}, y)$, efficient MRL (MRL-E) effectively accomplishes this goal by optimizing the following loss function—

$$
\mathcal{L}_{\text{MRL-E}} = \sum_{m \in \mathcal{M}} \mathcal{L} (W_{1:m}^T f_\theta(\mathbf{x})_{1:m}, y)
$$

—where:
* $\mathcal{L}$ is the loss function of the original task, e.g., cross entropy for a classification task.
* $\mathcal{M}$ is a log-spaced grid of dimension sizes in $[d] = \{1, \dots, d\}$, so the sum has only $O(\log d)$ terms and stays cheap to compute. For example, for $d = 64$, $\mathcal{M} = \{16, 32, 64\}$. (Figure 5 in the MRL paper demonstrates that using dimensions between the ones in $\mathcal{M}$ at inference time still works well.)
* $f_\theta(\mathbf{x})$ is the last (or second-to-last, depending on how you wanna think about it) vector of activations from the neural network $f_\theta$.

In words, the loss is the sum of the original loss using slices of $W$ and corresponding slices of $f_\theta(\mathbf{x})$.
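To make the sum concrete, here is a minimal pure-Python sketch of $\mathcal{L}_{\text{MRL-E}}$ for a single example, assuming cross entropy as $\mathcal{L}$ and storing $W$ as a list of $d$ rows of length $L$. The function names and toy values are hypothetical, not taken from the MRL codebase. Every term reuses the first $m$ rows of the same $W$, which is the weight tying that makes this variant "efficient".

```python
import math


def cross_entropy(logits: list[float], y: int) -> float:
    """Negative log-softmax of the true class, computed stably."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_z - logits[y]


def sliced_logits(W: list[list[float]], f: list[float], m: int) -> list[float]:
    """Compute W_{1:m}^T f_{1:m}, using only the first m rows of W and entries of f."""
    num_classes = len(W[0])
    return [sum(W[j][c] * f[j] for j in range(m)) for c in range(num_classes)]


def mrl_e_loss(W: list[list[float]], f: list[float], y: int, M: list[int]) -> float:
    """Sum of the original loss over the nested prefixes listed in M."""
    return sum(cross_entropy(sliced_logits(W, f, m), y) for m in M)


d, L = 64, 3
M = [16, 32, 64]  # the log-spaced grid from the d = 64 example
W = [[math.sin(j * L + c) / 8 for c in range(L)] for j in range(d)]  # toy weights
f = [math.cos(j) for j in range(d)]  # toy activations standing in for f_theta(x)
loss = mrl_e_loss(W, f, y=1, M=M)
print(loss)
```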


## A baseline

The loss makes sense, I guess. And it clearly works well in practice. But after staring at it, it looks like the weights for earlier dimensions of $W$ are, intentionally, applied redundantly. That's simply because $W_{1:16}$ is included in $W_{1:32}$ and $W_{1:64}$, and $W_{1:32}$ is included in $W_{1:64}$.
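One way to quantify that redundancy: row $j$ of $W$ (and coordinate $j$ of $f_\theta(\mathbf{x})$) shows up in every term of the sum with $m \geq j$. This small sketch counts those appearances for the $d = 64$, $\mathcal{M} = \{16, 32, 64\}$ example; `multiplicity` is a hypothetical helper name.

```python
def multiplicity(d: int, M: list[int]) -> list[int]:
    """For each dimension j = 1..d, count the prefixes m in M that include it."""
    return [sum(1 for m in M if m >= j) for j in range(1, d + 1)]


counts = multiplicity(64, [16, 32, 64])
# Dimensions 1-16 appear in 3 terms, 17-32 in 2 terms, and 33-64 in 1 term.
print(counts[0], counts[15], counts[16], counts[31], counts[32], counts[63])
# → 3 3 2 2 1 1
```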

I'm wondering if there's a simpler way to frame $\mathcal{L}_{\text{MRL-E}}$. The simplification doesn't have to be exact. It just has to roughly accomplish the same thing.


## Analysis

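One simplification I'm thinking about is a diagonal $d \times d$ matrix of weights applied to the dimensions of $f_\theta(\mathbf{x})$. (It resembles the diagonal weight matrix in weighted least squares, except that weighted least squares weights observations, whereas these weights scale dimensions.) These weights are hyperparameters. My hope is that they're easy to set because they can be pulled out of the gradient of $\mathcal{L}_{\text{MRL-E}}$.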

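To work through the gradient computation, let's simplify the problem. Change it to a linear regression with squared loss: treat $f_\theta(\mathbf{x})$ as a fixed input $\mathbf{x} \in \mathbb{R}^d$, let $W$ be a single column $\mathbf{w} \in \mathbb{R}^d$ (i.e., $L = 1$), and set $\mathcal{M} = \{m_1, 2m_1\}$ (i.e., $d = 2m_1$):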

$$
\mathcal{L}_{\text{MRL-E}} = (\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} - y)^2 + (\mathbf{w}^T \mathbf{x} - y)^2.
$$

The gradient wrt $\mathbf{w}$ is:

$$
\nabla_{\mathbf{w}}\mathcal{L}_{\text{MRL-E}} = (2(\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} - y) \mathbf{x}_{1:m_1}, \mathbf{0}_{m_1}) + 2(\mathbf{w}^T \mathbf{x} - y) \mathbf{x}.
$$
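As a sanity check on this gradient that doesn't rely on autograd, here is a pure-Python sketch comparing the closed form against central finite differences on random data (the helper names are made up for this example). Because the loss is quadratic in $\mathbf{w}$, central differences are exact up to rounding, so the two should agree to many digits.

```python
import random


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def loss_mrle(w: list[float], x: list[float], y: float, m1: int) -> float:
    """(w_{1:m1}^T x_{1:m1} - y)^2 + (w^T x - y)^2"""
    return (dot(w[:m1], x[:m1]) - y) ** 2 + (dot(w, x) - y) ** 2


def grad_mrle(w: list[float], x: list[float], y: float, m1: int) -> list[float]:
    """Closed form: (2(w_{1:m1}^T x_{1:m1} - y) x_{1:m1}, 0_{m1}) + 2(w^T x - y) x."""
    r_short = dot(w[:m1], x[:m1]) - y
    r_full = dot(w, x) - y
    return [
        (2 * r_short * x[k] if k < m1 else 0.0) + 2 * r_full * x[k]
        for k in range(len(w))
    ]


def finite_diff(loss, w: list[float], h: float = 1e-4) -> list[float]:
    """Central difference along each coordinate of w."""
    grad = []
    for k in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[k] += h
        w_minus[k] -= h
        grad.append((loss(w_plus) - loss(w_minus)) / (2 * h))
    return grad


random.seed(0)
m1 = 4
d = 2 * m1
x = [random.gauss(0, 1) for _ in range(d)]
w = [random.gauss(0, 1) for _ in range(d)]
y = random.gauss(0, 1)

analytic = grad_mrle(w, x, y, m1)
numeric = finite_diff(lambda v: loss_mrle(v, x, y, m1), w)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max abs difference: {max_err:.2e}")
```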

I think it's useful to understand how earlier vs later weights in $\mathbf{w}$ change. For $k = 1, 2, \dots, m_1$—

$$
\begin{align*}
(\nabla_{\mathbf{w}}\mathcal{L}_{\text{MRL-E}})_k &= 2(\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} - y) x_k + 2(\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} + \sum_{j=m_1 + 1}^{d} w_j x_j - y) x_k \\
&= 2x_k (2\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} + \sum_{j=m_1 + 1}^{d} w_j x_j - 2y).
\end{align*}
$$

—and for $k = m_1 + 1, m_1 + 2, \dots, 2m_1$—

$$
\begin{align*}
(\nabla_{\mathbf{w}}\mathcal{L}_{\text{MRL-E}})_k &= 2(\mathbf{w}^T \mathbf{x} - y) x_k \\
&= 2x_k (\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} + \sum_{j=m_1 + 1}^{d} w_j x_j - y).
\end{align*}
$$
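The two per-block expressions can be checked numerically too. In this sketch, `a` stands for $\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1}$ and `b` for $\sum_{j=m_1+1}^{d} w_j x_j$ (shorthand introduced only for the example). The full gradient vector, built from the formula for $\nabla_{\mathbf{w}}\mathcal{L}_{\text{MRL-E}}$, is compared against $2x_k(2a + b - 2y)$ for the first block and $2x_k(a + b - y)$ for the second.

```python
import random

random.seed(1)
m1 = 4
d = 2 * m1
x = [random.gauss(0, 1) for _ in range(d)]
w = [random.gauss(0, 1) for _ in range(d)]
y = random.gauss(0, 1)

a = sum(w[j] * x[j] for j in range(m1))     # w_{1:m1}^T x_{1:m1}
b = sum(w[j] * x[j] for j in range(m1, d))  # sum over j > m1 of w_j x_j

# Full gradient from the vector formula (w^T x = a + b).
grad = [
    (2 * (a - y) * x[k] if k < m1 else 0.0) + 2 * (a + b - y) * x[k]
    for k in range(d)
]

# Per-block closed forms from the derivation.
first_block = [2 * x[k] * (2 * a + b - 2 * y) for k in range(m1)]
last_block = [2 * x[k] * (a + b - y) for k in range(m1, d)]
```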

I don't think I can get the same gradients from reweighting alone. Let's see how these gradients differ from ones where we apply a diagonal weight matrix, $\Lambda = \text{diag}(\boldsymbol{\lambda})$, where $\boldsymbol{\lambda} = (2 \cdot \mathbf{1}_{m_1}, \mathbf{1}_{m_1})$:

$$
\mathcal{L}_{\text{re-weigh}} = (\mathbf{w}^T \Lambda \mathbf{x} - y)^2.
$$

The gradient wrt $\mathbf{w}$ is:

$$
\nabla_{\mathbf{w}}\mathcal{L}_{\text{re-weigh}} = 2(\mathbf{w}^T \Lambda \mathbf{x} - y) \Lambda \mathbf{x}.
$$

In other words, for $k = 1, 2, \dots, m_1$—

$$
\begin{align*}
(\nabla_{\mathbf{w}}\mathcal{L}_{\text{re-weigh}})_k &= 2(\mathbf{w}^T \Lambda \mathbf{x} - y) \lambda_k x_k \\
&= 2 \Bigg( \sum_{j=1}^{d} w_j \lambda_j x_j - y \Bigg) \lambda_k x_k \\
&= 2 \Bigg( 2 \sum_{j=1}^{m_1} w_j x_j + \sum_{j=m_1 + 1}^{d} w_j x_j - y \Bigg) 2 x_k && \text{plug in $\lambda$s} \\
&= 4 x_k \Bigg( 2 \mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} + \sum_{j=m_1 + 1}^{d} w_j x_j - y \Bigg).
\end{align*}
$$

—and for $k = m_1 + 1, m_1 + 2, \dots, 2m_1$—

$$
\begin{align*}
(\nabla_{\mathbf{w}}\mathcal{L}_{\text{re-weigh}})_k &= 2 x_k \Bigg( 2 \mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1} + \sum_{j=m_1 + 1}^{d} w_j x_j - y \Bigg).
\end{align*}
$$
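The same kind of check works for the re-weighted loss. This sketch uses $\boldsymbol{\lambda} = (2 \cdot \mathbf{1}_{m_1}, \mathbf{1}_{m_1})$ as above, compares the closed-form gradient $2(\mathbf{w}^T \Lambda \mathbf{x} - y) \Lambda \mathbf{x}$ against central finite differences, and then compares its blocks to $4x_k(2a + b - y)$ and $2x_k(2a + b - y)$, where `a` $= \mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1}$ and `b` $= \sum_{j=m_1+1}^{d} w_j x_j$ are shorthand for this example only.

```python
import random

random.seed(2)
m1 = 4
d = 2 * m1
x = [random.gauss(0, 1) for _ in range(d)]
w = [random.gauss(0, 1) for _ in range(d)]
y = random.gauss(0, 1)
lam = [2.0] * m1 + [1.0] * m1  # diagonal of Lambda


def loss_reweigh(v: list[float]) -> float:
    """(v^T Lambda x - y)^2 with Lambda = diag(lam)."""
    return (sum(v[j] * lam[j] * x[j] for j in range(d)) - y) ** 2


# Closed-form gradient: 2 (w^T Lambda x - y) Lambda x.
resid = sum(w[j] * lam[j] * x[j] for j in range(d)) - y
grad = [2 * resid * lam[k] * x[k] for k in range(d)]

# Central finite differences (exact up to rounding for a quadratic loss).
h = 1e-4
numeric = []
for k in range(d):
    w_plus, w_minus = list(w), list(w)
    w_plus[k] += h
    w_minus[k] -= h
    numeric.append((loss_reweigh(w_plus) - loss_reweigh(w_minus)) / (2 * h))

# Per-block closed forms from the derivation.
a = sum(w[j] * x[j] for j in range(m1))
b = sum(w[j] * x[j] for j in range(m1, d))
first_block = [4 * x[k] * (2 * a + b - y) for k in range(m1)]
last_block = [2 * x[k] * (2 * a + b - y) for k in range(m1, d)]
```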

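Looking at the gradients, both losses emphasize the first $m_1$ weights, but in different places. In $\mathcal{L}_{\text{MRL-E}}$, going from the last $m_1$ derivatives to the first, both $\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1}$ and $y$ get doubled inside the parentheses. In $\mathcal{L}_{\text{re-weigh}}$, the term in parentheses is the same for every $k$ (it always contains $2\mathbf{w}_{1:m_1}^T \mathbf{x}_{1:m_1}$ and a single $y$), and the first $m_1$ derivatives are instead scaled by an extra factor of $\lambda_k = 2$. I don't really see fundamental differences beyond that. Maybe I'm missing something.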

(In the full network, $\mathbf{x}$ is really $f_\theta(\mathbf{x})$, so the parameters $\theta$ that produce it also get updated. I'm not sure how important it is to incorporate these into the analysis. I'm hoping that looking at the gradient of the loss wrt $\mathbf{w}$ is enough to provide insight.)

I find $\mathcal{L}_{\text{re-weigh}}$ more intuitive b/c (I think?) it directly increases the importance of earlier dimensions. It also doesn't have to be discretely structured, which might avoid undesirable behavior like [this](https://twitter.com/dhruv___anand/status/1752641057278550199). It can continuously decay over the dimension index, e.g., $\lambda_j = b - \log_2 j$ for some constant $b$.
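Here is a small sketch of what such a decay could look like, reading it as $\lambda_j = b - \log_2 j$ over the dimension index $j = 1, \dots, d$ and picking $b = 1 + \log_2 d$ so that $\lambda_d = 1$ (both are assumptions for illustration). For comparison, it also counts how many terms of $\mathcal{L}_{\text{MRL-E}}$ contain dimension $j$ when $\mathcal{M}$ is every power of two up to $d$. The two profiles agree exactly at powers of two, and the log decay varies smoothly between the steps.

```python
import math

d = 64
b = 1 + math.log2(d)  # assumption: chosen so that lambda_d = 1
lam = [b - math.log2(j) for j in range(1, d + 1)]

# Multiplicity of dimension j in the MRL-E sum when M = {1, 2, 4, ..., d}.
M = [2 ** t for t in range(int(math.log2(d)) + 1)]
mult = [sum(1 for m in M if m >= j) for j in range(1, d + 1)]

for j in (1, 2, 3, 4, 16, 33, 64):
    print(j, round(lam[j - 1], 3), mult[j - 1])
```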

Geometrically, reweighting stretches the space along dimensions with larger $\lambda_j$, so points become more separated along those dimensions. (I drew a little picture of $(1, 1)$ and $(2, 1)$ and observed that.)
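One way to write this down, for two activation vectors $\mathbf{u}, \mathbf{v} \in \mathbb{R}^d$ (a worked identity, not something from the MRL paper):

$$
\|\Lambda \mathbf{u} - \Lambda \mathbf{v}\|_2^2 = \sum_{j=1}^{d} \lambda_j^2 (u_j - v_j)^2.
$$

A gap along dimension $j$ counts $\lambda_j^2$ times as much. For example, with $\boldsymbol{\lambda} = (2, 1)$, $\Lambda$ maps $(1, 1)$ to $(2, 1)$, so its squared distance from the origin goes from $1 + 1 = 2$ to $4 + 1 = 5$, and all of the extra separation comes from the first coordinate.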


## Questions

1. What is the substantive, statistical/optimization difference between $\mathcal{L}_{\text{MRL-E}}$ and $\mathcal{L}_{\text{re-weigh}}$? Their gradients are clearly mathematically different. An explanation in words and/or better math would help me understand why this difference matters.

2. Is there a reason why $\mathcal{L}_{\text{re-weigh}}$ wouldn't work well? It's simple enough that there should either be (1) a semi-obvious theoretical reason for why it doesn't accomplish the same goal as MRL, or (2) a paper empirically demonstrating that it doesn't work well.


## References

1. Kusupati, A., Bhatt, G., Rege, A., Wallingford, M., Sinha, A., Ramanujan, V., ... & Farhadi, A. (2022). [Matryoshka representation learning](https://arxiv.org/abs/2205.13147). Advances in Neural Information Processing Systems, 35, 30233-30249.

# Quick gradient tests

Make sure the math is right.

```python
import torch

# Input dimension
d = 32
assert d % 2 == 0, "d must be even, ideally a power of 2"

m = d // 2

x = torch.randn(d)  # from the big network
w = torch.randn(d, requires_grad=True)  # Matryoshka linear layer
y = torch.randn(1)  # label


def mse(w: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # forward and loss
    return ((w @ x) - y) ** 2


mrle_loss = mse(w[:m], x[:m], y) + mse(w, x, y)
mrle_loss.backward()


# hand-calculated gradient of mrle_loss wrt w
w_grad = (
    torch.concat([(2 * ((w[:m] @ x[:m]) - y) * x[:m]), torch.zeros(d - m)])
    + 2 * ((w @ x) - y) * x
).detach()

assert torch.allclose(w_grad, w.grad)
```

Test the component breakdown

```python
# first m coordinates: 2 x_k (2 a + b - 2y)
w_grad_first = 2 * x[:m] * (2 * (w[:m] @ x[:m]) + (w[m:] @ x[m:]) - 2 * y)

assert torch.allclose(w_grad[:m], w_grad_first)

# last m coordinates: 2 x_k (a + b - y)
w_grad_last = 2 * x[m:] * (1 * (w[:m] @ x[:m]) + (w[m:] @ x[m:]) - 1 * y)

assert torch.allclose(w_grad[m:], w_grad_last)
```

Test the diagonal (re-weigh) loss

```python
x = torch.randn(d)  # from the big network
w = torch.randn(d, requires_grad=True)  # Matryoshka linear layer
y = torch.randn(1)  # label

# Lambda = diag(2 * 1_m, 1_m)
_diag_vec = torch.concat([2 * torch.ones(m), torch.ones(d - m)])
diag = torch.diag(_diag_vec)

loss_diag = ((w @ diag @ x) - y) ** 2

loss_diag.backward()

w_grad_diag = (2 * ((w @ diag @ x) - y)) * (diag @ x)
assert torch.allclose(w_grad_diag, w.grad)
```

Test their components

```python
# first m coordinates: 4 x_k (2 a + b - y)
w_grad_first = 4 * x[:m] * (2 * (w[:m] @ x[:m]) + (w[m:] @ x[m:]) - 1 * y)

assert torch.allclose(w_grad_diag[:m], w_grad_first)

# last m coordinates: 2 x_k (2 a + b - y)
w_grad_last = 2 * x[m:] * (2 * (w[:m] @ x[:m]) + (w[m:] @ x[m:]) - 1 * y)

assert torch.allclose(w_grad_diag[m:], w_grad_last)
```

Diagonal candidate:

```python
num_m = torch.log2(torch.tensor(256)) - 3
torch.sqrt(torch.linspace(num_m, 1, steps=d))
```

```
tensor([2.2361, 2.2070, 2.1776, 2.1478, 2.1175, 2.0868, 2.0557, 2.0240, 1.9919,
        1.9593, 1.9261, 1.8923, 1.8579, 1.8228, 1.7871, 1.7506, 1.7133, 1.6752,
        1.6363, 1.5964, 1.5554, 1.5134, 1.4701, 1.4256, 1.3796, 1.3320, 1.2826,
        1.2313, 1.1778, 1.1216, 1.0626, 1.0000])
```