# There are 2 approaches possible:

1. Load the checkpoints into a standard ViT model and ignore all the extra layers 

2. Define a custom ViT model by wrapping a base ViT in layers corresponding to the checkpoints (weights that are named "loss_module_*" are ignored because they're mostly likely only for training and I don't know what kind of layers fit them)



In [1]:
import os
import numpy as np
import torch
import cv2
from timm import create_model
from torchvision import transforms
from sklearn.model_selection import train_test_split, LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm  # For progress bars
import torch
from timm import create_model
import timm
import torch.nn as nn
from wrappers_supervised import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load checkpoints

In [2]:

# Paths to model and data

joscha_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/vit_large_dinov2_ssl_joscha.ckpt"
robert_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/supervised_dinov2_large.ckpt"
vincent_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/ssl_vincent_vit_large.ckpt"
best_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/should-be-the-best-model_vit_large_dinoV2.ckpt"

model_name_joscha = "vit_large_dinov2_ssl_joscha"
model_name_robert = "supervised_dinov2_large"
model_name_vincent = "ssl_vincent_vit_large"
model_name_best = "best-model_vit_large_dinoV2"

wrapper_type = "timmWrapper"
model_name_wrapper = f"{model_name_best} in {wrapper_type}"

#checkpoint_path = robert_checkpoint_path
test_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/test"
train_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/train"
val_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/val"

all_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_face"

bristol_folder = "/workspaces/gorilla_watch/video_data/bristol/cropped_frames_square_filtered"
# filtered 372 images

### Inspect checkpoints in comparison to a standard ViT

In [3]:
# Load checkpoint and extract state_dict
checkpoint_joscha = torch.load(joscha_checkpoint_path, map_location=torch.device(device))
checkpoint_robert = torch.load(robert_checkpoint_path, map_location=torch.device(device))
checkpoint_vincent = torch.load(vincent_checkpoint_path, map_location=torch.device(device))
checkpoint_best = torch.load(best_checkpoint_path, map_location=torch.device(device))

# for k, v in checkpoint_best["hyper_parameters"].items():
#     print(k, v)
# print(checkpoint_best["hparams_name"])
# print(checkpoint_best["state_dict"].keys())


  checkpoint_joscha = torch.load(joscha_checkpoint_path, map_location=torch.device(device))
  checkpoint_robert = torch.load(robert_checkpoint_path, map_location=torch.device(device))
  checkpoint_vincent = torch.load(vincent_checkpoint_path, map_location=torch.device(device))
  checkpoint_best = torch.load(best_checkpoint_path, map_location=torch.device(device))


wandb_run <wandb.sdk.wandb_run.Run object at 0x7f7123031870>
loss_mode online/hard/l2sp
from_scratch False
weight_decay 0.2
lr_schedule cosine
warmup_mode constant
warmup_epochs 0
max_epochs 100
initial_lr 4.09891362474683e-06
start_lr 5.497236565453763e-06
end_lr 1e-07
stepwise_schedule False
lr_interval 1
beta1 0.9
beta2 0.999
epsilon 1e-07
embedding_size 256
batch_size 16
dataset_names ['cxlkfold', 'bristol', 'test']
accelerator cuda
dropout_p 0.0
use_dist_term False
use_inbatch_mixup False
kfold_k None
knn_with_train True
use_quantization_aware_training False
every_n_val_epochs 5
fast_dev_run False
margin 1.0
s 64.0
temperature 0.5
memory_bank_size 0
num_classes None
class_distribution None
l2_alpha 0.05901881729791581
l2_beta 0.028149782260071254
path_to_pretrained_weights ./pretrained_weights/vit_large_patch14_dinov2_lvd142m.pth
use_wildme_model False
k_subcenters 1
use_focal_loss False
label_smoothing 0.0
use_class_weights False
teacher_model_wandb_link 
loss_dist_term euclidean

In [4]:
# ### compare checkpoints' state_dict with a standard model's state_dict

def extract_clean_state_dict(checkpoint, wrapper_key="model_wrapper.", model_key="model."):
    # Extract the state_dict from the checkpoint
    state_dict = checkpoint.get('state_dict', checkpoint)  # Use 'state_dict' or checkpoint directly
    # Remove wrapper key prefix
    cleaned_state_dict = {k.replace(wrapper_key, ''): v for k, v in state_dict.items()}
    # Remove model key prefix
    cleaned_state_dict = {k.replace(model_key, ''): v for k, v in cleaned_state_dict.items()}
    
    return cleaned_state_dict
    
