## Joscha Model (VIT Large Dinov2)

1. Easy debug: pass in a few images and check the output (embeddings + classification)

2. Check which layers are added to the normal vit model (instead of removing weight params)

3. datasets

In [None]:
import os
import numpy as np
import torch
import cv2
from timm import create_model
from torchvision import transforms
from sklearn.model_selection import train_test_split, LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from tqdm import tqdm  # For progress bars
import torch
from timm import create_model
import timm
import torch.nn as nn

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Paths to model and data
#
# Every path hangs off a single project root. Set the GORILLATRACKER_ROOT
# environment variable to run the notebook against another checkout; when it is
# unset the original dev-container location is used, so all paths below resolve
# to exactly the same strings as before.
GORILLATRACKER_ROOT = os.environ.get(
    "GORILLATRACKER_ROOT",
    "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker",
)
MODELS_DIR = f"{GORILLATRACKER_ROOT}/models"
DATASETS_DIR = f"{GORILLATRACKER_ROOT}/datasets"
SPLIT_DIR = f"{DATASETS_DIR}/cxl_all_split_60-25-15"

joscha_checkpoint_path = f"{MODELS_DIR}/vit_large_dinov2_ssl_joscha.ckpt"
robert_checkpoint_path = f"{MODELS_DIR}/supervised_dinov2_large.ckpt"
vincent_checkpoint_path = f"{MODELS_DIR}/ssl_vincent_vit_large.ckpt"

#checkpoint_path = robert_checkpoint_path
test_folder = f"{SPLIT_DIR}/test"
train_folder = f"{SPLIT_DIR}/train"
val_folder = f"{SPLIT_DIR}/val"

all_folder = f"{DATASETS_DIR}/cxl_all_face"


## Load checkpoints (vit_large_dinov2_ssl_joscha)

In [None]:

### inspect the checkpoints

# Pick exactly one checkpoint to inspect.
# checkpoint_path = joscha_checkpoint_path
checkpoint_path = robert_checkpoint_path
# checkpoint_path = vincent_checkpoint_path

# Load checkpoint and extract state_dict.
# The tensors are mapped to the CPU rather than a hard-coded 'cuda' device:
# the load then also works on machines without a GPU, and the models and image
# tensors used further down in this notebook are kept on the CPU anyway.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# print(checkpoint["hyper_parameters"])
# print(checkpoint["hparams_name"])
# print(checkpoint["state_dict"].keys())

In [None]:
### compare checkpoints' state_dict with a standard model's state_dict

# Take the weights from the checkpoint's 'state_dict' entry, or use the
# checkpoint object itself when it has no such entry.
state_dict = checkpoint.get('state_dict', checkpoint)

# Remove the 'model_wrapper.' substring from every key, then the 'model.'
# substring, so the names can be compared with a plain timm ViT.
new_state_dict = dict(
    (name.replace('model_wrapper.', ''), tensor) for name, tensor in state_dict.items()
)
new_state_dict2 = dict(
    (name.replace('model.', ''), tensor) for name, tensor in new_state_dict.items()
)

# Reference architecture (randomly initialised; only its key names matter here)
vit_model = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=192)

# Key sets of the reference model and of the cleaned-up checkpoint
standard_vit_keys = vit_model.state_dict().keys()
custom_keys = new_state_dict2.keys()

# Checkpoint keys the reference model lacks, and reference keys the checkpoint lacks
missing_in_standard = list(filter(lambda name: name not in standard_vit_keys, custom_keys))
extra_in_standard = list(filter(lambda name: name not in custom_keys, standard_vit_keys))

# print("Missing keys in standard ViT:", missing_in_standard)
# print("Extra keys in standard ViT:", extra_in_standard)

# print(new_state_dict2)

# Show the shape of every checkpoint-only parameter
for key in missing_in_standard:
    print(f"{key}: {new_state_dict2[key].shape}")

embedding_keys = list(filter(lambda name: "embedding_layer" in name, new_state_dict2.keys()))
# print(embedding_keys)
    


In [None]:
# fix pos_embed mismatch
import torch
import torch.nn.functional as F

