In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torch
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from typing import Callable
from ViT import VissionTransformer

In [8]:
# Global and local multi-crop augmentations, modelled on the scheme proposed in the paper
def global_augment(images):
    """Build one large-scale ("global") augmented view for every input image.

    Each image is passed independently through a random pipeline: a resized
    crop to 224x224 covering 40-100% of the image area, a random horizontal
    flip, colour jitter, occasional grayscale conversion (p=0.2), Gaussian
    blur, conversion to a tensor and channel-wise normalisation.

    Args:
        images: Iterable of images accepted by the torchvision transforms
            (presumably 3-channel PIL images, since ``ToTensor`` is applied
            after the image-level transforms -- confirm against the caller).

    Returns:
        torch.Tensor: The augmented views stacked along a new leading
        dimension, one entry per input image.
    """
    pipeline_steps = [
        # Large crops: keep between 40% and 100% of the original area.
        transforms.RandomResizedCrop(size=224, scale=(0.4, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    augment = transforms.Compose(pipeline_steps)

    views = []
    for image in images:
        views.append(augment(image))
    return torch.stack(views)

def multiple_local_augments(images, num_crops=6):
    """Build ``num_crops`` small-scale ("local") augmented views of every image.

    Every view is drawn independently from the same random pipeline: a resized
    crop to 96x96 covering 5-40% of the image area, a random horizontal flip,
    colour jitter, occasional grayscale conversion (p=0.2), Gaussian blur,
    conversion to a tensor and channel-wise normalisation.

    Args:
        images: Iterable of images accepted by the torchvision transforms
            (presumably 3-channel PIL images, since ``ToTensor`` is applied
            after the image-level transforms -- confirm against the caller).
        num_crops: Number of independent local views to draw per image.
            Must be at least 1.

    Returns:
        torch.Tensor: Views stacked crop-major, i.e. the first dimension
        indexes the crop and the second the image
        (``num_crops x len(images) x ...``). Iterating over the result yields
        one batch of local views per crop.

    Raises:
        ValueError: If ``num_crops`` is smaller than 1.
    """
    if num_crops < 1:
        raise ValueError(f"num_crops must be >= 1, got {num_crops}")

    # Materialise once so that generators can be traversed for every crop.
    images = list(images)

    size = 96  # Smaller crops for local
    local_transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.05, 0.4)),  # Smaller, more concentrated crops
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Same level of jittering
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Apply the (random) transformation num_crops times to every image; each
    # pass produces one batch of local views.
    crop_batches = [
        torch.stack([local_transform(img) for img in images])
        for _ in range(num_crops)
    ]
    return torch.stack(crop_batches)

In [9]:
from typing import Callable

class DINO(nn.Module):
    """Student/teacher pair for self-distillation.

    Holds a trainable ``student`` network, a ``teacher`` network that starts as
    a copy of the student and never receives gradients, and a ``center`` buffer
    that is subtracted from the teacher logits inside
    :meth:`distillication_loss`.
    """

    def __init__(self, student_arch: Callable, teacher_arch: Callable, device: torch.device):
        """
        Args:
            student_arch: Zero-argument factory (e.g. the network class) that
                builds the student network, or an already constructed
                ``nn.Module``. The network must expose an ``output_dim``
                attribute.
            teacher_arch: Same as ``student_arch`` but for the teacher. Its
                ``state_dict`` layout must match the student's, and when
                instances are passed it must be a different object from the
                student.
            device: torch.device ('cuda' or 'cpu') on which both networks and
                the center buffer are placed.
        """
        super().__init__()

        self.student = self._build(student_arch).to(device)
        self.teacher = self._build(teacher_arch).to(device)
        # The teacher starts from exactly the same weights as the student.
        self.teacher.load_state_dict(self.student.state_dict())

        # Register the center as a buffer so it is saved with the module but
        # never receives gradients. The size is read from the existing student
        # (no extra network is built) and the buffer is created on the same
        # device as the networks so it can be combined with their logits.
        self.register_buffer(
            'center',
            torch.zeros(1, self.student.output_dim, device=device),
        )

        # The teacher is only updated through teacher_update, never by backprop.
        for param in self.teacher.parameters():
            param.requires_grad = False

    @staticmethod
    def _build(arch):
        """Return ``arch`` itself if it is already a module, else call it as a factory."""
        return arch if isinstance(arch, nn.Module) else arch()

    @staticmethod
    def distillication_loss(student_logits, teacher_logits, center, tau_s, tau_t):
        """
        Centered and sharpened cross-entropy between teacher and student outputs.

        Args:
            student_logits: Raw student outputs; softmax is taken over dim 1
                (assumed to be ``(batch, output_dim)`` -- confirm with caller).
            teacher_logits: Raw teacher outputs, broadcast-compatible with
                ``student_logits``. Detached inside this function.
            center: Tensor subtracted from the teacher logits before the
                softmax (broadcast against ``teacher_logits``).
            tau_s: Temperature dividing the student logits.
            tau_t: Temperature dividing the centered teacher logits.

        Returns:
            Scalar tensor: the cross-entropy summed over dim 1 and averaged
            over the remaining dimensions.
        """
        # Detaching teacher logits to stop gradients from flowing back into the teacher
        teacher_logits = teacher_logits.detach()

        # Center and sharpen the teacher's logits
        teacher_probs = F.softmax((teacher_logits - center) / tau_t, dim=1)

        # Sharpen the student's logits (log-probabilities for the cross-entropy)
        student_probs = F.log_softmax(student_logits / tau_s, dim=1)

        # Calculate cross-entropy loss between the student's and teacher's probs
        loss = - (teacher_probs * student_probs).sum(dim=1).mean()
        return loss

    @torch.no_grad()
    def teacher_update(self, beta: float):
        """Move the teacher weights towards the student by an exponential moving average.

        Every teacher parameter becomes
        ``beta * teacher + (1 - beta) * student``; parameters are paired in
        the order returned by ``parameters()``.

        Args:
            beta: Weight kept from the current teacher parameters.
        """
        for teacher_params, student_params in zip(self.teacher.parameters(), self.student.parameters()):
            teacher_params.mul_(beta).add_(student_params, alpha=(1 - beta))

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DINO builds the networks itself by calling `student_arch()` / `teacher_arch()`,
# so hand it the ViT class as a factory rather than already-constructed
# instances (calling an instance would run its forward pass with no input).
dino = DINO(VissionTransformer, VissionTransformer, device)

# Keep the `student` / `teacher` names bound to the networks DINO actually trains.
student = dino.student
teacher = dino.teacher

NameError: name 'ViT' is not defined