In [1]:
from torch.utils.data import DataLoader
import torch
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from sentence_transformers.util import cos_sim

In [2]:
# Pretrained checkpoint used for every encoding and fine-tuning step below.
MODEL_NAME = "paraphrase-albert-small-v2"
model = SentenceTransformer(MODEL_NAME)

In [3]:
# Per-dimension weight vector that negative_weighted_euclidean passes to
# _weighted_euclidean as `w`.
dim: int = model.get_sentence_embedding_dimension()
max_weight_squared = torch.log2(torch.tensor(dim)) - 3

# The squared weights fall off linearly from max_weight_squared (first
# embedding dimension) down to 1 (last embedding dimension).
squared_weights = torch.linspace(
    max_weight_squared,
    1,
    steps=dim,
    device=model.device,
    requires_grad=False,
)

# NOTE(review): _weighted_euclidean takes torch.sqrt of this vector again
# before scaling the embeddings -- confirm the double square root is intended.
diagonal_vec = squared_weights.sqrt() / 100

In [4]:
# Two tiny batches of embeddings (3 and 2 sentences) used to eyeball the
# similarity functions in the cells below.
x1, x2 = (
    model.encode(sentences, convert_to_tensor=True)
    for sentences in (["test", "hi", "ladsfj;"], ["hello", "testing"])
)

In [5]:
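# Reference only (unused): plain pairwise Euclidean distance written out via
# the ||a||^2 - 2 a.b + ||b||^2 expansion.
# def _euclidean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#     dists = a @ b.T
#     a_norms = (a * a).sum(dim=-1).unsqueeze(0).T
#     b_norms = (b * b).sum(dim=-1)
#     return torch.sqrt(a_norms - (2 * dists) + b_norms)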


def _weighted_euclidean(
    w: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    w_sqrt = torch.sqrt(w)
    return torch.cdist(a * w_sqrt, b * w_sqrt, p=2)


def negative_weighted_euclidean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Similarity score given by the negated weighted Euclidean distance.

    The distance between the rows of ``a`` and ``b`` is computed by
    ``_weighted_euclidean`` with the module-level ``diagonal_vec`` as the
    weights, then negated so that closer pairs receive larger scores.
    """
    # Deliberately a plain named function rather than functools.partial:
    # similarity_fct doesn't like partial objects because they have no
    # __name__ attribute.
    distances = _weighted_euclidean(diagonal_vec, a, b)
    return torch.neg(distances)

In [6]:
# Baseline for comparison: pairwise cosine similarity between the rows of x1
# and the rows of x2.
cos_sim(x1, x2)

tensor([[0.1056, 0.8345],
        [0.5414, 0.2684],
        [0.3557, 0.1200]], device='mps:0')

In [7]:
# Same pairs scored with the custom similarity. Values are negated distances,
# so a score closer to 0 means the pair is closer together.
negative_weighted_euclidean(x1, x2)

tensor([[-3.3179, -1.3817],
        [-2.1805, -2.6236],
        [-2.4544, -2.7100]], device='mps:0')

In [8]:
# Toy (anchor, positive) pairs -- placeholders to exercise the training loop.
train_pairs = [
    ("Anchor 1", "Positive 1"),
    ("somethin", "something else"),
]
train_examples = [InputExample(texts=list(pair)) for pair in train_pairs]
train_dataloader = DataLoader(train_examples, batch_size=32, shuffle=True)

# Rank each anchor's positive against the other in-batch texts, scoring pairs
# with the negated weighted Euclidean distance instead of the loss's default
# similarity.
train_loss = losses.MultipleNegativesRankingLoss(
    model=model,
    similarity_fct=negative_weighted_euclidean,
)

# Placeholder dev set for the evaluator.
# NOTE(review): both gold scores are identical; a correlation-based metric is
# undefined for constant labels -- replace with real dev data before trusting
# the reported numbers.
dev_sentences1 = ["aljfad", "a;lkjdfasl;jf"]
dev_sentences2 = ["sentence3", "sentence4"]
dev_scores = [0.9, 0.9]
dev_evaluator = EmbeddingSimilarityEvaluator(
    dev_sentences1,
    dev_sentences2,
    dev_scores,
    main_similarity=SimilarityFunction.COSINE,
    show_progress_bar=True,
    write_csv=False,
)

In [9]:
# Fine-tune for two epochs on the single (dataloader, loss) objective, running
# the dev evaluator as configured above.
train_objectives = [(train_dataloader, train_loss)]
model.fit(
    train_objectives=train_objectives,
    epochs=2,
    evaluator=dev_evaluator,
)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

NotImplementedError: The operator 'aten::_cdist_backward' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.