def _square_grid_side(num_patch_tokens, which):
    """Return the side length of a square patch grid, or raise if it is not square."""
    side = int(round(num_patch_tokens ** 0.5))
    if side < 1 or side * side != num_patch_tokens:
        raise ValueError(
            f"{which} positional embedding has {num_patch_tokens} patch tokens, "
            "which does not form a square grid"
        )
    return side


def interpolate_positional_embedding(pretrained_state_dict, current_model, patch_size):
    """
    Adjusts the positional embeddings from the pretrained model to fit the current model.

    The first token of ``pos_embed`` is treated as a class token and copied
    unchanged; the remaining tokens are reshaped to a square grid, resized
    bilinearly to the grid size of ``current_model`` and flattened again.

    Args:
        pretrained_state_dict (dict): State dictionary of the pretrained model.
            Its ``'pos_embed'`` entry is replaced IN PLACE.
        current_model (torch.nn.Module): The current Vision Transformer model;
            only the shape of its ``'pos_embed'`` entry is used.
        patch_size (int): Patch size of the ViT model. Not used by the
            computation (the grid sizes are derived from the token counts);
            kept so existing calls keep working.

    Returns:
        dict: The same ``pretrained_state_dict`` object, now compatible with
        the current model's positional embedding.

    Raises:
        ValueError: If the embedding widths differ or either token count
            (excluding the first token) is not a perfect square.
    """
    # Get the pretrained and the target positional embeddings
    pretrained_pos_embed = pretrained_state_dict['pos_embed']
    current_pos_embed = current_model.state_dict()['pos_embed']

    # Nothing to do when the checkpoint already fits the model.
    if pretrained_pos_embed.shape == current_pos_embed.shape:
        return pretrained_state_dict

    if pretrained_pos_embed.shape[-1] != current_pos_embed.shape[-1]:
        raise ValueError(
            f"embedding width mismatch: pretrained {pretrained_pos_embed.shape[-1]} "
            f"vs current {current_pos_embed.shape[-1]}"
        )

    # Split off the first (class) token; only the patch grid is resized
    cls_token = pretrained_pos_embed[:, :1, :]
    pretrained_grid = pretrained_pos_embed[:, 1:, :]

    # Grid side lengths; a plain int(sqrt(n)) would silently truncate a
    # non-square token count and produce an embedding of the wrong length.
    grid_size_pretrained = _square_grid_side(pretrained_grid.shape[1], "pretrained")
    grid_size_current = _square_grid_side(current_pos_embed.shape[1] - 1, "current")

    # (1, N, C) -> (1, C, H, W), the layout F.interpolate expects
    pretrained_grid = pretrained_grid.reshape(
        1, grid_size_pretrained, grid_size_pretrained, -1
    ).permute(0, 3, 1, 2)

    # Resize in float32 and cast back so the checkpoint dtype is preserved
    interpolated_grid = F.interpolate(
        pretrained_grid.float(),
        size=(grid_size_current, grid_size_current),
        mode='bilinear',
        align_corners=False,
    ).to(pretrained_pos_embed.dtype)

    # (1, C, H, W) -> (1, H*W, C)
    interpolated_grid = interpolated_grid.permute(0, 2, 3, 1).reshape(1, -1, interpolated_grid.shape[1])

    # Combine the class token and the new grid
    pretrained_state_dict['pos_embed'] = torch.cat([cls_token, interpolated_grid], dim=1)

    return pretrained_state_dict


# The model is a patch-14 ViT ('vit_large_patch14_dinov2'), so pass 14; the
# function derives the grid sizes from the token counts, so this value does not
# affect the result. Binding the return value (the same dict, updated in place)
# keeps the cell from dumping the entire state dict as its output.
new_state_dict2 = interpolate_positional_embedding(new_state_dict2, vit_model, 14)

In [None]:
# Load the model supervised_dinov2_large into vit

vit_model_robert = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=192)

# strict=False tolerates key mismatches between checkpoint and model instead of
# raising, so keep the returned report and show it as the cell output to make
# any missing or unexpected keys visible.
incompatible_keys = vit_model_robert.load_state_dict(new_state_dict2, strict=False)

# The model is only used for inference in this notebook.
vit_model_robert.eval()

incompatible_keys


In [None]:
# wrap the vit model with an embedding layer

