**Description**: a custom Matryoshka + MultipleNegativesRankingLoss (MNRL) training implementation that might save
memory. It still needs to be evaluated by running it on a GPU and monitoring memory usage.

Modified from
https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/matryoshka/matryoshka_nli.py

In [None]:
from datetime import datetime
import logging
from typing import Any

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
    util,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SequentialEvaluator,
    SimilarityFunction,
)
from sentence_transformers.training_args import BatchSamplers
import torch

In [None]:
# Toggle: True -> custom MatryoshkaTrainer below, False -> stock MatryoshkaLoss.
USE_CUSTOM: bool = True
# Run on the CPU whenever CUDA is missing. Hugging Face appeared to move things
# to MPS even when told not to, so CPU placement is forced explicitly later on.
FORCE_CPU: bool = not torch.cuda.is_available()

In [None]:
# Configure root logging at INFO level with timestamped messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [None]:
model_name = "distilroberta-base"

# Full-size batches on GPU; tiny ones on CPU so a smoke test finishes quickly.
batch_size = 3 if FORCE_CPU else 128
num_batches = 3 if FORCE_CPU else 20  # caps how many batches of data are used
num_train_epochs = 1
# Embedding prefix lengths to train; 768 is presumably the base model's full
# embedding size — TODO confirm via model.get_sentence_embedding_dimension().
matryoshka_dims = [768, 512, 256, 128, 64]

# Save path of the model, suffixed with the current local date and time
output_dir = (
    "output/matryoshka_nli_"
    + model_name.replace("/", "-")
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

In [None]:
# 1. Build the SentenceTransformer. A checkpoint that is not already a Sentence
# Transformer model is wrapped automatically with "mean" pooling.
model = SentenceTransformer(model_name)
# Optionally cap the maximum sequence length, e.g.:
# model.max_seq_length = 75
logging.info("%s", model)

In [None]:
# Without CUDA, pin the model to the CPU explicitly (see FORCE_CPU).
model = model.to("cpu") if FORCE_CPU else model

In [None]:
# 2. Load the AllNLI triplet split: https://huggingface.co/datasets/sentence-transformers/all-nli
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
logging.info("%s", train_dataset)

# Keep only the first num_batches * batch_size rows so the run stays short.
train_dataset = train_dataset.select(range(num_batches * batch_size))

In [None]:
# 3. Define our training loss
class MultipleNegativesRankingLoss(torch.nn.Module):
    """In-batch-negatives ranking loss computed on precomputed embeddings.

    ``forward`` takes embedding matrices directly rather than tokenized
    features. A caller can therefore embed a batch once and evaluate this loss
    on several truncated prefixes of the same embeddings.
    """

    def __init__(
        self,
        model: SentenceTransformer,
        scale: float = 20.0,
        similarity_fct=util.cos_sim,
    ) -> None:
        """
        Args:
            model: the model being trained; stored on the instance, not used by
                ``forward``.
            scale: multiplier applied to the similarity scores before the
                cross-entropy.
            similarity_fct: pairwise similarity ``(a, b) -> scores`` matrix.
        """
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(
        self,
        embeddings_a: torch.Tensor,
        embeddings_b: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the cross-entropy of scaled similarities against diagonal targets.

        Args:
            embeddings_a: anchor embeddings; assumed (batch, dim). Row ``i`` is
                the anchor whose match is row ``i`` of ``embeddings_b``.
            embeddings_b: candidate embeddings; assumed (n_candidates, dim) with
                n_candidates >= batch (presumably positives followed by hard
                negatives — confirm against the caller).
            labels: unused; accepted so callers can pass dataset labels through.

        Returns:
            Scalar mean cross-entropy loss.
        """
        # L2-normalize so dot-product-style similarity functions behave like cosine.
        embeddings_a = torch.nn.functional.normalize(embeddings_a, p=2, dim=-1)
        embeddings_b = torch.nn.functional.normalize(embeddings_b, p=2, dim=-1)
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        # Example a[i] should match with b[i]; every other column is a negative.
        range_labels = torch.arange(scores.size(0), device=scores.device)
        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        """Return the hyper-parameters used to describe this loss in model cards."""
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        """BibTeX entry for the paper introducing this loss."""
        return """
@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply}, 
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""

In [None]:
class MatryoshkaTrainer(SentenceTransformerTrainer):
    """Trainer applying a Matryoshka objective with a single model forward pass.

    Each step embeds the batch once and then evaluates ``self.loss`` on every
    prefix length in ``matryoshka_dims``. The per-dimension losses are
    back-propagated only into detached copies of the embeddings, where their
    gradients accumulate in ``.grad``. One final backward pass then pushes the
    summed gradient through the model. Whether this saves memory compared to
    ``losses.MatryoshkaLoss`` still has to be measured.
    """

    def __init__(self, matryoshka_dims: list[int], *args, **kwargs):
        """
        Args:
            matryoshka_dims: embedding prefix lengths to train on; must be non-empty.
            *args, **kwargs: forwarded to ``SentenceTransformerTrainer``.

        Raises:
            ValueError: if ``matryoshka_dims`` is empty (there would be no
                gradient to propagate through the model).
        """
        if not matryoshka_dims:
            raise ValueError("matryoshka_dims must contain at least one dimension")
        super().__init__(*args, **kwargs)
        self.matryoshka_dims = list(matryoshka_dims)

    def training_step(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        num_items_in_batch: int | None = None,
    ) -> torch.Tensor:
        """Run forward + backward for one batch and return the detached loss.

        Args:
            model: the model exactly as handed over by the Trainer (possibly
                wrapped, e.g. by DDP). It is used for the forward pass so that
                wrapper hooks run.
            inputs: collated batch from the Trainer's data collator.
            num_items_in_batch: passed by newer ``transformers`` Trainer
                versions; unused here but must be accepted.

        Returns:
            Sum of the per-dimension losses, detached, divided by
            ``gradient_accumulation_steps``.
        """
        if FORCE_CPU:
            # TODO: rm this stupid stuff Somewhere, the model gets moved to MPS,
            # ignoring the use_mps_device=False flag. To get this mini-test working,
            # force it to the CPU in the training step lol. I had to do this inside
            # sentence_transformers.losses.MatryoshkaLoss.forward as well
            model.to("cpu")

        model.train()

        inputs = self._prepare_inputs(inputs)
        features, labels = self.collect_features(inputs)

        # TODO: need the bells and whistles from super class' training_step method
        # (e.g. the autocast / loss context manager). Should also be calling
        # super's compute_loss instead.

        # One forward pass per input column: first the anchors, then the rest
        # (positives / negatives) stacked into one candidate matrix.
        reps = [
            model(sentence_feature)["sentence_embedding"]
            for sentence_feature in features
        ]
        A_full: torch.Tensor = reps[0]
        B_full = torch.cat(reps[1:])

        # Detached leaf copies: per-dimension backward passes stop here and
        # accumulate into .grad instead of traversing the model graph each time.
        A_full_detached = A_full.detach().requires_grad_()
        B_full_detached = B_full.detach().requires_grad_()

        # Mirrors the super class' training_step: release the batch early.
        del inputs
        torch.cuda.empty_cache()

        # Keep the running loss on the same device as the losses so no
        # host/device sync is needed per dimension.
        tr_loss = torch.tensor(0.0, device=A_full.device)
        for dim in self.matryoshka_dims:
            loss: torch.Tensor = self.loss(
                A_full_detached[..., :dim], B_full_detached[..., :dim], labels
            )
            # Accumulates d(loss)/d(X_full_detached) into X_full_detached.grad.
            self.accelerator.backward(loss)
            tr_loss += loss.detach()

        # Chain rule: push the accumulated gradients through the model in ONE
        # backward call (two separate calls over graphs sharing parameters can
        # confuse DDP's gradient reducer). Plain autograd here — any scaling
        # applied by accelerator.backward above is presumably already folded
        # into the .grad tensors. NOTE(review): confirm for fp16 / grad accumulation.
        torch.autograd.backward(
            [A_full, B_full],
            grad_tensors=[A_full_detached.grad, B_full_detached.grad],
        )

        # Reference ("plain") variant for memory comparison: sum the losses on
        # A_full/B_full slices directly and call accelerator.backward once; that
        # keeps every per-dimension similarity graph alive until backward.

        return tr_loss.detach() / self.args.gradient_accumulation_steps

In [None]:
# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    use_mps_device=False,
    use_cpu=FORCE_CPU,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    fp16=False,
    # Only request bf16 when the GPU reports support for it; otherwise the
    # training-arguments validation can reject the config. On CPU the `and`
    # short-circuits, so CUDA is never queried.
    bf16=not FORCE_CPU and torch.cuda.is_bf16_supported(),
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    seed=42,
)

