Import/install libraries

In [None]:
!pip install einops
from typing import Final, Optional, final

import torch
from torch.nn import TripletMarginWithDistanceLoss
from torch.nn.functional import binary_cross_entropy_with_logits, triplet_margin_loss

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


We denote **anchor examples** with $\mathbf{x}_a$, **positive examples** with $\mathbf{x}_{pos}$, **negative examples** with $\mathbf{x}_{neg}$ and **partially positive examples** with $\mathbf{x}_{part}$.

Let $\gamma \in [0, 1]$ and $\alpha_{pos\text{-}neg}, \alpha_{part\text{-}neg}, \alpha_{pos\text{-}part}, \lambda \in \mathbb{R}^+$, and let $E_{\eta}$ be a text encoder parametrized by $\eta$. We define the $\gamma$*-quadruplet loss* as:

\begin{align*}
\mathcal{L}_{\gamma} \left(\eta\right) := \mathbb{E}_{\mathbf{x}_a, \mathbf{x}_{pos}, \mathbf{x}_{part}, \mathbf{x}_{neg}}\big[ \, & \max\left(||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{pos}|| - ||\hat{\mathbf{x}}_{a} -\hat{\mathbf{x}}_{neg}|| + \alpha_{pos\text{-}neg}, 0\right) + \\ & \gamma \max\left(||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{part}|| - ||\hat{\mathbf{x}}_{a} -\hat{\mathbf{x}}_{neg}|| + \alpha_{part\text{-}neg}, 0\right) + \\ & (1 - \gamma) \max\left(||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{pos}|| - ||\hat{\mathbf{x}}_{a} -\hat{\mathbf{x}}_{part}|| + \alpha_{pos\text{-}part}, 0\right)\big]
\end{align*}
where:

\begin{align*}
    & \hat{\mathbf{x}}_{a} := E_{\eta}\left(\mathbf{x}_{a}\right) \\
    & \hat{\mathbf{x}}_{pos} := E_{\eta}\left(\mathbf{x}_{pos}\right) \\
    & \hat{\mathbf{x}}_{part} := E_{\eta}\left(\mathbf{x}_{part}\right) \\
   & \hat{\mathbf{x}}_{neg} := E_{\eta}\left(\mathbf{x}_{neg}\right)
\end{align*}

The following function computes the $\gamma$*-quadruplet loss*.

In [None]:
# Default weight of the partial-vs-negative term; the positive-vs-partial term
# is weighted by (1 - gamma).
DEFAULT_GAMMA: Final = 0.8
# Reduction modes accepted by the loss routines.
REDUCTIONS: Final = frozenset(["mean", "sum", "none"])


def gamma_quadruplet_loss(x_anchor: torch.Tensor,
                          x_pos: torch.Tensor,
                          x_part: torch.Tensor,
                          x_neg: torch.Tensor,
                          gamma: float = DEFAULT_GAMMA,
                          margin_pos_neg: float = 1.0,
                          margin_pos_part: float = 1.0,
                          margin_part_neg: float = 1.0,
                          p: float = 2.0,
                          swap: bool = False,
                          reduction: str = "mean") -> torch.Tensor:
  """Compute the gamma-quadruplet loss.

  The loss is the weighted sum of three triplet margin losses that share the
  same anchor:

    triplet(anchor, pos, neg)
    + gamma * triplet(anchor, part, neg)
    + (1 - gamma) * triplet(anchor, pos, part)

  Args:
    x_anchor: Embeddings of the anchor examples.
    x_pos: Embeddings of the positive examples.
    x_part: Embeddings of the partially positive examples.
    x_neg: Embeddings of the negative examples.
    gamma: Weight in [0, 1] balancing the part-vs-neg term (gamma) against
      the pos-vs-part term (1 - gamma).
    margin_pos_neg: Positive margin of the pos-vs-neg triplet term.
    margin_pos_part: Positive margin of the pos-vs-part triplet term.
    margin_part_neg: Positive margin of the part-vs-neg triplet term.
    p: Positive norm degree forwarded to ``triplet_margin_loss``.
    swap: Distance-swap flag forwarded to ``triplet_margin_loss``.
    reduction: One of ``REDUCTIONS`` ("mean", "sum" or "none").

  Returns:
    The unreduced (per-sample) loss when ``reduction`` is "none", otherwise
    its sum or mean as a scalar tensor.

  Raises:
    ValueError: If ``gamma`` is outside [0, 1], a margin or ``p`` is not
      positive, or ``reduction`` is not a supported mode.
  """
  if gamma < 0 or gamma > 1:
    raise ValueError(f"gamma must be between 0 and 1, {gamma} given")
  if margin_pos_neg <= 0:
    raise ValueError(f"margin_pos_neg must be positive, {margin_pos_neg} given")
  if margin_pos_part <= 0:
    raise ValueError(f"margin_pos_part must be positive, {margin_pos_part} given")
  if margin_part_neg <= 0:
    raise ValueError(f"margin_part_neg must be positive, {margin_part_neg} given")
  if reduction not in REDUCTIONS:
    raise ValueError(f"reduction must be one of: {REDUCTIONS}, "
                     f"{reduction} given")
  if p <= 0:
    raise ValueError(f"p must be positive, {p} given")

  # The three triplet terms are computed without reduction so that they can
  # be combined sample by sample before any reduction is applied.
  pos_vs_neg = triplet_margin_loss(
      anchor=x_anchor,
      positive=x_pos,
      negative=x_neg,
      margin=margin_pos_neg,
      p=p,
      swap=swap,
      reduction='none'
  )
  part_vs_neg = triplet_margin_loss(
      anchor=x_anchor,
      positive=x_part,
      negative=x_neg,
      margin=margin_part_neg,
      p=p,
      swap=swap,
      reduction='none'
  )
  pos_vs_part = triplet_margin_loss(
      anchor=x_anchor,
      positive=x_pos,
      negative=x_part,
      margin=margin_pos_part,
      p=p,
      swap=swap,
      reduction='none'
  )

  # Combine once, then reduce: every reduction mode shares the same weighting.
  loss = pos_vs_neg + gamma * part_vs_neg + (1 - gamma) * pos_vs_part
  if reduction == 'none':
    return loss
  if reduction == 'sum':
    return loss.sum()
  return loss.mean()

Let $D_{\theta}$ be a **binary discriminator**, parametrized by $\mathbf{\theta}$, that distinguishes positive examples from partially positive examples, while the text encoder $E_{\eta}$ is parametrized by $\mathbf{\eta}$. We thus define the $D_{\theta}$*-regularized quadruplet loss*:

\begin{align*}
\mathcal{L}_{D} \left(\eta, \theta\right):= \mathbb{E}_{\mathbf{x}_a, \mathbf{x}_{pos}, \mathbf{x}_{part}, \mathbf{x}_{neg}}\big[ \, & \max\left(||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{pos}|| - ||\hat{\mathbf{x}}_{a} -\hat{\mathbf{x}}_{neg}|| + \alpha_{pos\text{-}neg}, 0 \right) + \\ & \max\left(||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{part}|| -||\hat{\mathbf{x}}_{a} - \hat{\mathbf{x}}_{neg}|| + \alpha_{part\text{-}neg}, 0\right) \\ & - \lambda \log\left(D_{\theta}\left(\hat{\mathbf{x}}_a, \hat{\mathbf{x}}_{pos}\right)\right) - \lambda \log\left(1 - D_{\theta}\left(\hat{\mathbf{x}}_a, \hat{\mathbf{x}}_{part}\right)\right) \big]
\end{align*}

The following function computes the $D_{\theta}$*-regularized quadruplet loss*.

In [None]:
def d_regularized_quadruplet_loss(
                          x_anchor: torch.Tensor,
                          x_pos: torch.Tensor,
                          x_part: torch.Tensor,
                          x_neg: torch.Tensor,
                          margin_pos_neg: float = 1.0,
                          margin_part_neg: float = 1.0,
                          lmbd: float = 0.1,
                          discr: Optional[torch.nn.Module] = None,
                          discr_logits_pos: Optional[torch.Tensor] = None,
                          discr_logits_part: Optional[torch.Tensor] = None,
                          p: float = 2.0,
                          swap: bool = False,
                          reduction: str  = "mean") -> torch.Tensor:
  """Compute the discriminator-regularized quadruplet loss.

  The loss is

    triplet(anchor, pos, neg) + triplet(anchor, part, neg) + lmbd * BCE

  where BCE is the binary cross-entropy of the discriminator logits with
  target 1 for (anchor, pos) pairs and target 0 for (anchor, part) pairs, i.e.
  ``-log D(a, pos) - log(1 - D(a, part))``.

  Args:
    x_anchor: Embeddings of the anchor examples.
    x_pos: Embeddings of the positive examples.
    x_part: Embeddings of the partially positive examples.
    x_neg: Embeddings of the negative examples.
    margin_pos_neg: Positive margin of the pos-vs-neg triplet term.
    margin_part_neg: Positive margin of the part-vs-neg triplet term.
    lmbd: Positive weight of the discriminator term.
    discr: Discriminator called as ``discr(x_anchor, x)`` to obtain logits.
      Used whenever either of the two logits tensors is missing.
    discr_logits_pos: Precomputed logits for the (anchor, pos) pairs.
    discr_logits_part: Precomputed logits for the (anchor, part) pairs.
    p: Positive norm degree forwarded to ``triplet_margin_loss``.
    swap: Distance-swap flag forwarded to ``triplet_margin_loss``.
    reduction: One of ``REDUCTIONS`` ("mean", "sum" or "none").

  Returns:
    The unreduced (per-sample) loss, with the same shape as the triplet
    terms, when ``reduction`` is "none"; otherwise its sum or mean as a
    scalar tensor.

  Raises:
    ValueError: If ``lmbd``, a margin or ``p`` is not positive, ``reduction``
      is not a supported mode, or neither a discriminator nor both logits
      tensors are provided.
  """
  if lmbd <= 0:
    raise ValueError(f"lmbd must be positive, {lmbd} given")
  if margin_pos_neg <= 0:
    raise ValueError(f"margin_pos_neg must be positive, {margin_pos_neg} given")
  if margin_part_neg <= 0:
    raise ValueError(f"margin_part_neg must be positive, {margin_part_neg} given")
  if reduction not in REDUCTIONS:
    raise ValueError(f"reduction must be one of: {REDUCTIONS}, "
                     f"{reduction} given")
  if p <= 0:
    raise ValueError(f"p must be positive, {p} given")
  if discr is None and (discr_logits_part is None or discr_logits_pos is None):
    raise ValueError("Either discriminator or discriminator logits must be "
                     "given")

  # Triplet terms without reduction: one value per sample.
  pos_vs_neg = triplet_margin_loss(
      anchor=x_anchor,
      positive=x_pos,
      negative=x_neg,
      margin=margin_pos_neg,
      p=p,
      swap=swap,
      reduction='none'
  )
  part_vs_neg = triplet_margin_loss(
      anchor=x_anchor,
      positive=x_part,
      negative=x_neg,
      margin=margin_part_neg,
      p=p,
      swap=swap,
      reduction='none'
  )

  # Compute both logits with the discriminator if either one is missing.
  if discr_logits_pos is None or discr_logits_part is None:
    discr_logits_pos = discr(x_anchor, x_pos)
    discr_logits_part = discr(x_anchor, x_part)

  # (anchor, pos) pairs are labelled 1 and (anchor, part) pairs 0.
  bce_pos = binary_cross_entropy_with_logits(
      discr_logits_pos,
      target=torch.ones_like(discr_logits_pos),
      reduction='none'
  )
  bce_part = binary_cross_entropy_with_logits(
      discr_logits_part,
      target=torch.zeros_like(discr_logits_part),
      reduction='none'
  )

  # Align the per-sample BCE with the triplet terms: logits of shape (B, 1)
  # would otherwise broadcast against (B,) into a (B, B) matrix. This also
  # fails loudly if the discriminator emits more than one logit per sample.
  bce = (bce_pos + bce_part).reshape(pos_vs_neg.shape)

  # The BCE already equals -log D(a, pos) - log(1 - D(a, part)), so it is
  # added (not subtracted) to match the loss definition.
  loss = pos_vs_neg + part_vs_neg + lmbd * bce
  if reduction == 'none':
    return loss
  if reduction == 'sum':
    return loss.sum()
  return loss.mean()

The following blocks just show the functioning of the above routines.

In [None]:
# Column vectors of ones and zeros, each of shape (5, 1).
pos = torch.ones(5).unsqueeze(-1)
neg = torch.zeros(5).unsqueeze(-1)
print(pos, neg, sep="\n")

# Placing them side by side along the last axis yields shape (5, 2).
cat = torch.cat((pos, neg), dim=-1)
print(cat.shape, cat, sep="\n")

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]])
torch.Size([5, 2])
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])


In [None]:
class DummyDiscriminator(torch.nn.Module):
  """Toy discriminator: one linear layer over a concatenated input pair.

  ``in_channels`` is the feature size of each of the two inputs; the layer
  maps the ``2 * in_channels`` concatenated features to a single logit.
  """

  def __init__(self, in_channels: int):
    super().__init__()
    self._lin = torch.nn.Linear(2 * in_channels, 1)

  def forward(self, x_anchor: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Return the logit for the pair ``(x_anchor, x)``."""
    # Join the two inputs along the feature (last) axis.
    pair = torch.cat((x_anchor, x), dim=-1)
    # The shape is printed on purpose, to show what reaches the linear layer.
    print(pair.shape)
    return self._lin(pair)

# Toy setup: a dummy discriminator and a random quadruplet of embeddings.
batch_size = 5
l = DummyDiscriminator(10)
x_anchor = torch.randn(batch_size, 10)
x_pos = torch.randn(batch_size, 10)
x_part = torch.randn(batch_size, 10)
x_neg = torch.randn(batch_size, 10)
print(x_pos.shape)
print(x_part.shape)

# Score the (anchor, positive) and (anchor, partial) pairs. The anchor is the
# first argument in both calls, matching how the loss invokes the
# discriminator.
pred_logits_pos = l(x_anchor, x_pos).unsqueeze(1)
pred_logits_part = l(x_anchor, x_part).unsqueeze(1)
print(pred_logits_pos.shape)
print(pred_logits_part.shape)


# Stack the logits along dim 1 with matching targets: 1 for positive pairs,
# 0 for partially positive pairs.
pred_logits_cat = torch.cat([pred_logits_pos, pred_logits_part], dim=1)
target_pos = torch.ones_like(pred_logits_pos)
target_part = torch.zeros_like(pred_logits_part)
target_cat = torch.cat([target_pos, target_part], dim=1)
print(pred_logits_cat)
print(pred_logits_cat.shape)
print(target_cat)
print(target_cat.shape)

# Unreduced BCE, then summed over the positive/partial axis per sample.
bce_raw = binary_cross_entropy_with_logits(pred_logits_cat, target=target_cat, reduction='none')
print(bce_raw)
print(bce_raw.shape)
bce_sum = bce_raw.sum(dim=1)
print(bce_sum)
print(bce_sum.shape)
print(bce_sum.mean())
# For comparison: 'mean' averages over every element, not per sample.
bce = binary_cross_entropy_with_logits(pred_logits_cat, target=target_cat, reduction='mean')
print(bce)

torch.Size([5, 10])
torch.Size([5, 10])
torch.Size([5, 20])
torch.Size([5, 20])
torch.Size([5, 1, 1])
torch.Size([5, 1, 1])
tensor([[[-0.2870],
         [ 0.2598]],

        [[ 0.1180],
         [-0.1036]],

        [[ 0.4508],
         [ 0.0024]],

        [[-0.6488],
         [-0.9238]],

        [[-0.7359],
         [-0.2133]]], grad_fn=<CatBackward0>)
torch.Size([5, 2, 1])
tensor([[[1.],
         [0.]],

        [[1.],
         [0.]],

        [[1.],
         [0.]],

        [[1.],
         [0.]],

        [[1.],
         [0.]]])
torch.Size([5, 2, 1])
tensor([[[0.8469],
         [0.8314]],

        [[0.6359],
         [0.6427]],

        [[0.4930],
         [0.6943]],

        [[1.0693],
         [0.3343]],

        [[1.1273],
         [0.5922]]], grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
torch.Size([5, 2, 1])
tensor([[1.6784],
        [1.2786],
        [1.1873],
        [1.4036],
        [1.7195]], grad_fn=<SumBackward1>)
torch.Size([5, 1])
tensor(1.4535, grad_fn=<MeanBackw

In [None]:
# Inputs and hyper-parameters shared by the three calls below; only the
# reduction mode changes.
quadruplet = (x_anchor, x_pos, x_part, x_neg)
gql_kwargs = dict(gamma=0.8, margin_pos_neg=1.0, margin_pos_part=0.9,
                  margin_part_neg=0.8, p=2)

# Unreduced loss, followed by its mean.
gql = gamma_quadruplet_loss(*quadruplet, reduction='none', **gql_kwargs)
print(gql, gql.shape, gql.mean(), sep="\n")

# Summed loss; dividing by the batch size should reproduce the mean above.
gql = gamma_quadruplet_loss(*quadruplet, reduction='sum', **gql_kwargs)
print(gql, gql/batch_size, gql.shape, sep="\n")

# Mean-reduced loss.
gql = gamma_quadruplet_loss(*quadruplet, reduction='mean', **gql_kwargs)
print(gql, gql.shape, sep="\n")

tensor([4.6508, 0.3866, 3.0095, 3.2603, 3.3247])
torch.Size([5])
tensor(2.9264)
tensor(14.6319)
tensor(2.9264)
torch.Size([])
tensor(2.9264)
torch.Size([])


In [None]:
# Inputs and hyper-parameters shared by the three calls below. The
# discriminator is passed as a module, so the logits are computed inside the
# loss function.
quadruplet = (x_anchor, x_pos, x_part, x_neg)
gdl_kwargs = dict(lmbd=0.8, margin_pos_neg=1.0, margin_part_neg=0.9, p=2,
                  discr=l)

# Unreduced loss, followed by its mean.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='none', **gdl_kwargs)
print(gdl, gdl.shape, gdl.mean(), sep="\n")

# Summed loss, also shown divided by the batch size.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='sum', **gdl_kwargs)
print(gdl, gdl/batch_size, gdl.shape, sep="\n")

# Mean-reduced loss.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='mean', **gdl_kwargs)
print(gdl, gdl.shape, sep="\n")

torch.Size([5, 20])
torch.Size([5, 20])
torch.Size([5, 1, 1])
tensor([[ 3.8286, -0.8132,  2.2484,  2.2924,  2.4306],
        [ 3.6114, -1.0305,  2.0311,  2.0751,  2.2134],
        [ 3.7003, -0.9415,  2.1201,  2.1641,  2.3023],
        [ 3.4315, -1.2104,  1.8512,  1.8952,  2.0335],
        [ 3.5944, -1.0475,  2.0141,  2.0581,  2.1963]], grad_fn=<SubBackward0>)
torch.Size([5, 5])
tensor(1.8020, grad_fn=<MeanBackward0>)
torch.Size([5, 20])
torch.Size([5, 20])
torch.Size([5, 1, 1])
tensor(9.0098, grad_fn=<SubBackward0>)
tensor(1.8020, grad_fn=<DivBackward0>)
torch.Size([])
torch.Size([5, 20])
torch.Size([5, 20])
torch.Size([5, 1, 1])
tensor(1.8020, grad_fn=<SubBackward0>)
torch.Size([])


In [None]:
# Precompute the discriminator logits once and hand them to the loss instead
# of the discriminator module.
discr_logits_pos = l(x_anchor, x_pos)
discr_logits_part = l(x_anchor, x_part)

quadruplet = (x_anchor, x_pos, x_part, x_neg)
gdl_kwargs = dict(lmbd=0.8, margin_pos_neg=1.0, margin_part_neg=0.9, p=2,
                  discr_logits_pos=discr_logits_pos,
                  discr_logits_part=discr_logits_part)

# Unreduced loss, followed by its mean.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='none', **gdl_kwargs)
print(gdl, gdl.shape, gdl.mean(), sep="\n")

# Summed loss, also shown divided by the batch size.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='sum', **gdl_kwargs)
print(gdl, gdl/batch_size, gdl.shape, sep="\n")

# Mean-reduced loss.
gdl = d_regularized_quadruplet_loss(*quadruplet, reduction='mean', **gdl_kwargs)
print(gdl, gdl.shape, sep="\n")

torch.Size([5, 20])
torch.Size([5, 20])
torch.Size([5, 1, 1])
tensor([[ 3.8286, -0.8132,  2.2484,  2.2924,  2.4306],
        [ 3.6114, -1.0305,  2.0311,  2.0751,  2.2134],
        [ 3.7003, -0.9415,  2.1201,  2.1641,  2.3023],
        [ 3.4315, -1.2104,  1.8512,  1.8952,  2.0335],
        [ 3.5944, -1.0475,  2.0141,  2.0581,  2.1963]], grad_fn=<SubBackward0>)
torch.Size([5, 5])
tensor(1.8020, grad_fn=<MeanBackward0>)
torch.Size([5, 1, 1])
tensor(9.0098, grad_fn=<SubBackward0>)
tensor(1.8020, grad_fn=<DivBackward0>)
torch.Size([])
torch.Size([5, 1, 1])
tensor(1.8020, grad_fn=<SubBackward0>)
torch.Size([])


In [None]:
# Squeezing the trailing singleton axis: (5, 1, 2, 1) -> (5, 1, 2).
t = torch.randn(5, 1, 2, 1).squeeze(dim=-1)
t.shape

torch.Size([5, 1, 2])

In [None]:
from abc import ABC, abstractmethod


class QuadrupletLoss(torch.nn.Module, ABC):
  """Abstract base module for quadruplet losses.

  Holds and validates the hyper-parameters shared by concrete losses: the
  pos-vs-neg and pos-vs-part margins, the norm degree ``p``, the ``swap``
  flag and the default ``reduction`` mode. Every hyper-parameter is exposed
  as a property whose setter applies the same validation as the constructor.
  Subclasses must implement ``forward``.
  """

  def __init__(self,
               margin_pos_neg: float = 1.0,
               margin_pos_part: float = 1.0,
               p: float = 2.0,
               swap: bool = False,
               reduction: str  = "mean"):
    """Validate and store the shared hyper-parameters.

    Raises:
      ValueError: If a margin or ``p`` is not positive, or ``reduction`` is
        not one of ``REDUCTIONS``.
    """
    super().__init__()
    if margin_pos_neg <= 0:
      raise ValueError(f"margin_pos_neg must be positive, {margin_pos_neg} given")
    if margin_pos_part <= 0:
      raise ValueError(f"margin_pos_part must be positive, {margin_pos_part} given")
    if reduction not in REDUCTIONS:
      raise ValueError(f"reduction must be one of: {REDUCTIONS}, "
                      f"{reduction} given")
    if p <= 0:
      raise ValueError(f"p must be positive, {p} given")


    self.__margin_pos_neg: float = margin_pos_neg
    self.__margin_pos_part: float = margin_pos_part
    self.__p: float = p
    self.__swap: bool = swap
    self.__reduction: str = reduction

  @property
  def margin_pos_neg(self) -> float:
    """Margin of the pos-vs-neg triplet term (must be positive)."""
    return self.__margin_pos_neg

  @margin_pos_neg.setter
  def margin_pos_neg(self, margin_pos_neg: float):
    if margin_pos_neg <= 0:
      raise ValueError(f"margin_pos_neg must be positive, {margin_pos_neg} given")
    self.__margin_pos_neg = margin_pos_neg

  @property
  def margin_pos_part(self) -> float:
    """Margin of the pos-vs-part triplet term (must be positive)."""
    return self.__margin_pos_part

  @margin_pos_part.setter
  def margin_pos_part(self, margin_pos_part: float):
    if margin_pos_part <= 0:
      raise ValueError(f"margin_pos_part must be positive, {margin_pos_part} given")
    self.__margin_pos_part = margin_pos_part

  @property
  def p(self) -> float:
    """Norm degree used for the distances (must be positive)."""
    return self.__p

  @p.setter
  def p(self, p: float):
    if p <= 0:
      raise ValueError(f"p must be positive, {p} given")
    # Store the validated value; without the assignment the setter would
    # silently leave the previous norm degree in place.
    self.__p = p

  @property
  def swap(self) -> bool:
    """Distance-swap flag passed on to the triplet loss."""
    return self.__swap

  @swap.setter
  def swap(self, swap: bool):
    self.__swap = swap

  @property
  def reduction(self) -> str:
    """Default reduction mode; one of ``REDUCTIONS``."""
    return self.__reduction

  @reduction.setter
  def reduction(self, reduction: str):
    if reduction not in REDUCTIONS:
      raise ValueError(f"reduction must be one of: {REDUCTIONS}, "
                      f"{reduction} given")
    self.__reduction = reduction

  @abstractmethod
  def forward(self,
              x_anchor: torch.Tensor,
              x_pos: torch.Tensor,
              x_part: torch.Tensor,
              x_neg: torch.Tensor,
              reduction: Optional[str] = None,
              **kwargs) -> torch.Tensor:
    """Compute the loss for a quadruplet; must be implemented by subclasses.

    ``reduction`` is an optional per-call reduction mode.
    """
    raise NotImplementedError()


class GammaQuadrupletLoss(QuadrupletLoss):
  """Module wrapper around ``gamma_quadruplet_loss``.

  On top of the hyper-parameters handled by ``QuadrupletLoss`` it stores the
  weight ``gamma`` and the part-vs-neg margin, both validated on construction
  and on assignment.
  """

  def __init__(self,
               gamma: float = DEFAULT_GAMMA,
               margin_pos_neg: float = 1.0,
               margin_pos_part: float = 1.0,
               margin_part_neg: float = 1.0,
               p: float = 2.0,
               swap: bool = False,
               reduction: str  = "mean"):
    """Validate and store all hyper-parameters of the gamma-quadruplet loss."""
    # Shared hyper-parameters are validated by the base class first.
    super().__init__(margin_pos_neg=margin_pos_neg,
                     margin_pos_part=margin_pos_part,
                     p=p,
                     swap=swap,
                     reduction=reduction)

    if gamma > 1 or gamma < 0:
      raise ValueError(f"gamma must be between 0 and 1, {gamma} given")
    if margin_part_neg <= 0:
      raise ValueError(f"margin_part_neg must be positive, {margin_part_neg} given")

    self.__gamma = gamma
    self.__margin_part_neg = margin_part_neg

  @property
  def gamma(self) -> float:
    """Weight in [0, 1] passed to ``gamma_quadruplet_loss`` as ``gamma``."""
    return self.__gamma

  @gamma.setter
  def gamma(self, gamma: float):
    if gamma > 1 or gamma < 0:
      raise ValueError(f"gamma must be between 0 and 1, {gamma} given")
    self.__gamma = gamma

  @property
  def margin_part_neg(self) -> float:
    """Margin of the part-vs-neg triplet term (must be positive)."""
    return self.__margin_part_neg

  @margin_part_neg.setter
  def margin_part_neg(self, margin_part_neg: float):
    if margin_part_neg <= 0:
      raise ValueError(f"margin_part_neg must be positive, {margin_part_neg} given")
    self.__margin_part_neg = margin_part_neg

  def forward(self,
              x_anchor: torch.Tensor,
              x_pos: torch.Tensor,
              x_part: torch.Tensor,
              x_neg: torch.Tensor,
              reduction: Optional[str] = None,
              **kwargs) -> torch.Tensor:
    """Evaluate the gamma-quadruplet loss with the stored hyper-parameters.

    A ``reduction`` given here overrides the module's default for this call
    only; extra keyword arguments are accepted and ignored.
    """
    if reduction is None:
      reduction = self.reduction

    hyper_params = dict(gamma=self.gamma,
                        margin_pos_neg=self.margin_pos_neg,
                        margin_pos_part=self.margin_pos_part,
                        margin_part_neg=self.margin_part_neg,
                        p=self.p,
                        swap=self.swap,
                        reduction=reduction)
    return gamma_quadruplet_loss(x_anchor=x_anchor,
                                 x_pos=x_pos,
                                 x_part=x_part,
                                 x_neg=x_neg,
                                 **hyper_params)

# Module form of the loss, with 'none' as its default reduction.
gql_object = GammaQuadrupletLoss(
    gamma=0.8,
    margin_pos_neg=1.0,
    margin_pos_part=0.9,
    margin_part_neg=0.8,
    p=2,
    reduction='none',
)
quadruplet = (x_anchor, x_pos, x_part, x_neg)

# Default reduction ('none'): unreduced loss, followed by its mean.
gql = gql_object(*quadruplet)
print(gql, gql.shape, gql.mean(), sep="\n")

# Per-call override: summed loss, also shown divided by the batch size.
gql = gql_object(*quadruplet, reduction='sum')
print(gql, gql.shape, gql/batch_size, sep="\n")

# Per-call override: mean-reduced loss.
gql = gql_object(*quadruplet, reduction='mean')
print(gql.shape, gql, sep="\n")

tensor([4.6508, 0.3866, 3.0095, 3.2603, 3.3247])
torch.Size([5])
tensor(2.9264)
tensor(14.6319)
torch.Size([])
tensor(2.9264)
torch.Size([])
tensor(2.9264)
