## Best case optimization

### In general
- **goal:** estimate the theoretical limit that the embedding dimension imposes
- best case: optimize both the query and the document embeddings directly $\implies$ no constraints from natural language $\implies$ free embeddings
- if this best case has limits, then every real-world scenario has them too
- the qrel matrix is given in advance

### Experimental settings
- training set:
    - $n$ document embeddings of dimension $d$ ($d$ is the controlled variable)
    - $m$ query vectors of dimension $d$
    - in the qrel matrix, only the top-$k$ documents of each query are marked as relevant
    - we want the maximum number of distinct query results $\implies m := \binom{n}{k}$
    - if 2 queries have the same top-$k$ results, the second one only adds $k(n-k)$ inequalities of the same form as the first $\implies$ no new independent constraints $\implies$ duplicated rows in $2A - 1_{m\times n} \iff$ duplicated rows in $A$
    - we construct $A$ by enumerating all possible cases
- loss function - based on InfoNCE (but engineered without the negative event):
      $L_\text{total}=\displaystyle-\frac{1}{M}\sum_{i=1}^M\log\frac{\sum_{d_r\in R_i}\exp(\text{sim}(q_i, d_r)/\tau)}{\sum_{d_k\in D}\exp(\text{sim}(q_i, d_k)/\tau)}$ where
    - $R_i = \{d_r \in D \mid d_r \text{ relevant to query } q_i\}$
    - $D$ set of documents
    - $k:=2$ - for simplicity
- optimizer: SGD and Adam; Adam was used for most runs because it is faster
- the embeddings are re-normalized to unit length after each update
- early stopping if the loss does not improve within the next $1000$ iterations
- for a fixed $d$, increase the number of documents $n$ until the optimization can no longer reach $100\%$ accuracy
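
The construction of $A$ described above can be sketched in plain Python. This is a minimal illustration, not the notebook's torch implementation: the helper name `build_qrel_rows` and the toy sizes $n=4$, $k=2$ are assumptions made for the example.

```python
from itertools import combinations
from math import comb


def build_qrel_rows(n: int, k: int) -> list[list[int]]:
    """Enumerate every k-subset of n documents as one 0/1 qrel row."""
    rows = []
    for combo in combinations(range(n), k):
        row = [0] * n
        for doc_idx in combo:
            row[doc_idx] = 1  # mark the document as relevant for this query
        rows.append(row)
    return rows


# toy sizes chosen for readability (assumption, not the experiment's values)
A = build_qrel_rows(n=4, k=2)
print(len(A), comb(4, 2))  # the number of queries m equals C(n, k)
for row in A:
    print(row)
```

Every row contains exactly $k$ ones and no two rows are equal, which is why no query repeats the constraints of another.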

In [1]:
!nvidia-smi

Mon Nov  3 23:33:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P0             26W /  250W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                     

In [2]:
!pip install torch==2.5.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121
Collecting torch==2.5.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-linux_x86_64.whl (780.5 MB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.5.1+cu121)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.5.1+cu121)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)

In [3]:
import torch

print("CUDA available:", torch.cuda.is_available())


CUDA available: True


The loss function given in a footnote of the paper:
$$
L_\text{total}=\displaystyle-\frac{1}{M}\sum_{i=1}^M\log\frac{\sum_{d_r\in R_i}\exp(\text{sim}(q_i, d_r)/\tau)}{\sum_{d_k\in D}\exp(\text{sim}(q_i, d_k)/\tau)}
$$
where
- $R_i = \{d_r \in D \mid d_r \text{ relevant to query } q_i\}$
- $D$ set of documents
- $k:=2$ - for simplicity

The loss function used in [Google's implementation](https://github.com/google-deepmind/limit/blob/main/code/free_embedding_experiment.py):
$$
L_\text{total}=\displaystyle-\frac{1}{M}\sum_{i=1}^m\sum_{d_r\in R_i}\log\frac{\exp(\text{sim}(q_i, d_r)/\tau)}{\sum_{d_k\in D}\exp(\text{sim}(q_i, d_k)/\tau)}
$$

Reading the source code shows that here $M$ is not the number of queries but the number of positive query-document pairs (the number of ones in the qrel matrix), so $M = k \cdot m$ for our qrel matrix.
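
To make the difference between the two formulas concrete, here is a small pure-Python sketch that evaluates both on toy data. The function names, the similarity values and the qrel rows are assumptions made for illustration; the footnote variant is averaged over the number of queries, as its formula is written above.

```python
import math


def log_softmax(row, temp):
    """Numerically stable log-softmax of one row of similarities."""
    scaled = [s / temp for s in row]
    top = max(scaled)
    log_z = top + math.log(sum(math.exp(s - top) for s in scaled))
    return [s - log_z for s in scaled]


def paper_loss(sim, qrel, temp=0.1):
    """Footnote variant: log of the summed positive probabilities, averaged over queries."""
    total = 0.0
    for sim_row, rel_row in zip(sim, qrel):
        log_p = log_softmax(sim_row, temp)
        pos_mass = sum(math.exp(lp) for lp, rel in zip(log_p, rel_row) if rel)
        total -= math.log(pos_mass)
    return total / len(sim)


def implementation_loss(sim, qrel, temp=0.1):
    """Implementation variant: positive log-probabilities, averaged over positive pairs."""
    total, num_pos = 0.0, 0
    for sim_row, rel_row in zip(sim, qrel):
        log_p = log_softmax(sim_row, temp)
        total -= sum(lp for lp, rel in zip(log_p, rel_row) if rel)
        num_pos += sum(rel_row)
    return total / num_pos


# hypothetical toy data: 2 queries, 3 documents, k = 2
sim = [[0.9, 0.1, 0.8], [0.2, 0.7, 0.6]]
qrel = [[1, 0, 1], [0, 1, 1]]
print(paper_loss(sim, qrel), implementation_loss(sim, qrel))
```

The footnote variant is already small once the relevant documents jointly hold most of the probability mass, while the implementation variant penalizes each relevant document on its own, so in this toy case it is the larger of the two.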

In [4]:
import numpy as np
from itertools import combinations
import torch

class FreeEmbeddingsModel(torch.nn.Module):
    def __init__(self, num_of_docs: int, dimension: int, k: int = 2, temp: float = 0.07):
        super().__init__()

        self.__n = num_of_docs
        self.__d = dimension
        self.__k = k
        self.__temp = temp
        self.__m = None
        self.docs = None
        self.queries = None
        self.__qrel_matrix = None

        self.__qrel_matrix = self.__generate_qrel_matrix(self.__n, self.__k)
        self.__m = self.__qrel_matrix.shape[0]

        self.docs = torch.nn.Parameter(torch.randn(self.__n, self.__d))
        self.queries = torch.nn.Parameter(torch.randn(self.__m, self.__d))
        with torch.no_grad():
            self.queries.div_(self.queries.norm(dim=1, keepdim=True))
            self.docs.div_(self.docs.norm(dim=1, keepdim=True))

    @staticmethod
    def __generate_qrel_matrix(n: int, k: int) -> torch.Tensor:
        combos = list(combinations(range(n), k))
        matrix = torch.zeros((len(combos), n), dtype=torch.int)
        for i, combo in enumerate(combos):
            matrix[i, list(combo)] = 1
        return matrix

    # the loss function from the paper (kept for reference)
    # it yields a linear d vs. critical_n curve
    # def forward(self):
    #     self.__qrel_matrix = self.__qrel_matrix.to(self.docs.device)

    #     # normalize the vectors
    #     docs_norm = self.docs / self.docs.norm(dim=1, keepdim=True)
    #     queries_norm = self.queries / self.queries.norm(dim=1, keepdim=True)

    #     sim = queries_norm @ docs_norm.T
    #     print(f"sim.mean()={sim.mean()}, sim.std()={sim.std()}")
    #     exp_sim = torch.exp(sim / self.__temp)

    #     # filter the relevant docs using the qrel matrix as mask
    #     num = (exp_sim * self.__qrel_matrix).sum(dim=1)
    #     den = exp_sim.sum(dim=1)
    #     M = self.__qrel_matrix.sum()
    #     total_loss = -torch.log(num / (den + 1e-12)).sum() / M
    #     return total_loss

    # the loss function from their implementation
    # it yields a cubic d vs. critical_n curve, which lies above the expected one
    def forward(self):
        self.__qrel_matrix = self.__qrel_matrix.to(self.docs.device)

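        # no explicit normalization here: the parameters are normalized in __init__
        # and re-normalized by train() after every optimizer step
        queries_norm = self.queries #/ (self.queries.norm(dim=1, keepdim=True))
        docs_norm = self.docs #/ (self.docs.norm(dim=1, keepdim=True))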

        logits = (queries_norm @ docs_norm.T) / self.__temp
        log_probs = torch.log_softmax(logits, dim=1)

        sum_pos_log_probs = (log_probs * self.__qrel_matrix).sum()
        M = self.__qrel_matrix.sum()
        total_loss = -sum_pos_log_probs / M
        return total_loss

    def accuracy(self) -> float:
        docs_norm = self.docs / self.docs.norm(dim=1, keepdim=True)
        queries_norm = self.queries / self.queries.norm(dim=1, keepdim=True)

        sim = queries_norm @ docs_norm.T
        similar_rows = 0
        for i in range(self.__m):
            # use masking to avoid ties
            top_k_mask = self.__qrel_matrix[i].bool()

            pos_vals = sim[i][top_k_mask]
            neg_vals = sim[i][~top_k_mask]

            if pos_vals.numel() == 0 or neg_vals.numel() == 0:
                continue

            similar_rows += int(torch.min(pos_vals) >= torch.max(neg_vals))

        return similar_rows / self.__m
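
The criterion in `accuracy()` can be illustrated without torch: a query counts as correct when its lowest-scoring relevant document still scores at least as high as its highest-scoring irrelevant one. The sketch below uses hypothetical similarity values and assumes every row has at least one relevant and one irrelevant document (the method above skips rows where that is not the case).

```python
def row_is_correct(sim_row, rel_row):
    """True if every relevant doc scores at least as high as every irrelevant one."""
    pos = [s for s, rel in zip(sim_row, rel_row) if rel]
    neg = [s for s, rel in zip(sim_row, rel_row) if not rel]
    return min(pos) >= max(neg)


def accuracy(sim, qrel):
    """Fraction of queries whose relevant set is ranked on top."""
    return sum(row_is_correct(s, r) for s, r in zip(sim, qrel)) / len(sim)


# hypothetical toy data: the first query is ranked correctly, the second is not
sim = [[0.9, 0.1, 0.8], [0.2, 0.7, 0.1]]
qrel = [[1, 0, 1], [1, 0, 1]]
print(accuracy(sim, qrel))
```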

In [5]:
from typing import Any, Dict

# params from the original experiment
DEFAULT_EXPERIMENT_PARAMS: Dict[str, Any] = {
    "q": None,
    "learning_rate": 0.01,
    "num_iterations": 100000,
    "temperature": 0.1,
    "seed": 42,
    "show_progress": True,
    "device": "gpu",
    "log_interval": 50,
    "early_stopping_patience": 1000,
    "early_stopping_min_delta": 0.00001,
    "early_stopping_monitor_metric": "loss",
    "early_stopping_restore_best_weights": False,
}

In [None]:
import torch

def train(
    num_of_docs: int,
    dimension: int,
    max_patience: int = 1000,
    temp: float = 0.1,
    learning_rate: float = 0.01,
    max_iters: int = 100000,
    min_delta: float = 0.00001
) -> float:
    # best loss seen so far and the weights that produced it
    min_loss = torch.finfo(torch.float32).max
    best_query_weights = None
    best_doc_weights = None

    # fall back to the CPU when no GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FreeEmbeddingsModel(num_of_docs=num_of_docs, dimension=dimension, temp=temp).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    prev_loss = torch.finfo(torch.float32).max
    iters = 0
    patience = 0
    while max_patience > patience and iters < max_iters:
        optimizer.zero_grad()
        loss = model()

        # compare plain floats so min_loss does not keep the autograd graph alive
        if min_loss - loss.item() > min_delta:
            min_loss = loss.item()
            best_query_weights = model.queries.detach().clone()
            best_doc_weights = model.docs.detach().clone()
            patience = 0
        else:
            patience += 1

        if iters % 1000 == 0:
            accuracy = model.accuracy()
            print(f"[docs={num_of_docs}, dim={dimension}]: iteration #{iters}, patience: {patience}/{max_patience}, accuracy={accuracy}, loss={loss.item():.6f}")
            if accuracy >= 1.0:
                return 1.0

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            model.queries.div_(model.queries.norm(dim=1, keepdim=True))
            model.docs.div_(model.docs.norm(dim=1, keepdim=True))

        iters += 1

    with torch.no_grad():
        model.queries.copy_(best_query_weights)
        model.docs.copy_(best_doc_weights)

    return model.accuracy()

In [None]:
# cubic d -> critical_n curve (see the note on forward() above)
def pred(d: float) -> float:
    return -10.5322 + 4.0309*d + 0.0520*d**2 + 0.0037*d**3

pred(15)

In [None]:
import math

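def find_critical_num_of_docs(dimension: int) -> int:
    """Search for the smallest number of docs at which training stops reaching 100% accuracy."""
    # cache of already trained sizes: num_of_docs -> reached 100% accuracy
    is_accurate = dict()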

    product = 5
    while train(num_of_docs=int(product), dimension=dimension) >= 1:
        product *= 1.5

    lower = int(product / 1.5)
    upper = int(product)

    while lower <= upper:
        middle = (lower + upper) // 2

        is_mid_acc = is_accurate.get(middle)
        if is_mid_acc is None:
            is_mid_acc = train(num_of_docs=middle, dimension=dimension) >= 1
            is_accurate[middle] = is_mid_acc

        if not is_mid_acc and middle >= 1:
            is_prev_mid_acc = is_accurate.get(middle - 1)
            if is_prev_mid_acc is None:
                is_prev_mid_acc = train(num_of_docs=middle - 1, dimension=dimension) >= 1
                is_accurate[middle - 1] = is_prev_mid_acc

            if is_prev_mid_acc:
                return middle

        if is_mid_acc:
            lower = middle + 1
        else:
            upper = middle - 1

    return lower
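
The search strategy above (grow $n$ geometrically until training fails, then bisect the last interval) can be sketched on its own with a cheap stand-in for `train(...) >= 1`. The function name `first_failing_n` and the predicate are hypothetical. The sketch assumes the predicate is true up to some $n$ and false afterwards, which a stochastic training run only approximates.

```python
from typing import Callable


def first_failing_n(
    reaches_full_accuracy: Callable[[int], bool],
    start: int = 5,
    growth: float = 1.5,
) -> int:
    """Smallest n for which the predicate is False, assuming it is monotone."""
    if not reaches_full_accuracy(start):
        return start  # already failing at the starting size

    # phase 1: grow n geometrically until the predicate fails
    lower, upper = start, max(start + 1, int(start * growth))
    while reaches_full_accuracy(upper):
        lower, upper = upper, max(upper + 1, int(upper * growth))

    # phase 2: bisect, keeping the predicate True at lower and False at upper
    while upper - lower > 1:
        middle = (lower + upper) // 2
        if reaches_full_accuracy(middle):
            lower = middle
        else:
            upper = middle
    return upper


# stand-in predicate (assumption): pretend training succeeds for fewer than 37 docs
critical = first_failing_n(lambda n: n < 37)
print(critical)
```

Keeping the invariant "succeeds at `lower`, fails at `upper`" means each size is evaluated at most once in the bisection, which matters when a single evaluation is a full training run.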

In [None]:
critical_n = []
for d in range(9, 31):
    critical_n.append(find_critical_num_of_docs(d))

In [None]:
critical_n