def compare_checkpoints_with_model(checkpoint, model_fn, model_name, img_size):
    # # Extract the state_dict from the checkpoint
    # state_dict = checkpoint.get('state_dict', checkpoint)  # Use 'state_dict' or checkpoint directly
    # new_state_dict = {k.replace('model_wrapper.', ''): v for k, v in state_dict.items()}
    # new_state_dict2 = {k.replace('model.', ''): v for k, v in new_state_dict.items()}
    
    cleaned_state_dict = extract_clean_state_dict(checkpoint)

    # Initialize the model
    vit_model = model_fn(model_name, pretrained=False, img_size=img_size)

    # Compare keys
    standard_vit_keys = vit_model.state_dict().keys()
    custom_keys = cleaned_state_dict.keys()
    missing_in_standard = [k for k in custom_keys if k not in standard_vit_keys]
    extra_in_standard = [k for k in standard_vit_keys if k not in custom_keys]

    # Collect information about embedding keys
    embedding_keys = [key for key in cleaned_state_dict.keys() if "embedding_layer" in key]

    # Print missing keys and their shapes
    # for key in missing_in_standard:
        # print(f"{key}: {new_state_dict2[key].shape}")

    return {
        "missing_in_standard": missing_in_standard,
        "extra_in_standard": extra_in_standard,
        "embedding_keys": embedding_keys,
    }
        
    
# Compare the checkpoint with the standard ViT model
# comparison_joscha = compare_checkpoints_with_model(
#     checkpoint=checkpoint_joscha,
#     model_fn=create_model,
#     model_name='vit_large_patch14_dinov2.lvd142m',
#     img_size=192
# )
# comparison_robert = compare_checkpoints_with_model(
#     checkpoint=checkpoint_robert,
#     model_fn=create_model,
#     model_name='vit_large_patch14_dinov2.lvd142m',
#     img_size=192
# )
# comparison_vincent = compare_checkpoints_with_model(
#     checkpoint=checkpoint_vincent,
#     model_fn=create_model,
#     model_name='vit_large_patch14_dinov2.lvd142m',
#     img_size=192
# )
comparison_best = compare_checkpoints_with_model(
    checkpoint=checkpoint_best,
    model_fn=create_model,
    model_name='vit_large_patch14_dinov2.lvd142m',
    img_size=224
)

# print("Missing keys in standard ViT (Joscha):", comparison_joscha)
# print("Missing keys in standard ViT (Robert):", comparison_robert)
# print("Missing keys in standard ViT (Vincent):", comparison_vincent)


def filter_missing_keys(keys, filter_keywords):
    return [key for key in keys if not any(keyword in key for keyword in filter_keywords)]


# Define the keywords to filter out
filter_keywords = ["loss_module_val", "loss_module_train"]

# # Apply filtering to each comparison
# comparison_joscha_missing_keys = filter_missing_keys(comparison_joscha["missing_in_standard"], filter_keywords)
# comparison_robert_missing_keys = filter_missing_keys(comparison_robert["missing_in_standard"], filter_keywords)
# comparison_vincent_missing_keys = filter_missing_keys(comparison_vincent["missing_in_standard"], filter_keywords)
comparison_best_missing_keys = filter_missing_keys(comparison_best["missing_in_standard"], filter_keywords)

# # Print filtered results
# print("Missing keys in standard ViT (Joscha):", comparison_joscha_missing_keys)
# print("Missing keys in standard ViT (Robert):", comparison_robert_missing_keys)
# print("Missing keys in standard ViT (Vincent):", comparison_vincent_missing_keys)
print("Missing keys in standard ViT (best):", comparison_best_missing_keys)

cleaned_state_dict = extract_clean_state_dict(checkpoint_best)
# print(cleaned_state_dict.keys())





Missing keys in standard ViT (best): ['embedding_layer.weight', 'embedding_layer.bias']


### helper functions used by both approaches:

In [5]:
# fix pos_emd mismatch
import torch
import torch.nn.functional as F

def interpolate_positional_embedding(pretrained_state_dict, current_model, patch_size):
    # Get the pretrained positional embeddings
    pretrained_pos_embed = pretrained_state_dict['pos_embed']
    current_pos_embed = current_model.state_dict()['pos_embed']
    
    # Exclude the CLS token (first token) if present
    cls_token = pretrained_pos_embed[:, :1, :]
    pretrained_grid = pretrained_pos_embed[:, 1:, :]
    
    # Get the grid dimensions
    num_patches = current_pos_embed.shape[1] - 1  # Exclude CLS token
    grid_size_pretrained = int((pretrained_grid.shape[1])**0.5)  # sqrt(num_patches)
    grid_size_current = int(num_patches**0.5)

    # Reshape the positional embeddings to a grid
    pretrained_grid = pretrained_grid.reshape(1, grid_size_pretrained, grid_size_pretrained, -1).permute(0, 3, 1, 2)
    
    # Interpolate to the current grid size
    interpolated_grid = F.interpolate(pretrained_grid, size=(grid_size_current, grid_size_current), mode='bilinear', align_corners=False)
    interpolated_grid = interpolated_grid.permute(0, 2, 3, 1).reshape(1, -1, interpolated_grid.shape[1])
    
    # Combine CLS token and the new grid
    new_pos_embed = torch.cat([cls_token, interpolated_grid], dim=1)
    pretrained_state_dict['pos_embed'] = new_pos_embed

    return pretrained_state_dict


