In [None]:
!pip -q install -U sentence-transformers datasets accelerate

import math
import os
import random
import numpy as np
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers import losses
from sentence_transformers.util import cos_sim


def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    The same integer is handed, in order, to Python's ``random`` module,
    NumPy's legacy global generator, PyTorch's default generator and
    PyTorch's CUDA generators.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once, up front, so later cells start from a known state.
set_seed(42)

In [None]:
@torch.no_grad()
def retrieval_metrics_mrr_recall_at_k(
    model,
    queries,
    corpus,
    qrels,
    dims_list=(64, 128, 256, None),
    k=10,
    batch_size=64,
):
    """Score top-k retrieval at several embedding truncation widths.

    Queries and documents are embedded once with ``model.encode``. For each
    entry of ``dims_list`` the embeddings are cut to their first ``dim``
    components (``None`` keeps them whole), L2-normalised again, and every
    query is ranked against the whole corpus by cosine similarity.

    Only queries with at least one relevant document in ``qrels`` are
    encoded and scored, and the averages are taken over exactly those
    queries. Queries without judgments therefore no longer pull the
    averages towards zero.

    Args:
        model: Object exposing ``to(device)`` and
            ``encode(texts, batch_size=..., convert_to_tensor=True,
            normalize_embeddings=True)`` that returns a 2-D tensor.
        queries: Mapping of query id to query text.
        corpus: Mapping of document id to document text.
        qrels: Mapping of query id to a collection of relevant document ids.
        dims_list: Truncation widths to evaluate; ``None`` means full width.
        k: Ranking cut-off.
        batch_size: Batch size forwarded to ``model.encode``.

    Returns:
        Dict keyed by ``str(dim)`` (``"full"`` for ``None``). Each value is
        ``{"MRR@k": float, "Recall@k": float}``. ``Recall@k`` here is the
        fraction of scored queries with at least one relevant document in
        the top ``k`` (a hit rate), not set recall. If there is nothing to
        score (no judged query, or an empty corpus) both metrics are 0.0.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # A query is only scorable when it has at least one relevant document.
    qids = [qid for qid in queries if qrels.get(qid)]
    docids = list(corpus.keys())

    mrr_key = f"MRR@{k}"
    recall_key = f"Recall@{k}"

    def _dim_name(dim):
        return "full" if dim is None else str(dim)

    if not qids or not docids:
        # Nothing to rank: report zeros instead of encoding an empty list.
        return {_dim_name(dim): {mrr_key: 0.0, recall_key: 0.0} for dim in dims_list}

    q_texts = [queries[qid] for qid in qids]
    d_texts = [corpus[did] for did in docids]

    q_emb = model.encode(q_texts, batch_size=batch_size, convert_to_tensor=True, normalize_embeddings=True)
    d_emb = model.encode(d_texts, batch_size=batch_size, convert_to_tensor=True, normalize_embeddings=True)

    top_n = min(k, len(docids))
    results = {}

    for dim in dims_list:
        if dim is None:
            qe, de = q_emb, d_emb
        else:
            qe, de = q_emb[:, :dim], d_emb[:, :dim]

        # Truncation breaks unit length, so normalise again; the dot product
        # of unit vectors is then the cosine similarity.
        qe = torch.nn.functional.normalize(qe, p=2, dim=1)
        de = torch.nn.functional.normalize(de, p=2, dim=1)
        sims = qe @ de.T

        # One batched top-k for all queries instead of one call per row.
        ranked = torch.topk(sims, k=top_n, dim=1, largest=True).indices.tolist()

        mrr_total = 0.0
        recall_total = 0.0
        for qid, doc_indices in zip(qids, ranked):
            rel = qrels[qid]
            rr = 0.0
            for rank, j in enumerate(doc_indices, start=1):
                if docids[j] in rel:
                    rr = 1.0 / rank
                    break
            mrr_total += rr
            # A non-zero reciprocal rank means a relevant doc was in the top k.
            recall_total += 1.0 if rr > 0.0 else 0.0

        denom = len(qids)
        results[_dim_name(dim)] = {mrr_key: mrr_total / denom, recall_key: recall_total / denom}

    return results


def pretty_print(results, title):
    """Print a titled table of metrics, one line per embedding width.

    Args:
        results: Mapping of width label to a mapping of metric name to
            numeric value.
        title: Heading printed between two 80-character rules.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(title)
    print(rule)
    for dim, metrics in results.items():
        cells = []
        for name, value in metrics.items():
            cells.append(f"{name}={value:.4f}")
        # Right-align the width label in four columns so rows line up.
        print(f"dim={dim:>4} | " + " | ".join(cells))

In [None]:
# Build a small training set and a disjoint retrieval evaluation set from a
# streamed triplet dataset.
#
# Two properties are enforced here:
#   * A query that is used for evaluation is never added to the training
#     triplets, so the "AFTER" metrics are not measured on training queries.
#   * Identical query / passage text always maps to a single id. Rows that
#     repeat a query (for example with a different negative) extend that
#     query's relevance set and the corpus instead of creating duplicate
#     documents whose copies would be scored as non-relevant.
DATASET_ID = "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1"
SUBSET = "triplet-hard"
SPLIT = "train"