In [None]:
# wrap the VIT model with custom layers (vit_large_dinov2_ssl_joscha)

class CustomVisionTransformer(nn.Module):
    """ViT backbone followed by a BatchNorm -> Linear -> BatchNorm embedding head.

    The head maps the 1024-dimensional feature of token 0 to a 256-dimensional
    embedding. The attribute names (``embedding_layer_0/2/3``) follow the
    ``embedding_layer.0/.2/.3`` checkpoint keys referenced elsewhere in this
    notebook.

    Args:
        base_vit: Backbone exposing ``forward_features(x)`` whose output is
            indexed as ``[batch, token, 1024]``.
        verbose (bool): When True, print the tensor shape after every stage of
            ``forward`` (debugging aid). Off by default so that inference does
            not flood the cell output.
    """

    def __init__(self, base_vit, verbose=False):
        super().__init__()
        self.base_vit = base_vit
        self.verbose = verbose

        # Define embedding layers based on checkpoint dimensions
        self.embedding_layer_0 = nn.BatchNorm1d(1024)  # Normalize input size 1024
        self.embedding_layer_2 = nn.Linear(1024, 256)  # Linear layer: 1024 -> 256
        self.embedding_layer_3 = nn.BatchNorm1d(256)  # Normalize input size 256

    def _log_shape(self, stage, tensor):
        """Print the shape produced by ``stage`` when verbose mode is enabled."""
        if self.verbose:
            print(f"output shape of {stage}: {tensor.shape}")

    def forward(self, x):
        """Return the 256-dimensional embedding for a batch of images."""
        # Backbone output is a token sequence: [batch_size, num_tokens, 1024]
        x = self.base_vit.forward_features(x)
        self._log_shape("base_vit", x)

        # Keep only token 0 (presumably the class token): [batch_size, 1024]
        x = x[:, 0, :]

        # Pass through the additional embedding layers
        x = self.embedding_layer_0(x)  # BatchNorm1d for 1024
        self._log_shape("embedding_layer_0", x)
        x = self.embedding_layer_2(x)  # Linear transformation to 256
        self._log_shape("embedding_layer_2", x)
        x = self.embedding_layer_3(x)  # BatchNorm1d for 256
        self._log_shape("embedding_layer_3", x)

        return x

# Initialize the base ViT model
# NOTE(review): created with pretrained=False, and the cell that would load the
# checkpoint into custom_vit is commented out in this notebook, so this model
# appears to run with randomly initialised weights — confirm before trusting
# the "joscha" numbers.
base_vit = create_model('vit_large_patch14_224', pretrained=False)

# Wrap the base ViT model with the custom embedding layers.
# The model is only used for inference, so switch it to eval mode: otherwise
# the BatchNorm1d layers normalise with per-batch statistics and keep updating
# their running statistics on every forward pass.
custom_vit = CustomVisionTransformer(base_vit).eval()


In [None]:
# # load the checkpoints' state_dict into the custom_vit 

# # Adjust the keys
# adjusted_state_dict = {k.replace("model_wrapper.", ""): v for k, v in state_dict.items()}
# backbone_keys_adjusted = {k.replace("model.", ""): v for k, v in adjusted_state_dict.items()}

# # Separate backbone and embedding layers
# backbone_state_dict = {k: v for k, v in backbone_keys_adjusted.items() if "embedding_layer" not in k}
# filtered_backbone_state_dict = {k: v for k, v in backbone_state_dict.items() if "ls1.gamma" not in k and "ls2.gamma" not in k}

# import torch.nn.functional as F

# # Resize the positional embeddings
# checkpoint_pos_embed = backbone_state_dict["pos_embed"]
# checkpoint_pos_embed = checkpoint_pos_embed.view(1, 170, -1)  # Flatten
# new_pos_embed = F.interpolate(checkpoint_pos_embed.permute(0, 2, 1), size=257, mode='linear')  # Interpolate
# new_pos_embed = new_pos_embed.permute(0, 2, 1).view(1, 257, -1)

# # Replace the key in the state_dict
# filtered_backbone_state_dict["pos_embed"] = new_pos_embed

# embedding_layer_state_dict = {k: v for k, v in backbone_keys_adjusted.items() if "embedding_layer" in k}


