**Description**: Sigmoid loss for text similarity training, inspired by
[SigLIP](https://arxiv.org/abs/2303.15343). The benefits are outlined in the SigLIP paper;
the main one is that it's easier to increase the batch size to get more in-batch negatives.

**Usage**: run on a T4 GPU.

This notebook is modified from this SentenceTransformers
[script](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/matryoshka/matryoshka_nli.py).

In [None]:
!pip install datasets --upgrade sentence-transformers

In [None]:
!pip uninstall wandb

In [None]:
from datetime import datetime
from typing import Any, Iterable

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
    util,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from sentence_transformers.training_args import BatchSamplers
import torch
import tqdm

In [None]:
# Selects the loss built in the trainer cell:
#   True  -> AccumulatedMultipleNegativesRankingSigmoidLoss (defined in this notebook)
#   False -> losses.MultipleNegativesRankingLoss
USE_CUSTOM = True

In [None]:
# Base checkpoint to fine-tune.
model_name = "distilroberta-base"

# Training schedule: number of passes over the data and per-device batch size.
num_train_epochs = 1
batch_size = 128

# Save path of the model: "/" in the model name becomes "-" and a timestamp is
# appended, so every run gets its own directory.
output_dir = "-".join(
    (
        "output/sigltt_nli_" + model_name.replace("/", "-"),
        datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )
)

In [None]:
# 1. Here we define our SentenceTransformer model from the checkpoint named in `model_name`.
# If it is not already a Sentence Transformer model, it will automatically create one
# with "mean" pooling.
model = SentenceTransformer(model_name)
# If we want, we can limit the maximum sequence length for the model
# model.max_seq_length = 75

In [None]:
# 2. Load the "triplet" configuration (train split) of the AllNLI dataset:
# https://huggingface.co/datasets/sentence-transformers/all-nli
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")

# Keep only the first 10,000 training samples. Comment this line out to train
# on the full split.
train_dataset = train_dataset.select(range(10_000))

# Quick demo: summed BCE-with-logits equals the summed negative log-sigmoid loss

In [None]:
# Toy sizes: b anchors, each contributing 1 positive and _nn negatives to the
# shared pool of candidates.
b = 3
_nn = 3
m = b + b * _nn  # m = len(candidates)

# Random stand-in for the (b, m) matrix of anchor-vs-candidate logits.
scores = torch.randn(b, m)

# 0/1 targets: candidate i is the positive of anchor i (identity in the first
# b columns); every other (anchor, candidate) pair is a negative.
Y = torch.zeros(b, m)
Y[:, :b] = torch.eye(b)

In [None]:
# Binary cross-entropy on the raw logits, summed over every (anchor, candidate) pair ...
loss = torch.nn.BCEWithLogitsLoss(reduction="sum")
# ... and divided by the number of anchors. The value is displayed as the cell output.
loss(scores, Y) / b

In [None]:
# Same quantity in log-sigmoid form: 2 * Y - 1 maps the 0/1 targets to -1/+1, so a
# positive pair contributes -log(sigmoid(score)) and a negative pair
# -log(sigmoid(-score)). The displayed value should match the BCE cell.
log_sigmoid = torch.nn.LogSigmoid()
-torch.sum(log_sigmoid((2 * Y - 1) * scores)) / b

# Loss implementation

In [None]:
# 3. Define our training loss
class MultipleNegativesRankingSigmoidLoss(torch.nn.Module):
    """Sigmoid (pairwise binary cross-entropy) loss with in-batch negatives.

    Every anchor is scored against every candidate in the batch. The logits are
    ``similarity * scale + bias``, where ``scale`` and ``bias`` are learnable
    parameters. The pair (anchor[i], candidates[i]) is the only positive for
    anchor i; every other pair is a negative. The loss is the binary
    cross-entropy summed over all pairs and divided by the number of anchors.
    """

    def __init__(
        self,
        model: SentenceTransformer,
        scale: float = 20.0,
        similarity_fct=util.cos_sim,
        bias: float = -10.0,
    ) -> None:
        """Set up the loss.

        Args:
            model: Model whose forward pass returns a dict containing a
                "sentence_embedding" entry. Its ``device`` is used to place the
                learnable scale and bias.
            scale: Initial value of the learnable multiplier applied to the
                similarities.
            similarity_fct: Callable mapping (anchors, candidates) to a
                similarity matrix.
            bias: Initial value of the learnable offset added to the scaled
                similarities.
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct

        # Both the scale and the bias are trained together with the model.
        target_device = model.device
        self.scale = torch.nn.Parameter(torch.tensor(scale, device=target_device))
        self.bias = torch.nn.Parameter(torch.tensor(bias, device=target_device))

        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def forward(
        self, sentence_features: Iterable[dict[str, torch.Tensor]], labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            sentence_features: One feature dict per input column. The first
                column holds the anchors; the remaining columns hold the
                positives and, optionally, explicit negatives.
            labels: Not used; the targets are derived from the batch layout.

        Returns:
            Scalar tensor: summed binary cross-entropy divided by the number of
            anchors.
        """
        per_column = [
            self.model(features)["sentence_embedding"] for features in sentence_features
        ]
        # (batch_size, embedding_dim)
        anchors = per_column[0]
        # (batch_size * (1 + num_negatives), embedding_dim)
        candidates = torch.cat(per_column[1:])

        num_anchors = len(anchors)
        num_candidates = len(candidates)

        # Every anchor is compared with all candidates, including those that
        # belong to other anchors, which yields many in-batch negatives.
        # Shape: (batch_size, batch_size * (1 + num_negatives))
        logits: torch.Tensor = (
            self.similarity_fct(anchors, candidates) * self.scale + self.bias
        )

        # candidates[i] is the paired positive of anchors[i], so the target
        # matrix is an identity block followed by zeros for everything else.
        positives_block = torch.eye(num_anchors, device=logits.device)
        negatives_block = torch.zeros(
            (num_anchors, num_candidates - num_anchors), device=logits.device
        )
        targets = torch.concat((positives_block, negatives_block), dim=1)

        return self.bce_loss(logits, targets) / num_anchors

    def get_config_dict(self) -> dict[str, Any]:
        """Return the current scale, similarity function name and bias."""
        config: dict[str, Any] = {}
        config["scale"] = self.scale.item()
        config["similarity_fct"] = self.similarity_fct.__name__
        config["bias"] = self.bias.item()
        return config

    @property
    def citation(self) -> str:
        """BibTeX entry for the SigLIP paper."""
        return """
@inproceedings{zhai2023sigmoid,
    title={Sigmoid loss for language image pre-training},
    author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages={11975--11986},
    year={2023}
}
"""

In [None]:
# 3. Define our training loss
class AccumulatedMultipleNegativesRankingSigmoidLoss(torch.nn.Module):
    """Sigmoid loss with in-batch negatives, computed in mini-batches of anchors.

    Embeddings are computed for the full batch and then detached. The
    anchor-vs-candidate logits (``similarity * scale + bias``) and their binary
    cross-entropy are evaluated ``mini_batch_size`` anchors at a time, each
    slice against all candidates, and back-propagated immediately so that the
    gradients accumulate on the detached embeddings (and on ``scale`` and
    ``bias``). The accumulated embedding gradients are then pushed through the
    model with one backward call per embedding tensor.

    The value returned by ``forward`` is the binary cross-entropy summed over
    all (anchor, candidate) pairs and divided by the number of anchors.
    """

    def __init__(
        self,
        model: SentenceTransformer,
        scale: float = 20.0,
        bias: float = -10.0,
        similarity_fct=util.cos_sim,
        mini_batch_size: int = 32,
        show_progress_bar: bool = False,
    ) -> None:
        """Set up the loss.

        Args:
            model: Model whose forward pass returns a dict containing a
                "sentence_embedding" entry. Its ``device`` is used to place the
                learnable scale and bias.
            scale: Initial value of the learnable multiplier applied to the
                similarities.
            bias: Initial value of the learnable offset added to the scaled
                similarities.
            similarity_fct: Callable mapping (anchors, candidates) to a
                similarity matrix.
            mini_batch_size: Number of anchors (rows of the similarity matrix)
                processed per accumulation step.
            show_progress_bar: Whether to show a progress bar over the
                mini-batches.
        """
        super().__init__()
        self.model = model
        self.scale = torch.nn.Parameter(torch.tensor(scale, device=model.device))
        self.bias = torch.nn.Parameter(torch.tensor(bias, device=model.device))
        self.similarity_fct = similarity_fct
        self.mini_batch_size = mini_batch_size
        self.show_progress_bar = show_progress_bar
        # reduction="sum", not the default mean over every matrix element: the
        # loss is normalised by the number of anchors only, so it does not
        # shrink as the number of candidates grows.
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def _calculate_loss_mini_batch(
        self, begin: int, anchors_mini_batch: torch.Tensor, candidates: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss of one slice of anchors against all candidates.

        The binary cross-entropy is summed over every (anchor, candidate) pair
        of the slice and divided by the number of anchors in the slice, i.e. it
        is the average per-anchor loss of this mini-batch.

        Args:
            begin: Row index, within the full batch, of the first anchor in
                ``anchors_mini_batch``.
            anchors_mini_batch: 2-D tensor holding a contiguous slice of the
                anchor embeddings.
            candidates: Embeddings of all candidates of the full batch.

        Returns:
            Scalar tensor with the average per-anchor loss of the slice.
        """
        mini_batch_size, _ = anchors_mini_batch.shape
        # For every anchor, we compute the similarity to all other candidates (positives
        # and negatives), also from other anchors. This gives us a lot of in-batch
        # negatives.
        scores: torch.Tensor = (
            self.similarity_fct(anchors_mini_batch, candidates) * self.scale
        ) + self.bias
        # (mini_batch_size, full batch size * (1 + num negatives))

        # NOTE: we could additionally batch over candidates, since the loss is just a
        # (double) sum. (Then we'd need to backward in here.) That'd only be useful if
        # there are millions of candidates / something wild where a mini_batch_size of 1
        # still gives you OOMs.

        # anchor[i] is paired with candidates[i]. B/c we're batching anchors but not
        # batching candidates, row r of this slice is anchor (begin + r), so its
        # positive sits in column (begin + r).
        labels = torch.zeros_like(scores)
        rows = torch.arange(mini_batch_size, device=scores.device)
        labels[rows, begin + rows] = 1.0
        return self.bce_loss(scores, labels) / mini_batch_size

    def forward(
        self, sentence_features: Iterable[dict[str, torch.Tensor]], labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss for one batch and, in grad mode, accumulate gradients.

        Args:
            sentence_features: One feature dict per input column. The first
                column holds the anchors; the remaining columns hold the
                positives and, optionally, explicit negatives.
            labels: Not used; the targets are derived from the batch layout.

        Returns:
            Scalar loss tensor. When autograd is enabled, the gradients have
            already been accumulated by the time this returns, and the returned
            tensor is a detached leaf with ``requires_grad=True`` so that a
            later ``backward()`` call on it is a harmless no-op. When autograd
            is disabled (e.g. under ``torch.no_grad()``), only the loss value is
            computed.
        """
        # Compute the embeddings and distribute them to anchor and candidates (positive
        # and optionally negatives)
        embeddings = [
            self.model(sentence_feature)["sentence_embedding"]
            for sentence_feature in sentence_features
        ]
        anchors: torch.Tensor = embeddings[0]  # (batch_size, embedding_dim)
        candidates = torch.cat(
            embeddings[1:]
        )  # (batch_size * (1 + num_negatives), embedding_dim)
        batch_size = len(anchors)

        # Calling backward() is only possible (and only wanted) while autograd is
        # recording; without this guard the loss would crash under torch.no_grad().
        accumulate_grad = torch.is_grad_enabled()

        # Detach from the computation graph => don't back-propagate the gradient to the
        # model's parameters yet. Accumulate the 2 embeddings' gradients.
        anchors_detached = anchors.detach()
        candidates_detached = candidates.detach()
        if accumulate_grad:
            anchors_detached.requires_grad = True
            candidates_detached.requires_grad = True

        # Accumulate the gradients over mini-batches of anchors, i.e., batch over rows
        # of the similarity matrix.
        losses: list[torch.Tensor] = []
        for begin in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Mini-batching",
            disable=not self.show_progress_bar,
        ):
            anchors_mini_batch = anchors_detached[
                begin : (begin + self.mini_batch_size)
            ]
            mini_batch_size, _ = anchors_mini_batch.shape
            # Weight each mini-batch average by its share of the anchors, so the
            # pieces add up to the average per-anchor loss over the full batch
            # (the last mini-batch may be smaller than the others).
            loss_mini_batch = self._calculate_loss_mini_batch(
                begin, anchors_mini_batch, candidates_detached
            ) * (mini_batch_size / batch_size)
            if accumulate_grad:
                loss_mini_batch.backward()  # accumulate the gradient
            losses.append(loss_mini_batch.detach())

        loss = sum(losses)
        if not accumulate_grad:
            return loss

        # Push the accumulated embedding gradients through the model.
        anchors.backward(gradient=anchors_detached.grad)
        candidates.backward(gradient=candidates_detached.grad)

        # requires_grad_() lets the trainer call backward on the loss. That call
        # does nothing b/c the loss is a sum over detached tensors. We already
        # accumulated the gradient in the loop.
        # NOTE(review): consequently anything a trainer does to the loss between
        # this return and its backward() call (e.g. rescaling it) would not reach
        # the gradients — confirm before relying on such options.
        return loss.requires_grad_()

    def get_config_dict(self) -> dict[str, Any]:
        """Return the current scale, similarity function name, bias and mini-batch size."""
        return {
            "scale": self.scale.item(),
            "similarity_fct": self.similarity_fct.__name__,
            "bias": self.bias.item(),
            "mini_batch_size": self.mini_batch_size,
        }

    @property
    def citation(self) -> str:
        """BibTeX entry for the SigLIP paper."""
        return """
@inproceedings{zhai2023sigmoid,
    title={Sigmoid loss for language image pre-training},
    author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages={11975--11986},
    year={2023}
}
"""

# Demo

In [None]:
bf16 = torch.cuda.is_bf16_supported()
if bf16:
    print("Using mixed precision in bf16")
else:
    print("Not using mixed precision")

In [None]:
# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    use_mps_device=False,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,  # same batch size for training and evaluation
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    fp16=False,  # fp16 is always off; bf16 is turned on below when the bf16 check passed
    bf16=bf16,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    seed=42,
)

