In [1]:
from torch.utils.data import DataLoader
import torch
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from sentence_transformers.util import cos_sim

In [2]:
# Instantiate the sentence-embedding model named "paraphrase-albert-small-v2".
# This single `model` object is reused by the cells below for encoding,
# for building the loss, and for fine-tuning.
model = SentenceTransformer("paraphrase-albert-small-v2")

  return self.fget.__get__(instance, owner)()


In [3]:
# Width of the embeddings produced by `model`.
dim = model.get_sentence_embedding_dimension()

# Largest squared weight: log2(dim) - 3, i.e. log2(dim / 8).
max_weight_squared = torch.tensor(dim).log2() - 3

# Squared weights decay linearly from `max_weight_squared` (first dimension)
# down to 1 (last dimension); the square root of that ramp is the per-dimension
# weight vector used by the weighted cosine similarity.
_squared_weight_ramp = torch.linspace(
    max_weight_squared,
    1,
    steps=dim,
    device=model.device,
    requires_grad=False,
)
diagonal_vec = _squared_weight_ramp.sqrt()

In [4]:
# Encode two small batches of probe sentences (three, then two) as tensors so
# the similarity functions can be compared on them in the following cells.
x1, x2 = (
    model.encode(sentences, convert_to_tensor=True)
    for sentences in (["test", "hi", "ladsfj;"], ["hello", "testing"])
)

In [5]:
def _weighted_cosine_similarity(
    w: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Cosine similarity of ``a`` and ``b`` after per-dimension weighting.

    Both inputs are multiplied element-wise by ``sqrt(w)`` and then passed to
    ``cos_sim``, so each product of matching coordinates that enters the
    similarity is scaled by ``w``.

    Args:
        w: Per-dimension weights; must broadcast against the last dimension
            of ``a`` and ``b``. Entries should be non-negative, because the
            square root of a negative weight is NaN.
        a: First batch of embeddings.
        b: Second batch of embeddings.

    Returns:
        The result of ``cos_sim`` on the rescaled inputs.
    """
    w_sqrt = torch.sqrt(w)
    # The weight vector is usually created once, up front, and may therefore
    # live on a different device than the embeddings it is applied to (e.g.
    # when the model is moved to another device for training). Align it with
    # the embeddings so the multiplication cannot fail on a device mismatch;
    # this is a no-op when the devices already agree.
    if isinstance(a, torch.Tensor):
        w_sqrt = w_sqrt.to(a.device)
    a = a * w_sqrt
    b = b * w_sqrt
    return cos_sim(a, b)


def weighted_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Two-argument weighted cosine similarity bound to the global ``diagonal_vec``.

    This is a plain ``def`` instead of a ``functools.partial`` because
    ``similarity_fct`` does not accept a callable without a ``__name__``
    attribute, which a partial object lacks.
    """
    weights = diagonal_vec
    return _weighted_cosine_similarity(weights, a, b)

In [6]:
# Baseline for comparison: unweighted cosine similarity between every
# sentence in x1 (rows) and every sentence in x2 (columns).
cos_sim(x1, x2)

tensor([[0.1056, 0.8345],
        [0.5414, 0.2684],
        [0.3557, 0.1200]])

In [7]:
# Same sentence pairs scored with the weighted variant, to compare against
# the plain cosine similarities shown in the previous cell.
weighted_cosine_similarity(x1, x2)

tensor([[0.0992, 0.8354],
        [0.5372, 0.2548],
        [0.3490, 0.1062]])

In [103]:
# Toy (anchor, positive) pairs used as training data.
train_examples = [
    InputExample(texts=[anchor, positive])
    for anchor, positive in (
        ("Anchor 1", "Positive 1"),
        ("somethin", "something else"),
    )
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

# The ranking loss scores pairs with the weighted cosine similarity instead of
# the library default.
train_loss = losses.MultipleNegativesRankingLoss(
    model=model, similarity_fct=weighted_cosine_similarity
)

# NOTE(review): the evaluator is configured with plain cosine similarity, not
# the weighted variant used by the loss — confirm that this is intended.
# NOTE(review): both gold scores are 0.9; presumably this is only a smoke
# test, since identical labels give a score-agreement metric nothing to rank.
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["aljfad", "a;lkjdfasl;jf"],
    sentences2=["sentence3", "sentence4"],
    scores=[0.9, 0.9],
    main_similarity=SimilarityFunction.COSINE,
    write_csv=False,
    show_progress_bar=True,
)

In [104]:
# Fine-tune `model` for two epochs on the single (dataloader, loss) objective
# defined above, passing `dev_evaluator` as the evaluator for the run.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=2,
)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]



Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