# # Load weights into the backbone (base_vit)
# custom_vit.base_vit.load_state_dict(filtered_backbone_state_dict, strict=False)

# # Load weights into the embedding layers
# custom_vit.embedding_layer_0.load_state_dict({
#     "weight": embedding_layer_state_dict["embedding_layer.0.weight"],
#     "bias": embedding_layer_state_dict["embedding_layer.0.bias"],
#     "running_mean": embedding_layer_state_dict["embedding_layer.0.running_mean"],
#     "running_var": embedding_layer_state_dict["embedding_layer.0.running_var"],
#     "num_batches_tracked": embedding_layer_state_dict["embedding_layer.0.num_batches_tracked"],
# })
# custom_vit.embedding_layer_2.load_state_dict({
#     "weight": embedding_layer_state_dict["embedding_layer.2.weight"],
#     "bias": embedding_layer_state_dict["embedding_layer.2.bias"],
# })
# custom_vit.embedding_layer_3.load_state_dict({
#     "weight": embedding_layer_state_dict["embedding_layer.3.weight"],
#     "bias": embedding_layer_state_dict["embedding_layer.3.bias"],
#     "running_mean": embedding_layer_state_dict["embedding_layer.3.running_mean"],
#     "running_var": embedding_layer_state_dict["embedding_layer.3.running_var"],
#     "num_batches_tracked": embedding_layer_state_dict["embedding_layer.3.num_batches_tracked"],
# })

# # custom_vit.to(device)
# custom_vit.eval()


In [None]:
# ### Max' approach to loading the state_dict

# # Adjust the keys if necessary (remove any prefix like 'model.')
# new_state_dict = {k.replace('model_wrapper.', ''): v for k, v in state_dict.items()}
# new_state_dict2 = {k.replace('model.', ''): v for k, v in new_state_dict.items()}

# # Filter out unexpected keys from the state_dict
# model_keys = set(vit_model.state_dict().keys())
# filtered_state_dict = {k: v for k, v in new_state_dict2.items() if k in model_keys}

# # Load the filtered state_dict into the model
# vit_model.load_state_dict(filtered_state_dict, strict=True)
# vit_model.eval()  # Set to evaluation mode

# ### MAx' approach

In [None]:
# load the data
# Preprocessing pipeline applied to every image: convert the array to a PIL
# image, resize to 224x224, convert to a tensor and normalise each channel.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Adjust based on model pretraining
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(preprocessing_steps)

# Load images and extract class codes from file names
def load_images_and_labels(folder):
    """Load and preprocess every .jpg/.png image in ``folder``.

    Each image is read with OpenCV, converted from BGR to RGB and passed
    through the notebook-level ``transform`` that is current at call time.
    The label is taken from the first four characters of the file name.

    Args:
        folder (str): Directory containing the image files.

    Returns:
        tuple: ``(images, labels)`` — two lists of equal length, where
        ``labels[i]`` belongs to ``images[i]``.
    """
    images = []
    labels = []
    # os.listdir returns entries in arbitrary order; sort so that the sample
    # order (and everything derived from it) is reproducible across runs.
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith(('.jpg', '.png')):
            continue

        image = cv2.imread(os.path.join(folder, filename))
        if image is None:
            # cv2.imread reports an unreadable file by returning None instead
            # of raising; skip it rather than crash inside cvtColor.
            print(f"Skipping unreadable image: {filename}")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(transform(image))
        labels.append(filename[:4])  # Assuming first 4 chars are label

    return images, labels

# Load data — only the test split is active; the other splits are kept below
# as commented-out alternatives.
test_images, test_labels = load_images_and_labels(folder=test_folder)
# train_images, train_labels = load_images_and_labels(train_folder)
# val_images, val_labels = load_images_and_labels(val_folder)
# all_images, all_labels = load_images_and_labels(all_folder)

# Stack the per-image tensors along a new leading batch dimension
test_images_tensor = torch.stack(test_images, dim=0)
# train_images_tensor = torch.stack(train_images)
# val_images_tensor = torch.stack(val_images)
# all_images_tensor = torch.stack(all_images)

# train_val_images_tensor = torch.cat((train_images_tensor, val_images_tensor), 0)
# train_val_labels = train_labels + val_labels


## Robert Loading

