In [1]:
import torch
import numpy as np
import cv2 as cv
import torch.nn as nn
import torch.nn.functional as F

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou

In [2]:
# Sizes of the random feature batch built further down in this cell:
# it is created with shape [n_rois, n_channels, output_size, output_size].
n_rois = 100  # batch dimension (first axis of the demo tensor)
n_channels = 128  # channel dimension; also the length of each embedding
output_size = 15  # side length of the square spatial grid

def features_to_emb(features: torch.Tensor) -> torch.Tensor:
    """Computes embedding vectors from tracker template (exemplar) features.

    For each feature tensor in a batch, it applies global average pooling over
    the two spatial dimensions (yielding one value per channel). Afterwards,
    it L2-normalizes the vectors to project them onto the unit hypersphere.

    Args:
        features (torch.Tensor): template features of shape [B, C, S, S]

    Returns:
        torch.Tensor: embedding vectors of shape [B, C]. The shape is kept
        even when B == 1 or C == 1. An all-zero pooled vector maps to zeros
        rather than NaN.

    Raises:
        AssertionError: if the input is not 4-D or its spatial grid is not
            square.
    """
    assert features.ndim == 4
    assert features.shape[-1] == features.shape[-2]

    # Reduce only the spatial axes. A bare ``squeeze()`` would also drop the
    # batch (or channel) axis whenever it has size 1, which breaks the
    # per-row normalization below.
    avg = features.mean(dim=(2, 3))  # [B, C]

    # F.normalize divides by max(norm, eps), so a zero vector does not
    # produce a 0/0 NaN.
    emb = F.normalize(avg, p=2, dim=1)  # [B, C]

    return emb

# Smoke test: embed a random feature batch and display the input and output
# shapes (the embedding should drop both spatial axes).
features = torch.rand(n_rois, n_channels, output_size, output_size)
emb = features_to_emb(features)
features.shape, emb.shape

(torch.Size([100, 128, 15, 15]), torch.Size([100, 128]))

In [3]:
class BalancedMarginContrastiveLoss(nn.Module):
    """Class-balanced margin loss over all pairs of embeddings in a batch.

    Every unordered pair (i, j) gets a label ``y = +1`` when both embeddings
    share the same id (positive pair) and ``y = -1`` otherwise (negative
    pair). With ``d`` the Euclidean distance of the pair, its loss term is::

        max(0, alpha + y * (d - beta))

    so positive pairs are penalized once ``d > beta - alpha`` and negative
    pairs once ``d < beta + alpha``. The terms are combined as a weighted
    sum in which the positive pairs and the negative pairs each carry the
    same total weight, regardless of how many pairs of each kind exist. If
    only one kind of pair is present, it carries the whole weight.

    Args:
        alpha (float): margin around the boundary.
        beta (float): distance boundary between positive and negative pairs.
    """

    def __init__(self, alpha: float = 1, beta: float = 2) -> None:
        # nn.Module has to be initialized first; without it the instance
        # lacks the internal registries that ``__call__`` relies on.
        super().__init__()
        self.alpha: float = alpha
        self.beta: float = beta

    def forward(self, embs: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        """Computes the balanced margin loss for a batch.

        Args:
            embs (torch.Tensor): embedding vectors of shape [N, D]
            ids (torch.Tensor): identity label per embedding, shape [N,]

        Returns:
            torch.Tensor: scalar loss. It is zero when fewer than two
            embeddings are given, since no pair can be formed.
        """
        assert len(embs) == len(ids)
        assert (embs.ndim == 2) and (ids.ndim == 1)

        n_embs = len(embs)
        if n_embs < 2:
            # No pairs exist; return a zero that is still attached to the
            # autograd graph of ``embs`` so ``backward()`` keeps working.
            return embs.sum() * 0.0

        # Keep the index tensor and the ids on the device of the embeddings.
        ids = ids.to(embs.device)
        idxs = torch.arange(n_embs, device=embs.device)
        idx_pairs = torch.combinations(idxs, 2)  # [P, 2]
        emb_pairs = embs[idx_pairs]  # [P, 2, D]

        pair_dist = torch.linalg.norm(
            emb_pairs[:, 0, :] - emb_pairs[:, 1, :], dim=1
        )  # [P,]

        ids_first = ids[idx_pairs[:, 0]]
        ids_second = ids[idx_pairs[:, 1]]
        neg_pairs_mask = (ids_first != ids_second)  # [P,]

        # +1 for positive pairs, -1 for negative pairs.
        labels = torch.ones_like(pair_dist)
        labels[neg_pairs_mask] = -1

        n_neg = int(torch.sum(neg_pairs_mask).item())
        n_pos = neg_pairs_mask.numel() - n_neg

        # A batch may contain only positive or only negative pairs; give the
        # missing kind a zero weight instead of dividing by zero.
        pos_weight = 1.0 / n_pos if n_pos > 0 else 0.0
        neg_weight = 1.0 / n_neg if n_neg > 0 else 0.0

        weights = torch.full_like(pair_dist, pos_weight)
        weights[neg_pairs_mask] = neg_weight
        # At least one pair exists here, so the sum is strictly positive.
        weights = weights / weights.sum()

        hinge = torch.clamp(
            self.alpha + labels * (pair_dist - self.beta), min=0
        )
        loss = torch.sum(weights * hinge)

        return loss

In [4]:
# Toy batch for stepping through the loss by hand: four 3-D embeddings,
# where items 0 and 2 share id 1 and coincide exactly.
ids = torch.tensor([1, 2, 1, 3])
embs = torch.tensor(
    [[1, 1, 1], [10, 10, 10], [1, 1, 1], [4, 4, 4]],
    dtype=torch.float32,
)
ids, embs

(tensor([1, 2, 1, 3]),
 tensor([[ 1.,  1.,  1.],
         [10., 10., 10.],
         [ 1.,  1.,  1.],
         [ 4.,  4.,  4.]]))

In [5]:
# Each embedding needs exactly one id.
assert len(ids) == len(embs)

# Enumerate every unordered pair (i, j) with i < j of row indices.
idxs = torch.arange(embs.shape[0])
idx_pairs = torch.combinations(idxs, r=2)
idx_pairs, idx_pairs.shape

(tensor([[0, 1],
         [0, 2],
         [0, 3],
         [1, 2],
         [1, 3],
         [2, 3]]),
 torch.Size([6, 2]))

In [6]:
# Gather both members of every pair, then take the Euclidean distance
# between them (one value per pair).
emb_pairs = embs[idx_pairs]
pair_dist = (emb_pairs[:, 0, :] - emb_pairs[:, 1, :]).norm(dim=1)
pair_dist, pair_dist.shape

(tensor([15.5885,  0.0000,  5.1962, 15.5885, 10.3923,  5.1962]),
 torch.Size([6]))

In [7]:
# Ids of the first and second member of every pair.
ids_first = ids[idx_pairs[:, 0]]
ids_second = ids[idx_pairs[:, 1]]

# A pair is negative when its two ids differ.
neg_pairs_mask = ids_first.ne(ids_second)

# Pair labels: +1 for positive pairs, -1 for negative pairs.
labels = 1 - 2 * neg_pairs_mask.to(pair_dist.dtype)

# Number of negative pairs.
neg_pairs_mask.sum()

tensor(5)

In [8]:
alpha = 1
beta = 2

# Plain (unweighted) mean of the per-pair hinge terms
# max(0, alpha + label * (dist - beta)).
margin_loss = torch.clamp(alpha + labels * (pair_dist - beta), min=0).mean()
margin_loss

tensor(0.)