In [55]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

# import torch.distributed as dist


# Module-level logger named "dinov2"; nothing in the code shown in this notebook logs through it.
logger = logging.getLogger("dinov2")


class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search

    Each row of the batch is L2-normalized. Its nearest *other* row is found
    via inner products, and the loss is the batch mean of
    ``-log(distance_to_nearest_neighbour + eps)``. Minimizing it pushes nearest
    neighbours apart, i.e. spreads the embeddings out.
    """

    def __init__(self):
        super().__init__()
        # Euclidean (p=2) distance between corresponding rows of two tensors.
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """
        Pairwise nearest neighbors for L2-normalized vectors.
        Uses Torch rather than Faiss to remain on GPU.

        Args:
            x (BxD): rows are expected to be L2-normalized (``forward`` normalizes
                before calling this).

        Returns:
            Index tensor of shape (B,): ``I[i]`` is the row ``j != i`` with the
            largest inner product with row ``i``, i.e. its nearest neighbour.
            With B == 1 the only candidate is the row itself.
        """
        # Pairwise dot products. For unit vectors ||a - b||^2 = 2 - 2 <a, b>, so a
        # larger dot product means a smaller Euclidean distance.
        dots = torch.mm(x, x.t())
        n = x.shape[0]
        # Striding the flat (n*n) view by n + 1 visits exactly the diagonal; -1 is
        # the smallest possible cosine, so a row never selects itself.
        dots.view(-1)[:: (n + 1)].fill_(-1)  # Trick to fill diagonal with -1
        # max inner prod -> min distance
        _, I = torch.max(dots, dim=1)  # noqa: E741
        return I

    def forward(self, student_output, eps=1e-8):
        """
        Args:
            student_output (BxD): backbone output of student
            eps (float): used both as the L2-normalization epsilon and inside the
                log, so identical neighbours do not produce log(0).

        Returns:
            Scalar tensor: mean over the batch of -log(nearest-neighbour distance + eps).
        """
        # Mixed-precision autocast is turned off for the loss computation
        # (presumably for numerical stability of the normalization and log).
        with torch.cuda.amp.autocast(enabled=False):
            student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
            I = self.pairwise_NNs_inner(student_output)  # noqa: E741
            distances = self.pdist(student_output, student_output[I])  # BxD, BxD -> B
            loss = -torch.log(distances + eps).mean()
        return loss


The `KoLeoLoss` class implements a loss based on the Kozachenko-Leonenko entropic regularizer described by Sablayrolles et al. in their 2018 paper **"Spreading vectors for similarity search"**.

The loss first L2-normalizes the $n$ vectors $x_1, \dots, x_n$ of a batch. For each $x_i$ it then finds the distance to its nearest *other* vector, $\rho_i = \min_{j \neq i} \lVert x_i - x_j \rVert_2$, and returns

$$\mathcal{L}_{\text{KoLeo}} = -\frac{1}{n} \sum_{i=1}^{n} \log(\rho_i + \varepsilon).$$

Minimizing this loss increases the nearest-neighbor distances, so it encourages the vectors to spread out.

This spreading benefits similarity search, where distinct, well-separated representations make retrieval more efficient and accurate.

In [62]:
# Toy batch: B=5 embeddings of dimension D=7. Rows 0/3 and rows 2/4 are
# near-duplicates of each other.
eps = 1e-8
student_output = torch.tensor([[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
                               [1.0, 2.2, 1.2, 1.5, 1.4, 1.4, 1.6],
                               [1.5, 1.7, 1.2, 1.3, 1.4, 1.5, 2.6],
                               [1.0, 1.1, 1.2, 1.3, 1.3, 1.5, 1.61],
                               [1.51, 1.7, 1.2, 1.3, 1.4, 1.5, 2.6]])
# L2-normalize each row exactly as KoLeoLoss.forward does, reusing `eps`
# rather than repeating the literal.
student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
assert torch.allclose(student_output.norm(p=2, dim=-1), torch.ones(student_output.shape[0]))
student_output

tensor([[0.2874, 0.3161, 0.3448, 0.3736, 0.4023, 0.4310, 0.4598],
        [0.2499, 0.5498, 0.2999, 0.3749, 0.3499, 0.3499, 0.3999],
        [0.3420, 0.3876, 0.2736, 0.2964, 0.3192, 0.3420, 0.5927],
        [0.2902, 0.3192, 0.3483, 0.3773, 0.3773, 0.4353, 0.4673],
        [0.3440, 0.3873, 0.2734, 0.2961, 0.3189, 0.3417, 0.5923]])

In [63]:
# Reproduce pairwise_NNs_inner step by step on the normalized batch.
x = student_output    # 5x7, rows are unit-norm
dots = x @ x.T        # 5x5 matrix of pairwise inner products
n = x.size(0)         # 5 (batch size)
print(dots)           # before masking the diagonal
# view(-1) is a flat view sharing storage with `dots`, so striding it by n + 1
# writes -1 onto the 2-D diagonal in place -- no reshape back is needed.
dots.view(-1)[::n + 1].fill_(-1)
dots

tensor([[1.0000, 0.9645, 0.9742, 0.9996, 0.9741],
        [0.9645, 1.0000, 0.9601, 0.9651, 0.9599],
        [0.9742, 0.9601, 1.0000, 0.9763, 1.0000],
        [0.9996, 0.9651, 0.9763, 1.0000, 0.9762],
        [0.9741, 0.9599, 1.0000, 0.9762, 1.0000]])


tensor([[-1.0000,  0.9645,  0.9742,  0.9996,  0.9741],
        [ 0.9645, -1.0000,  0.9601,  0.9651,  0.9599],
        [ 0.9742,  0.9601, -1.0000,  0.9763,  1.0000],
        [ 0.9996,  0.9651,  0.9763, -1.0000,  0.9762],
        [ 0.9741,  0.9599,  1.0000,  0.9762, -1.0000]])

In [64]:
# For every row, the column holding the largest (diagonal-masked) inner product.
_, I = dots.max(dim=1)
# Rows 0 and 3 of student_output are nearly identical, as are rows 2 and 4,
# so each of those pairs selects the other as nearest neighbour.
I

tensor([3, 3, 4, 0, 2])

In [59]:
# The batch, followed by the row each sample picked as its nearest neighbour.
print(student_output, student_output[I], sep="\n")

tensor([[0.2874, 0.3161, 0.3448, 0.3736, 0.4023, 0.4310, 0.4598],
        [0.2499, 0.5498, 0.2999, 0.3749, 0.3499, 0.3499, 0.3999],
        [0.3420, 0.3876, 0.2736, 0.2964, 0.3192, 0.3420, 0.5927],
        [0.2902, 0.3192, 0.3483, 0.3773, 0.3773, 0.4353, 0.4673],
        [0.3440, 0.3873, 0.2734, 0.2961, 0.3189, 0.3417, 0.5923]])
tensor([[0.2902, 0.3192, 0.3483, 0.3773, 0.3773, 0.4353, 0.4673],
        [0.2902, 0.3192, 0.3483, 0.3773, 0.3773, 0.4353, 0.4673],
        [0.3440, 0.3873, 0.2734, 0.2961, 0.3189, 0.3417, 0.5923],
        [0.2874, 0.3161, 0.3448, 0.3736, 0.4023, 0.4310, 0.4598],
        [0.3420, 0.3876, 0.2736, 0.2964, 0.3192, 0.3420, 0.5927]])


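The loss looks at the distance $d$ between each batch representation and its nearest neighbor in the batch. A first guess might be that training makes these distances smaller, but it is the opposite: minimizing $-\log d$ makes them **bigger**.

- For $d = 0.01$: loss $= -\log(0.01) \approx 4.605$
- For $d = 1$: loss $= -\log(1) = 0$

Close neighbors are therefore penalized heavily, and the loss drops as neighbors move apart.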

In [60]:
# Euclidean distance from each row to its nearest neighbour (B x D, B x D -> B).
nearest = student_output[I]
pdist = nn.PairwiseDistance(p=2, eps=eps)
distances = pdist(student_output, nearest)
distances

tensor([0.0273, 0.2641, 0.0021, 0.0273, 0.0021])

In [61]:
# Batch mean of -log(distance + eps); eps keeps log() finite for a zero distance.
loss = (distances + eps).log().mean().neg()
loss

tensor(4.1657)

In [65]:
# Display the repr of the PairwiseDistance module used in the distance cell.
pdist

PairwiseDistance()