In [68]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
from torchvision import transforms
from collections import Counter
from tqdm import tqdm
from timm import create_model
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score
from collections import defaultdict
from wrappers_supervised import *


In [None]:
# Evaluation configuration: compute device, trained checkpoint, and test split.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

best_checkpoint_path = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/models/should-be-the-best-model_vit_large_dinoV2.ckpt"
# torch.load unpickles the file: only load checkpoints from a trusted source.
# Loaded exactly once (a duplicated, mis-indented load used to follow below).
checkpoint_best = torch.load(best_checkpoint_path, map_location=device)

# Folder holding the test images to evaluate
test_folder = "/workspaces/gorilla_watch/video_data/gorillawatch/gorillatracker/datasets/cxl_faces_squared_openset_kfold-5/test"


For cross-video KNN (KNN-CV) we only classify an image if the same individual has at least 3 images from videos other than the image's own video. Images that do not satisfy this condition are still used as neighbours when classifying other images. Each image is classified using only neighbours that come from a different video.

In [None]:
# Load the dataset and determine which images each KNN evaluation classifies
img_size = 224  # square input resolution fed to the backbone

# numpy HxWxC image -> PIL -> square resize -> [0, 1] tensor -> [-1, 1] via (x - 0.5) / 0.5
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

class GorillaDataset(Dataset):
    """Gorilla images read from one flat folder, plus KNN evaluation index lists.

    Filenames are parsed as ``<label>_<part1>_<part2>...``: the first
    underscore-separated field is the individual's label and ``part1_part2``
    identifies the source video. Only ``.jpg`` and ``.png`` files are used.

    Attributes (index lists refer to positions in ``self.images``):
        images, labels, videos: parallel lists of path, label and video id.
        images_for_cross_video_knn: images whose individual has at least
            ``min_other_video_images`` images in *other* videos.
        images_for_standard_knn: images whose individual has at least
            ``threshold`` images in total.

    ``__getitem__`` returns ``(transform(rgb_image), label, video)``.

    Raises:
        ValueError: if a .jpg/.png filename has fewer than three
            underscore-separated fields.
    """

    def __init__(self, folder, transform, threshold=4, min_other_video_images=3):
        self.folder = folder
        self.transform = transform
        self.threshold = threshold
        self.min_other_video_images = min_other_video_images

        self.images = []
        self.labels = []
        self.videos = []

        # indices of the images each KNN protocol classifies
        self.images_for_standard_knn = []
        self.images_for_cross_video_knn = []

        # sorted() makes dataset indices independent of the filesystem's listing
        # order, so every machine sees the same image order
        for filename in sorted(os.listdir(folder)):
            if not filename.endswith(('.jpg', '.png')):
                continue
            parts = filename.split("_")
            if len(parts) < 3:
                raise ValueError(
                    f"Cannot parse label and video from {filename!r}; "
                    "expected '<label>_<part1>_<part2>...'"
                )
            self.images.append(os.path.join(folder, filename))
            self.labels.append(parts[0])
            self.videos.append(parts[1] + "_" + parts[2])

        images_per_label = Counter(self.labels)
        images_per_label_video = Counter(zip(self.labels, self.videos))

        for idx, (label, video) in enumerate(zip(self.labels, self.videos)):
            # images of the same individual that come from a different video
            other_videos_count = images_per_label[label] - images_per_label_video[(label, video)]
            if other_videos_count >= min_other_video_images:
                self.images_for_cross_video_knn.append(idx)
            # individuals with at least `threshold` images take part in standard KNN
            if images_per_label[label] >= threshold:
                self.images_for_standard_knn.append(idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return (image, label, video)."""
        image_path = self.images[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising for missing/unreadable files
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.transform(image), self.labels[idx], self.videos[idx]

# Build the test split. shuffle=False keeps the loader order equal to the dataset
# index order, which the KNN index lists (images_for_*_knn) rely on.
test_dataset = GorillaDataset(test_folder, transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)


In [71]:

# Load model from provided checkpoint
def extract_clean_state_dict_for_wrapper(checkpoint, wrapper_key="model_wrapper.", model_key="model."):
    """Return the checkpoint's state dict with the ``wrapper_key`` prefix removed.

    Args:
        checkpoint: either a dict with a ``'state_dict'`` entry or a state dict itself.
        wrapper_key: prefix to strip from keys that start with it; keys without the
            prefix are kept unchanged.
        model_key: currently unused; kept so existing callers keep working.

    Returns:
        A new dict with the same values and cleaned keys.
    """
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Strip only a leading prefix: str.replace would also rewrite occurrences in
    # the middle of a key (e.g. "x.model_wrapper.y").
    return {
        (key[len(wrapper_key):] if key.startswith(wrapper_key) else key): value
        for key, value in state_dict.items()
    }

# Build the model with TimmWrapper (provided by the previous BP); embedding_id may be "linear" or "".
model_wrapper = TimmWrapper(
    backbone_name="vit_large_patch14_dinov2.lvd142m",
    embedding_size=256,
    embedding_id="linear",  # possible values: "linear", ""
    dropout_p=0.0,
    pool_mode="none",
    img_size=224,
)

cleaned_state_dict_wrapper = extract_clean_state_dict_for_wrapper(checkpoint_best)
# strict=False tolerates extra checkpoint entries (the saved state also holds
# training-only weights such as loss_module_train.*). Every parameter the model
# expects must still be present; otherwise we would silently evaluate partly
# initialised weights.
load_result = model_wrapper.load_state_dict(cleaned_state_dict_wrapper, strict=False)
assert not load_result.missing_keys, f"checkpoint is missing model weights: {load_result.missing_keys}"
load_result


_IncompatibleKeys(missing_keys=[], unexpected_keys=['loss_module_train.model.cls_token', 'loss_module_train.model.pos_embed', 'loss_module_train.model.patch_embed.proj.weight', 'loss_module_train.model.patch_embed.proj.bias', 'loss_module_train.model.blocks.0.norm1.weight', 'loss_module_train.model.blocks.0.norm1.bias', 'loss_module_train.model.blocks.0.attn.qkv.weight', 'loss_module_train.model.blocks.0.attn.qkv.bias', 'loss_module_train.model.blocks.0.attn.proj.weight', 'loss_module_train.model.blocks.0.attn.proj.bias', 'loss_module_train.model.blocks.0.ls1.gamma', 'loss_module_train.model.blocks.0.norm2.weight', 'loss_module_train.model.blocks.0.norm2.bias', 'loss_module_train.model.blocks.0.mlp.fc1.weight', 'loss_module_train.model.blocks.0.mlp.fc1.bias', 'loss_module_train.model.blocks.0.mlp.fc2.weight', 'loss_module_train.model.blocks.0.mlp.fc2.bias', 'loss_module_train.model.blocks.0.ls2.gamma', 'loss_module_train.model.blocks.1.norm1.weight', 'loss_module_train.model.blocks.1.n

In [72]:
# Embed the test dataset: put the model on the evaluation device and switch it
# to evaluation mode (affects layers such as dropout). .to()/.eval() return the module.
model_wrapper.to(device).eval()

# Generate embeddings
def generate_embeddings(model, data_loader, device=None):
    """Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    Args:
        model: module mapping a batch of images to a batch of embeddings.
        data_loader: iterable yielding ``(images, labels, videos)`` batches.
        device: device for the forward pass. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.

    Returns:
        ``(embeddings, labels, videos)``. ``embeddings`` is the per-batch outputs
        moved to CPU and concatenated along dim 0. ``labels`` and ``videos`` are
        flat lists in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        # follow the model instead of a notebook-global device variable
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    all_embeddings, all_labels, all_videos = [], [], []
    with torch.no_grad():
        for images, labels, videos in tqdm(data_loader, desc="Generating Embeddings"):
            embeddings = model(images.to(device))
            all_embeddings.append(embeddings.cpu())
            all_labels.extend(labels)
            all_videos.extend(videos)

    if not all_embeddings:
        raise ValueError("data_loader produced no batches; cannot build embeddings")
    return torch.cat(all_embeddings), all_labels, all_videos

# One embedding per test image, with labels/video ids in dataset index order
embeddings, labels, video_ids = generate_embeddings(model=model_wrapper, data_loader=test_loader)



Generating Embeddings:   0%|          | 0/4 [00:00<?, ?it/s]

Generating Embeddings: 100%|██████████| 4/4 [00:22<00:00,  5.73s/it]


In [73]:
# KNN evaluation helpers

def calculate_distance(embeddings, test_embedding, metric):
    """Distance from ``test_embedding`` to every row of ``embeddings``.

    Args:
        embeddings: 2-D array with one embedding per row.
        test_embedding: 1-D query embedding.
        metric: "euclidean" (L2 distance) or "cosine" (1 - cosine similarity).

    Returns:
        1-D array with one distance per row of ``embeddings``.

    Raises:
        ValueError: for any other ``metric``.
    """
    if metric not in ("euclidean", "cosine"):
        raise ValueError(f"Unsupported distance metric: {metric}")

    if metric == "euclidean":
        return np.linalg.norm(embeddings - test_embedding, axis=1)

    # cosine: compare unit-length vectors, then turn similarity into a distance
    unit_rows = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_query = test_embedding / np.linalg.norm(test_embedding)
    return 1 - np.dot(unit_rows, unit_query)

def KNN_CV(embeddings, images_to_check, labels, video_ids, distance_metric="euclidean", num_neighbors=5):
    """Cross-video k-nearest-neighbour accuracy.

    Each image whose index is in ``images_to_check`` is classified by a majority
    vote over its ``num_neighbors`` nearest neighbours. Only images from a
    *different* video count as neighbours; every image, checked or not, may serve
    as a neighbour.

    Args:
        embeddings: 2-D embeddings (one row per image), as a torch tensor or array-like.
        images_to_check: indices of the images to classify.
        labels: ground-truth labels aligned with ``embeddings``.
        video_ids: video ids aligned with ``embeddings``.
        distance_metric: "euclidean" or "cosine" (passed to ``calculate_distance``).
        num_neighbors: number of cross-video neighbours that vote (>= 1).

    Returns:
        Accuracy over the classified images, or ``nan`` if none could be classified.

    Raises:
        ValueError: if ``num_neighbors`` < 1.
    """
    if num_neighbors < 1:
        raise ValueError(f"num_neighbors must be >= 1, got {num_neighbors}")
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.detach().cpu().numpy()
    embeddings = np.asarray(embeddings)
    to_check = set(images_to_check)  # O(1) membership instead of scanning a list per image

    vit_y_pred = []
    actual_labels = []
    for idx, test_embedding in enumerate(embeddings):
        if idx not in to_check:
            continue

        distances = calculate_distance(embeddings, test_embedding, distance_metric)

        # Nearest-first neighbours from other videos. This also excludes idx itself,
        # since an image always shares its own video id.
        valid_neighbors = []
        for neighbor_idx in np.argsort(distances):
            if video_ids[neighbor_idx] != video_ids[idx]:
                valid_neighbors.append(neighbor_idx)
                if len(valid_neighbors) == num_neighbors:
                    break

        if len(valid_neighbors) < num_neighbors:
            print(f"Warning: Less than {num_neighbors} valid neighbors for index {idx}.")
        if not valid_neighbors:
            # nothing to vote with: leave the image out instead of crashing on an empty vote
            continue

        # Counter.most_common breaks ties by first occurrence, i.e. in favour of the
        # nearest tied label. max(set(...)) depended on randomised string hashing.
        valid_neighbor_labels = [labels[i] for i in valid_neighbors]
        predicted_label = Counter(valid_neighbor_labels).most_common(1)[0][0]

        vit_y_pred.append(predicted_label)
        actual_labels.append(labels[idx])

    if not actual_labels:
        return float("nan")
    return accuracy_score(actual_labels, vit_y_pred)

print(f"Images classified by cross-video KNN: {len(test_dataset.images_for_cross_video_knn)}")
# k is passed explicitly so the "KNN5" label cannot drift from the default
KNN5_CV_accuracy = KNN_CV(
    embeddings,
    test_dataset.images_for_cross_video_knn,
    labels,
    video_ids,
    distance_metric="euclidean",
    num_neighbors=5,
)
print(f"Cross-Video KNN5 Accuracy: {KNN5_CV_accuracy:.4f}")


129
Cross-Video KNN5 Accuracy: 0.8372


In [74]:
# Standard KNN classification (neighbours may come from the same video)
def KNN_standard(embeddings, images_to_check, labels, distance_metric="euclidean", num_neighbors=5):
    """Standard k-nearest-neighbour accuracy (no cross-video restriction).

    Each image whose index is in ``images_to_check`` is classified by a majority
    vote over its ``num_neighbors`` nearest other images.

    Args:
        embeddings: 2-D embeddings (one row per image), as a torch tensor or array-like.
        images_to_check: indices of the images to classify.
        labels: ground-truth labels aligned with ``embeddings``.
        distance_metric: "euclidean" or "cosine" (passed to ``calculate_distance``).
        num_neighbors: number of neighbours that vote (>= 1).

    Returns:
        Accuracy over the classified images, or ``nan`` if none could be classified.

    Raises:
        ValueError: if ``num_neighbors`` < 1.
    """
    if num_neighbors < 1:
        raise ValueError(f"num_neighbors must be >= 1, got {num_neighbors}")
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.detach().cpu().numpy()
    embeddings = np.asarray(embeddings)
    to_check = set(images_to_check)  # O(1) membership instead of scanning a list per image

    vit_y_pred = []
    actual_labels = []
    for idx, test_embedding in enumerate(embeddings):
        if idx not in to_check:
            continue

        distances = calculate_distance(embeddings, test_embedding, distance_metric)
        sorted_indices = np.argsort(distances)
        neighbors = sorted_indices[sorted_indices != idx][:num_neighbors]  # exclude self
        if len(neighbors) == 0:
            # only possible with a single-image dataset: nothing to vote with
            continue

        # Nearest-first order, so Counter ties go to the nearest tied label
        # (max(set(...)) depended on randomised string hashing).
        neighbor_labels = [labels[i] for i in neighbors]
        vit_y_pred.append(Counter(neighbor_labels).most_common(1)[0][0])
        actual_labels.append(labels[idx])

    if not actual_labels:
        return float("nan")
    return accuracy_score(actual_labels, vit_y_pred)

print(f"Images in test set: {len(test_dataset.images)}")
# Report how many images are actually classified (individuals with >= threshold images)
print(f"Images classified by standard KNN: {len(test_dataset.images_for_standard_knn)}")
KNN5_standard_accuracy = KNN_standard(
    embeddings,
    test_dataset.images_for_standard_knn,
    labels,
    distance_metric="euclidean",
    num_neighbors=5,
)
print(f"Standard KNN5 Accuracy: {KNN5_standard_accuracy:.4f}")


205
Standard KNN5 Accuracy: 0.9560