In [None]:
# 6. Create the trainer: either the stock MatryoshkaLoss wrapped around the
# library's MultipleNegativesRankingLoss, or the custom trainer defined above.
if not USE_CUSTOM:
    print("Using original Matryoshka + MNRL implementation")
    inner_train_loss = losses.MultipleNegativesRankingLoss(model)
    train_loss = losses.MatryoshkaLoss(
        model, inner_train_loss, matryoshka_dims=matryoshka_dims
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )
else:
    print("Using **CUSTOM** Matryoshka + MNRL implementation")
    train_loss = MultipleNegativesRankingLoss(model)
    trainer = MatryoshkaTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
        matryoshka_dims=matryoshka_dims,
    )

In [None]:
# Run training with whichever trainer the previous cell selected (custom
# MatryoshkaTrainer when USE_CUSTOM is True).
trainer.train()

In [None]:
# 7. Evaluate on the STS Benchmark test split, once per Matryoshka prefix length
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
evaluators = [
    EmbeddingSimilarityEvaluator(
        sentences1=test_dataset["sentence1"],
        sentences2=test_dataset["sentence2"],
        scores=test_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name=f"sts-test-{dim}",
        truncate_dim=dim,
    )
    for dim in matryoshka_dims
]
test_evaluator = SequentialEvaluator(evaluators)
test_evaluator(model)

In [None]:
# 8. Persist the trained and evaluated model under <output_dir>/final
final_output_dir = output_dir + "/final"
model.save(final_output_dir)