# interpolate_positional_embedding(new_state_dict2, vit_model, 16)

In [6]:
# # extract cleaned state_dict from checkpoint
cleaned_state_dict_joscha = extract_clean_state_dict(checkpoint_joscha, wrapper_key="model_wrapper.", model_key="model.")
cleaned_state_dict_robert = extract_clean_state_dict(checkpoint_robert, wrapper_key="model_wrapper.", model_key="model.")
cleaned_state_dict_vincent = extract_clean_state_dict(checkpoint_vincent, wrapper_key="model_wrapper.", model_key="model.")
cleaned_state_dict_best = extract_clean_state_dict(checkpoint_best, wrapper_key="model_wrapper.", model_key="model.")

In [None]:
# load the data
# hyperparams
img_size = 224

# Preprocessing function to resize and normalize images
transform_custom = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Adjust based on model pretraining
])

transform_standard = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Adjust based on model pretraining
])

# Load images and extract class codes from file names
# def load_images_and_labels(folder, transform):
#     images = []
#     labels = []
#     for filename in os.listdir(folder):
#         if filename.endswith('.jpg') or filename.endswith('.png'):
#             image = cv2.imread(os.path.join(folder, filename))
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#             image = transform(image)
#             images.append(image)
#             label = filename.split("_")[0]
#             # labels.append(filename[:4])  # Assuming first 4 chars are label
#             labels.append(label)
                
#     return images, labels

def load_images_and_labels(folder, transform, threshold=3):
    images = []
    labels = []
    
    # Temporary storage for all data
    temp_images = []
    temp_labels = []
    
    for filename in os.listdir(folder):
        if filename.endswith('.jpg') or filename.endswith('.png'):
            image = cv2.imread(os.path.join(folder, filename))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = transform(image)
            temp_images.append(image)
            label = filename.split("_")[0]
            temp_labels.append(label)
    
    # Count occurrences of each class
    label_counts = Counter(temp_labels)
    
    # Filter out classes with fewer than 'threshold' images
    valid_classes = {label for label, count in label_counts.items() if count >= threshold}
    
    for image, label in zip(temp_images, temp_labels):
        if label in valid_classes:
            images.append(image)
            labels.append(label)
    
    return images, labels

# unique_labels = set(labels)



In [None]:
# filter out samples with less occurrences
from collections import Counter

def filter_samples_by_threshold(embeddings, labels, threshold=3):
    # Count occurrences of each label
    label_counts = Counter(labels)
    print(f"Number of images before filtering: {len(labels)}")

    # Filter embeddings and labels for labels with >= threshold occurrences
    filtered_indices = [i for i, label in enumerate(labels) if label_counts[label] >= threshold]
    filtered_embeddings = embeddings[filtered_indices]
    filtered_labels = [labels[i] for i in filtered_indices]

    # Log results
    print(f"Number of embeddings after filtering: {len(filtered_embeddings)}")

    return filtered_embeddings, filtered_labels


In [9]:

# Function to generate embeddings from a model
def generate_embeddings(model, images_tensor):
    with torch.no_grad():
        # Call the forward method to extract embeddings
        embeddings = model(images_tensor)
        embeddings_flat = embeddings.view(embeddings.size(0), -1).numpy() # Flatten embeddings(optional?)

    return embeddings

# KNN classifier and excluding the closest neighbor (self-matching)
def knn_classifier(embeddings, labels):
    for k in range(5, 6):
        # Use k+1 neighbors to account for excluding self
        vit_knn = KNeighborsClassifier(n_neighbors=k + 1)
        vit_knn.fit(embeddings, labels)  # Fit KNN on all data

        vit_y_pred = []
        for idx, test_embedding in enumerate(embeddings):
            # Find k+1 neighbors (including the test sample itself)
            neighbors = vit_knn.kneighbors(test_embedding.reshape(1, -1), return_distance=False)[0]

            # Exclude the test sample itself (assumed to be the closest neighbor)
            filtered_neighbors = neighbors[1:]  # Exclude the first neighbor
        
            # Predict based on the remaining neighbors (majority vote)
            filtered_neighbor_labels = [labels[n] for n in filtered_neighbors]
            predicted_label = max(set(filtered_neighbor_labels), key=filtered_neighbor_labels.count)
            vit_y_pred.append(predicted_label)

        # Calculate accuracy and F1 for the current k
        vit_accuracy = accuracy_score(labels, vit_y_pred)
        vit_f1 = f1_score(labels, vit_y_pred, average='macro')
        print(f"ViT Accuracy with {k}-nearest neighbors: {vit_accuracy:.4f}, F1_score:{vit_f1: 4f}")
        

In [None]:
# load the data (choose the folder here!!!!)
# Load data (SPAC or Bristol)
# test_images, test_labels = load_images_and_labels(test_folder, transform_standard)

