In [38]:
import os
import torch
import sys
sys.path.append("/home/arda/dinov2")
from torch.utils.data import Dataset
from PIL import Image
from dinov2.data.augmentations import DataAugmentationDINO

class CustomImageDataset(Dataset):
    """Dataset over the image files of a single directory.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file (extension matched
    case-insensitively) directly inside ``image_dir`` is one sample. A sample
    is the result of applying ``transform`` to the image loaded as RGB.
    """

    # Extensions (lower-case, with leading dot) that count as images.
    VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png'})

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (string): Directory with all the images
            transform (callable, optional): Called with the RGB PIL image;
                its return value is the sample. When omitted, a
                ``DataAugmentationDINO`` with 8 local crops is built.
        """
        self.image_dir = image_dir
        if transform is None:
            transform = DataAugmentationDINO(
                global_crops_scale=(0.4, 1.0),
                local_crops_scale=(0.05, 0.4),
                local_crops_number=8,
            )
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping identical across runs and machines.
        self.image_files = sorted(
            filename
            for filename in os.listdir(image_dir)
            if os.path.splitext(filename)[1].lower() in self.VALID_EXTENSIONS
        )

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and return it transformed."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        return self.transform(image)

def collate_data_and_cast(samples_list, dtype, n_global_crops=None, n_local_crops=None):
    """Stack the crops of a list of samples into two batched tensors.

    Crops are ordered crop-index-major: first crop 0 of every sample, then
    crop 1 of every sample, and so on.

    Args:
        samples_list: Non-empty list of dicts, each holding a sequence of
            tensors under ``"global_crops"`` and under ``"local_crops"``.
        dtype: torch dtype both stacked tensors are cast to.
        n_global_crops: Number of global crops to take per sample. Defaults
            to the number present in the first sample.
        n_local_crops: Number of local crops to take per sample. Defaults
            to the number present in the first sample.

    Returns:
        dict with keys ``"collated_global_crops"`` and
        ``"collated_local_crops"``.

    Raises:
        ValueError: If ``samples_list`` is empty.
    """
    if not samples_list:
        raise ValueError("samples_list must contain at least one sample")

    # Read the crop counts from the data so this function cannot drift out of
    # sync with the augmentation settings used by the dataset.
    first_sample = samples_list[0]
    if n_global_crops is None:
        n_global_crops = len(first_sample["global_crops"])
    if n_local_crops is None:
        n_local_crops = len(first_sample["local_crops"])

    collated_global_crops = torch.stack(
        [s["global_crops"][i] for i in range(n_global_crops) for s in samples_list]
    )
    collated_local_crops = torch.stack(
        [s["local_crops"][i] for i in range(n_local_crops) for s in samples_list]
    )

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
    }

# Example usage: build the dataset and wrap it in a DataLoader.
from functools import partial

image_dir = "/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5/images"  # Path to your image directory
dataset = CustomImageDataset(image_dir)

# Create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
    # A partial of a module-level function can be pickled and sent to worker
    # processes; a lambda cannot, which breaks num_workers > 0 whenever the
    # multiprocessing start method is "spawn".
    collate_fn=partial(collate_data_and_cast, dtype=torch.float32),
)

import torch
import torch.nn as nn
import torchvision.models as models

class CustomResNet(nn.Module):
    """ResNet-50 trunk that exposes both the last feature map and a pooled embedding.

    The classification layer of the torchvision model is not kept; only the
    stem, the four residual stages and the average pool are reused.
    """

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained (bool): Load torchvision's pretrained weights when
                true, otherwise start from random initialisation.
        """
        super().__init__()
        # Load ResNet50. Newer torchvision releases deprecate `pretrained=`
        # in favour of `weights=`; fall back to the old keyword only when the
        # weights enum is not available.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        # Split model into layers
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

    def forward(self, x):
        """Run the trunk on a batch of images.

        Returns:
            dict with
              ``'layer4_output'``: output of the last residual stage, and
              ``'embeddings'``: that output after average pooling, flattened
              from dimension 1 onward (one vector per input).
        """
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        layer4_output = self.layer4(x)

        pooled = self.avgpool(layer4_output)
        embeddings = torch.flatten(pooled, 1)

        return {
            'layer4_output': layer4_output,
            'embeddings': embeddings
        }

    
# Teacher: the "dinov2_vits14" entry of the facebookresearch/dinov2 hub repo.
teacher = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
# Student: the ResNet wrapper defined above, with its default arguments.
student = CustomResNet()

# The teacher is never optimised, so switch off gradients for all its weights.
for teacher_param in teacher.parameters():
    teacher_param.requires_grad = False
def get_teacher_output(teacher, global_crops, n_global_crops):
    """Compute teacher CLS and patch tokens for the global crops, without gradients.

    Args:
        teacher: Model exposing ``forward_features(x)`` that returns a dict
            containing ``'x_norm_clstoken'`` and ``'x_norm_patchtokens'``.
        global_crops: Batched global crops, laid out crop-major (all first
            crops, then all second crops).
        n_global_crops: Number of global crops per image. The swap below uses
            only the first two chunks, so this is expected to be 2.

    Returns:
        Tuple ``(teacher_cls_tokens, teacher_patch_tokens)``. The CLS tokens
        have their two crop chunks swapped; the patch tokens are returned in
        the original crop order.
    """
    with torch.no_grad():
        # Process global crops through teacher. forward_features is called
        # explicitly because it returns the full token dict, whereas calling
        # the DINOv2 hub model directly returns only the head output tensor,
        # which cannot be indexed by token name.
        teacher_output = teacher.forward_features(global_crops)
        teacher_cls_tokens = teacher_output['x_norm_clstoken']

        # Split into chunks for each global crop
        cls_chunks = teacher_cls_tokens.chunk(n_global_crops)

        # Concatenate in reverse order to match crops A->B with B->A
        teacher_cls_tokens = torch.cat((cls_chunks[1], cls_chunks[0]))
        teacher_patch_tokens = teacher_output['x_norm_patchtokens']

        return teacher_cls_tokens, teacher_patch_tokens

def get_student_output(student, global_crops, local_crops, n_local_crops, n_global_crops):
    """Compute student CLS-like tokens and patch features for all crops.

    Args:
        student: Callable returning a dict with ``'embeddings'`` and
            ``'layer4_output'`` for a batch of images. If it has a
            ``dino_head`` attribute, that head is applied to the embeddings;
            otherwise the embeddings are used as they are.
        global_crops: Batched global crops, laid out crop-major.
        local_crops: Batched local crops.
        n_local_crops: Number of local crops per image. Currently unused;
            kept so the signature stays stable.
        n_global_crops: Number of global crops per image. The swap below uses
            only the first two chunks, so this is expected to be 2.

    Returns:
        Tuple ``(global_cls_tokens, local_cls_tokens, patch_tokens)`` where
        ``global_cls_tokens`` has its two crop chunks swapped and
        ``patch_tokens`` is the student's ``'layer4_output'`` for the global
        crops in their original order.
    """
    # A single pass over the global crops provides both the pooled embeddings
    # and the feature map; running the student twice would double the cost
    # and update any batch-norm statistics twice per step.
    global_output = student(global_crops)
    student_global_embeddings = global_output['embeddings']
    ibot_student_patch_tokens = global_output['layer4_output']
    student_local_embeddings = student(local_crops)['embeddings']

    # The projection head is optional: the student model in this file does
    # not define one.
    dino_head = getattr(student, 'dino_head', None)
    if dino_head is not None:
        student_local_cls_tokens = dino_head(student_local_embeddings)
        student_global_cls_tokens = dino_head(student_global_embeddings)
    else:
        student_local_cls_tokens = student_local_embeddings
        student_global_cls_tokens = student_global_embeddings

    # Concatenate in reverse order to match teacher
    # NOTE(review): get_teacher_output also swaps its two chunks, so after
    # both swaps each student crop lines up with the teacher output of the
    # same crop. Confirm that same-view pairing is what the loss expects.
    global_chunks = student_global_cls_tokens.chunk(n_global_crops)
    student_global_cls_tokens = torch.cat((global_chunks[1], global_chunks[0]))

    return student_global_cls_tokens, student_local_cls_tokens, ibot_student_patch_tokens

# Cross entropy loss for comparing teacher and student outputs
criterion = nn.CrossEntropyLoss()

# Optimizer for student model (only the student's parameters are passed in)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)


# Sanity check of the data pipeline: the loop previously had no body, which
# is a syntax error. Print the shape of one batch of global crops and stop;
# this matches the single torch.Size line recorded as this cell's output.
for batch in dataloader:
    print(batch["collated_global_crops"].shape)
    break




torch.Size([4, 3, 224, 224])


I20241118 16:41:37 564450 dinov2 config.py:59] git:
  sha: e1277af2ba9496fbadf7aec6eba56e8d882d1e35, status: has uncommitted changes, branch: main

I20241118 16:41:37 564450 dinov2 config.py:60] config_file: /home/arda/dinov2/dinov2/configs/ssl_default_config.yaml
eval: 
eval_only: False
no_resume: False
opts: ['train.output_dir=/storage/disk0/arda/dinov2/dinov2/train']
output_dir: /storage/disk0/arda/dinov2/dinov2/train
I20241118 16:41:37 564450 dinov2 config.py:26] sqrt scaling learning rate; base: 0.004, new: 0.001
I20241118 16:41:37 564450 dinov2 config.py:33] MODEL:
  WEIGHTS: ''
compute_precision:
  grad_scaler: true
  teacher:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD