In [None]:
import csv
from datetime import datetime
import gzip
import math
import os
import random

from datasets import load_dataset
from sentence_transformers import (
    models,
    losses,
    datasets,
    SentenceTransformer,
    util,
    InputExample,
)
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from sentence_transformers.util import cos_sim
import torch

# Hyperparameters

In [None]:
# Name of the pretrained transformer that is wrapped into a SentenceTransformer below.
model_name = "distilroberta-base"
train_batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
# Passed as max_seq_length to models.Transformer.
max_seq_length = 75
# Used both for the warm-up step computation and as `epochs` in model.fit.
num_epochs = 1

# Seed the global RNGs once, up front, so that Restart & Run All is repeatable:
# the notebook draws triplets with `random.choice`, and training is torch-based.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Load non-similarity LM

In [None]:
# Assemble the SentenceTransformer from two modules: a transformer encoder
# producing token embeddings, followed by mean pooling over those embeddings.
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
# The pooling layer is sized from the encoder's word-embedding dimension.
word_embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(word_embedding_dim, pooling_mode="mean")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Load data

In [None]:
# Local location of the AllNLI archive; it is fetched only when it is not
# already present on disk.
nli_dataset_path = "data/AllNLI.tsv.gz"

nli_dataset_missing = not os.path.exists(nli_dataset_path)
if nli_dataset_missing:
    util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path)

# The archive is parsed further down in this cell to build the training set.
print("Read AllNLI train dataset")


# Maps a sentence to the sentences it was paired with, grouped by NLI label.
train_data: dict[str, dict[str, set]] = {}


def add_to_samples(sent1, sent2, label):
    """Record ``sent2`` as a ``label`` partner of ``sent1`` in ``train_data``.

    The first time ``sent1`` is seen, an entry with one empty set per known
    label ("contradiction", "entailment", "neutral") is created for it.

    Args:
        sent1: sentence used as the key in ``train_data``.
        sent2: sentence added to the set selected by ``label``.
        label: one of the three known label names.

    Raises:
        KeyError: if ``label`` is not one of the known labels (the entry for
            ``sent1`` has already been created at that point).
    """
    try:
        partners_by_label = train_data[sent1]
    except KeyError:
        partners_by_label = train_data[sent1] = {
            known_label: set()
            for known_label in ("contradiction", "entailment", "neutral")
        }
    partners_by_label[label].add(sent2)


# Stream the gzipped TSV and keep only the rows whose "split" column is "train".
with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn:
    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if row["split"] != "train":
            continue

        sent1, sent2 = row["sentence1"].strip(), row["sentence2"].strip()
        label = row["label"]

        # Register the pair in both directions so either sentence can later
        # serve as the lookup key.
        add_to_samples(sent1, sent2, label)
        add_to_samples(sent2, sent1, label)


# Build triplets for every sentence that has at least one entailment partner
# and at least one contradiction partner. Two examples are emitted per
# sentence: one with the sentence in first position and one with it in second
# position; the third text is always a contradiction partner.
#
# Reproducibility: the partners are stored in sets of strings, whose iteration
# order follows string hashes and therefore changes between interpreter
# sessions unless PYTHONHASHSEED is fixed. Sorting the candidates and drawing
# from a dedicated, seeded RNG makes the sampled triplets identical on every
# run, independent of any other use of the global `random` state.
triplet_rng = random.Random(42)

train_samples = []
for sent1, others in train_data.items():
    if not others["entailment"] or not others["contradiction"]:
        continue

    # Materialise each candidate list once per sentence instead of once per draw.
    entailments = sorted(others["entailment"])
    contradictions = sorted(others["contradiction"])

    train_samples.append(
        InputExample(
            texts=[
                sent1,
                triplet_rng.choice(entailments),
                triplet_rng.choice(contradictions),
            ]
        )
    )
    train_samples.append(
        InputExample(
            texts=[
                triplet_rng.choice(entailments),
                sent1,
                triplet_rng.choice(contradictions),
            ]
        )
    )

print("Train samples: {}".format(len(train_samples)))

In [None]:
# Special data loader that avoids duplicates within a batch.
# Note: `datasets` here is the sentence_transformers.datasets module imported
# at the top of the notebook. Only the first 200,000 samples are used.
train_subset = train_samples[:200_000]
train_dataloader = datasets.NoDuplicatesDataLoader(
    train_subset, batch_size=train_batch_size
)

# Dev data

In [None]:
# Validation split of the STS benchmark, used to evaluate during training.
stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation")

# Gold scores are divided by 5 before being handed to the evaluator
# (presumably mapping a 0-5 scale onto [0, 1] - confirm against the dataset).
dev_gold_scores = [score / 5 for score in stsb_dev["score"]]

