In [1]:
import torch.nn as nn
import torch

In [6]:
class SimpleMarginContrastiveLoss(nn.Module):
    def __init__(self, margin=0.05):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negatives):
    
        # Compute distances between anchor and positive 
        # L-1 norm between anchor and positive sample
        
        dist_ap = torch.abs(anchor - positive).mean(dim=(1, 2, 3))  # Shape: (B,)
        print('dist_ap:', dist_ap)

        # Concatenate negatives from the list
        negatives_tensor = torch.stack(negatives, dim=1)  # Shape: (B, N, C, H, W)
        B, N, C, H, W = negatives_tensor.shape
        print('negatives_tensor:', B,N,C,H,W)
        
        # take a lotta negative samples then take avg distance from anchor
        # Reshape negatives for vectorized distance computation
        negatives_flat = negatives_tensor.reshape(B * N, C, H, W)  # Shape: (B*N, C, H, W)
        print('negatives_flat:', negatives_flat.shape)
        anchor_expanded = anchor.unsqueeze(1).repeat(1, N, 1, 1, 1).view(B * N, C, H, W)  # Shape: (B*N, C, H, W)
        print('anchor_expanded:', anchor_expanded.shape)

        # Compute distances between anchor and negatives
        dist_an = torch.abs(anchor_expanded - negatives_flat).mean(dim=(1, 2, 3))  # Shape: (B*N,)
        dist_an = dist_an.view(B, N)  # Reshape back to (B, N)
        print('dist_an: ', dist_an.shape)

        # Aggregate negative distances (e.g., average)
        loss_neg = dist_an.sum(dim=1)  # Shape: (B,)

        # Compute contrastive loss
        loss = dist_ap / (self.margin + loss_neg)  # Add small epsilon to prevent division by zero

        return loss.mean()

In [7]:
contrastive_loss = SimpleMarginContrastiveLoss(margin=0.05)

anchor = torch.randn(1, 3, 256, 256)
positive = torch.randn(1, 3, 256, 256)
negatives = [torch.randn(1, 3, 256, 256),torch.randn(1, 3, 256, 256),torch.randn(1, 3, 256, 256),torch.randn(1, 3, 256, 256),]

In [8]:
contrastive_loss(anchor, positive, negatives)

dist_ap: tensor([1.1277])
negatives_tensor: 1 4 3 256 256
negatives_flat: torch.Size([4, 3, 256, 256])
anchor_expanded: torch.Size([4, 3, 256, 256])
dist_an:  torch.Size([1, 4])


tensor(0.2470)