Most code taken from:
https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/matryoshka/matryoshka_nli.py

TODO: open a PR so that the script also produces the plots shown
[here](https://www.sbert.net/examples/training/matryoshka/README.html#results).

In [None]:
from contextlib import contextmanager
import csv
from datetime import datetime
from functools import wraps
import gzip
import logging
import math
import os
import random
import sys
from typing import Any, Callable

from datasets import load_dataset
import numpy as np
from sentence_transformers import (
    models,
    losses,
    datasets,
    LoggingHandler,
    SentenceTransformer,
    util,
    InputExample,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from sentence_transformers.util import cos_sim
import torch
from tqdm.auto import tqdm

In [None]:
# Route INFO-level (and above) records through sentence-transformers'
# LoggingHandler. `level` has to be a logging level (an int or a level name);
# a callable such as `print` makes `basicConfig` raise.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Settings

In [None]:
# Base checkpoint: the first command-line argument when one is given,
# otherwise "distilroberta-base".
# NOTE(review): inside a Jupyter kernel `sys.argv` usually carries the
# kernel's own launch arguments, so the default may never be used there;
# confirm how this cell is meant to be run.
if len(sys.argv) > 1:
    model_name = sys.argv[1]
else:
    model_name = "distilroberta-base"

# Training hyper-parameters.
num_epochs = 1
max_seq_length = 75
train_batch_size = 128

# Embedding sizes handed to MatryoshkaLoss and evaluated one by one at the end.
MATRYOSHKA_DIMENSIONS = [768, 512, 256, 128, 64]

# Load pretrained model

The model has not yet been trained for sentence similarity.

In [None]:
# Timestamped output directory. "/" in the model name is replaced with "-"
# so that the name does not introduce extra path levels.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = "".join(
    ["output/matryoshka_nli_", model_name.replace("/", "-"), "-", run_timestamp]
)

# Transformer module followed by mean pooling, combined into one
# SentenceTransformer.
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(embedding_dim, pooling_mode="mean")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Load training and validation data

In [None]:
# Download the AllNLI archive unless a local copy is already present.
nli_dataset_url = "https://sbert.net/datasets/AllNLI.tsv.gz"
nli_dataset_path = "data/AllNLI.tsv.gz"

if not os.path.exists(nli_dataset_path):
    util.http_get(nli_dataset_url, nli_dataset_path)

# Announce the parsing step; the file itself is read further down.
print("Read AllNLI train dataset")

# Maps every training sentence to the sentences it was paired with, grouped
# by NLI label:
# {
#     sentence: {
#         "contradiction": {sentences paired with it under this label},
#         "entailment": {...},
#         "neutral": {...},
#     }
# }
train_data: dict[str, dict[str, set]] = {}


def add_to_samples(sent1, sent2, label):
    """Record `sent2` as a `label` partner of `sent1` in `train_data`.

    An entry holding three empty sets is created the first time `sent1` is
    seen. `label` must be one of that entry's keys ("contradiction",
    "entailment", "neutral"); any other value raises `KeyError`.
    """
    partners_by_label = train_data.get(sent1)
    if partners_by_label is None:
        partners_by_label = {
            "contradiction": set(),
            "entailment": set(),
            "neutral": set(),
        }
        train_data[sent1] = partners_by_label
    partners_by_label[label].add(sent2)


# Stream the gzipped TSV and index every "train" row in both directions, so
# that each sentence of a pair gets its own entry in `train_data`.
with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn:
    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if row["split"] != "train":
            continue  # rows of other splits are skipped

        sent1, sent2 = row["sentence1"].strip(), row["sentence2"].strip()
        for anchor, partner in ((sent1, sent2), (sent2, sent1)):
            add_to_samples(anchor, partner, row["label"])


# Turn `train_data` into InputExamples of three texts each: the sentence, one
# of its entailment partners and one of its contradiction partners.
train_samples = []
for sent1, others in train_data.items():
    entailments = list(others["entailment"])
    contradictions = list(others["contradiction"])
    if not entailments or not contradictions:
        continue  # both an entailment and a contradiction partner are needed

    # First example: the sentence itself in the first slot.
    sentence_first = [
        sent1,
        random.choice(entailments),
        random.choice(contradictions),
    ]
    train_samples.append(InputExample(texts=sentence_first))

    # Second example: an entailment partner in the first slot, the sentence
    # second. Partners are drawn again, independently of the first example.
    partner_first = [
        random.choice(entailments),
        sent1,
        random.choice(contradictions),
    ]
    train_samples.append(InputExample(texts=partner_first))

print("Train samples: {}".format(len(train_samples)))

In [None]:
# Batches come from NoDuplicatesDataLoader, a special loader that avoids
# duplicates within a batch.
train_dataloader = datasets.NoDuplicatesDataLoader(
    train_examples=train_samples,
    batch_size=train_batch_size,
)

In [None]:
# Warm-up covers 10% of the total number of training batches, rounded up.
steps_per_epoch = len(train_dataloader)
warmup_steps = math.ceil(steps_per_epoch * num_epochs * 0.1)
print("Warmup-steps: {}".format(warmup_steps))

# Loss: MultipleNegativesRankingLoss wrapped in MatryoshkaLoss, which is given
# the embedding sizes listed in MATRYOSHKA_DIMENSIONS.
base_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, base_loss, MATRYOSHKA_DIMENSIONS)

