<a href="https://colab.research.google.com/github/basaanithanaveenkumar/self-implementation-DINO/blob/main/notebooks/DINO_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The EMA update rule for a parameter vector is:

$$
\theta_{\text{teacher}} \gets m \times \theta_{\text{teacher}} + (1 - m) \times \theta_{\text{student}}
$$

Where:
- $\theta_{\text{teacher}}$: Teacher model parameters
- $\theta_{\text{student}}$: Student model parameters  
- $m$: Momentum coefficient (typically close to 1, e.g., 0.99, 0.996)

# Why EMA is Used in DINO

## 1. Stable Targets
The teacher network provides consistent, slowly evolving targets for the student to learn from.

## 2. Prevents Collapse
EMA helps avoid the trivial solution where both networks output constant representations.

## 3. Improved Generalization
The teacher acts as an ensemble of previous student models, capturing robust features./

In [4]:

class DINO_MLP_HD(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=1256,
                 n_layers=4, use_layer_norm=True,):
        super().__init__()
        # Build the MLP layers
        layers = []

        # Input layer
        layers.append(nn.Linear(in_dim, hidden_dim))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())

        # Hidden layers
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())

        # Bottleneck layer
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))

        # Create the MLP
        self.mlp = nn.Sequential(*layers)

        self.last_layer=nn.linear(bottleneck_dim,out_dim)

    def forward(self, x):
        x = self.mlp(x)
        x=self.last_layer(x)
        return x

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# vibe coding


class DINO_MLP_HD(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=1256,
                 n_layers=4, use_layer_norm=True):
        super().__init__()

        # Build the MLP layers
        layers = []

        # Input layer
        layers.append(nn.Linear(in_dim, hidden_dim))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.GELU())

        # Hidden layers
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())

        # Bottleneck layer
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))

        # Create the MLP
        self.mlp = nn.Sequential(*layers)

        # Last layer (corrected from nn.linear to nn.Linear)
        self.last_layer = nn.Linear(bottleneck_dim, out_dim)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Use Xavier initialization for linear layers
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = self.last_layer(x)
        return x

    def print_summary(self, input_size=(1, 256)):
        """
        Print model summary similar to torchsummary
        """
        print("=" * 80)
        print("DINO MLP Head Summary")
        print("=" * 80)
        print(f"Input size: {input_size}")
        print(f"Output size: {input_size[0]}, {self.last_layer.out_features}")
        print("-" * 80)

        total_params = 0
        trainable_params = 0

        print(f"{'Layer (type)':<25} {'Output Shape':<20} {'Param #':<15} {'Trainable':<10}")
        print("=" * 80)

        # Create a dummy input
        x = torch.randn(input_size)

        # Track layers
        layers_info = []

        # Process through MLP
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, nn.Linear):
                layer_type = f"Linear_{i//3+1}"
                param_count = sum(p.numel() for p in layer.parameters())
                total_params += param_count
                trainable_params += param_count

                # Get output shape
                x = layer(x)
                output_shape = list(x.shape)

                layers_info.append({
                    'name': layer_type,
                    'output_shape': output_shape.copy(),
                    'params': param_count,
                    'trainable': True
                })
            elif isinstance(layer, nn.LayerNorm):
                layer_type = f"LayerNorm_{i//3+1}"
                param_count = sum(p.numel() for p in layer.parameters())
                total_params += param_count
                trainable_params += param_count

                # Get output shape
                x = layer(x)
                output_shape = list(x.shape)

                layers_info.append({
                    'name': layer_type,
                    'output_shape': output_shape.copy(),
                    'params': param_count,
                    'trainable': True
                })
            elif isinstance(layer, nn.GELU):
                layer_type = f"GELU_{i//3+1}"
                param_count = 0

                # Get output shape
                x = layer(x)
                output_shape = list(x.shape)

                layers_info.append({
                    'name': layer_type,
                    'output_shape': output_shape.copy(),
                    'params': param_count,
                    'trainable': False
                })

        # Process last layer
        layer_type = "Last_Linear"
        param_count = sum(p.numel() for p in self.last_layer.parameters())
        total_params += param_count
        trainable_params += param_count

        x = self.last_layer(x)
        output_shape = list(x.shape)

        layers_info.append({
            'name': layer_type,
            'output_shape': output_shape.copy(),
            'params': param_count,
            'trainable': True
        })

        # Print all layers
        for info in layers_info:
            print(f"{info['name']:<25} {str(info['output_shape']):<20} {info['params']:<15} {info['trainable']:<10}")

        print("=" * 80)
        print(f"Total params: {total_params:,}")
        print(f"Trainable params: {trainable_params:,}")
        print(f"Non-trainable params: {total_params - trainable_params:,}")
        print("=" * 80)

        return x.shape  # Return final output shape


# Test the model
if __name__ == "__main__":
    # Create model instance
    in_dim = 256
    out_dim = 1024
    model = DINO_MLP_HD(in_dim, out_dim, n_layers=4, use_layer_norm=True)

    # Print model summary
    print("Model Architecture:")
    print(model)
    print("\n" + "="*50 + "\n")

    # Print detailed summary
    final_shape = model.print_summary(input_size=(1, in_dim))

    # Test forward pass
    print("\nTesting forward pass...")
    batch_size = 4
    test_input = torch.randn(batch_size, in_dim)

    with torch.no_grad():
        output = model(test_input)

    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min().item():.4f}, {output.max().item():.4f}]")
    print(f"Output mean: {output.mean().item():.4f}")
    print(f"Output std: {output.std().item():.4f}")

Model Architecture:
DINO_MLP_HD(
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=2048, bias=True)
    (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (5): GELU(approximate='none')
    (6): Linear(in_features=2048, out_features=2048, bias=True)
    (7): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (8): GELU(approximate='none')
    (9): Linear(in_features=2048, out_features=1256, bias=True)
  )
  (last_layer): Linear(in_features=1256, out_features=1024, bias=True)
)