In [None]:
# 6. Create the trainer
# Choose the loss according to USE_CUSTOM, announcing the choice before building it.
if not USE_CUSTOM:
    print("Using OG MNRL")
    train_loss = losses.MultipleNegativesRankingLoss(model)
else:
    print("Using **CUSTOM** Sigmoid loss")
    train_loss = AccumulatedMultipleNegativesRankingSigmoidLoss(model)

# The trainer ties together the model, the training arguments, the training
# data and the selected loss.
trainer = SentenceTransformerTrainer(
    loss=train_loss,
    train_dataset=train_dataset,
    args=args,
    model=model,
)

In [None]:
# Reset CUDA's peak-memory counters so that the peak figures printed after
# training are measured from this point on.
torch.cuda.reset_peak_memory_stats()

In [None]:
# NOTE(review): the "og" tag presumably marks the output recorded with
# USE_CUSTOM = False (the "Using OG MNRL" branch) — confirm.
trainer.train()  # og

In [None]:
# NOTE(review): the "custom" tag presumably marks the output recorded with
# USE_CUSTOM = True — confirm. This is a second trainer.train() call on the same
# trainer object as the "og" cell, so run only one of the two per session.
trainer.train()  # custom

In [None]:
# With the custom loss, show its scale and bias parameters, one per line.
if USE_CUSTOM:
    print(train_loss.scale, train_loss.bias, sep="\n")

In [None]:
# og
# Read CUDA's peak allocated / peak reserved memory counters and print them.
# The counters are divided by 1024**3, so (assuming they are byte counts) the
# figures labelled "GB" are GiB.
peak_memory_allocated = torch.cuda.max_memory_allocated()
peak_memory_reserved = torch.cuda.max_memory_reserved()

print(f"Peak memory allocated: {peak_memory_allocated / 1024**3:.2f} GB")
print(f"Peak memory reserved: {peak_memory_reserved / 1024**3:.2f} GB")

In [None]:
# custom
# Read CUDA's peak allocated / peak reserved memory counters and print them.
# The counters are divided by 1024**3, so (assuming they are byte counts) the
# figures labelled "GB" are GiB.
peak_memory_allocated = torch.cuda.max_memory_allocated()
peak_memory_reserved = torch.cuda.max_memory_reserved()

print(f"Peak memory allocated: {peak_memory_allocated / 1024**3:.2f} GB")
print(f"Peak memory reserved: {peak_memory_reserved / 1024**3:.2f} GB")

In [None]:
# 7. Evaluate the model performance on the STS Benchmark test dataset
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
# Evaluator built from the dataset's sentence pairs and their "score" column,
# with cosine as the main similarity function.
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
# Calling the evaluator on the model runs the evaluation; the result is kept in
# test_result for display.
test_result = evaluator(model)

In [None]:
# Display the evaluation result.
# NOTE(review): "og" presumably marks the output recorded with USE_CUSTOM = False — confirm.
test_result  # og

In [None]:
# Display the evaluation result.
# NOTE(review): "custom" presumably marks the output recorded with USE_CUSTOM = True — confirm.
test_result  # custom

In [None]:
# # 8. Save the trained & evaluated model locally
# final_output_dir = f"{output_dir}/final"
# model.save(final_output_dir)
# final_output_dir