In [63]:
import torch
import torch.nn.functional as F
from torch import nn, FloatTensor, IntTensor

In [64]:
def create_metric(metric):
    lambdas_dict = {
        'euclidian': lambda x, y: F.pairwise_distance(x, y, p=2),
        'manhattan': lambda x, y: F.pairwise_distance(x, y, p=1),
        'cosine': lambda x, y: 1 - F.cosine_similarity(x, y)
    }
    assert metric.lower() in lambdas_dict.keys(), "This metric is not supported"
    return lambdas_dict[metric.lower()]

In [331]:
class ContrastiveLoss(nn.Module):

    """
    Contrastive loss. Expects as input two embeddings and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.

    :param distance_metric: Can be 'euclidian', 'manhattan' or 'cosine' (inverted to be distance).
    :param positive_margin: The distance over which positive pairs will contribute to the loss.
    :param negative_margin: The distance under which negative pairs will contribute to the loss.
    :param do_average: Average by mean in batch or summation

    | Further information:
    | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    | https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#contrastiveloss
    """

    def __init__(self,
                 distance_metric='cosine',
                 positive_margin=0,
                 negative_margin=1,
                 do_average=True):

        super(ContrastiveLoss, self).__init__()
        self.distance_metric = create_metric(distance_metric)
        self.positive_margin = positive_margin
        self.negative_margin = negative_margin
        self.do_average = do_average

    def forward(self, x1: FloatTensor, x2: FloatTensor, label: IntTensor):
        """
        :param x1: embeddings of size [bs, emd_dim]
        :param x2: embeddings of size [bs, emd_dim]
        :param label: labels in range [0, 1] of size [bs,]
        """

        distances = self.distance_metric(x1, x2)
        losses = 0.5 * (
                label.float() * F.relu(self.negative_margin - distances).pow(2) +
                (1 - label).float() * F.relu(distances - self.positive_margin).pow(2)
        )  # RELU can be replaced with torch.clamp; 0.5 scaling is optional

        return losses.mean() if self.do_average else losses.sum()

In [332]:
x1 = torch.randn([10, 128])
x2 = torch.randn([10, 128])
label = torch.randint(low=0, high=2, size=[10])
label

tensor([0, 1, 1, 1, 1, 1, 0, 1, 1, 1])

In [329]:
distances = 1 - F.cosine_similarity(x1, x2)
distances

tensor([1.1794, 0.9802, 1.1684, 0.9810, 1.0024, 1.0226, 1.0147, 0.8088, 0.8799,
        1.0057])

In [330]:
ContrastiveLoss(distance_metric='cosine', positive_margin=0, negative_margin=1)(x1, x2, label)

tensor(0.2574)