In [1]:
import torch
import numpy as np
import cv2 as cv
import torch.nn as nn
import torch.nn.functional as F

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou

In [14]:
B = 4
E = 3

embs = torch.randint(0, 3, (B, E), dtype=torch.float)
embs

tensor([[2., 2., 2.],
        [0., 2., 2.],
        [0., 1., 0.],
        [0., 2., 0.]])

In [48]:
def _pairwise_l2_dist(
    embs: torch.Tensor,
    *,
    squared: bool = False,
    eps: float = 1e-16
) -> torch.Tensor:
    dot_product = torch.matmul(embs, embs.T)  # [B,B]
    square_norm = torch.diag(dot_product)  # [B,]

    # Apply the l2 norm formula using the dot product:
    # ||A - B||^2 = ||A||^2 - 2<A,B> + ||B||^2
    distances = (
        torch.unsqueeze(square_norm, dim=1) -  # [B,1]
        (2 * dot_product) +  # [B,B]
        torch.unsqueeze(square_norm, dim=0)  # [1,B]
    )

    # Due to potential errors caused by numerical instability, some values may
    # have become negative. Thus, we have to make sure the min. value is zero.
    zero = torch.tensor(0.0)
    distances = torch.maximum(distances, zero)  # [B,B]

    if not squared:
        # Since the gradient of sqrt(0) is infinite, we, therefore, have to
        # add a small epsilon to the zero terms to prevent this.
        zeroes_mask = ((distances - zero) < eps).float()  # [B,B]
        distances += zeroes_mask * eps
        
        distances = torch.sqrt(distances)  # [B,B]

        # Set all the "zero" values back to zero after adding the epsilon value.
        distances *= (1.0 - zeroes_mask)
    
    return distances

def _get_anchor_positive_mask(labels: torch.Tensor) -> torch.Tensor:
    labels_eq_mask = (labels[..., None] == labels[None, ...])  # [B,B]
    idxs_neq_mask = ~torch.eye(len(labels), dtype=torch.bool)  # [B,B]
    anchor_positive_mask = (labels_eq_mask & idxs_neq_mask)  # [B,B]
    return anchor_positive_mask

def _get_anchor_negative_mask(labels: torch.Tensor) -> torch.Tensor:
    anchor_negative_mask = (labels[..., None] != labels[None, ...])  # [B,B]
    return anchor_negative_mask

def get_triplet_mask(labels: torch.Tensor) -> torch.Tensor:
    idxs_neq_mask = ~torch.eye(len(labels), dtype=torch.bool)  # [B,B]
    idx_i_neq_j_mask = torch.unsqueeze(idxs_neq_mask, dim=2)  # [B,B,1]
    idx_i_neq_k_mask = torch.unsqueeze(idxs_neq_mask, dim=1)  # [B,1,B]
    idx_j_neq_k_mask = torch.unsqueeze(idxs_neq_mask, dim=0)  # [1,B,B]
    triplet_idxs_neq_mask = (
        idx_i_neq_j_mask & idx_i_neq_k_mask & idx_j_neq_k_mask  # [B,B,B]
    )

    labels_eq_mask = (labels[..., None] == labels[None, ...])  # [B,B]
    label_i_eq_j = torch.unsqueeze(labels_eq_mask, dim=2)  # [B,B,1]
    label_i_neq_k = ~torch.unsqueeze(labels_eq_mask, dim=1)  # [B,1,B]
    triplet_labels_valid_mask = (label_i_eq_j & label_i_neq_k)  # [B,B,B]

    triplet_mask = (
        triplet_idxs_neq_mask & triplet_labels_valid_mask  # [B,B,B]
    )

    return triplet_mask

class SemiHardTripletLoss(nn.Module):
    def __init__(self, margin: float = 1.0, squared: bool = True) -> None:
        super().__init__()

        self.margin: float = margin
        self.squared: bool = squared
    
    def forward(self, embs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        pairwise_dist = _pairwise_l2_dist(embs, self.squared)

        anchor_positive_mask = _get_anchor_positive_mask(labels)  # [B,B]
        anchor_positive_dist = (
            pairwise_dist * anchor_positive_mask.float()
        )  # [B,B]
        hardest_positive_dist = torch.amax(
            anchor_positive_dist, dim=1, keepdim=True
        )  # [B,1]

        anchor_negative_mask = _get_anchor_negative_mask(
            labels
        ).float()  # [B,B]
        max_anchor_negative_dist = torch.amax(
            pairwise_dist, dim=1, keepdim=True
        )  # [B,1]
        anchor_negative_dist = (
            pairwise_dist +
            (1 - anchor_negative_mask) * max_anchor_negative_dist
        )  # [B,B]
        hardest_negative_dist = torch.amin(
            anchor_negative_dist, dim=1, keepdim=True
        )  # [B,1]

        triplet_loss = torch.clamp(
            hardest_positive_dist - hardest_negative_dist + self.margin, min=0
        )  # [B,1]
        triplet_loss = torch.mean(triplet_loss)  # [c]

        return triplet_loss
    

SyntaxError: positional argument follows keyword argument (<ipython-input-48-2f81062dbd5e>, line 88)

In [44]:
dist = torch.rand((B, B))
dist

tensor([[0.7602, 0.6728, 0.1440, 0.5144],
        [0.2222, 0.8913, 0.7648, 0.6707],
        [0.2119, 0.1066, 0.5787, 0.0769],
        [0.0636, 0.4100, 0.1285, 0.3217]])

In [47]:
torch.amax(dist, dim=1, keepdim=True)

tensor([[0.7602],
        [0.8913],
        [0.5787],
        [0.4100]])

In [8]:
n = 3
indices_neq_mask= ~torch.eye(n, dtype=torch.bool)
mask_ij = torch.unsqueeze(indices_neq_mask, dim=2)
mask_ik = torch.unsqueeze(indices_neq_mask, dim=1)
mask_jk = torch.unsqueeze(indices_neq_mask, dim=0)
mask_ij.shape, mask_ik.shape, mask_jk.shape

(torch.Size([3, 3, 1]), torch.Size([3, 1, 3]), torch.Size([1, 3, 3]))