## Using the custom model

In [None]:
# Function to generate embeddings from a model
def generate_embeddings(model, images_tensor, batch_size=None):
    """Run ``model`` on ``images_tensor`` in inference mode and return its output.

    The model is switched to eval mode for the forward pass so that layers
    such as BatchNorm and Dropout behave deterministically and do not update
    their running statistics; its previous train/eval mode is restored
    afterwards. Gradients are disabled.

    Args:
        model (torch.nn.Module): Model to evaluate.
        images_tensor (torch.Tensor): Input batch, first dimension = samples.
        batch_size (int, optional): If given, the input is processed in chunks
            of this many samples and the outputs are concatenated, which
            bounds peak memory. ``None`` (default) runs a single forward pass.

    Returns:
        torch.Tensor: The model output for all samples, in input order.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if batch_size is None:
                embeddings = model(images_tensor)
            else:
                chunks = [model(chunk) for chunk in torch.split(images_tensor, batch_size)]
                embeddings = torch.cat(chunks, dim=0)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    return embeddings

# test_images_tensor = test_images_tensor.to(device)
# all_images_tensor = all_images_tensor.to(device)

# Embed the test images with the wrapped ViT
test_vit_embeddings = generate_embeddings(model=custom_vit, images_tensor=test_images_tensor)
# train_val_vit_embeddings = generate_embeddings(custom_vit, train_val_images_tensor)
# all_vit_embeddings = generate_embeddings(custom_vit, all_images_tensor)

# Collapse everything after the batch dimension and hand the result to NumPy
test_vit_embeddings_flat = test_vit_embeddings.view(test_vit_embeddings.shape[0], -1).numpy()
# train_val_vit_embeddings_flat = train_val_vit_embeddings.view(train_val_vit_embeddings.size(0), -1).numpy()
# all_vit_embeddings_flat = all_vit_embeddings.view(all_vit_embeddings.size(0), -1).numpy()




In [None]:
# load the data
# Preprocessing pipeline applied to every image: convert the array to a PIL
# image, resize to 192x192, convert to a tensor and normalise each channel.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    # Adjust based on model pretraining
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(preprocessing_steps)

# Load images and extract class codes from file names
def load_images_and_labels(folder):
    """Load and preprocess every .jpg/.png image in ``folder``.

    Each image is read with OpenCV, converted from BGR to RGB and passed
    through the notebook-level ``transform`` that is current at call time.
    The label is taken from the first four characters of the file name.

    Args:
        folder (str): Directory containing the image files.

    Returns:
        tuple: ``(images, labels)`` — two lists of equal length, where
        ``labels[i]`` belongs to ``images[i]``.
    """
    images = []
    labels = []
    # os.listdir returns entries in arbitrary order; sort so that the sample
    # order (and everything derived from it) is reproducible across runs.
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith(('.jpg', '.png')):
            continue

        image = cv2.imread(os.path.join(folder, filename))
        if image is None:
            # cv2.imread reports an unreadable file by returning None instead
            # of raising; skip it rather than crash inside cvtColor.
            print(f"Skipping unreadable image: {filename}")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(transform(image))
        labels.append(filename[:4])  # Assuming first 4 chars are label

    return images, labels

# Load data — only the test split is active; the other splits are kept below
# as commented-out alternatives.
test_images, test_labels = load_images_and_labels(folder=test_folder)
# train_images, train_labels = load_images_and_labels(train_folder)
# val_images, val_labels = load_images_and_labels(val_folder)
# all_images, all_labels = load_images_and_labels(all_folder)

# Stack the per-image tensors along a new leading batch dimension
test_images_tensor = torch.stack(test_images, dim=0)
# train_images_tensor = torch.stack(train_images)
# val_images_tensor = torch.stack(val_images)
# all_images_tensor = torch.stack(all_images)

# train_val_images_tensor = torch.cat((train_images_tensor, val_images_tensor), 0)
# train_val_labels = train_labels + val_labels


# Embed the re-loaded test images with the model built from the checkpoint
robert_test_embeddings = generate_embeddings(model=vit_model_robert, images_tensor=test_images_tensor)

In [None]:
# # try reducing samples with insufficient data

# from collections import Counter
# threshold = 3

# # Count occurrences of each label
# label_counts = Counter(test_labels)
# print(f"Number of images before filtering: {len(test_labels)}")

# # 
# # Filter embeddings and labels for individuals with >= 3 samples
# filtered_indices = [i for i, label in enumerate(test_labels) if label_counts[label] >= threshold]
# test_vit_embeddings_filtered = test_vit_embeddings_flat[filtered_indices]
# test_labels_filtered = [test_labels[i] for i in filtered_indices]

# # check the number of images after filtering
# print(f"Number of images after filtering: {len(test_labels_filtered)}")
# print(f"Number of images after filtering: {len(test_vit_embeddings_filtered)}")

from collections import Counter

def filter_samples_by_threshold(embeddings, labels, threshold=3):
    """Keep only the samples whose label occurs at least ``threshold`` times.

    Args:
        embeddings: Array-like indexed by a list of row positions.
        labels (list): One label per row of ``embeddings``.
        threshold (int): Minimum number of occurrences a label needs.

    Returns:
        tuple: ``(filtered_embeddings, filtered_labels)`` in the original order.
    """
    occurrences = Counter(labels)
    print(f"Number of images before filtering: {len(labels)}")

    # Walk the labels once, remembering both the position and the label of
    # every sample that belongs to a sufficiently frequent class.
    kept_positions = []
    kept_labels = []
    for position, label in enumerate(labels):
        if occurrences[label] >= threshold:
            kept_positions.append(position)
            kept_labels.append(label)

    kept_embeddings = embeddings[kept_positions]

    # Report how much survived the filter
    print(f"Number of images after filtering: {len(kept_labels)}")
    print(f"Number of embeddings after filtering: {len(kept_embeddings)}")

    return kept_embeddings, kept_labels



In [None]:
# KNN classifier and excluding the closest neighbor (self-matching)
def knn_classifier(embeddings, labels, max_k=5):
    """Print leave-one-out k-NN accuracy on ``embeddings`` for k = 1..max_k.

    Every sample is classified by a majority vote over its k nearest
    neighbours among the *other* samples. Ties are resolved in favour of the
    label whose first voter is closest to the sample, so results are
    reproducible from run to run.

    Args:
        embeddings: Array-like of shape (n_samples, n_features).
        labels (list): One label per sample.
        max_k (int): Largest neighbourhood size to evaluate (default 5). It is
            capped at ``n_samples - 1``.

    Returns:
        None. The accuracy for each k is printed.
    """
    from collections import Counter  # local import keeps the function self-contained

    n_samples = len(labels)
    # A sample can have at most n_samples - 1 neighbours other than itself.
    largest_k = min(max_k, n_samples - 1)
    if largest_k < 1:
        print("Not enough samples for a leave-one-out k-NN evaluation")
        return

    # Fit and query once for the largest k. Calling kneighbors() without a
    # query makes scikit-learn return the neighbours of the fitted points with
    # each point excluded from its own neighbour list, which avoids assuming
    # that a sample is always its own first neighbour (not guaranteed when
    # duplicate embeddings exist).
    vit_knn = KNeighborsClassifier(n_neighbors=largest_k)
    vit_knn.fit(embeddings, labels)
    neighbor_indices = vit_knn.kneighbors(return_distance=False)

    for k in range(1, largest_k + 1):
        vit_y_pred = []
        for row in neighbor_indices:
            # Rows are sorted by increasing distance, so the first k entries
            # are the k nearest neighbours. Counter keeps first-seen order for
            # equal counts, hence ties go to the nearest voter's label.
            votes = Counter(labels[n] for n in row[:k])
            vit_y_pred.append(votes.most_common(1)[0][0])

        # Calculate accuracy for the current k
        vit_accuracy = accuracy_score(labels, vit_y_pred)
        print(f"ViT Accuracy with {k}-nearest neighbors (excluding self): {vit_accuracy:.4f}")
        

# knn_classifier(all_vit_embeddings_reduced, all_labels, test_vit_embeddings_reduced, test_labels)
# knn_classifier(all_vit_embeddings_filtered, all_labels_filtered, test_vit_embeddings_flat, test_labels)

# Evaluate both models on the complete test set
print("all test samples: joscha")
knn_classifier(test_vit_embeddings_flat, test_labels)
print("all test samples: robert")
knn_classifier(robert_test_embeddings, test_labels)

# Repeat the evaluation on classes with at least `threshold` samples. The same
# variable drives both the filter and the printed header so they cannot drift
# apart.
threshold = 3
test_vit_embeddings_filtered, test_labels_filtered = filter_samples_by_threshold(
    test_vit_embeddings_flat,
    test_labels,
    threshold=threshold
)
print(f"filtered test samples (threshold = {threshold}):joscha")
knn_classifier(test_vit_embeddings_filtered, test_labels_filtered)

robert_embeddings_filtered, robert_labels_filtered = filter_samples_by_threshold(
    robert_test_embeddings,
    test_labels,
    threshold=threshold
)
print(f"filtered test samples (threshold = {threshold}):robert")
knn_classifier(robert_embeddings_filtered, robert_labels_filtered)



all test samples: joscha
ViT Accuracy with 1-nearest neighbors (excluding self): 0.3704
ViT Accuracy with 2-nearest neighbors (excluding self): 0.2963
ViT Accuracy with 3-nearest neighbors (excluding self): 0.2222
ViT Accuracy with 4-nearest neighbors (excluding self): 0.2407
ViT Accuracy with 5-nearest neighbors (excluding self): 0.2222
all test samples: robert
ViT Accuracy with 1-nearest neighbors (excluding self): 0.8704
ViT Accuracy with 2-nearest neighbors (excluding self): 0.7778
ViT Accuracy with 3-nearest neighbors (excluding self): 0.7222
ViT Accuracy with 4-nearest neighbors (excluding self): 0.6852
ViT Accuracy with 5-nearest neighbors (excluding self): 0.6852
Number of images before filtering: 54
Number of images after filtering: 40
Number of embeddings after filtering: 40
filtered test samples (threshold = 3):joscha
ViT Accuracy with 1-nearest neighbors (excluding self): 0.4000
ViT Accuracy with 2-nearest neighbors (excluding self): 0.3250
ViT Accuracy with 3-nearest neighbors (excluding self): 0.2750
ViT Accuracy with 4-nearest neighbors (excluding self): 0.3250
ViT Accuracy with 5-nearest neighbors (excluding self): 0.2500
Number of images before filtering: 54
Number of images after filtering: 40
Number of embeddings after filtering: 40
filtered test samples (threshold = 3):robert
ViT Accuracy with 1-nearest neighbors (excluding self): 0.9750
ViT Accuracy with 2-nearest neighbors (excluding self): 0.9500
ViT Accuracy with 3-nearest neighbors (excluding self): 0.9750
ViT Accuracy with 4-nearest neighbors (excluding self): 0.9250
ViT Accuracy with 5-nearest neighbors (excluding self): 0.8750


Letztes mal:

ViT Accuracy: 0.6667 (Mit test_size 0.05)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Visualise the flattened test embeddings (`test_vit_embeddings_flat`) with
# their labels (`test_labels`); the all-data variant is kept commented out.

# Using t-SNE to reduce dimensionality to 2D for visualization.
# t-SNE needs perplexity < number of samples, so cap it for small sets.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version.
tsne_perplexity = min(30, len(test_labels) - 1)
tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity, n_iter=1000)
# reduced_embeddings_tsne = tsne.fit_transform(all_vit_embeddings_flat)
reduced_embeddings_tsne = tsne.fit_transform(test_vit_embeddings_flat)

# Convert labels to numeric values for coloring. Sorting fixes the
# label -> index mapping, so each class keeps its colour between runs
# (a bare set has no stable iteration order for strings).
unique_labels = sorted(set(test_labels))
label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
numeric_labels = [label_to_index[label] for label in test_labels]

# Visualizing the reduced embeddings with a scatter plot
plt.figure(figsize=(12, 8))
scatter = plt.scatter(
    reduced_embeddings_tsne[:, 0],
    reduced_embeddings_tsne[:, 1],
    c=numeric_labels,
    cmap="tab10",
    alpha=0.7,
    edgecolor="k"
)
plt.colorbar(scatter, ticks=range(len(unique_labels)), label="Classes")
plt.title("t-SNE Visualization of Embeddings")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.show()