test_images, test_labels = load_images_and_labels(bristol_folder, transform_standard)

# test_images, test_labels = load_images_and_labels(val_folder, transform_standard)

# train_images, train_labels = load_images_and_labels(train_folder)
# val_images, val_labels = load_images_and_labels(val_folder)
# all_images, all_labels = load_images_and_labels(all_folder)

test_images_tensor = torch.stack(test_images)
# train_images_tensor = torch.stack(train_images)
# val_images_tensor = torch.stack(val_images)
# all_images_tensor = torch.stack(all_images)

# train_val_images_tensor = torch.cat((train_images_tensor, val_images_tensor), 0)
# train_val_labels = train_labels + val_labels

### Approach 1: load the chcekpoints into a standard ViT moel


In [11]:
### Approach 1: load the chcekpoints into a standard ViT moel

vit_model_joscha = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=img_size)
vit_model_robert = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=img_size)
vit_model_vincent = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=img_size)
vit_model_best = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=img_size)

# interpolate pos_embeddings
interpolated_state_dict_joscha = interpolate_positional_embedding(cleaned_state_dict_joscha, vit_model_joscha, 16)
interpolated_state_dict_robert = interpolate_positional_embedding(cleaned_state_dict_robert, vit_model_robert, 16)
interpolated_state_dict_vincent = interpolate_positional_embedding(cleaned_state_dict_vincent, vit_model_vincent, 16)
interpolated_state_dict_best = interpolate_positional_embedding(cleaned_state_dict_best, vit_model_best, 16)

vit_model_joscha.load_state_dict(interpolated_state_dict_joscha, strict=False)
vit_model_robert.load_state_dict(interpolated_state_dict_robert, strict=False)
vit_model_vincent.load_state_dict(interpolated_state_dict_vincent, strict=False)
vit_model_best.load_state_dict(interpolated_state_dict_best, strict=False)