dev_evaluator = EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    dev_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)

# Diagonaloss

In [None]:
# Per-dimension weight vector: the square roots of `dim` evenly spaced values
# that run from log2(dim) - 3 (first dimension) to 1 (last dimension).
# For dim < 8 the start value is negative and its square root is NaN.
dim: int = model.get_sentence_embedding_dimension()
max_weight_squared = torch.log2(torch.tensor(dim)) - 3
squared_weights = torch.linspace(
    max_weight_squared, 1, steps=dim, device=model.device, requires_grad=False
)
# Created on the device the model is on when this cell runs.
diagonal_vec = squared_weights.sqrt()

In [None]:
def _weighted_cosine_similarity(
    w: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Cosine similarity of ``a`` and ``b`` after scaling each dimension by ``sqrt(w)``.

    Scaling both inputs by ``sqrt(w)`` turns their dot product into
    ``sum_i w_i * a_i * b_i``, so ``w`` acts as a per-dimension weight.

    Args:
        w: Per-dimension weights. Multiplied element-wise into ``a`` and ``b``
            (after the square root), so it must broadcast against them.
            Negative entries produce NaN through the square root.
        a: First batch of embeddings (a tensor).
        b: Second batch of embeddings (a tensor).

    Returns:
        The result of ``cos_sim`` applied to the scaled inputs.
    """
    # `w` is built once at notebook level and may live on a different device
    # than the embeddings (e.g. if it was created before the model was moved
    # to the GPU); align it with `a` so the multiplication cannot fail with a
    # device mismatch.
    w_sqrt = torch.sqrt(w).to(a.device)
    return cos_sim(a * w_sqrt, b * w_sqrt)


def weighted_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Weighted cosine similarity of ``a`` and ``b`` using the global ``diagonal_vec``.

    A plain named function is used instead of ``functools.partial`` because
    ``similarity_fct`` needs a ``__name__`` attribute, which partial objects
    do not have.

    NOTE(review): ``diagonal_vec`` is already the square root of a linspace,
    and ``_weighted_cosine_similarity`` takes another square root of its
    weight argument, so embeddings end up scaled by the fourth root of the
    linspace values - confirm this is intended.
    """
    return _weighted_cosine_similarity(w=diagonal_vec, a=a, b=b)

In [None]:
# Training objective: MultipleNegativesRankingLoss over `model`, with the
# weighted cosine similarity defined above plugged in as its similarity
# function in place of the loss's default.
train_loss = losses.MultipleNegativesRankingLoss(
    model=model, similarity_fct=weighted_cosine_similarity
)

# Train

In [None]:
# Configure the training: warm up over the first 10% of all training steps.
total_train_steps = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_train_steps * 0.1)
print("Warmup-steps: {}".format(warmup_steps))

# Every run gets its own timestamped output directory; "/" in the model name
# is replaced so the name stays a single path component.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = (
    "output/matryoshka_nli_" + model_name.replace("/", "-") + "-" + run_timestamp
)

# Evaluate on the dev set after every 10% of one pass over the dataloader.
evaluation_interval = int(len(train_dataloader) * 0.1)

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_interval,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=False,  # Set to True, if your GPU supports FP16 operations
)

# Push to HF hub

In [None]:
# Runs the `huggingface-cli login` shell command; presumably needed so that the
# `push_to_hub` call in the following cell is authenticated - confirm.
!huggingface-cli login

In [None]:
# Keep only the part after the last "/" (a name without "/" is left unchanged)
# and use it as the prefix of the Hub repository name.
model_name = model_name.split("/")[-1]
model.push_to_hub(f"{model_name}-nli-diagonaloss")

# Test

This evaluation can instead be run with `./eval.py`.

In [None]:
# Rebind `model` to a SentenceTransformer loaded from `model_save_path`, the
# directory that was passed as `output_path` to `model.fit` above.
model = SentenceTransformer(model_save_path)

In [None]:
# Load the "test" split of the same dataset ("mteb/stsbenchmark-sts") whose
# "validation" split was used for the dev evaluator.
stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test")

In [None]:
# Same evaluator setup as for the dev split, built on the test split.
# Gold scores are divided by 5 before being handed to the evaluator
# (presumably mapping a 0-5 scale onto [0, 1] - confirm against the dataset).
test_gold_scores = [score / 5 for score in stsb_test["score"]]

test_evaluator = EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    test_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)

In [None]:
# Run the test evaluator on the reloaded model. `output_path=model_save_path`
# presumably makes the evaluator write its results next to the saved model -
# confirm. As the last expression of the cell, the return value is displayed.
test_evaluator(model, output_path=model_save_path)