In [2]:
import os
import cv2
import numpy as np
import pandas as pd
import transformers
import tensorflow as tf
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import tensorflow_addons as tfa

2023-12-28 02:07:14.286648: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-28 02:07:14.575133: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-28 02:07:14.575161: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-28 02:07:14.576861: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-28 02:07:14.716563: I tensorflow/core/platform/cpu_feature_g

In [15]:
class DinoLoss(tf.keras.losses.Loss):
    '''
    DINO self-distillation loss: cross-entropy between teacher and student outputs.

    The teacher logits are turned into a target distribution either by
    centering + temperature softmax or by Sinkhorn-Knopp normalisation; the
    student logits are only divided by `student_temp`.

    The loss is exposed through `forward`; Keras' `call` is not overridden,
    so use `loss.forward(student_output, teacher_output, epoch)` directly.
    '''

    # The teacher output is the concatenation of two views (cat([t1, t2]) in the
    # article pseudocode), so it is split into this many chunks along the batch axis.
    N_TEACHER_VIEWS = 2

    def __init__(
            self,
            ncrops,
            warmup_teacher_temp,
            teacher_temp,
            warmup_teacher_temp_epochs,
            nepochs,
            student_temp=0.1,
            center_momentum=0.9
        ) -> None:
        '''
        Args:
            ncrops: number of equally sized chunks the student output is split into.
            warmup_teacher_temp: teacher temperature at epoch 0.
            teacher_temp: teacher temperature once the warm-up is over.
            warmup_teacher_temp_epochs: length of the linear warm-up, in epochs.
            nepochs: total number of epochs covered by the temperature schedule.
            student_temp: temperature dividing the student logits.
            center_momentum: momentum of the running teacher center.
        '''
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops

        # Running center of the teacher logits. It is created lazily with shape
        # (1, out_dim) on first use, because out_dim is not known here.
        self.center = None

        # Per-epoch teacher temperature: linear warm-up followed by a constant.
        self.teacher_temp_schedule = tf.concat(
            (
                tf.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
                tf.ones([nepochs - warmup_teacher_temp_epochs]) * teacher_temp,
            ),
            axis=0,
        )


    def _ensure_center(self, teacher_output) -> None:
        '''Create the zero-initialised (1, out_dim) center the first time it is needed.'''
        if self.center is None:
            self.center = tf.zeros_like(teacher_output[:1])


    def update_center(self, teacher_output) -> None:
        '''
        Update center used for teacher output.

        In article pseudocode -> C = m * C + (1-m) * cat([t1, t2]).mean(dim=0)
        '''

        self._ensure_center(teacher_output)

        # Mean over the batch axis, kept as (1, out_dim) to match the center.
        batch_center = tf.math.reduce_mean(teacher_output, axis=0, keepdims=True)

        self.center = tf.stop_gradient(self.center * self.center_momentum
                                       + batch_center * (1 - self.center_momentum))


    def softmax_center_teacher(self, teacher_output, teacher_temp):
        '''
        Center + sharpen for teacher

        In article pseudocode -> t = softmax((t - C) / tpt, dim=1)

        Returns a tensor shaped like `teacher_output` with gradients stopped.
        '''

        self._ensure_center(teacher_output)

        return tf.stop_gradient(
            tf.nn.softmax((teacher_output - self.center) / teacher_temp, axis=-1)
        )


    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
        '''
        Sinkhorn-Knopp centering https://arxiv.org/pdf/2006.09882.pdf (page 14)

        Alternately normalises the rows and columns of exp(teacher_output / temp)
        for `n_iterations` rounds and returns a tensor shaped like `teacher_output`
        in which every sample (row of the result) sums to 1.
        '''

        # Work on the transposed matrix: rows are output dimensions, columns are samples.
        Q = tf.transpose(tf.math.exp(teacher_output / teacher_temp))
        Q /= tf.reduce_sum(Q)

        # Dynamic shape so this also works when the static batch size is unknown.
        dims = tf.shape(Q)
        K = tf.cast(dims[0], Q.dtype)
        B = tf.cast(dims[1], Q.dtype)

        for _ in range(n_iterations):
            # keepdims=True is required: a (K,) row-sum vector would otherwise be
            # broadcast along the column axis of the (K, B) matrix.
            Q /= tf.reduce_sum(Q, axis=1, keepdims=True)
            Q /= K

            Q /= tf.reduce_sum(Q, axis=0, keepdims=True)
            Q /= B

        # Columns currently sum to 1/B; rescale so each sample sums to 1.
        Q *= B

        return tf.transpose(Q)


    def forward(self, student_output, teacher_output, epoch, centering='softmax_center'):
        '''
        Cross-entropy between softmax outputs of the teacher and student networks.

        Args:
            student_output: student logits, split into `ncrops` equal chunks along axis 0.
            teacher_output: teacher logits, split into two equal chunks along axis 0.
            epoch: index into the teacher temperature schedule.
            centering: 'softmax_center' or 'sinkhorn_knopp'.

        Returns:
            Scalar loss averaged over all (teacher view, student crop) pairs with
            different indices.

        Raises:
            ValueError: if `centering` is not one of the supported algorithms.
        '''

        teacher_output = tf.cast(teacher_output, tf.float32)
        student_output = tf.cast(student_output, tf.float32)

        student_views = tf.split(
            student_output / self.student_temp, num_or_size_splits=self.ncrops
        )
        # log-softmax depends only on the student view, so compute it once per view.
        student_log_probs = [tf.nn.log_softmax(view, axis=-1) for view in student_views]

        # teacher centering and sharpening
        teacher_temp = self.teacher_temp_schedule[epoch]

        if centering == 'softmax_center':
            teacher_out = self.softmax_center_teacher(teacher_output, teacher_temp)
        elif centering == 'sinkhorn_knopp':
            teacher_out = self.sinkhorn_knopp_teacher(teacher_output, teacher_temp)
        else:
            raise ValueError('Wrong centering algorithm')

        # The teacher only provides targets: no gradient flows through it.
        teacher_views = tf.split(
            tf.stop_gradient(teacher_out), num_or_size_splits=self.N_TEACHER_VIEWS
        )

        total_loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_views):
            for v, log_p in enumerate(student_log_probs):
                if v == iq:
                    # skip cases where student and teacher operate on the same view
                    continue

                loss = tf.reduce_sum(-q * log_p, axis=-1)
                total_loss += tf.math.reduce_mean(loss)
                n_loss_terms += 1

        total_loss /= n_loss_terms

        # The center is updated after the loss so the current batch is centered
        # with statistics from previous batches only.
        self.update_center(teacher_output)

        return total_loss


In [14]:
# Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search
# Article -> https://arxiv.org/pdf/1806.03198.pdf

class KoLeoLoss(tf.keras.losses.Loss):
    '''
    Kozachenko-Leonenko entropic regulariser.

    Penalises small distances between each L2-normalised embedding and its
    nearest neighbour in the batch: loss = -mean(log(distance + eps)).

    The loss is exposed through `forward`; Keras' `call` is not overridden,
    so use `loss.forward(student_output)` directly.
    '''

    def __init__(self) -> None:
        super().__init__()


    def pairwise_NNs_inner(self, x):
        '''
        Index of the nearest neighbour of every row of `x` by inner product.

        Args:
            x: 2-D tensor with one embedding per row; `forward` passes
               L2-normalised rows.

        Returns:
            1-D integer tensor; entry i is the row with the largest inner
            product with row i once the diagonal has been masked.
        '''
        dots = tf.linalg.matmul(x, x, transpose_b=True)

        # Dynamic shape so an unknown static batch size is handled too.
        n = tf.shape(dots)[0]

        # Overwrite the self-similarities with -1 so a row does not select itself.
        # The diagonal must have the same dtype as `dots`.
        dots = tf.linalg.set_diag(dots, -tf.ones([n], dtype=dots.dtype))

        return tf.math.argmax(dots, axis=1)


    def forward(self, student_output, eps=1e-8):
        '''
        Compute the KoLeo loss for a batch of embeddings.

        Args:
            student_output: 2-D tensor with one embedding per row.
            eps: passed as `epsilon` to the L2 normalisation and added to the
                 distances before the log to keep it finite.

        Returns:
            Scalar tensor: minus the mean log nearest-neighbour distance.
        '''
        student_output = tf.nn.l2_normalize(student_output, axis=1, epsilon=eps)
        idx = self.pairwise_NNs_inner(student_output)

        # Tensors cannot be indexed with an index tensor; gather the neighbour rows.
        neighbours = tf.gather(student_output, idx)

        # One Euclidean distance per row; without `axis` tf.norm would reduce
        # the whole matrix to a single scalar.
        distances = tf.norm(student_output - neighbours, ord='euclidean', axis=-1)

        return -tf.math.reduce_mean(tf.math.log(distances + eps))