In [22]:
import torch
import torch.nn.functional as F
from torch import nn, FloatTensor, IntTensor

In [23]:
def create_metric(metric):
    """
    Build a batched distance function from its (case-insensitive) name.

    :param metric: 'euclidean' (the legacy spelling 'euclidian' is also accepted), 'manhattan' or 'cosine'.
    :return: A callable ``f(x, y)`` computing one distance per pair of rows of ``x`` and ``y``
        (shape [bs] for 2-D inputs of shape [bs, emb_dim]). 'cosine' is turned into a distance
        as ``1 - cosine_similarity``.
    :raises AssertionError: If ``metric`` is not one of the supported names.
    """
    metrics = {
        'euclidian': lambda x, y: F.pairwise_distance(x, y, p=2),
        'manhattan': lambda x, y: F.pairwise_distance(x, y, p=1),
        'cosine': lambda x, y: 1 - F.cosine_similarity(x, y)
    }
    # Accept the correct spelling too; the original misspelled key stays for existing callers.
    metrics['euclidean'] = metrics['euclidian']

    key = metric.lower()
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so code that relied on the old assert still works.
    if key not in metrics:
        raise AssertionError(
            f"This metric is not supported: {metric!r} (expected one of {sorted(metrics)})"
        )
    return metrics[key]

In [24]:
class TripletLoss(nn.Module):
    """
    Triplet loss. Given a triplet of (anchor, positive, negative), the loss minimizes the distance between anchor
    and positive while it maximizes the distance between anchor and negative.

    It computes the following loss function, where d is the configured distance metric:
    loss = max(d(anchor, positive) - d(anchor, negative) + margin, 0).

    :param distance_metric: Name passed to ``create_metric``: 'euclidian', 'manhattan' or 'cosine'
        (cosine similarity inverted to a distance as 1 - similarity).
    :param margin: The desired difference between the anchor-positive distance and the anchor-negative distance.
    :param do_average: Reduce the per-triplet losses by mean over the batch (True) or by sum (False).
        Ignored when ``reduction`` is given.
    :param reduction: Optional explicit reduction: 'mean', 'sum' or 'none' (return the per-triplet losses,
        e.g. for hard-example mining). The default None means "follow ``do_average``".
    :raises ValueError: If ``reduction`` is not None and not one of the supported values.

    For further details, see: https://en.wikipedia.org/wiki/Triplet_loss
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self,
                 distance_metric='cosine',
                 margin=0.5,
                 do_average=True,
                 reduction=None):

        super().__init__()
        if reduction is not None and reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS} or None, got {reduction!r}")
        self.distance_metric = create_metric(distance_metric)
        self.margin = margin
        self.do_average = do_average
        self.reduction = reduction

    def forward(self, anchor: FloatTensor, positive: FloatTensor, negative: FloatTensor) -> torch.Tensor:
        """
        :param anchor: embeddings of size [bs, emb_dim]
        :param positive: embeddings of size [bs, emb_dim]
        :param negative: embeddings of size [bs, emb_dim]
        :return: A scalar loss for 'mean'/'sum' reduction, or the per-triplet losses for 'none'.
        """

        distance_pos = self.distance_metric(anchor, positive)
        distance_neg = self.distance_metric(anchor, negative)

        # Hinge: a triplet whose negative is already farther than its positive by at least
        # `margin` contributes zero loss.
        losses = F.relu(distance_pos - distance_neg + self.margin)

        # Resolved at call time so that changing `do_average` after construction still takes effect.
        reduction = self.reduction
        if reduction is None:
            reduction = 'mean' if self.do_average else 'sum'
        if reduction == 'none':
            return losses
        return losses.mean() if reduction == 'mean' else losses.sum()

In [25]:
# Fixed seed so that "Restart & Run All" reproduces the loss value shown below.
SEED = 0
BATCH_SIZE, EMB_DIM = 10, 128

torch.manual_seed(SEED)
anchor = torch.randn([BATCH_SIZE, EMB_DIM])
positive = torch.randn([BATCH_SIZE, EMB_DIM])
negative = torch.randn([BATCH_SIZE, EMB_DIM])

In [26]:
# Instantiate the loss module once, then apply it to the random triplet batch (last expression is displayed).
triplet_loss = TripletLoss(distance_metric='cosine', margin=0.5)
triplet_loss(anchor, positive, negative)

tensor(0.5067)