_IncompatibleKeys(missing_keys=[], unexpected_keys=['embedding_layer.weight', 'embedding_layer.bias', 'loss_module_train.cls_token', 'loss_module_train.pos_embed', 'loss_module_train.patch_embed.proj.weight', 'loss_module_train.patch_embed.proj.bias', 'loss_module_train.blocks.0.norm1.weight', 'loss_module_train.blocks.0.norm1.bias', 'loss_module_train.blocks.0.attn.qkv.weight', 'loss_module_train.blocks.0.attn.qkv.bias', 'loss_module_train.blocks.0.attn.proj.weight', 'loss_module_train.blocks.0.attn.proj.bias', 'loss_module_train.blocks.0.ls1.gamma', 'loss_module_train.blocks.0.norm2.weight', 'loss_module_train.blocks.0.norm2.bias', 'loss_module_train.blocks.0.mlp.fc1.weight', 'loss_module_train.blocks.0.mlp.fc1.bias', 'loss_module_train.blocks.0.mlp.fc2.weight', 'loss_module_train.blocks.0.mlp.fc2.bias', 'loss_module_train.blocks.0.ls2.gamma', 'loss_module_train.blocks.1.norm1.weight', 'loss_module_train.blocks.1.norm1.bias', 'loss_module_train.blocks.1.attn.qkv.weight', 'loss_module

In [12]:
# # load the data
# # Load data (SPAC or Bristol)
# # test_images, test_labels = load_images_and_labels(test_folder, transform_standard)
# # test_images, test_labels = load_images_and_labels(bristol_folder, transform_standard)
# test_images, test_labels = load_images_and_labels(val_folder, transform_standard)
# print(test_labels[:5])

# # train_images, train_labels = load_images_and_labels(train_folder)
# # val_images, val_labels = load_images_and_labels(val_folder)
# # all_images, all_labels = load_images_and_labels(all_folder)

# test_images_tensor = torch.stack(test_images)
# # train_images_tensor = torch.stack(train_images)
# # val_images_tensor = torch.stack(val_images)
# # all_images_tensor = torch.stack(all_images)

# # train_val_images_tensor = torch.cat((train_images_tensor, val_images_tensor), 0)
# # train_val_labels = train_labels + val_labels

#### evaluate with standard models



In [None]:
vit_model_joscha.eval()
vit_model_robert.eval()
vit_model_vincent.eval()
vit_model_best.eval()

standard_embeddings_joscha = generate_embeddings(vit_model_joscha, test_images_tensor)
stadard_embeddings_robert = generate_embeddings(vit_model_robert, test_images_tensor)
standard_embeddings_vincent = generate_embeddings(vit_model_vincent, test_images_tensor)
standard_embeddings_best = generate_embeddings(vit_model_best, test_images_tensor)


vit_large_dinov2_ssl_joscha:
ViT Accuracy with 5-nearest neighbors: 0.7124, F1_score: 0.496258
supervised_dinov2_large:
ViT Accuracy with 5-nearest neighbors: 0.8170, F1_score: 0.648936
ssl_vincent_vit_large:
ViT Accuracy with 5-nearest neighbors: 0.6993, F1_score: 0.514996
best-model_vit_large_dinoV2:
ViT Accuracy with 5-nearest neighbors: 0.8431, F1_score: 0.687583


In [34]:
# evaluate with unfiltered embeddings
def classify_and_print(model_name, model, embeddings,labels):
    print(f"{model_name}:")
    knn_classifier(embeddings, labels)
    
classify_and_print(model_name_joscha, vit_model_joscha, standard_embeddings_joscha, test_labels)
classify_and_print(model_name_robert, vit_model_robert, stadard_embeddings_robert, test_labels)
classify_and_print(model_name_vincent, vit_model_vincent, standard_embeddings_vincent, test_labels)
classify_and_print(model_name_best, vit_model_best, standard_embeddings_best, test_labels)

# evaluate with filtered embeddings
filtered_embeddings_joscha, filtered_labels_joscha = filter_samples_by_threshold(standard_embeddings_joscha, test_labels, threshold=3)
filtered_embeddings_robert, filtered_labels_robert = filter_samples_by_threshold(stadard_embeddings_robert, test_labels, threshold=3)
filtered_embeddings_vincent, filtered_labels_vincent = filter_samples_by_threshold(standard_embeddings_vincent, test_labels, threshold=3)
filtered_embeddings_best, filtered_labels_best = filter_samples_by_threshold(standard_embeddings_best, test_labels, threshold=3)

print("Evaluate with filtered embeddings")
classify_and_print(model_name_joscha, vit_model_joscha, filtered_embeddings_joscha, filtered_labels_joscha)
classify_and_print(model_name_robert, vit_model_robert, filtered_embeddings_robert, filtered_labels_robert)
classify_and_print(model_name_vincent, vit_model_vincent, filtered_embeddings_vincent, filtered_labels_vincent)
classify_and_print(model_name_best, vit_model_best, filtered_embeddings_best, filtered_labels_best)


vit_large_dinov2_ssl_joscha:


ValueError: Found input variables with inconsistent numbers of samples: [153, 372]

### Approach 2: Define custom ViT models for all models

In [15]:
# Custom ViT for supervised_dinov2_large

class CustomViT_supervised(nn.Module):
    def __init__(self, base_vit):
        super(CustomViT_supervised, self).__init__()
        self.base_vit = base_vit
        # Define embedding layer to match the missing keys
        self.embedding_layer = nn.Linear(1024, 256)  # Assuming base_vit outputs 768-d embeddings
        

    def forward(self, x):
        # Pass input through the Vision Transformer backbone
        x = self.base_vit.forward_features(x)  # Output size is [batch_size, num_tokens, 768]
        print(f"Output shape of base_vit: {x.shape}")
        
        # Extract the [CLS] token embedding (assuming it's the first token)
        x = x[:, 0, :]  # Shape: [batch_size, 768]
        # print(f"Output shape after selecting CLS token: {x.shape}")
        
        # Pass through the custom embedding layer
        x = self.embedding_layer(x)  # Linear transformation to 1024
        # print(f"Output shape of embedding_layer: {x.shape}")
        
        return x
    


### test with the model wrapper only!

In [31]:
# original model wrapper
embedding_id = "linear"
# embedding_id = "mlp"
# embedding_id = "mlp_norm_dropout"
# embedding_id = "linear_norm_dropout"
# embedding_id = ""
# Custom ViT for supervised_dinov2_large
model_wrapper = TimmWrapper(
    backbone_name="vit_large_patch14_dinov2.lvd142m",
    embedding_size=256,  # Set based on your checkpoint
    embedding_id=embedding_id,  # Ensure this matches the checkpoint (OG)
    dropout_p=0.0,  # Set dropout probability
    pool_mode="none",  # Assuming no global pooling for Vision Transformers
    img_size=224,  # Match the input size expected by the model
)


In [32]:
# Load weights into the model

def extract_clean_state_dict_for_wrapper(checkpoint, wrapper_key="model_wrapper.", model_key="model."):
    # Extract the state_dict from the checkpoint
    state_dict = checkpoint.get('state_dict', checkpoint)  # Use 'state_dict' or checkpoint directly
    # Remove wrapper key prefix
    cleaned_state_dict = {k.replace(wrapper_key, ''): v for k, v in state_dict.items()}
    
    return cleaned_state_dict

# only wrapper. prefix are removed
# only loss_module. keys are ignored
cleaned_state_dict_wrapper = extract_clean_state_dict_for_wrapper(checkpoint_best)
model_wrapper.load_state_dict(cleaned_state_dict_wrapper, strict=False)
# print("Checkpoint loaded successfully!") 


_IncompatibleKeys(missing_keys=[], unexpected_keys=['loss_module_train.model.cls_token', 'loss_module_train.model.pos_embed', 'loss_module_train.model.patch_embed.proj.weight', 'loss_module_train.model.patch_embed.proj.bias', 'loss_module_train.model.blocks.0.norm1.weight', 'loss_module_train.model.blocks.0.norm1.bias', 'loss_module_train.model.blocks.0.attn.qkv.weight', 'loss_module_train.model.blocks.0.attn.qkv.bias', 'loss_module_train.model.blocks.0.attn.proj.weight', 'loss_module_train.model.blocks.0.attn.proj.bias', 'loss_module_train.model.blocks.0.ls1.gamma', 'loss_module_train.model.blocks.0.norm2.weight', 'loss_module_train.model.blocks.0.norm2.bias', 'loss_module_train.model.blocks.0.mlp.fc1.weight', 'loss_module_train.model.blocks.0.mlp.fc1.bias', 'loss_module_train.model.blocks.0.mlp.fc2.weight', 'loss_module_train.model.blocks.0.mlp.fc2.bias', 'loss_module_train.model.blocks.0.ls2.gamma', 'loss_module_train.model.blocks.1.norm1.weight', 'loss_module_train.model.blocks.1.n

In [40]:
### test only the model wrapper
### generate embeddings from the custom_vit & KNN classification
model_wrapper.eval()

# create corresponding image tensors (224 instead of 192)
# test_images, test_labels = load_images_and_labels(test_folder, transform_standard)
# test_images_tensor = torch.stack(test_images)

custom_embeddings_wrapper = generate_embeddings(model_wrapper, test_images_tensor)

def classify_and_print_custom(model_name, model, embeddings, labels):
    print(f"{model_name} with custom model:")
    knn_classifier(embeddings, labels)
    

    
classify_and_print_custom(model_name_wrapper, model_wrapper, custom_embeddings_wrapper, test_labels)
# knn_classifier(custom_embeddings_joscha, test_labels)

# evaluate with filtered embeddings

filtered_embeddings_wrapper, filtered_labels_wrapper = filter_samples_by_threshold(custom_embeddings_wrapper, test_labels, threshold=3)

print(f"Evaluate with filtered embeddings: {embedding_id}")
classify_and_print_custom(model_name_wrapper, model_wrapper, filtered_embeddings_wrapper, filtered_labels_wrapper)



best-model_vit_large_dinoV2 in timmWrapper with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8441, F1_score: 0.825991
Number of images before filtering: 372
Number of embeddings after filtering: 372
Evaluate with filtered embeddings: linear
best-model_vit_large_dinoV2 in timmWrapper with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8441, F1_score: 0.825991


In [19]:
def load_checkpoint_into_custom_vit_supervised(checkpoint_state_dict, custom_vit, patch_size=16):
    """
    Loads a checkpoint's state_dict into a custom Vision Transformer (ViT) model.

    Args:
        checkpoint_state_dict (dict): The state_dict from the checkpoint.
        custom_vit (torch.nn.Module): The custom ViT model with backbone and embedding layers.
        patch_size (int): Patch size for the ViT model, used for positional embedding interpolation.

    Returns:
        None
    """
    # Separate backbone (base_vit) and embedding layer weights
    backbone_state_dict = {k: v for k, v in checkpoint_state_dict.items() if "embedding_layer" not in k}
    embedding_layer_state_dict = {k: v for k, v in checkpoint_state_dict.items() if "embedding_layer" in k}

    # Interpolate positional embeddings if necessary
    interpolated_state_dict = interpolate_positional_embedding(backbone_state_dict, custom_vit.base_vit, patch_size)
    
    backbone_state_dict["pos_embed"] = interpolated_state_dict["pos_embed"]
    
    # Load the backbone weights into base_vit
    custom_vit.base_vit.load_state_dict(backbone_state_dict, strict=False)

    # Load weights into the custom embedding layer
    custom_vit.embedding_layer.load_state_dict({
        "weight": embedding_layer_state_dict["embedding_layer.weight"],
        "bias": embedding_layer_state_dict["embedding_layer.bias"],
    })



In [20]:
# Custom VIT for vit_large_dinov2_ssl_joscha & ssl_vincent_vit_large
# check thesis for best embedding layer

class CustomVisionTransformer_ssl(nn.Module):
    def __init__(self, base_vit):
        super(CustomVisionTransformer_ssl, self).__init__()
        self.base_vit = base_vit
        
        # Define embedding layers based on checkpoint dimensions
        self.embedding_layer_0 = nn.BatchNorm1d(1024)  # Normalize input size 1024
        self.embedding_layer_2 = nn.Linear(1024, 256)  # Linear layer: 1024 -> 256
        self.embedding_layer_3 = nn.BatchNorm1d(256)  # Normalize input size 256

    def forward(self, x):
        # Pass input through the Vision Transformer backbone
        x = self.base_vit.forward_features(x)  # Output size is [batch_size, 1024]
        print(f"output shape of base_vit: {x.shape}")
        
        # Flatten the patch and token dimensions into batch
        # x = x.view(-1, x.size(-1))  # Shape: [batch_size * 257, 1024]
        x = x[:, 0, :]  # Shape: [batch_size, 1024]

        
        # Pass through the additional embedding layers
        x = self.embedding_layer_0(x)  # BatchNorm1d for 1024
        # print(f"output shape of embedding_layer_0: {x.shape}")
        
        x = self.embedding_layer_2(x)  # Linear transformation to 256
        # print(f"output shape of embedding_layer_2: {x.shape}")
        
        x = self.embedding_layer_3(x)  # BatchNorm1d for 256
        # print(f"output shape of embedding_layer_3: {x.shape}")
        
        return x


In [21]:
def load_checkpoint_into_custom_vit_ssl(checkpoint_state_dict, custom_vit, patch_size=16):
    """
    Loads a checkpoint's state_dict into a custom Vision Transformer (ViT) model.

    Args:
        checkpoint_state_dict (dict): The state_dict from the checkpoint.
        vit_model (torch.nn.Module): The base Vision Transformer model to adjust positional embeddings.
        custom_vit (torch.nn.Module): The custom ViT model with backbone and embedding layers.
        patch_size (int): Patch size for the ViT model, used for positional embedding interpolation.

    Returns:
        None
    """
    # Separate backbone and embedding layers
    backbone_state_dict = {k: v for k, v in checkpoint_state_dict.items() if "embedding_layer" not in k}
    filtered_backbone_state_dict = {
        k: v for k, v in backbone_state_dict.items() 
        if "ls1.gamma" not in k and "ls2.gamma" not in k
    }

    interpolated_backbone_state_dict = interpolate_positional_embedding(filtered_backbone_state_dict, custom_vit.base_vit, patch_size)

    # Replace the key in the state_dict
    filtered_backbone_state_dict["pos_embed"] = interpolated_backbone_state_dict["pos_embed"]

    # Load weights into the backbone (base_vit)
    custom_vit.base_vit.load_state_dict(filtered_backbone_state_dict, strict=False)

    # Extract embedding layer weights
    embedding_layer_state_dict = {k: v for k, v in checkpoint_state_dict.items() if "embedding_layer" in k}

    # Load weights into the embedding layers
    custom_vit.embedding_layer_0.load_state_dict({
        "weight": embedding_layer_state_dict["embedding_layer.0.weight"],
        "bias": embedding_layer_state_dict["embedding_layer.0.bias"],
        "running_mean": embedding_layer_state_dict["embedding_layer.0.running_mean"],
        "running_var": embedding_layer_state_dict["embedding_layer.0.running_var"],
        "num_batches_tracked": embedding_layer_state_dict["embedding_layer.0.num_batches_tracked"],
    })
    custom_vit.embedding_layer_2.load_state_dict({
        "weight": embedding_layer_state_dict["embedding_layer.2.weight"],
        "bias": embedding_layer_state_dict["embedding_layer.2.bias"],
    })
    custom_vit.embedding_layer_3.load_state_dict({
        "weight": embedding_layer_state_dict["embedding_layer.3.weight"],
        "bias": embedding_layer_state_dict["embedding_layer.3.bias"],
        "running_mean": embedding_layer_state_dict["embedding_layer.3.running_mean"],
        "running_var": embedding_layer_state_dict["embedding_layer.3.running_var"],
        "num_batches_tracked": embedding_layer_state_dict["embedding_layer.3.num_batches_tracked"],
    })


#### create custom models and load the checkpoints' state_dict 

In [22]:
base_vit_joscha = create_model("vit_large_patch14_dinov2", pretrained=False, img_size=img_size)
base_vit_vincent = create_model("vit_large_patch14_dinov2", pretrained=False, img_size=img_size)
base_vit_robert = create_model("vit_large_patch14_dinov2", pretrained=False, img_size=img_size)
base_vit_best = create_model("vit_large_patch14_dinov2", pretrained=False, img_size=img_size)

custom_vit_joscha = CustomVisionTransformer_ssl(base_vit_joscha)
custom_vit_vincent = CustomVisionTransformer_ssl(base_vit_vincent)
custom_vit_robert = CustomViT_supervised(base_vit_robert)
custom_vit_best = CustomViT_supervised(base_vit_robert)

### load the checkpoints' state_dict into the custom_vit
load_checkpoint_into_custom_vit_ssl(cleaned_state_dict_joscha, custom_vit_joscha, patch_size=16)
load_checkpoint_into_custom_vit_ssl(cleaned_state_dict_vincent, custom_vit_vincent, patch_size=16)
load_checkpoint_into_custom_vit_supervised(cleaned_state_dict_robert, custom_vit_robert, patch_size=16)
load_checkpoint_into_custom_vit_supervised(cleaned_state_dict_best, custom_vit_best, patch_size=16)



In [23]:
### generate embeddings from the custom_vit & KNN classification
custom_vit_joscha.eval()
custom_vit_vincent.eval()
custom_vit_robert.eval()
custom_vit_best.eval()
model_wrapper.eval()


# create corresponding image tensors (224 instead of 192)
# test_images, test_labels = load_images_and_labels(test_folder, transform_standard)
# test_images, test_labels = load_images_and_labels(bristol_folder, transform_standard)

# test_images_tensor = torch.stack(test_images)

custom_embeddings_joscha = generate_embeddings(custom_vit_joscha, test_images_tensor)
custom_embeddings_vincent = generate_embeddings(custom_vit_vincent, test_images_tensor)
custom_embeddings_robert = generate_embeddings(custom_vit_robert, test_images_tensor)
custom_embeddings_best = generate_embeddings(custom_vit_best, test_images_tensor)
custom_embeddings_wrapper = generate_embeddings(model_wrapper, test_images_tensor)

def classify_and_print_custom(model_name, model, embeddings, labels):
    print(f"{model_name} with custom model:")
    knn_classifier(embeddings, labels)
    
classify_and_print_custom(model_name_joscha, custom_vit_joscha, custom_embeddings_joscha, test_labels)
classify_and_print_custom(model_name_robert, custom_vit_robert, custom_embeddings_robert, test_labels)
classify_and_print_custom(model_name_vincent, custom_vit_vincent, custom_embeddings_vincent, test_labels)
classify_and_print_custom(model_name_best, custom_vit_best, custom_embeddings_best, test_labels)
classify_and_print_custom(model_name_wrapper, model_wrapper, custom_embeddings_wrapper, test_labels)
# knn_classifier(custom_embeddings_joscha, test_labels)


output shape of base_vit: torch.Size([153, 257, 1024])
output shape of base_vit: torch.Size([153, 257, 1024])
Output shape of base_vit: torch.Size([153, 257, 1024])
Output shape of base_vit: torch.Size([153, 257, 1024])
vit_large_dinov2_ssl_joscha with custom model:
ViT Accuracy with 5-nearest neighbors: 0.3595, F1_score: 0.204865
supervised_dinov2_large with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8105, F1_score: 0.641030
ssl_vincent_vit_large with custom model:
ViT Accuracy with 5-nearest neighbors: 0.4510, F1_score: 0.261567
best-model_vit_large_dinoV2 with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8039, F1_score: 0.631577
best-model_vit_large_dinoV2 in timmWrapper with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8039, F1_score: 0.631577


In [24]:
# evaluate with filtered embeddings

filtered_embeddings_joscha, filtered_labels_joscha = filter_samples_by_threshold(custom_embeddings_joscha, test_labels, threshold=3)
filtered_embeddings_robert, filtered_labels_robert = filter_samples_by_threshold(custom_embeddings_robert, test_labels, threshold=3)
filtered_embeddings_vincent, filtered_labels_vincent = filter_samples_by_threshold(custom_embeddings_vincent, test_labels, threshold=3)
filtered_embeddings_best, filtered_labels_best = filter_samples_by_threshold(custom_embeddings_best, test_labels, threshold=3)
filtered_embeddings_wrapper, filtered_labels_wrapper = filter_samples_by_threshold(custom_embeddings_wrapper, test_labels, threshold=3)

print("Evaluate with filtered embeddings")
classify_and_print_custom(model_name_joscha, custom_vit_joscha, filtered_embeddings_joscha, filtered_labels_joscha)
classify_and_print_custom(model_name_robert, custom_vit_robert, filtered_embeddings_robert, filtered_labels_robert)
classify_and_print_custom(model_name_vincent, custom_vit_vincent, filtered_embeddings_vincent, filtered_labels_vincent)
classify_and_print_custom(model_name_best, custom_vit_best, filtered_embeddings_best, filtered_labels_best)
classify_and_print_custom(model_name_wrapper, model_wrapper, filtered_embeddings_wrapper, filtered_labels_wrapper)


Number of images before filtering: 153
Number of embeddings after filtering: 144
Number of images before filtering: 153
Number of embeddings after filtering: 144
Number of images before filtering: 153
Number of embeddings after filtering: 144
Number of images before filtering: 153
Number of embeddings after filtering: 144
Number of images before filtering: 153
Number of embeddings after filtering: 144
Evaluate with filtered embeddings
vit_large_dinov2_ssl_joscha with custom model:
ViT Accuracy with 5-nearest neighbors: 0.3889, F1_score: 0.249918
supervised_dinov2_large with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8681, F1_score: 0.789346
ssl_vincent_vit_large with custom model:
ViT Accuracy with 5-nearest neighbors: 0.4792, F1_score: 0.318877
best-model_vit_large_dinoV2 with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8542, F1_score: 0.770804
best-model_vit_large_dinoV2 in timmWrapper with custom model:
ViT Accuracy with 5-nearest neighbors: 0.8542, F1_score: 