## Joscha Model (VIT Large Dinov2)

In [4]:
import os
import numpy as np
import torch
import cv2
from timm import create_model
from torchvision import transforms
from sklearn.model_selection import train_test_split, LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from tqdm import tqdm  # For progress bars
import torch
from timm import create_model
import timm


In [5]:
# Paths to model and data

joscha_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/vit_large_dinov2_ssl_joscha.ckpt"
robert_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/supervised_dinov2_large.ckpt"
vincent_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/ssl_vincent_vit_large.ckpt"

#checkpoint_path = robert_checkpoint_path
test_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/test"
train_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/train"
val_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_split_60-25-15/val"

all_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_all_face"


## Joscha Laden

In [7]:
checkpoint_path = joscha_checkpoint_path
# Initialize the ViT model
vit_model = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=192)

# Load checkpoint and extract state_dict
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))
state_dict = checkpoint.get('state_dict', checkpoint)  # Get 'state_dict' or use checkpoint directly if no wrapper exists

# Adjust the keys if necessary (remove any prefix like 'model.')
new_state_dict = {k.replace('model_wrapper.', ''): v for k, v in state_dict.items()}
new_state_dict2 = {k.replace('model.', ''): v for k, v in new_state_dict.items()}

# Filter out unexpected keys from the state_dict
model_keys = set(vit_model.state_dict().keys())
filtered_state_dict = {k: v for k, v in new_state_dict2.items() if k in model_keys}

# Load the filtered state_dict into the model
vit_model.load_state_dict(filtered_state_dict, strict=True)
vit_model.eval()  # Set to evaluation mode

# Preprocessing function to resize and normalize images
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Adjust based on model pretraining
])

# Load images and extract class codes from file names
def load_images_and_labels(folder):
    images = []
    labels = []
    for filename in os.listdir(folder):
        if filename.endswith('.jpg') or filename.endswith('.png'):
            image = cv2.imread(os.path.join(folder, filename))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = transform(image)
            images.append(image)
            labels.append(filename[:4])  # Assuming first 4 chars are label
    return images, labels

# Load data
data_folder = test_folder
images, labels = load_images_and_labels(data_folder)
images_tensor = torch.stack(images)


  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))


## Robert Loading

In [8]:
checkpoint_path = robert_checkpoint_path
# Initialize the ViT model
vit_model = create_model('vit_large_patch14_dinov2.lvd142m', pretrained=False, img_size=192)

# Load checkpoint and extract state_dict
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))
state_dict = checkpoint.get('state_dict', checkpoint)  # Get 'state_dict' or use checkpoint directly if no wrapper exists

# Adjust the keys if necessary (remove any prefix like 'model.')
new_state_dict = {k.replace('model_wrapper.', ''): v for k, v in state_dict.items()}
new_state_dict2 = {k.replace('model.', ''): v for k, v in new_state_dict.items()}

# Interpolate positional embeddings if size mismatch
if 'pos_embed' in new_state_dict2:
    pos_embed_checkpoint = new_state_dict2['pos_embed']
    pos_embed_model = vit_model.state_dict()['pos_embed']
    if pos_embed_checkpoint.shape != pos_embed_model.shape:
        print(f"Interpolating pos_embed from {pos_embed_checkpoint.shape} to {pos_embed_model.shape}")
        num_patches = pos_embed_model.shape[1] - 1  # Exclude class token
        class_pos_embed = pos_embed_checkpoint[:, :1, :]  # Class token
        patch_pos_embed = pos_embed_checkpoint[:, 1:, :]  # Patch tokens

        # Reshape and interpolate patch embeddings
        patch_pos_embed = patch_pos_embed.reshape(1, int(patch_pos_embed.size(1)**0.5), -1, patch_pos_embed.size(-1))
        patch_pos_embed = torch.nn.functional.interpolate(
            patch_pos_embed.permute(0, 3, 1, 2),  # Convert to NCHW for interpolation
            size=(int(num_patches**0.5), int(num_patches**0.5)),  # Target size
            mode='bilinear',
            align_corners=False
        ).permute(0, 2, 3, 1).reshape(1, num_patches, -1)  # Back to NHWC

        # Concatenate class token and interpolated patch embeddings
        new_pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=1)
        new_state_dict2['pos_embed'] = new_pos_embed


# # Filter out unexpected keys from the state_dict
# model_keys = set(vit_model.state_dict().keys())
# filtered_state_dict = {k: v for k, v in new_state_dict2.items() if k in model_keys}

# Load the filtered state_dict into the model
vit_model.load_state_dict(filtered_state_dict, strict=True)
vit_model.eval()  # Set to evaluation mode

# Preprocessing function to resize and normalize images
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Adjust based on model pretraining
])

# Load images and extract class codes from file names
def load_images_and_labels(folder):
    images = []
    labels = []
    for filename in os.listdir(folder):
        if filename.endswith('.jpg') or filename.endswith('.png'):
            image = cv2.imread(os.path.join(folder, filename))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = transform(image)
            images.append(image)
            labels.append(filename[:4])  # Assuming first 4 chars are label
    return images, labels

# Load data
images, labels = load_images_and_labels(data_folder)
images_tensor = torch.stack(images)


  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))


Interpolating pos_embed from torch.Size([1, 257, 1024]) to torch.Size([1, 170, 1024])


## Using the model

In [9]:
# Function to generate embeddings from a model
def generate_embeddings(model, images_tensor):
    with torch.no_grad():
        embeddings = model.forward_features(images_tensor)
    return embeddings



# Generate and flatten embeddings for ViT
vit_embeddings = generate_embeddings(vit_model, images_tensor)
vit_embeddings_flat = vit_embeddings.view(vit_embeddings.size(0), -1).numpy()

# Train-test split and train KNN on ViT embeddings
vit_X_train, vit_X_test, vit_y_train, vit_y_test = train_test_split(vit_embeddings_flat, labels, test_size=0.2, random_state=100)
vit_knn = KNeighborsClassifier(n_neighbors=5)
vit_knn.fit(vit_X_train, vit_y_train)
vit_y_pred = vit_knn.predict(vit_X_test)
vit_accuracy = accuracy_score(vit_y_test, vit_y_pred)
print(f'ViT Accuracy: {vit_accuracy:.4f}')

# Leave-One-Out Cross-Validation KNN Classification with progress bar
def leave_one_out_knn_classification(model, model_name, images_tensor, labels):
    print(f"Using model: {model_name}")
    embeddings = generate_embeddings(model, images_tensor)
    embeddings_flat = embeddings.view(embeddings.size(0), -1).numpy()
    loo = LeaveOneOut()
    y_true, y_pred = [], []

    for train_index, test_index in tqdm(loo.split(embeddings_flat), desc="Leave-One-Out CV", total=len(embeddings_flat), unit="sample"):
        X_train, X_test = embeddings_flat[train_index], embeddings_flat[test_index]
        y_train, y_test = np.array(labels)[train_index], np.array(labels)[test_index]
        knn = KNeighborsClassifier(n_neighbors=5)
        knn.fit(X_train, y_train)
        y_test_pred = knn.predict(X_test)
        y_true.append(y_test[0])
        y_pred.append(y_test_pred[0])

    accuracy = accuracy_score(y_true, y_pred)
    print(f'Leave-One-Out Cross-Validation Accuracy for {model_name}: {accuracy:.4f}')
    return accuracy

# Perform Leave-One-Out KNN Classification on ViT model
#leave_one_out_knn_classification(vit_model, 'ViT', images_tensor, labels)


ViT Accuracy: 0.0909


Letztes mal:

ViT Accuracy: 0.6667 (Mit test_size 0.05)