DINO MLP Head Summary
Input size: (1, 256)
Output size: 1, 1024
--------------------------------------------------------------------------------
Layer (type)              Output Shape         Param #         Trainable 
Linear_1                  [1, 2048]            526336          1         
LayerNorm_1              

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import copy
from typing import List, Optional

class DINO_MLP_HD(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256,
                 n_layers=4, use_layer_norm=True):
        super().__init__()

        # Build the MLP layers
        layers = []

        # Input layer
        layers.append(nn.Linear(in_dim, hidden_dim))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.GELU())

        # Hidden layers
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())

        # Bottleneck layer
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))

        # Create the MLP
        self.mlp = nn.Sequential(*layers)

        # Last layer
        self.last_layer = nn.Linear(bottleneck_dim, out_dim)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = self.last_layer(x)
        return x


class VisionTransformerWrapper(nn.Module):
    """
    Wrapper for Vision Transformer backbones
    """
    def __init__(self, model_name: str, img_size: int = 224, pretrained: bool = False):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            num_classes=0  # Remove classification head
        )

        # Get feature dimension
        if hasattr(self.backbone, 'num_features'):
            self.feature_dim = self.backbone.num_features
        else:
            # For ViT models, use embed_dim
            self.feature_dim = self.backbone.embed_dim

    def forward(self, x):
        return self.backbone(x)


class DINOStudent(nn.Module):
    """
    Student model for DINO framework
    """
    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        momentum_teacher: float = 0.996,
        use_momentum_schedule: bool = True
    ):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.momentum_teacher = momentum_teacher
        self.use_momentum_schedule = use_momentum_schedule

        # Initialize teacher with the same architecture but no gradients
        self.teacher = None
        self.init_teacher()

        # Register buffer for momentum schedule
        self.register_buffer('iteration', torch.tensor(0, dtype=torch.long))

    def init_teacher(self):
        """Initialize teacher model with student weights"""
        # Create a copy of student backbone and head
        teacher_backbone = copy.deepcopy(self.backbone)
        teacher_head = copy.deepcopy(self.head)

        # Freeze teacher parameters
        for param in teacher_backbone.parameters():
            param.requires_grad = False
        for param in teacher_head.parameters():
            param.requires_grad = False

        self.teacher = nn.ModuleDict({
            'backbone': teacher_backbone,
            'head': teacher_head
        })

    @torch.no_grad()
    def update_teacher(self):
        """Update teacher weights with exponential moving average"""
        # Calculate momentum based on schedule if enabled
        if self.use_momentum_schedule:
            momentum = self.momentum_schedule()
        else:
            momentum = self.momentum_teacher

        # Update teacher backbone
        for param_s, param_t in zip(self.backbone.parameters(),
                                   self.teacher['backbone'].parameters()):
            param_t.data.mul_(momentum).add_((1 - momentum) * param_s.detach().data)

        # Update teacher head
        for param_s, param_t in zip(self.head.parameters(),
                                   self.teacher['head'].parameters()):
            param_t.data.mul_(momentum).add_((1 - momentum) * param_s.detach().data)

        # Increment iteration counter
        self.iteration += 1

    def momentum_schedule(self):
        """Cosine momentum schedule from 0.996 to 1.0"""
        # Original DINO schedule: cosine from base_m to 1.0
        base_m = self.momentum_teacher
        final_m = 1.0
        # In practice, DINO uses a different schedule but this is a common approximation
        return final_m - (final_m - base_m) * (torch.cos(torch.pi * self.iteration / 200000) + 1) / 2

    def forward(self, x: torch.Tensor, return_features: bool = False):
        """Forward pass through student network"""
        features = self.backbone(x)
        output = self.head(features)

        if return_features:
            return output, features
        return output

    @torch.no_grad()
    def teacher_forward(self, x: torch.Tensor):
        """Forward pass through teacher network (no gradients)"""
        features = self.teacher['backbone'](x)
        output = self.teacher['head'](features)
        return output


class DINOTeacher(nn.Module):
    """
    Teacher model for DINO framework (separate implementation for clarity)
    """
    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head

        # Freeze all parameters
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Forward pass through teacher network"""
        with torch.no_grad():
            features = self.backbone(x)
            output = self.head(features)
        return output


class MultiCropDINO(nn.Module):
    """
    Wrapper for DINO that handles multiple crops
    """
    def __init__(self, student: DINOStudent):
        super().__init__()
        self.student = student

    def forward(self, x: List[torch.Tensor]):
        """
        Forward pass for multiple crops

        Args:
            x: List of tensors representing different crops

        Returns:
            student_outputs: List of student outputs for each crop
            teacher_outputs: List of teacher outputs for each crop
        """
        student_outputs = []
        teacher_outputs = []

        # Process each crop through student and teacher
        for crop in x:
            # Student forward
            student_out = self.student(crop)
            student_outputs.append(student_out)

            # Teacher forward
            with torch.no_grad():
                teacher_out = self.student.teacher_forward(crop)
                teacher_outputs.append(teacher_out)

        return student_outputs, teacher_outputs


# Example usage and testing
if __name__ == "__main__":
    # Configuration
    backbone_name = "vit_small_patch16_224"
    img_size = 224
    in_dim = 384  # ViT-small feature dimension
    out_dim = 65536  # As in original DINO paper
    hidden_dim = 2048
    bottleneck_dim = 256

    # Create backbone
    backbone = VisionTransformerWrapper(backbone_name, img_size, pretrained=False)

    # Create DINO head
    dino_head = DINO_MLP_HD(
        in_dim=in_dim,
        out_dim=out_dim,
        hidden_dim=hidden_dim,
        bottleneck_dim=bottleneck_dim,
        n_layers=4,
        use_layer_norm=True
    )

    # Create student model
    student = DINOStudent(
        backbone=backbone,
        head=dino_head,
        momentum_teacher=0.996,
        use_momentum_schedule=True
    )

    # Create multi-crop wrapper
    model = MultiCropDINO(student)

    # Print model information
    print("Student Model:")
    print(student)
    print(f"Number of parameters: {sum(p.numel() for p in student.parameters()):,}")
    print(f"Number of trainable parameters: {sum(p.numel() for p in student.parameters() if p.requires_grad):,}")

    print("\nTeacher Model:")
    print(student.teacher)
    print(f"Number of parameters: {sum(p.numel() for p in student.teacher.parameters()):,}")

    # Test with sample input (2 global crops + 4 local crops)
    batch_size = 2
    global_crops = [torch.randn(batch_size, 3, img_size, img_size) for _ in range(2)]
    local_crops = [torch.randn(batch_size, 3, img_size//2, img_size//2) for _ in range(4)]
    all_crops = global_crops + local_crops

    # Forward pass
    student_outputs, teacher_outputs = model(all_crops)

    print(f"\nNumber of crops: {len(all_crops)}")
    print(f"Student outputs: {len(student_outputs)} tensors")
    print(f"Teacher outputs: {len(teacher_outputs)} tensors")

    for i, (s_out, t_out) in enumerate(zip(student_outputs, teacher_outputs)):
        print(f"Crop {i}: student shape {s_out.shape}, teacher shape {t_out.shape}")

    # Update teacher
    student.update_teacher()
    print(f"\nUpdated teacher (iteration {student.iteration.item()})")

    # Test with a single image
    single_image = torch.randn(1, 3, img_size, img_size)
    student_output = student(single_image)
    teacher_output = student.teacher_forward(single_image)

    print(f"\nSingle image:")
    print(f"Student output shape: {student_output.shape}")
    print(f"Teacher output shape: {teacher_output.shape}")

In [6]:

class DINO_MLP_HD(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=1256,
                 n_layers=4, use_layer_norm=True,):
        super().__init__()
        # Build the MLP layers
        layers = []

        # Input layer
        layers.append(nn.Linear(in_dim, hidden_dim))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())

        # Hidden layers
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())

        # Bottleneck layer
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))

        # Create the MLP
        self.mlp = nn.Sequential(*layers)

        self.last_layer=nn.linear(bottleneck_dim,out_dim)

    def forward(self, x):
        x = self.mlp(x)
        x=self.last_layer(x)
        return x





class VisionTransformerWrapper(nn.Module):
    """
    Wrapper for Vision Transformer backbones
    """
    def __init__(self, model_name: str, img_size: int = 224, pretrained: bool = True, is_teacher=True):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            #num_classes=0  # Remove classification head
        )

        # Get feature dimension
        if hasattr(self.backbone, 'num_features'):
            self.feature_dim = self.backbone.num_features
        else:
            # For ViT models, use embed_dim
            self.feature_dim = self.backbone.embed_dim
        input_vec_dim=self.backbone.in_features
        # self.backbone.head = nn.Identity()
        layers = list(self.backbone.children())[:-1]
        self.backbone = nn.Sequential(*layers)
        for param in self.model.parameters():
          if is_teacher:
            param.requires_grad = False
          else:
            param.requires_grad = True
        self.dino_mlp_head = DINO_MLP_HD(input_dim, out_dim)
        dino_head = DINO_MLP_HD(
        in_dim=input_vec_dim,
        out_dim=1024,
        hidden_dim=2038,
        bottleneck_dim=152,
        n_layers=5,
        use_layer_norm=True
    )

    def forward(self, x):
        vis_features=self.backbone(x)
        cls_token=vis_features[:,0]
        print(cls_token.shape, "log class token shape")
        x=self.dino_mlp_head(cls_token)
        return F.normalize(x, dim=-1, p=2)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DINOLoss(nn.Module):
    """
    DINO loss function implementation.

    This loss function implements the self-distillation with no labels approach
    used in the DINO paper. It consists of:
    1. Cross-entropy loss between student and teacher outputs
    2. Centering of teacher outputs to avoid collapse
    3. Sharpening of teacher distributions with temperature

    Args:
        out_dim (int): Output dimension of the projection head
        warmup_teacher_temp (float): Initial teacher temperature
        teacher_temp (float): Final teacher temperature (after warmup)
        warmup_teacher_temp_epochs (int): Number of warmup epochs for teacher temperature
        nepochs (int): Total number of epochs
        student_temp (float): Student temperature
        center_momentum (float): Momentum for center update
    """

    def __init__(self, out_dim, warmup_teacher_temp=0.04, teacher_temp=0.04,
                 warmup_teacher_temp_epochs=30, nepochs=100, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Temperature scheduling
        self.teacher_temp_schedule = torch.cat((
            torch.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            torch.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Forward pass of the DINO loss.

        Args:
            student_output: List of student outputs for different crops
            teacher_output: List of teacher outputs for different crops
            epoch: Current epoch number (for temperature scheduling)

        Returns:
            Loss value
        """
        # Get current teacher temperature
        teacher_temp = self.teacher_temp_schedule[epoch].item()

        # Gather all outputs
        student_out = self.gather_outputs(student_output)
        teacher_out = self.gather_outputs(teacher_output)

        # Apply temperature to student outputs
        student_out = student_out / self.student_temp

        # Apply temperature and center to teacher outputs
        teacher_out = F.softmax((teacher_out - self.center) / teacher_temp, dim=-1)
        teacher_out = teacher_out.detach()  # Detach to stop gradients

        # Calculate cross-entropy loss
        loss = -torch.sum(teacher_out * F.log_softmax(student_out, dim=-1), dim=-1)
        loss = loss.mean()

        # Update center
        self.update_center(teacher_output)

        return loss

    def gather_outputs(self, outputs):
        """
        Gather outputs from all crops and concatenate them.
        """
        return torch.cat([output for output in outputs], dim=0)

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output centering.
        """
        # Gather all teacher outputs
        teacher_out = self.gather_outputs(teacher_output)

        # Calculate batch mean
        batch_center = torch.mean(teacher_out, dim=0, keepdim=True)

        # Update center with momentum
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


In [None]:
# Test parameters
model_name = "vit_tiny_patch16_224"
img_size = 224
batch_size = 2

print("Testing VisionTransformerWrapper...")
print("=" * 50)

# Create student model (not teacher)
print("Creating student model...")
student_model = VisionTransformerWrapper(
    model_name=model_name,
    img_size=img_size,
    pretrained=False,  # Use False for faster testing
    is_teacher=False
)

# Create teacher model
print("Creating teacher model...")
teacher_model = VisionTransformerWrapper(
    model_name=model_name,
    img_size=img_size,
    pretrained=False,  # Use False for faster testing
    is_teacher=True
)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import timm

# First, let's define the DINO_MLP_HD class if it's not already defined
class DINO_MLP_HD(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048, bottleneck_dim=256,
                 n_layers=4, use_layer_norm=True):
        super().__init__()

        # Build the MLP layers
        layers = []

        # Input layer
        layers.append(nn.Linear(in_dim, hidden_dim))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.GELU())

        # Hidden layers
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())

        # Bottleneck layer
        layers.append(nn.Linear(hidden_dim, bottleneck_dim))

        # Create the MLP
        self.mlp = nn.Sequential(*layers)

        # Last layer
        self.last_layer = nn.Linear(bottleneck_dim, out_dim)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = self.last_layer(x)
        return x

# Now let's define the VisionTransformerWrapper with a fix for the parameter freezing
class VisionTransformerWrapper(nn.Module):
    """
    Wrapper for Vision Transformer backbones
    """
    def __init__(self, model_name: str, img_size: int = 224, pretrained: bool = True, is_teacher=True):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            #num_classes=0  # Remove classification head
        )

        # Get feature dimension
        if hasattr(self.backbone, 'num_features'):
            self.feature_dim = self.backbone.num_features
        else:
            # For ViT models, use embed_dim
            self.feature_dim = self.backbone.embed_dim

        # Get input dimension from the backbone
        input_vec_dim = self.feature_dim

        # Remove the classification head
        layers = list(self.backbone.children())[:-1]
        self.backbone = nn.Sequential(*layers)

        # Freeze parameters if it's a teacher
        for param in self.backbone.parameters():
            if is_teacher:
                param.requires_grad = False
            else:
                param.requires_grad = True

        # Create the DINO head
        self.dino_mlp_head = DINO_MLP_HD(
            in_dim=input_vec_dim,
            out_dim=1024,
            hidden_dim=2048,
            bottleneck_dim=256,
            n_layers=3,
            use_layer_norm=True
        )

    def forward(self, x):
        vis_features = self.backbone(x)
        # Assuming the first token is the class token
        cls_token = vis_features[:, 0]
        print(f"Class token shape: {cls_token.shape}")
        x = self.dino_mlp_head(cls_token)
        return F.normalize(x, dim=-1, p=2)

# Test function
def test_vision_transformer_wrapper():
    # Test parameters
    model_name = "vit_tiny_patch16_224"
    img_size = 224
    batch_size = 2

    print("Testing VisionTransformerWrapper...")
    print("=" * 50)

    # Create student model (not teacher)
    print("Creating student model...")
    student_model = VisionTransformerWrapper(
        model_name=model_name,
        img_size=img_size,
        pretrained=False,  # Use False for faster testing
        is_teacher=False
    )

    # Create teacher model
    print("Creating teacher model...")
    teacher_model = VisionTransformerWrapper(
        model_name=model_name,
        img_size=img_size,
        pretrained=False,  # Use False for faster testing
        is_teacher=True
    )

    # Create a dummy input tensor
    dummy_input = torch.randn(batch_size, 3, img_size, img_size)
    print(f"Input shape: {dummy_input.shape}")

    # Test student model
    print("\nTesting student model forward pass...")
    student_output = student_model(dummy_input)
    print(f"Student output shape: {student_output.shape}")
    print(f"Student output norm: {torch.norm(student_output, dim=1)}")

    # Check gradient requirements for student
    student_params_require_grad = sum(p.requires_grad for p in student_model.parameters())
    print(f"Student parameters requiring grad: {student_params_require_grad}")

    # Test teacher model
    print("\nTesting teacher model forward pass...")
    teacher_output = teacher_model(dummy_input)
    print(f"Teacher output shape: {teacher_output.shape}")
    print(f"Teacher output norm: {torch.norm(teacher_output, dim=1)}")

    # Check gradient requirements for teacher
    teacher_params_require_grad = sum(p.requires_grad for p in teacher_model.parameters())
    print(f"Teacher parameters requiring grad: {teacher_params_require_grad}")

    # Test with gradients
    print("\nTesting backward pass with student model...")
    dummy_input.requires_grad = True
    student_output = student_model(dummy_input)

    # Create a dummy loss
    dummy_target = torch.randn_like(student_output)
    loss = F.mse_loss(student_output, dummy_target)

    # Backward pass
    loss.backward()

    print(f"Loss: {loss.item()}")
    print("Backward pass completed successfully!")

    # Verify gradients
    has_gradients = any(p.grad is not None for p in student_model.parameters() if p.requires_grad)
    print(f"Model parameters have gradients: {has_gradients}")

    # Print model summary
    print("\nModel Summary:")
    print("=" * 50)
    total_params = sum(p.numel() for p in student_model.parameters())
    trainable_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Frozen parameters: {total_params - trainable_params:,}")

    # Test with different input sizes
    print("\nTesting with different input sizes...")
    for test_size in [224]:
        test_input = torch.randn(batch_size, 3, test_size, test_size)
        try:
            output = student_model(test_input)
            print(f"Input size {test_size}x{test_size}: Output shape {output.shape}")
        except Exception as e:
            print(f"Input size {test_size}x{test_size}: Error - {e}")

if __name__ == "__main__":
    test_vision_transformer_wrapper()

Testing VisionTransformerWrapper...
Creating student model...
Creating teacher model...
Input shape: torch.Size([2, 3, 224, 224])

Testing student model forward pass...
Class token shape: torch.Size([2, 192])
Student output shape: torch.Size([2, 1024])
Student output norm: tensor([1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)
Student parameters requiring grad: 160

Testing teacher model forward pass...
Class token shape: torch.Size([2, 192])
Teacher output shape: torch.Size([2, 1024])
Teacher output norm: tensor([1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)
Teacher parameters requiring grad: 12

Testing backward pass with student model...
Class token shape: torch.Size([2, 192])
Loss: 0.9629690051078796
Backward pass completed successfully!
Model parameters have gradients: True

Model Summary:
Total parameters: 10,873,920
Trainable parameters: 10,873,920
Frozen parameters: 0

Testing with different input sizes...
Class token shape: torch.Size([2, 192])
Input size 224x22

In [None]:
```markdown
# Setup Instructions

## 1. Clone the Repository
```bash
git clone https://github.com/basaanithanaveenkumar/object-detection-BBD.git
cd object-detection-BBD
```

## 2. Create Data Directory
```bash
mkdir -p data
```

## 3. Download Dataset
```bash
python scripts/download_dataset.py
```

## 4. Organize Directory Structure
```bash
mv data/100k/val data/100k/valid
```

## 5. Convert to COCO Format
```bash
python scripts/convert_to_coco.py
```

## Workflow Summary
This setup process:
1. Clones the object detection project repository
2. Creates the necessary directory structure
3. Downloads the BBD (Berkeley DeepDrive) dataset
4. Renames the validation directory to match expected conventions
5. Converts the BBD dataset format to standard COCO format for compatibility with object detection frameworks
```

In [4]:
# clone the object detection-BBD project
!git clone https://github.com/basaanithanaveenkumar/object-detection-BBD.git
%cd object-detection-BBD

# create a dir to store the dataset
!mkdir -p data

# Download and extract dataset
!python scripts/download_dataset.py

# organise the dir structure for coco datset
!mv data/100k/val data/100k/valid

# convert the BBD dataset into COCO Dataset
!python scripts/convert_to_coco.py

Cloning into 'object-detection-BBD'...
remote: Enumerating objects: 176, done.[K
remote: Counting objects: 100% (176/176), done.[K
remote: Compressing objects: 100% (124/124), done.[K
remote: Total 176 (delta 111), reused 92 (delta 46), pack-reused 0 (from 0)[K
Receiving objects: 100% (176/176), 5.90 MiB | 30.80 MiB/s, done.
Resolving deltas: 100% (111/111), done.
/content/object-detection-BBD

Downloading images_100k...
data/bdd100k_images_100k.zip: 100% 5.28G/5.28G [11:09<00:00, 8.47MiB/s]
Extracting data/bdd100k_images_100k.zip...
Extracted to data

Downloading labels...
data/bdd100k_labels.zip: 100% 181M/181M [00:17<00:00, 10.8MiB/s]
Extracting data/bdd100k_labels.zip...
Extracted to data

BDD100K dataset download complete!
{'car': 1, 'traffic sign': 2, 'traffic light': 3, 'person': 4, 'truck': 5, 'bus': 6, 'bike': 7, 'rider': 8, 'motor': 9, 'train': 10}
Processing annotations: 100% 10000/10000 [00:03<00:00, 3027.12it/s]
Conversion complete. COCO format JSON saved to data/100k/

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        dist.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)



# # one more generated by deep seek


class DINOLoss(nn.Module):
    """
    DINO loss function implementation.

    This loss function implements the self-distillation with no labels approach
    used in the DINO paper. It consists of:
    1. Cross-entropy loss between student and teacher outputs
    2. Centering of teacher outputs to avoid collapse
    3. Sharpening of teacher distributions with temperature

    Args:
        out_dim (int): Output dimension of the projection head
        warmup_teacher_temp (float): Initial teacher temperature
        teacher_temp (float): Final teacher temperature (after warmup)
        warmup_teacher_temp_epochs (int): Number of warmup epochs for teacher temperature
        nepochs (int): Total number of epochs
        student_temp (float): Student temperature
        center_momentum (float): Momentum for center update
    """

    def __init__(self, out_dim, warmup_teacher_temp=0.04, teacher_temp=0.07,
                 warmup_teacher_temp_epochs=30, nepochs=100, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.nepochs = nepochs
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs

        # Register buffer for center
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Temperature scheduling
        self.warmup_teacher_temp = warmup_teacher_temp
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, epoch):
        """
        Forward pass of the DINO loss.

        Args:
            student_output: List of student outputs for different crops
            teacher_output: List of teacher outputs for different crops
            epoch: Current epoch number (for temperature scheduling)

        Returns:
            Loss value
        """
        # Get current teacher temperature (with warmup)
        teacher_temp = self.get_teacher_temp(epoch)

        # Gather all outputs
        student_out = self.gather_outputs(student_output)
        teacher_out = self.gather_outputs(teacher_output).detach()  # Detach early

        # Apply temperature to student outputs
        if isinstance(student_output, list):
            student_output = torch.cat(student_output, dim=0)
            teacher_output = torch.cat(teacher_output, dim=0)
        student_out = student_out / self.student_temp

        # Apply temperature and center to teacher outputs
        teacher_out = (teacher_out - self.center) / teacher_temp
        teacher_out = F.softmax(teacher_out, dim=-1)

        # Calculate cross-entropy loss
        loss = -torch.sum(teacher_out * F.log_softmax(student_out, dim=-1), dim=-1)
        loss = loss.mean()

        # Update center
        self.update_center(teacher_output)

        return loss

    def get_teacher_temp(self, epoch):
        """Get teacher temperature with warmup schedule"""
        if epoch < self.warmup_teacher_temp_epochs:
            # Linear warmup
            return self.warmup_teacher_temp + (self.teacher_temp - self.warmup_teacher_temp) * \
                   epoch / self.warmup_teacher_temp_epochs
        else:
            return self.teacher_temp

    def gather_outputs(self, outputs):
        """
        Gather outputs from all crops and concatenate them.
        """
        return torch.cat([output for output in outputs], dim=0)

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output centering.
        """
        # Gather all teacher outputs (raw, before temperature/softmax)
        teacher_out = self.gather_outputs(teacher_output)

        # Calculate batch mean
        batch_center = torch.mean(teacher_out, dim=0, keepdim=True)

        # Update center with momentum
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
import cv2
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_tensor
import random

class COCOMultiCropDataset(Dataset):
    """
    COCO dataset with multi-crop transformations for DINO training
    """
    def __init__(self, annFile, dataDir, global_crop_size=224, local_crop_size=96,
                 num_local_crops=4, transform=None):
        self.coco = COCO(annFile)
        self.dataDir = dataDir
        self.img_ids = self.coco.getImgIds()
        self.global_crop_size = global_crop_size
        self.local_crop_size = local_crop_size
        self.num_local_crops = num_local_crops

        # Normalization (ImageNet stats)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        # Default transformations if none provided
        if transform is None:
            self.transform = self.get_default_transforms()
        else:
            self.transform = transform

    def get_default_transforms(self):
        """
        Default DINO multi-crop transformations
        """
        # Global crops (2x)
        global_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.global_crop_size, scale=(0.4, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(5)], p=0.5),
            transforms.ToTensor(),
            self.normalize
        ])

        # Local crops (multiple)
        local_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.local_crop_size, scale=(0.05, 0.4)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize((self.global_crop_size, self.global_crop_size)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(5)], p=0.5),
            transforms.ToTensor(),
            self.normalize
        ])

        return {
            'global': global_transform,
            'local': local_transform
        }

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        # Load image
        img_path = f"{self.dataDir}/{img_info['file_name']}"
        image = cv2.imread(img_path)

        if image is None:
            # If image loading fails, return a random image
            image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert to PIL Image for transformations
        image_pil = transforms.ToPILImage()(image)

        # Apply transformations
        crops = []

        # Global crops (2x)
        for _ in range(2):
            crops.append(self.transform['global'](image_pil))

        # Local crops (num_local_crops x)
        for _ in range(self.num_local_crops):
            crops.append(self.transform['local'](image_pil))

        return crops

# Create dataset and data loader
def create_coco_dataloader(annFile, dataDir, batch_size=4, num_workers=4,
                          global_crop_size=224, local_crop_size=96, num_local_crops=4):
    """
    Create COCO data loader for DINO training
    """
    dataset = COCOMultiCropDataset(
        annFile=annFile,
        dataDir=dataDir,
        global_crop_size=global_crop_size,
        local_crop_size=local_crop_size,
        num_local_crops=num_local_crops
    )

    # Custom collate function for multi-crop data
    def collate_fn(batch):
        # batch is a list of lists of crops
        # We need to transpose it to group crops by type
        transposed = list(zip(*batch))
        return [torch.stack(crops) for crops in transposed]

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

    return dataloader

# Usage example
if __name__ == "__main__":
    # Initialize COCO dataset
    dataDir = "/content/object-detection-BBD/data/100k/test/"
    annFile = f"{dataDir}/_annotations.coco.json"

    # Create data loader
    dataloader = create_coco_dataloader(
        annFile=annFile,
        dataDir=dataDir,
        batch_size=4,
        num_workers=4,
        global_crop_size=224,
        local_crop_size=96,
        num_local_crops=4
    )

    # Test the data loader
    for batch_idx, crops in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        for i, crop_batch in enumerate(crops):
            print(f"  Crop {i}: shape {crop_batch.shape}")

        if batch_idx >= 2:  # Just test a few batches
            break

loading annotations into memory...
Done (t=1.57s)
creating index...
index created!




Batch 0:
  Crop 0: shape torch.Size([4, 3, 224, 224])
  Crop 1: shape torch.Size([4, 3, 224, 224])
  Crop 2: shape torch.Size([4, 3, 224, 224])
  Crop 3: shape torch.Size([4, 3, 224, 224])
  Crop 4: shape torch.Size([4, 3, 224, 224])
  Crop 5: shape torch.Size([4, 3, 224, 224])
Batch 1:
  Crop 0: shape torch.Size([4, 3, 224, 224])
  Crop 1: shape torch.Size([4, 3, 224, 224])
  Crop 2: shape torch.Size([4, 3, 224, 224])
  Crop 3: shape torch.Size([4, 3, 224, 224])
  Crop 4: shape torch.Size([4, 3, 224, 224])
  Crop 5: shape torch.Size([4, 3, 224, 224])
Batch 2:
  Crop 0: shape torch.Size([4, 3, 224, 224])
  Crop 1: shape torch.Size([4, 3, 224, 224])
  Crop 2: shape torch.Size([4, 3, 224, 224])
  Crop 3: shape torch.Size([4, 3, 224, 224])
  Crop 4: shape torch.Size([4, 3, 224, 224])
  Crop 5: shape torch.Size([4, 3, 224, 224])


In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import time
import logging
from pathlib import Path

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DINO_Trainer")

class DINOTrainer:
    def __init__(self, student_model, teacher_model, dataloader,
                 loss_fn, optimizer, device, out_dir="./dino_checkpoints",
                 warmup_epochs=10, total_epochs=100, save_freq=10,
                 use_amp=True, base_lr=0.0005):  # Added base_lr parameter
        """
        DINO trainer for self-supervised learning

        Args:
            student_model: Student model (trainable)
            teacher_model: Teacher model (EMA of student)
            dataloader: DataLoader with multi-crop images
            loss_fn: DINO loss function
            optimizer: Optimizer for student model
            device: Training device (cuda/cpu)
            out_dir: Directory to save checkpoints
            warmup_epochs: Number of warmup epochs for learning rate
            total_epochs: Total training epochs
            save_freq: Frequency of saving checkpoints
            use_amp: Whether to use automatic mixed precision
            base_lr: Base learning rate for scheduling
        """
        self.student = student_model
        self.teacher = teacher_model
        self.dataloader = dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device
        self.out_dir = Path(out_dir)
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.save_freq = save_freq
        self.use_amp = use_amp
        self.base_lr = base_lr  # Store base learning rate

        # Create output directory
        self.out_dir.mkdir(parents=True, exist_ok=True)

        # Learning rate scheduler
        self.lr_schedule = self._get_lr_schedule()

        # Mixed precision scaler
        self.scaler = GradScaler(enabled=use_amp)

        # Move models to device
        self.student.to(device)
        self.teacher.to(device)

        # Set teacher to eval mode
        self.teacher.eval()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')

        # Logging
        logger.info(f"DINO Trainer initialized on device: {device}")
        logger.info(f"Using mixed precision: {use_amp}")
        logger.info(f"Base learning rate: {base_lr}")

    def _get_lr_schedule(self):
        """Create learning rate schedule with warmup"""
        def lr_schedule(step):
            # Warmup for the first warmup_steps
            warmup_steps = self.warmup_epochs * len(self.dataloader)
            if step < warmup_steps:
                return (step + 1) / warmup_steps
            else:
                # Cosine decay after warmup
                total_steps = self.total_epochs * len(self.dataloader)
                return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))

        return lr_schedule

    def update_teacher(self, momentum=0.996):
        """Update teacher model with EMA of student weights"""
        with torch.no_grad():
            for param_s, param_t in zip(self.student.parameters(), self.teacher.parameters()):
                param_t.data.mul_(momentum).add_((1 - momentum) * param_s.data)

    def train_epoch(self):
        """Train for one epoch"""
        self.student.train()

        epoch_loss = 0
        num_batches = len(self.dataloader)

        for batch_idx, crops in enumerate(self.dataloader):
            # Move crops to device
            crops = [crop.to(self.device, non_blocking=True) for crop in crops]

            # Update learning rate
            self._adjust_learning_rate(self.global_step)
            lr = self.optimizer.param_groups[0]["lr"]

            # Forward pass
            with autocast(enabled=self.use_amp):
                # Student forward pass (all crops)
                student_outputs = []
                for crop in crops:
                    student_outputs.append(self.student(crop))

                # Teacher forward pass (only global crops)
                teacher_outputs = []
                with torch.no_grad():
                    for crop in crops[:2]:  # First two are global crops
                        teacher_outputs.append(self.teacher(crop))

                # Compute loss
                print(len(student_outputs),"len of student outputs")
                print(len(teacher_outputs), "len of teacher outputs")
                loss = self.loss_fn(student_outputs, teacher_outputs, self.epoch)

            # Backward pass
            self.optimizer.zero_grad()

            if self.use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            # Update teacher with EMA
            self.update_teacher()

            # Update metrics
            epoch_loss += loss.item()
            self.global_step += 1

            # Log progress
            if batch_idx % 100 == 0:
                logger.info(
                    f"Epoch {self.epoch}/{self.total_epochs} | "
                    f"Batch {batch_idx}/{num_batches} | "
                    f"Loss: {loss.item():.4f} | "
                    f"LR: {lr:.6f}"
                )

        return epoch_loss / num_batches

    def _adjust_learning_rate(self, step):
        """Adjust learning rate based on schedule"""
        # Calculate the multiplier from the schedule
        multiplier = self.lr_schedule(step)

        # Set the learning rate for all parameter groups
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.base_lr * multiplier

    def train(self):
        """Main training loop"""
        logger.info("Starting DINO training...")
        start_time = time.time()

        for epoch in range(self.epoch, self.total_epochs):
            self.epoch = epoch

            # Train for one epoch
            epoch_loss = self.train_epoch()

            # Log epoch results
            logger.info(
                f"Epoch {epoch}/{self.total_epochs} | "
                f"Avg Loss: {epoch_loss:.4f} | "
                f"Time: {time.time() - start_time:.2f}s"
            )

            # Save checkpoint
            if epoch % self.save_freq == 0 or epoch == self.total_epochs - 1:
                self.save_checkpoint(epoch_loss)

            # Update best loss
            if epoch_loss < self.best_loss:
                self.best_loss = epoch_loss
                self.save_checkpoint(epoch_loss, is_best=True)

        logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")

    def save_checkpoint(self, loss, is_best=False):
        """Save training checkpoint"""
        checkpoint = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'student_state_dict': self.student.state_dict(),
            'teacher_state_dict': self.teacher.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'loss': loss,
            'best_loss': self.best_loss,
            'base_lr': self.base_lr,  # Save base learning rate
        }

        # Save regular checkpoint
        checkpoint_path = self.out_dir / f"checkpoint_epoch_{self.epoch}.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.out_dir / "best_checkpoint.pth"
            torch.save(checkpoint, best_path)
            logger.info(f"New best checkpoint saved with loss: {loss:.4f}")

    def load_checkpoint(self, checkpoint_path):
        """Load training checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']
        self.base_lr = checkpoint.get('base_lr', 0.0005)  # Load base learning rate

        self.student.load_state_dict(checkpoint['student_state_dict'])
        self.teacher.load_state_dict(checkpoint['teacher_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint['scaler_state_dict'] is not None:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        logger.info(f"Loaded checkpoint from epoch {self.epoch} with loss {checkpoint['loss']:.4f}")

In [1]:


class DINOLoss(nn.Module):
    """
    DINO loss function implementation.

    This loss function implements the self-distillation with no labels approach
    used in the DINO paper. It consists of:
    1. Cross-entropy loss between student and teacher outputs
    2. Centering of teacher outputs to avoid collapse
    3. Sharpening of teacher distributions with temperature

    Args:
        out_dim (int): Output dimension of the projection head
        warmup_teacher_temp (float): Initial teacher temperature
        teacher_temp (float): Final teacher temperature (after warmup)
        warmup_teacher_temp_epochs (int): Number of warmup epochs for teacher temperature
        nepochs (int): Total number of epochs
        student_temp (float): Student temperature
        center_momentum (float): Momentum for center update
    """

    def __init__(self, out_dim, warmup_teacher_temp=0.04, teacher_temp=0.07,
                 warmup_teacher_temp_epochs=30, nepochs=100, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.nepochs = nepochs
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs

        # Register buffer for center
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Temperature scheduling
        self.warmup_teacher_temp = warmup_teacher_temp
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, epoch):
        """
        Forward pass of the DINO loss.

        Args:
            student_output: List of student outputs for different crops
            teacher_output: List of teacher outputs for different crops
            epoch: Current epoch number (for temperature scheduling)

        Returns:
            Loss value
        """
        # Get current teacher temperature (with warmup)
        teacher_temp = self.get_teacher_temp(epoch)

        # Gather all outputs
        student_out = self.gather_outputs(student_output)
        teacher_out = self.gather_outputs(teacher_output).detach()  # Detach early

        # Apply temperature to student outputs
        if isinstance(student_output, list):
            student_output = torch.cat(student_output, dim=0)
            teacher_output = torch.cat(teacher_output, dim=0)
        student_out = student_out / self.student_temp

        # Apply temperature and center to teacher outputs
        teacher_out = (teacher_out - self.center) / teacher_temp
        teacher_out = F.softmax(teacher_out, dim=-1)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    def get_teacher_temp(self, epoch):
        """Get teacher temperature with warmup schedule"""
        if epoch < self.warmup_teacher_temp_epochs:
            # Linear warmup
            return self.warmup_teacher_temp + (self.teacher_temp - self.warmup_teacher_temp) * \
                   epoch / self.warmup_teacher_temp_epochs
        else:
            return self.teacher_temp

    def gather_outputs(self, outputs):
        """
        Gather outputs from all crops and concatenate them.
        """
        return torch.cat([output for output in outputs], dim=0)

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output centering.
        """
        # Gather all teacher outputs (raw, before temperature/softmax)
        teacher_out = self.gather_outputs(teacher_output)

        # Calculate batch mean
        batch_center = torch.mean(teacher_out, dim=0, keepdim=True)

        # Update center with momentum
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

# Example usage
if __name__ == "__main__":
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize models (using your VisionTransformerWrapper)
    model_name = "vit_small_patch16_224"
    img_size = 224
    std_img_size = 224
    out_dim = 1024

    # Student model (trainable)
    student_model = VisionTransformerWrapper(
        model_name=model_name,
        img_size=std_img_size,
        pretrained=True,
        is_teacher=False
    )

    # Teacher model (frozen, updated via EMA)
    teacher_model = VisionTransformerWrapper(
        model_name=model_name,
        img_size=img_size,
        pretrained=True,
        is_teacher=True
    )
    # Example configuration (adjust values based on your setup)
    ncrops = 6  # 2 global crops + 4 local crops
    warmup_teacher_temp = 0.04
    teacher_temp = 0.07
    warmup_teacher_temp_epochs = 30
    nepochs = 100

    # Initialize DINOLoss with the required arguments
    loss_fn = DINOLoss(
        #ncrops=ncrops,
        warmup_teacher_temp=warmup_teacher_temp,
        teacher_temp=teacher_temp,
        warmup_teacher_temp_epochs=warmup_teacher_temp_epochs,
        nepochs=nepochs,
        out_dim=1024
    )

    # Create COCO data loader
    dataDir = "/content/object-detection-BBD/data/100k/test"
    annFile = f"{dataDir}/_annotations.coco.json"
    dataloader = create_coco_dataloader(
        annFile=annFile,
        dataDir=dataDir,
        batch_size=4,
        num_workers=4,
        global_crop_size=224,
        local_crop_size=96,
    )

    # Initialize optimizer
    optimizer = optim.AdamW(
        student_model.parameters(),
        lr=0.0005,
        weight_decay=0.04
    )

    # Initialize trainer
    trainer = DINOTrainer(
        student_model=student_model,
        teacher_model=teacher_model,
        dataloader=dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
        out_dir="./dino_coco_checkpoints",
        warmup_epochs=1,
        total_epochs=10,
        save_freq=10,
        use_amp=True,
        # use_ddp=False  # Set to True for multi-GPU training
    )

    # Start training
    trainer.train()

NameError: name 'nn' is not defined