# Dev evaluator: validation split of the STS benchmark, with the gold scores
# divided by 5.
stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation")
dev_gold_scores = [score / 5 for score in stsb_dev["score"]]
dev_evaluator = EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    dev_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)

# Train the model. `evaluation_steps` is a tenth of the batches in one epoch.
evaluation_steps = int(steps_per_epoch * 0.1)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_steps,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=False,  # Set to True, if your GPU supports FP16 operations
)

# Evaluate on STS benchmark dataset

In [None]:
# Reload the model from `model_save_path`, the directory that was handed to
# `model.fit` as `output_path`.
model = SentenceTransformer(model_save_path)

In [None]:
# Test evaluator: same setup as the dev evaluator, but on the "test" split of
# the STS benchmark. Gold scores are divided by 5.
stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test")
test_gold_scores = [score / 5 for score in stsb_test["score"]]

test_evaluator = EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    test_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)

In [None]:
# patch encode to truncate lol
@contextmanager
def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable):
    original_method = getattr(obj, method_name)
    # Need to use __get__ when patching methods
    # https://stackoverflow.com/a/28127947/18758987
    try:
        setattr(obj, method_name, new_method.__get__(obj, obj.__class__))
        yield
    finally:
        setattr(obj, method_name, original_method.__get__(obj, obj.__class__))

@contextmanager
def _monkeypatch_instance_method(obj: Any, method_name: str, new_method: Callable):
    original_method = getattr(obj, method_name)
    # Need to use __get__ when patching methods
    # https://stackoverflow.com/a/28127947/18758987
    try:
        setattr(obj, method_name, new_method.__get__(obj, obj.__class__))
        yield
    finally:
        setattr(obj, method_name, original_method.__get__(obj, obj.__class__))


@contextmanager
def truncate_embeddings(model: SentenceTransformer, dim: int):
    """
    Temporarily make `model.encode` return only the first `dim` entries of
    the last axis of its output.

    Parameters
    ----------
    model : SentenceTransformer
        model where `model.encode` outputs a (D,) or (N, D) array or tensor of
        embeddings
    dim : int
        dimension to truncate at. So a (N, D) array becomes (N, `dim`)
    """

    # Captured before patching; this is the bound, untruncated `encode`.
    full_encode = model.encode

    @wraps(full_encode)
    def truncated_encode(self, *args, **kwargs) -> np.ndarray | torch.Tensor:
        # `self` is supplied when the patch binds this function to the model.
        # `full_encode` is already bound, so `self` is not forwarded.
        return full_encode(*args, **kwargs)[..., :dim]

    with _monkeypatch_instance_method(model, "encode", truncated_encode):
        yield

In [None]:
# Run the STS test evaluator once per Matryoshka dimension, with the model's
# embeddings truncated to that dimension. Each dimension gets its own output
# directory.
for dim in tqdm(MATRYOSHKA_DIMENSIONS, desc="Dimensions"):
    output_path = f"{model_save_path}-dim{dim}"
    # Nothing earlier in this notebook creates this per-dimension directory.
    # NOTE(review): the evaluator appears to write its results file into
    # `output_path` without creating it first; confirm against the installed
    # sentence-transformers version. Creating it here is harmless either way.
    os.makedirs(output_path, exist_ok=True)
    with truncate_embeddings(model, dim):
        test_evaluator(model, output_path=output_path)