In [87]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
import logging
logger = logging.getLogger("dinov2")

try:
    from xformers.ops import cross_entropy

    def lossfunc(t, s, temp):
        s = s.float()
        t = t.float()
        if s.ndim == 2:
            return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
        elif s.ndim == 3:
            return -cross_entropy(s, t, temp, bw_inplace=True)

except ImportError:

    def lossfunc(t, s, temp):
        return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
    #dim=-1: Specifies that the softmax operation should be applied along the 
    #last dimension of the tensor. This is typically the feature dimension in the context
    #of a batch of data.


class iBOTPatchLoss(nn.Module):
    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        self.apply_center_update()
        # teacher centering and sharpening
        #
        # WARNING:
        #   as self.center is a float32, everything gets casted to float32 afterwards
        #
        # teacher_patch_tokens = teacher_patch_tokens.float()
        # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1)

        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

        # this is experimental, keep everything in float16 and let's see what happens:
        # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        teacher_output = teacher_output.float()
        # world_size = dist.get_world_size() if dist.is_initialized() else 1
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
        # B = Q.shape[1] * world_size # number of samples to assign
        B = n_masked_patches_tensor
        dist.all_reduce(B)
        K = Q.shape[0]  # how many prototypes

        # make the matrix sums to 1
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True


In [88]:
# Parameters
batch_size = 1
num_patches = 4
patch_dim = 3

student_temp=0.1
center_momentum=0.9

# Dummy data
student_patch_tokens = torch.randn(batch_size, num_patches, patch_dim)
teacher_patch_tokens = torch.randn(batch_size, num_patches, patch_dim)
student_masks_flat = torch.randint(0, 2, (batch_size, num_patches)).float()  #tensor([[1., 0., 1., 0.]])

# Initialize the loss function
loss_fn = iBOTPatchLoss(patch_out_dim=patch_dim)

# Compute the loss
loss = loss_fn(student_patch_tokens, teacher_patch_tokens, student_masks_flat)
print(f'Loss: {loss.item()}')

Loss: 1.7183916568756104


In [89]:
student_masks_flat

tensor([[1., 0., 0., 1.]])

In [90]:
def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
    t = teacher_patch_tokens   
    s = student_patch_tokens
    loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
    loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
    return -loss.mean()

In [91]:
s = student_patch_tokens
s

tensor([[[-0.9904, -0.7244, -1.0313],
         [-0.5965, -1.2006, -0.5768],
         [ 0.8252,  1.3768,  1.7738],
         [ 0.8361,  0.5736, -0.0891]]])

In [92]:
F.softmax(s / 0.1, dim=-1)

tensor([[[6.2615e-02, 8.9580e-01, 4.1589e-02],
         [4.5050e-01, 1.0712e-03, 5.4842e-01],
         [7.4442e-05, 1.8525e-02, 9.8140e-01],
         [9.3232e-01, 6.7590e-02, 8.9417e-05]]])

In [93]:
F.log_softmax(s / 0.1, dim=-1)

tensor([[[-2.7708, -0.1100, -3.1799],
         [-0.7974, -6.8390, -0.6007],
         [-9.5055, -3.9887, -0.0188],
         [-0.0701, -2.6943, -9.3222]]])

In [109]:
# From chatGPT (does not work when it is not square!)
#Row Normalization: Each row is normalized to sum to 1/K.
#Column Normalization: Each column is normalized to sum to 1/B in the for loop.
import torch
import torch.nn.functional as F

def sinkhorn_knopp(teacher_output, num_iterations=3):
    """
    Applies the Sinkhorn-Knopp algorithm to normalize the teacher output.
    
    Args:
        teacher_output (torch.Tensor): The output logits from the teacher network, shape (B, K) where B is the batch size and K is the number of prototypes.
        num_iterations (int): Number of iterations for the Sinkhorn-Knopp algorithm.
    
    Returns:
        torch.Tensor: The normalized output, shape (B, K).
    """
    Q = torch.exp(teacher_output).t()  # Transpose to have K x B
    B, K = Q.shape[1], Q.shape[0]

    for _ in range(num_iterations):
        # Normalize rows
        Q /= torch.sum(Q, dim=1, keepdim=True)
        Q /= K

        # Normalize columns
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # Rescale the columns to sum to 1
    return Q.t()  # Transpose back to B x K

# Example usage
batch_size = 5
num_prototypes = 5

# Simulated teacher output logits
teacher_output = torch.randn(batch_size, num_prototypes)
# teacher_output = torch.tensor([[ 10,  2,  3,  4,  5],
#         [ 60.,  7,  8,  9, 10],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20],
#         [21, 22, 23, 24, 25]])

# Apply Sinkhorn-Knopp normalization
normalized_output = sinkhorn_knopp(teacher_output)
print(normalized_output)


tensor([[0.3983, 0.1297, 0.1776, 0.2745, 0.0199],
        [0.1030, 0.2527, 0.2393, 0.2110, 0.1941],
        [0.0215, 0.2209, 0.3148, 0.0742, 0.3685],
        [0.2831, 0.0533, 0.0480, 0.2742, 0.3414],
        [0.1902, 0.3465, 0.2231, 0.1637, 0.0765]])


In [112]:
teacher_output

tensor([[ 1.2317, -0.1360,  0.9248,  0.7052, -1.3081],
        [-0.5308,  0.1209,  0.8131,  0.0324,  0.5588],
        [-2.0928, -0.0104,  1.0902, -1.0095,  1.2026],
        [ 1.1024, -0.8132, -0.1703,  0.9166,  1.7455],
        [-0.6171, -0.2631,  0.0432, -0.9210, -1.0722]])

In [113]:
print(torch.sum(teacher_output,0))

tensor([-0.9065, -1.1017,  2.7009, -0.2763,  1.1267])


In [114]:
print(torch.sum(teacher_output,1))

tensor([ 1.4177,  0.9944, -0.8200,  2.7811, -2.8302])


In [110]:
print(torch.sum(normalized_output,0))

tensor([0.9961, 1.0031, 1.0028, 0.9977, 1.0003])


In [111]:
print(torch.sum(normalized_output,1))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [97]:
teacher_output

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25]])

In [27]:
print(torch.sum(teacher_output,1))

tensor([ 0.2065, -0.4713,  0.9528, -1.1330])


In [28]:
print(torch.sum(teacher_output,0))

tensor([-1.9272,  2.7210, -1.5340,  0.2952])


In [29]:
16*16*8

2048

In [30]:
96*96*96

884736

In [31]:
(96*96*96)/(16*16*8)

432.0

In [33]:
teacher_output

tensor([[-0.6461,  0.3358,  0.5791, -3.6048],
        [ 0.5991,  0.2356,  0.2242,  1.3083],
        [ 0.2423,  0.3662,  0.1327,  2.1231],
        [ 0.8047,  0.0623,  0.0640,  1.1734]])

In [64]:
Q = torch.randn(4, 4)
print(Q)
Q = Q/torch.sum(Q, dim=0, keepdim=True)
Q /= 4
print(Q)
print(torch.sum(Q,0))

tensor([[-0.5424, -1.0715,  0.0973, -0.6550],
        [-0.5035, -1.8907,  1.4194,  0.6357],
        [ 1.6272, -0.4093,  0.3786, -0.0419],
        [ 0.1899,  1.7773,  1.0848, -0.2223]])
tensor([[-0.1758,  0.1680,  0.0082,  0.5776],
        [-0.1632,  0.2965,  0.1191, -0.5606],
        [ 0.5275,  0.0642,  0.0318,  0.0370],
        [ 0.0616, -0.2787,  0.0910,  0.1961]])
tensor([0.2500, 0.2500, 0.2500, 0.2500])


In [65]:
print(Q)
Q = Q/torch.sum(Q, dim=1, keepdim=True)
Q /= 4
print(Q)
print(torch.sum(Q,1))

tensor([[-0.1758,  0.1680,  0.0082,  0.5776],
        [-0.1632,  0.2965,  0.1191, -0.5606],
        [ 0.5275,  0.0642,  0.0318,  0.0370],
        [ 0.0616, -0.2787,  0.0910,  0.1961]])
tensor([[-0.0761,  0.0727,  0.0035,  0.2498],
        [ 0.1324, -0.2404, -0.0966,  0.4546],
        [ 0.1997,  0.0243,  0.0120,  0.0140],
        [ 0.2201, -0.9966,  0.3254,  0.7011]])
tensor([0.2500, 0.2500, 0.2500, 0.2500])


In [73]:
import torch.nn.functional as F
s = torch.randn(4, 4)
t = torch.randn(4, 4)
print(s)
torch.sum(t * F.log_softmax(s / 0.1, dim=-1), dim=-1)

tensor([[-0.3363, -0.9429, -0.9437, -0.2731],
        [-0.9939, -0.4279, -0.3170,  0.0103],
        [-2.6770,  0.6739,  0.7381, -0.6415],
        [-0.2887, -0.1877,  0.7344, -0.2026]])


tensor([-11.7797,  -2.5179,  -2.5611,  15.2144])

In [84]:
vv = F.log_softmax(s / 0.1, dim=-1)

In [85]:
torch.sum(vv, 1)

tensor([-15.7476, -17.8935, -50.2803, -28.8236])