TRAIN_SAMPLES = 4000
EVAL_QUERIES = 300

stream = load_dataset(DATASET_ID, SUBSET, split=SPLIT, streaming=True)

train_examples = []
eval_queries = {}  # query id -> query text
eval_corpus = {}   # document id -> passage text
eval_qrels = {}    # query id -> set of relevant document ids

doc_id_counter = 0
qid_counter = 0

# Reverse lookups so repeated text reuses the id it was first given.
eval_qid_by_text = {}
eval_doc_id_by_text = {}

for row in stream:
    q = (row.get("query") or "").strip()
    pos = (row.get("positive") or "").strip()
    neg = (row.get("negative") or "").strip()

    # Skip rows with a missing or blank field.
    if not q or not pos or not neg:
        continue

    qid = eval_qid_by_text.get(q)
    if qid is None and len(eval_queries) < EVAL_QUERIES:
        # First time this query is seen and the eval set still has room.
        qid = f"q{qid_counter}"
        qid_counter += 1
        eval_qid_by_text[q] = qid
        eval_queries[qid] = q
        eval_qrels[qid] = set()

    if qid is not None:
        # Evaluation row: register both passages once, mark the positive.
        for text in (pos, neg):
            if text not in eval_doc_id_by_text:
                new_doc_id = f"d{doc_id_counter}"
                doc_id_counter += 1
                eval_doc_id_by_text[text] = new_doc_id
                eval_corpus[new_doc_id] = text
        eval_qrels[qid].add(eval_doc_id_by_text[pos])
    elif len(train_examples) < TRAIN_SAMPLES:
        # Training row: its query is not part of the evaluation set.
        train_examples.append(InputExample(texts=[q, pos, neg]))

    if len(train_examples) >= TRAIN_SAMPLES and len(eval_queries) >= EVAL_QUERIES:
        break

print(len(train_examples), len(eval_queries), len(eval_corpus))

In [None]:
# Load the pretrained checkpoint and record how well its embeddings retrieve
# when truncated, before any fine-tuning.
MODEL_ID = "BAAI/bge-base-en-v1.5"

# Use the GPU when one is visible, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = SentenceTransformer(MODEL_ID, device=device)

# Native embedding width reported by the model.
full_dim = model.get_sentence_embedding_dimension()

baseline = retrieval_metrics_mrr_recall_at_k(
    model,
    eval_queries,
    eval_corpus,
    eval_qrels,
    dims_list=(64, 128, 256, None),
    k=10,
)
pretty_print(baseline, "BEFORE")

In [None]:
# Fine-tune with a Matryoshka objective, re-evaluate, save the model, and
# show how to load it with 64-dimensional output.
batch_size = 16
epochs = 1
warmup_steps = 100

# When wandb is installed, training can stop at an interactive
# "wandb: Enter your choice:" prompt, which blocks an unattended Run All.
# Default to disabled; setdefault keeps any value the user exported.
os.environ.setdefault("WANDB_MODE", "disabled")

train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, drop_last=True)

base_loss = losses.MultipleNegativesRankingLoss(model=model)

# Train the full width plus every standard smaller width that actually fits.
# For full_dim of 512, 768 or 1024 this yields the same list as before; for
# smaller models it no longer requests widths larger than the embedding
# (or the same width twice).
mrl_dims = [full_dim] + [d for d in (512, 256, 128, 64) if d < full_dim]
mrl_loss = losses.MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=mrl_dims
)

model.fit(
    train_objectives=[(train_loader, mrl_loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    show_progress_bar=True,
)

# Same evaluation settings as the baseline run so the two tables compare.
after = retrieval_metrics_mrr_recall_at_k(
    model,
    queries=eval_queries,
    corpus=eval_corpus,
    qrels=eval_qrels,
    dims_list=(64, 128, 256, None),
    k=10,
)
pretty_print(after, "AFTER")

out_dir = "mrl-msmarco-demo"
model.save(out_dir)

# Reload the saved model with truncate_dim=64 and print the embedding shape.
m64 = SentenceTransformer(out_dir, truncate_dim=64)
emb = m64.encode(
    ["what is the liberal arts?", "liberal arts covers humanities and sciences"],
    normalize_embeddings=True
)
print(emb.shape)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.2/515.2 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.6/47.6 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25h

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Train triplets: 4000
Eval queries:  300 | Eval docs: 600
Example row:
  query   : what are the liberal arts?
  positive: liberal arts. 1. the academic course of instruction at a college intended to provide general knowledge and comprising th
  negative: Rather than preparing students for a specific career, liberal arts programs focus on cultural literacy and hone communic


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


Base model: BAAI/bge-base-en-v1.5
Device: cpu
Embedding dim: 768

BEFORE fine-tuning (baseline truncation performance)
dim=  64 | MRR@10=0.0650 | Recall@10=0.2600
dim= 128 | MRR@10=0.0788 | Recall@10=0.2900
dim= 256 | MRR@10=0.0802 | Recall@10=0.3033
dim=full | MRR@10=0.0815 | Recall@10=0.3000

Training with Matryoshka dims: [768, 512, 256, 128, 64]


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice: