In [1]:
from models import DINOv2ViT, CustomResNet
import torch
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

x = torch.randn(1, 3, 224, 224).to(device)
teacher = DINOv2ViT().to(device)


# out = teacher(x)
# print(out["patch_embeddings"].shape)
# print(out["embedding"].shape)
# print(out["feature_map"].shape)

student = CustomResNet().to(device)
out = student(x)
print(out["dinov2_feature_map"].shape)
print(out["embedding"].shape)
print(out["contrastive_embeddings"].shape)

Using cache found in /home/arda/.cache/torch/hub/facebookresearch_dinov2_main


layer_3 shape: torch.Size([1, 1024, 14, 14])
torch.Size([1, 1536, 16, 16])
torch.Size([1, 1536])
torch.Size([1, 2048])


In [2]:
# from datasets.GTA5 import GTA5Dataset
import sys
sys.path.append('./datasets')  # Add the datasets directory to the Python path

from collate_fn import collate_data_and_cast  # Adjusted import statement
# from datasets.collate_fn import collate_data_and_cast
from dinov2.data.augmentations import DataAugmentationDINO
from torch.utils.data import DataLoader
from imagenet import ImageNetDataset

import yaml

# Load configurations
with open("config/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

# Data Transformation
data_transform = DataAugmentationDINO(
    global_crops_scale=tuple(cfg['data_transform']['global_crops_scale']),
    local_crops_scale=tuple(cfg['data_transform']['local_crops_scale']),
    local_crops_number=cfg['data_transform']['n_local_crops'],
    global_crops_size=tuple(cfg['data_transform']['global_crops_size']),
    local_crops_size=tuple(cfg['data_transform']['local_crops_size']),
)


# Create train and test datasets
train_dataset = ImageNetDataset(type='train', transform=data_transform, num_samples = 5000)
test_dataset = ImageNetDataset(type='test', transform=data_transform, num_samples = 500)
# Create train and test dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg['data_loader']['batch_size'], 
    num_workers=cfg['data_loader']['num_workers'],
    shuffle=cfg['data_loader']['shuffle'],
    collate_fn=collate_data_and_cast
)

test_loader = DataLoader(
    test_dataset,
    batch_size=cfg['data_loader']['batch_size'],
    num_workers=cfg['data_loader']['num_workers'], 
    shuffle=False,
    collate_fn=collate_data_and_cast
)

# Optimizer
optimizer = getattr(torch.optim, cfg['optimizer']['type'])([
    {"params": student.parameters()},
], lr=2.5e-4)

# Freeze teacher model
for param in teacher.parameters():
    param.requires_grad = False


Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/25 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/25 [00:00<?, ?it/s]

In [3]:
from tqdm import tqdm
import os
import torch.nn.functional as F  # Added import for functional operations
from torch.cuda.amp import GradScaler, autocast  # Import for mixed precision

best_test_similarity = 0
save_frequency = 5
checkpoint_dir = "./checkpoints"
scaler = GradScaler()

def compute_feature_similarity(feat1, feat2):
    # Reshape feature maps to 2D: (batch*height*width, channels)
    f1 = feat1.reshape(-1, feat1.shape[-1])
    f2 = feat2.reshape(-1, feat2.shape[-1])
    
    # Compute cosine similarity
    similarity = torch.nn.functional.cosine_similarity(f1, f2, dim=1)
    return similarity.mean()

def compound_loss(mse_loss, cosine_sim_loss, alpha=1.0, beta=1.0):
    """
    Combine MSE loss and Cosine Similarity loss.
    
    Args:
        mse_loss (torch.Tensor): Mean Squared Error loss.
        cosine_sim_loss (torch.Tensor): Cosine Similarity loss.
        alpha (float): Weight for MSE loss.
        beta (float): Weight for Cosine Similarity loss.
        
    Returns:
        torch.Tensor: Combined loss.
    """
    return alpha * mse_loss + beta * cosine_sim_loss

checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    student.load_state_dict(checkpoint['student_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_test_similarity = checkpoint['best_test_similarity']
    print(f"Resuming from epoch {start_epoch}")

for epoch in range(1000):
    epoch_loss = []
    similarities = []
    embedding_similarities = []
    student.train()
    teacher.eval()
    for i, data in enumerate(tqdm(train_loader)):
        global_crops = data["collated_global_crops"].to(device)
        local_crops = data["collated_local_crops"].to(device)

        # Mixed precision training
        with autocast():
            # Get feature maps from teacher
            with torch.no_grad():
                teacher_output = teacher(global_crops)
                teacher_feature_maps = teacher_output["feature_map"]
                teacher_embedding = teacher_output["embedding"]

            # Get feature maps from student
            student_output = student(global_crops)
            student_feature_maps = student_output["dinov2_feature_map"]
            student_embedding = student_output["embedding"]

            # Calculate MSE loss between feature maps
            mse_loss = torch.nn.functional.mse_loss(
                student_feature_maps,
                teacher_feature_maps
            )
            mse_embedding_loss = torch.nn.functional.mse_loss(
                student_embedding,
                teacher_embedding
            )
            
            # Calculate Cosine Similarity loss
            student_feature_normalized = F.normalize(student_feature_maps, p=2, dim=1)
            teacher_feature_normalized = F.normalize(teacher_feature_maps, p=2, dim=1)
            cosine_similarity = torch.nn.functional.cosine_similarity(
                student_feature_normalized, 
                teacher_feature_normalized, 
                dim=1
            )
            cosine_similarity_loss = 1 - cosine_similarity.mean()  # Convert similarity to loss

            student_embedding_normalized = F.normalize(student_embedding, p=2, dim=1)
            teacher_embedding_normalized = F.normalize(teacher_embedding, p=2, dim=1)
            cosine_similarity_embedding = torch.nn.functional.cosine_similarity(
                student_embedding_normalized, 
                teacher_embedding_normalized, 
                dim=1
            )
            cosine_similarity_embedding_loss = 1 - cosine_similarity_embedding.mean()  # Convert similarity to loss

            # Combine the losses
            total_loss = compound_loss(mse_loss, cosine_similarity_loss, alpha=1.0, beta=1.0)
            total_embedding_loss = compound_loss(mse_embedding_loss, cosine_similarity_embedding_loss, alpha=1.0, beta=1.0)
            total_loss += total_embedding_loss
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Calculate similarity for logging
        similarity = compute_feature_similarity(student_feature_maps, teacher_feature_maps)
        similarities.append(similarity.item())
        # Calculate embedding similarity for logging
        embedding_similarity = compute_feature_similarity(student_embedding, teacher_embedding)
        embedding_similarities.append(embedding_similarity.item())
        
        epoch_loss.append(total_loss.item())

    # Evaluation on test set
    student.eval()
    test_losses = []
    test_similarities = []
    test_embedding_similarities = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):
            global_crops = data["collated_global_crops"].to(device)
            
            teacher_output = teacher(global_crops)
            student_output = student(global_crops)
            
            # Feature map losses
            test_mse = torch.nn.functional.mse_loss(
                student_output["dinov2_feature_map"],
                teacher_output["feature_map"]
            )
            test_similarity = compute_feature_similarity(
                student_output["dinov2_feature_map"],
                teacher_output["feature_map"]
            )
            
            # Embedding losses
            test_embedding_mse = torch.nn.functional.mse_loss(
                student_output["embedding"],
                teacher_output["embedding"]
            )
            test_embedding_similarity = compute_feature_similarity(
                student_output["embedding"],
                teacher_output["embedding"]
            )
            
            test_losses.append(test_mse.item() + test_embedding_mse.item())
            test_similarities.append(test_similarity.item())
            test_embedding_similarities.append(test_embedding_similarity.item())

    # Calculate average metrics
    avg_train_loss = sum(epoch_loss)/len(epoch_loss)
    avg_train_similarity = sum(similarities)/len(similarities)
    avg_train_embedding_similarity = sum(embedding_similarities)/len(embedding_similarities)
    avg_test_loss = sum(test_losses)/len(test_losses)
    avg_test_similarity = sum(test_similarities)/len(test_similarities)
    avg_test_embedding_similarity = sum(test_embedding_similarities)/len(test_embedding_similarities)

    # Print metrics
    print(f"Epoch {epoch}")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Train Feature Similarity: {avg_train_similarity:.4f}")
    print(f"Train Embedding Similarity: {avg_train_embedding_similarity:.4f}")
    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Test Feature Similarity: {avg_test_similarity:.4f}")
    print(f"Test Embedding Similarity: {avg_test_embedding_similarity:.4f}")

     # Save checkpoint
    if (epoch + 1) % save_frequency == 0:
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': student.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'test_loss': avg_test_loss,
            'train_feature_similarity': avg_train_similarity,
            'train_embedding_similarity': avg_train_embedding_similarity,
            'test_similarity': avg_test_similarity,
            'best_test_similarity': best_test_similarity
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth"))
        torch.save(checkpoint, os.path.join(checkpoint_dir, "latest_checkpoint.pth"))
        
    # Save best model
    if avg_test_similarity > best_test_similarity:
        best_test_similarity = avg_test_similarity
        torch.save({
            'epoch': epoch,
            'student_state_dict': student.state_dict(),
            'test_similarity': avg_test_similarity,
            'test_embedding_similarity': avg_test_similarity
        }, os.path.join(checkpoint_dir, "best_model.pth"))

  0%|          | 0/16 [00:00<?, ?it/s]

  storage_data_ptr = tensors[0].storage().data_ptr()
  if x.storage().data_ptr() != storage_data_ptr:
100%|██████████| 16/16 [01:01<00:00,  3.82s/it]
100%|██████████| 2/2 [00:12<00:00,  6.46s/it]


Epoch 0
Train Loss: 5.8403
Train Feature Similarity: 0.0879
Train Embedding Similarity: 0.0714
Test Loss: 4.1575
Test Similarity: 0.0881


100%|██████████| 16/16 [01:03<00:00,  3.98s/it]
100%|██████████| 2/2 [00:12<00:00,  6.50s/it]


Epoch 1
Train Loss: 5.5363
Train Feature Similarity: 0.1243
Train Embedding Similarity: 0.0848
Test Loss: 3.9246
Test Similarity: 0.1102


100%|██████████| 16/16 [01:05<00:00,  4.12s/it]
 50%|█████     | 1/2 [00:10<00:10, 10.78s/it]

In [5]:
labels = torch.arange(32).repeat(2)  # Creates [0,1,2,...,batch_size-1, 0,1,2,...,batch_size-1]
labels = labels.to(device)

In [6]:
labels

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31], device='cuda:1')