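**Description**: a custom Matryoshka + MultipleNegativesRankingLoss (MNRL) training implementation that saves enough memory to make training with a batch size of 768 feasible without gradient caching. The model is run once per batch. The loss for each Matryoshka dimension is back-propagated only to detached copies of the full-size embeddings, and the accumulated embedding gradient is then back-propagated through the model. In fp32, the gradients are numerically identical to those of the reference SentenceTransformers implementation. With `bf16` mixed precision they currently differ. That is still being debugged, but the resulting models are statistically similar.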

TODO: this change should have a clearer impact when training a model with a large embedding size, e.g., 8192. Even then, the relative impact is probably tiny, because embedding size is correlated with model size. Maybe try dummy-finetuning [this model](https://huggingface.co/dunzhang/stella_en_400M_v5) to get some optimistic memory statistics.

Adapted from the SentenceTransformers Matryoshka NLI example
[script](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/matryoshka/matryoshka_nli.py).

In [1]:
!pip install datasets sentence-transformers

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.1/227.1 kB[0

In [2]:
from datetime import datetime
import logging
from typing import Any

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
    util,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SequentialEvaluator,
    SimilarityFunction,
)
from sentence_transformers.training_args import BatchSamplers
import torch
from transformers.utils import is_apex_available

In [3]:
if is_apex_available():
    from apex import amp

In [4]:
# True: custom MatryoshkaTrainer + MNRL defined below.
# False: stock sentence_transformers MatryoshkaLoss wrapping MNRL.
USE_CUSTOM = True
# Without CUDA, pin everything to the CPU (HF seems to always put stuff on MPS).
FORCE_CPU = not torch.cuda.is_available()

In [5]:
# Log at INFO level so training progress and the model summary are visible.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [6]:
model_name = "distilroberta-base"

# A real run on GPU; a tiny smoke test on CPU. num_batches caps how many training
# examples are kept (batch_size * num_batches).
if FORCE_CPU:
    batch_size, num_batches = 3, 3
else:
    batch_size, num_batches = 768, 20
num_train_epochs = 1
# Truncation sizes used both by the Matryoshka loss and by the evaluation below.
matryoshka_dims = [768, 512, 256, 128, 64]

# Save path of the model: one timestamped directory per run
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"output/matryoshka_nli_{model_name.replace('/', '-')}-{run_timestamp}"

In [7]:
# 1. Build the SentenceTransformer. A checkpoint that is not already a Sentence
# Transformer model is wrapped automatically, using "mean" pooling.
model = SentenceTransformer(model_name_or_path=model_name)
# Optionally cap the maximum sequence length:
# model.max_seq_length = 75
logging.info(model)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
# Keep the model on the CPU when FORCE_CPU is set; otherwise leave it where it is.
model = model.to("cpu") if FORCE_CPU else model

In [9]:
# 2. Load the "triplet" subset of AllNLI:
# https://huggingface.co/datasets/sentence-transformers/all-nli
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
logging.info(train_dataset)

# Keep only the examples needed for `num_batches` batches of `batch_size`.
num_train_examples = batch_size * num_batches
train_dataset = train_dataset.select(range(num_train_examples))

Downloading readme:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/782k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/810k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/557850 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6584 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6609 [00:00<?, ? examples/s]

In [10]:
# 3. Define our training loss
class MultipleNegativesRankingLoss(torch.nn.Module):
    """MNRL (in-batch negatives ranking loss) computed on precomputed embeddings.

    Unlike ``sentence_transformers.losses.MultipleNegativesRankingLoss``, ``forward``
    takes embeddings instead of tokenized features. That lets a trainer run the model
    once and evaluate this loss on several truncations of the same embeddings.

    Args:
        model: the SentenceTransformer being trained. Stored as ``self.model``;
            ``forward`` does not use it.
        scale: multiplier applied to the similarity scores before cross-entropy.
        similarity_fct: callable returning the pairwise score matrix between the rows
            of its two arguments (default: ``util.cos_sim``).
    """

    def __init__(
        self,
        model: SentenceTransformer,
        scale: float = 20.0,
        similarity_fct=util.cos_sim,
    ) -> None:
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(
        self,
        embeddings_a: torch.Tensor,
        embeddings_b: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Mean cross-entropy of matching ``embeddings_a[i]`` with ``embeddings_b[i]``.

        Args:
            embeddings_a: anchor embeddings, one row per example.
            embeddings_b: candidate embeddings. Row ``i`` is the positive for anchor
                ``i``; every other row (including rows past ``len(embeddings_a)``,
                such as hard negatives) acts as a negative.
            labels: ignored. Accepted so callers can pass their batch labels
                through unchanged; may be ``None``.

        Returns:
            Scalar loss tensor.
        """
        # Embeddings are handed to `similarity_fct` unchanged, so the configured
        # similarity is the one actually used. `util.cos_sim` normalizes internally;
        # normalizing here as well would be redundant for it and would silently
        # turn e.g. a dot-product similarity into cosine similarity.
        scores: torch.Tensor = (
            self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        )
        # Example a[i] should match with b[i]
        range_labels = torch.arange(0, scores.size(0), device=scores.device)
        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        """Hyperparameters reported for this loss (scale and similarity name)."""
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        return """
@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""

In [11]:
class MatryoshkaTrainer(SentenceTransformerTrainer):
    """SentenceTransformerTrainer with a memory-saving Matryoshka training step.

    Each step runs the model once to get full-size embeddings. ``self.loss`` is then
    evaluated on every truncation in ``matryoshka_dims``, against *detached* copies
    of those embeddings. The per-dimension backward passes therefore stop at the
    embeddings and only accumulate gradients there. Finally, the accumulated
    embedding gradients are back-propagated through the model in a single backward
    pass.

    ``self.loss`` must be callable as ``loss(embeddings_a, embeddings_b, labels)``
    and return a scalar, e.g. the custom ``MultipleNegativesRankingLoss`` defined in
    this notebook.

    Args:
        matryoshka_dims: embedding sizes to truncate to. A truncation keeps the first
            ``dim`` entries of the last axis; a ``dim`` at least as large as the
            embedding size uses the full embedding.
        *args, **kwargs: forwarded to ``SentenceTransformerTrainer``.

    Raises:
        ValueError: if ``matryoshka_dims`` is empty.
    """

    def __init__(self, matryoshka_dims: list[int], *args, **kwargs):
        if not matryoshka_dims:
            # With no dims, no embedding gradient would be produced to back-propagate.
            raise ValueError("matryoshka_dims must contain at least one dimension")
        super().__init__(*args, **kwargs)
        self.matryoshka_dims = matryoshka_dims

    def training_step(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        num_items_in_batch: Any = None,
    ) -> torch.Tensor:
        """Forward + backward for one batch; returns the detached training loss.

        Args:
            model: the model to train as passed by the training loop. This may be a
                wrapped version of ``self.model``.
            inputs: collated batch from the data collator.
            num_items_in_batch: unused. Accepted because newer ``transformers``
                releases pass it to ``training_step``. NOTE(review): confirm against
                the installed version.

        Returns:
            Sum of the per-dimension losses divided by
            ``gradient_accumulation_steps``, as ``Trainer.training_step`` returns.
            Parameter gradients have already been accumulated into ``.grad``.

        Raises:
            ValueError: if the batch has fewer than two text columns.
        """
        if FORCE_CPU:
            # TODO: rm this stupid line. I only need it for the CPU test which checks
            # that the gradient is identical to the sentence_transformers training code.
            #
            # Somewhere in the HF training code, the model gets moved to MPS, ignoring
            # the use_mps_device=False flag. To get this mini-test working, force it to
            # the CPU in the training step lol. I had to do this inside
            # sentence_transformers.losses.MatryoshkaLoss.forward as well
            model.to("cpu")

        model.train()

        inputs = self._prepare_inputs(inputs)
        features, labels = self.collect_features(inputs)
        if len(features) < 2:
            raise ValueError(
                "MatryoshkaTrainer needs at least two text columns per example "
                f"(e.g. anchor + positive), got {len(features)}"
            )

        # Full-size embeddings, still attached to the model's graph. Use the `model`
        # argument (which the training loop may have wrapped, e.g. for multi-GPU),
        # not `self.model`, so the wrapper sees the forward passes.
        reps = [
            model(sentence_feature)["sentence_embedding"]
            for sentence_feature in features
        ]
        A_full: torch.Tensor = reps[0]
        B_full = torch.cat(reps[1:])

        # Detached leaf copies: the per-dimension backward passes below end here and
        # only fill A_leaf.grad / B_leaf.grad, never touching model parameters.
        A_leaf = A_full.detach().requires_grad_()
        B_leaf = B_full.detach().requires_grad_()

        # From the super class' training_step method:
        del inputs
        torch.cuda.empty_cache()

        tr_loss = torch.zeros((), device=A_full.device)
        for dim in self.matryoshka_dims:
            loss: torch.Tensor = self.loss(
                A_leaf[..., :dim], B_leaf[..., :dim], labels
            )
            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            # Plain autograd on purpose. This graph ends at the detached leaves, so
            # loss scaling (fp16 GradScaler / apex) and accelerate's division by
            # gradient_accumulation_steps belong only in the single backward through
            # the model below. Applying them here too would apply them twice.
            loss.backward()
            tr_loss += loss.detach()

        # d(surrogate)/d(A_full) == A_leaf.grad (likewise for B). Back-propagating
        # this scalar applies the chain rule for the accumulated embedding gradients
        # in ONE backward pass, so scaling is applied exactly once. One scalar
        # backward per step is also what DDP / DeepSpeed expect
        # (NOTE(review): multi-GPU is untested here).
        surrogate = (A_full * A_leaf.grad).sum() + (B_full * B_leaf.grad).sum()
        if self.use_apex:
            with amp.scale_loss(surrogate, self.optimizer) as scaled_surrogate:
                scaled_surrogate.backward()
        else:
            self.accelerator.backward(surrogate)

        return tr_loss / self.args.gradient_accumulation_steps

In [12]:
# Use bf16 mixed precision only on a GPU that supports it (never on the CPU path).
bf16 = torch.cuda.is_bf16_supported() if not FORCE_CPU else False
print("Using mixed precision in bf16" if bf16 else "Not using mixed precision")

Using mixed precision in bf16


In [13]:
# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Device placement. NOTE(review): `use_mps_device` is deprecated in recent
    # transformers releases; `use_cpu` is the flag that forces CPU training.
    use_cpu=FORCE_CPU,
    use_mps_device=False,
    # Schedule:
    num_train_epochs=num_train_epochs,
    warmup_ratio=0.1,
    seed=42,
    # Batching:
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Precision:
    fp16=False,
    bf16=bf16,
)

In [14]:
# 6. Create the trainer
if USE_CUSTOM:
    print("Using **CUSTOM** Matryoshka + MNRL implementation")
    train_loss = MultipleNegativesRankingLoss(model)
    trainer = MatryoshkaTrainer(
        matryoshka_dims=matryoshka_dims,
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )
else:
    print("Using original Matryoshka + MNRL implementation")
    inner_train_loss = losses.MultipleNegativesRankingLoss(model)
    train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dims)
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )

Using **CUSTOM** Matryoshka + MNRL implementation


In [15]:
# Reset CUDA peak-memory counters so the report after training covers only training.
# Guarded because this call fails without CUDA (the FORCE_CPU path).
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

In [16]:
trainer.train()  # bf16 from custom implementation

Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=20, training_loss=16.214540100097658, metrics={'train_runtime': 78.8778, 'train_samples_per_second': 194.732, 'train_steps_per_second': 0.254, 'total_flos': 0.0, 'train_loss': 16.214540100097658, 'epoch': 1.0})

In [17]:
# Report peak CUDA memory from training. Without CUDA these counters are
# meaningless (they read as 0), so say so instead of printing "0.00 GB".
if torch.cuda.is_available():
    peak_memory_allocated = torch.cuda.max_memory_allocated()
    peak_memory_reserved = torch.cuda.max_memory_reserved()

    print(f"Peak memory allocated: {peak_memory_allocated / 1024**3:.2f} GB")
    print(f"Peak memory reserved: {peak_memory_reserved / 1024**3:.2f} GB")
else:
    peak_memory_allocated = peak_memory_reserved = None
    print("CUDA not available: no peak GPU memory statistics to report")

Peak memory allocated: 14.36 GB
Peak memory reserved: 14.55 GB


In [18]:
# 7. Evaluate on the STS Benchmark test split, once per Matryoshka dimension
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
evaluators = [
    EmbeddingSimilarityEvaluator(
        sentences1=test_dataset["sentence1"],
        sentences2=test_dataset["sentence2"],
        scores=test_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name=f"sts-test-{dim}",
        truncate_dim=dim,
    )
    for dim in matryoshka_dims
]
test_evaluator = SequentialEvaluator(evaluators)
test_result = test_evaluator(model)

Downloading readme:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/471k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/108k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5749 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1379 [00:00<?, ? examples/s]

In [19]:
# Per-dimension STS-B test metrics; keys are prefixed with "sts-test-<dim>_".
test_result

{'sts-test-768_pearson_cosine': 0.7222296872147473,
 'sts-test-768_spearman_cosine': 0.7279093710315198,
 'sts-test-768_pearson_manhattan': 0.7305430260426594,
 'sts-test-768_spearman_manhattan': 0.7109340923162628,
 'sts-test-768_pearson_euclidean': 0.7327646399348023,
 'sts-test-768_spearman_euclidean': 0.7126823136233756,
 'sts-test-768_pearson_dot': 0.2758066409299485,
 'sts-test-768_spearman_dot': 0.2651155306337087,
 'sts-test-768_pearson_max': 0.7327646399348023,
 'sts-test-768_spearman_max': 0.7279093710315198,
 'sts-test-512_pearson_cosine': 0.730052611674607,
 'sts-test-512_spearman_cosine': 0.7150714764884654,
 'sts-test-512_pearson_manhattan': 0.7315677539545927,
 'sts-test-512_spearman_manhattan': 0.7112519115852266,
 'sts-test-512_pearson_euclidean': 0.7337904993431678,
 'sts-test-512_spearman_euclidean': 0.712717508267019,
 'sts-test-512_pearson_dot': 0.4472367331494563,
 'sts-test-512_spearman_dot': 0.44244456499950935,
 'sts-test-512_pearson_max': 0.7337904993431678,
 

In [20]:
# # 8. Save the trained & evaluated model locally
# final_output_dir = f"{output_dir}/final"
# model.save(final_output_dir)
# final_output_dir