### Imports and metric calculation setup

In [1]:
from collections import defaultdict
from itertools import islice
from typing import Any, Dict, List, Literal, Optional

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import seaborn as sns
import sklearn
import torch
import torchmetrics as tm
import tqdm
import wandb
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from scipy import stats
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, f1_score, precision_score
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader as Dataloader
from torchmetrics.functional import pairwise_euclidean_distance
from torchvision.transforms import ToPILImage

import gorillatracker.type_helper as gtypes
from gorillatracker.data.contrastive_sampler import ContrastiveKFoldValSampler, get_individual, get_individual_video_id
from gorillatracker.data.nlet_dm import NletDataModule
from gorillatracker.utils.labelencoder import LinearSequenceEncoder

# Global matplotlib style for every figure in this notebook: 11 pt serif font.
params = dict(
    [
        ("font.size", 11),
        ("font.family", "serif"),
    ]
)
plt.rcParams.update(params)

In [2]:
def get_partition_from_dataframe(
    data: pd.DataFrame, partition: Literal["val", "train", "test"] = "val"
) -> tuple[pd.DataFrame, torch.Tensor, torch.Tensor, list[gtypes.Id], torch.Tensor]:
    """Select the rows of one partition and convert their columns to tensors.

    Args:
        data: Dataframe with the columns ``partition``, ``label``, ``embedding``, ``id`` and
            ``encoded_label``.
        partition: Value of the ``partition`` column to select.

    Returns:
        ``(partition_df, labels, embeddings, ids, encoded_labels)`` in dataframe row order:
        the selected rows, the ``label`` column as an int64 tensor, the stacked per-row
        embeddings as a float32 tensor, the ``id`` column as a list, and the
        ``encoded_label`` column as an int64 tensor.

    Raises:
        ValueError: if no row belongs to ``partition``.
    """
    # Plain boolean indexing: ``where(...).dropna()`` would also drop rows of this partition
    # that have a missing value in any unrelated column, and it casts integer columns to float.
    partition_df = data.loc[data["partition"] == partition].copy()
    if partition_df.empty:
        # np.stack would otherwise fail with an unhelpful "need at least one array" error.
        raise ValueError(f"No rows with partition == {partition!r}")

    partition_labels = torch.tensor(partition_df["label"].tolist()).long()
    partition_embeddings = np.stack(partition_df["embedding"].apply(np.array).tolist()).astype(np.float32)
    partition_embeddings = torch.tensor(partition_embeddings)
    partition_ids = partition_df["id"].tolist()
    partition_encoded_labels = torch.tensor(partition_df["encoded_label"].tolist()).long()

    return partition_df, partition_labels, partition_embeddings, partition_ids, partition_encoded_labels


def _get_crossvideo_masks(
    labels: torch.Tensor, ids: list[gtypes.Id], min_samples: int = 3
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the masks needed for cross-video k-NN evaluation.

    Args:
        labels: Per-sample labels; only its length matters and it is expected to equal
            ``len(ids)``.
        ids: Per-sample ids; ``get_individual`` and ``get_individual_video_id`` derive the
            individual and the individual's video from them.
        min_samples: Minimum number of images of the same individual that must come from
            other videos for a sample to be classified.

    Returns:
        ``(distance_mask, classification_mask)``: an (n, n) bool tensor that is True where
        samples i and j come from different videos, and an (n,) bool tensor that is True for
        samples whose individual has at least ``min_samples`` images outside that sample's video.
    """
    individuals = [get_individual(sample_id) for sample_id in ids]
    videos = [get_individual_video_id(sample_id) for sample_id in ids]

    # Replace every video id by a dense integer code so "different video" becomes one
    # broadcasted tensor comparison instead of a Python double loop.
    code_of_video: dict[str, int] = {}
    video_codes = torch.tensor(
        [code_of_video.setdefault(video, len(code_of_video)) for video in videos], dtype=torch.long
    )
    distance_mask = video_codes.unsqueeze(1) != video_codes.unsqueeze(0)

    # Image counts per individual and per (individual, video) pair.
    images_of_individual: defaultdict[str, int] = defaultdict(int)
    images_in_video: defaultdict[tuple[str, str], int] = defaultdict(int)
    for individual, video in zip(individuals, videos):
        images_of_individual[individual] += 1
        images_in_video[(individual, video)] += 1

    # Images of the same individual that are left once the sample's own video is excluded.
    classification_mask = torch.tensor(
        [
            images_of_individual[individual] - images_in_video[(individual, video)] >= min_samples
            for individual, video in zip(individuals, videos)
        ],
        dtype=torch.bool,
    )
    return distance_mask, classification_mask


def _pairwise_distances(embeddings: torch.Tensor, distance_metric: str) -> torch.Tensor:
    """Return the (n, n) pairwise distance matrix of the (n, d) ``embeddings``.

    Raises:
        ValueError: if ``distance_metric`` is neither "cosine" nor "euclidean".
    """
    if distance_metric == "cosine":
        # Cosine distance = 1 - cosine similarity, range [0, 2].
        similarity = torch.nn.functional.cosine_similarity(
            embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=-1
        )
        return 1.0 - similarity
    if distance_metric == "euclidean":
        return pairwise_euclidean_distance(embeddings)  # range [0, inf)
    raise ValueError(f"Unknown distance metric: {distance_metric}")


def _break_ties_by_nearest_neighbor(votes: torch.Tensor, neighbor_labels: torch.Tensor, eps: float = 1e-6) -> None:
    """Resolve vote ties in place in favour of the class of the nearest tied neighbour.

    Args:
        votes: (n, num_classes) vote fractions; modified in place.
        neighbor_labels: (n, k) class labels of each row's neighbours, nearest first.
        eps: Tolerance for treating two vote fractions as equal, and size of the nudge.
    """
    for i, row_neighbor_labels in enumerate(neighbor_labels.tolist()):
        row_votes = votes[i]
        tied_classes = set(torch.nonzero(row_votes.max() - row_votes < eps).flatten().tolist())
        if len(tied_classes) <= 1:
            continue
        # Neighbours are sorted by ascending distance, so the first one whose *class* is
        # among the tied classes decides.
        for label in row_neighbor_labels:
            if label in tied_classes:
                votes[i, label] += eps
                break


def knn(
    data: pd.DataFrame,
    average: Literal["micro", "macro", "weighted", "none"] = "weighted",
    k: int = 5,
    use_train_embeddings: bool = False,
    use_crossvideo_positives: bool = False,
    distance_metric: Literal["euclidean", "cosine"] = "euclidean",
    use_filter: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """k-nearest-neighbour classification of the validation embeddings in ``data``.

    Algorithm:
    1. Compute the distance matrix between all (train + val) embeddings and set the diagonal
       to ``inf`` so a sample is never its own neighbour.
    2. For every validation sample take the k closest embeddings (ascending distance) and
       look up their encoded labels.
    3. Build a vote matrix (n_val x num_classes) holding the fraction of the k neighbours
       in each class; break ties in favour of the nearest tied neighbour.
    4. Return only the rows selected by the classification mask (see ``use_filter`` and
       ``use_crossvideo_positives``). Metrics are computed by the caller.

    Args:
        data: Dataframe as expected by ``get_partition_from_dataframe``. Labels are read from
            the ``encoded_label`` column and must form the contiguous range 0..C-1.
        average: Unused; kept for backward compatibility.
        k: Number of neighbours (capped at the number of classes).
        use_train_embeddings: Also use the "train" partition as neighbour candidates.
        use_crossvideo_positives: Never use a validation neighbour from the query's own video,
            and only classify samples with at least 3 images of the same individual in other
            videos.
        distance_metric: "euclidean" or "cosine".
        use_filter: Skip classes with fewer than ``k // 2 + 2`` validation samples.

    Returns:
        ``(predictions, labels, samples_left)`` for the classified validation samples:
        - the (n x num_classes) vote matrix;
        - their encoded labels;
        - per sample, the number of other validation samples of the same class it could be
          matched against (restricted to other videos when ``use_crossvideo_positives`` is set).
        All three are aligned row by row.

    Raises:
        ValueError: if the labels are not contiguous or ``distance_metric`` is unknown.
    """
    _, _, val_embeddings, val_ids, val_labels = get_partition_from_dataframe(data, partition="val")
    if use_train_embeddings:
        _, _, train_embeddings, _, train_labels = get_partition_from_dataframe(data, partition="train")
        combined_embeddings = torch.cat([train_embeddings, val_embeddings], dim=0)
        combined_labels = torch.cat([train_labels, val_labels], dim=0)
    else:
        combined_embeddings, combined_labels = val_embeddings, val_labels
    # Validation rows/columns follow the training ones in the combined matrices.
    n_train = len(combined_labels) - len(val_labels)

    # NOTE(rob2u): a majority among k neighbours needs k // 2 + 1 other samples of the class,
    # plus one for the query sample itself.
    min_amount = k // 2 + 2 if use_filter else 0
    classification_mask = torch.ones(len(val_labels), dtype=torch.bool)
    for label, count in zip(*torch.unique(val_labels, return_counts=True)):
        if count < min_amount:
            classification_mask[val_labels == label] = False

    num_classes = int(combined_labels.max().item()) + 1
    if num_classes != len(torch.unique(combined_labels)):
        raise ValueError("knn expects encoded labels forming the contiguous range 0..num_classes-1")
    k = min(k, num_classes)

    distance_matrix = _pairwise_distances(combined_embeddings, distance_metric)
    distance_matrix.fill_diagonal_(float("inf"))  # never pick a sample as its own neighbour

    same_label = val_labels.unsqueeze(0) == val_labels.unsqueeze(1)  # (n_val, n_val)
    if use_crossvideo_positives:
        crossvideo_mask, classification_mask_cv = _get_crossvideo_masks(val_labels, val_ids)
        classification_mask &= classification_mask_cv
        # Hide val-vs-val pairs from the same video; train candidates stay visible.
        val_block = distance_matrix[n_train:, n_train:]  # view into distance_matrix
        val_block[~crossvideo_mask] = float("inf")
        # NOTE(review): if fewer than k unmasked candidates remain, topk still returns masked
        # (inf-distance) neighbours and they vote - confirm this cannot happen for the data used.
        samples_left_all = (same_label & crossvideo_mask).sum(dim=1)
    else:
        samples_left_all = same_label.sum(dim=1) - 1  # exclude the sample itself
    # Filter with the *final* mask so samples_left stays aligned with the returned rows.
    samples_left = samples_left_all[classification_mask].tolist()

    # Only validation rows are classified; the columns still cover all candidates.
    _, closest_indices = torch.topk(distance_matrix[n_train:], k, largest=False, sorted=True)
    closest_labels = combined_labels[closest_indices]  # (n_val, k), nearest first

    # Fraction of the k neighbours that belong to each class.
    votes = torch.nn.functional.one_hot(closest_labels.long(), num_classes).sum(dim=1) / k
    _break_ties_by_nearest_neighbor(votes, closest_labels)

    return votes[classification_mask], val_labels[classification_mask], samples_left

### load model

In [None]:
from gorillatracker.model.wrappers_supervised import BaseModuleSupervised

# Checkpoint to evaluate; going by its directory name, a ViT-L/DINOv2 model trained on CXL fold 0.
checkpoint_path = (
    "/workspaces/gorillatracker/models/roberts_models/gorillas_models/vit_large_dinov2_bayes/"
    "fold-0-epoch-19-cxlkfold/fold-0/val/embeddings/knn5_crossvideo/accuracy-0.63.ckpt"
)

# Only the wrapped embedding network is needed; strict=False tolerates state-dict key mismatches.
lightning_checkpoint_kwargs = {"data_module": None, "strict": False}
model = BaseModuleSupervised.load_from_checkpoint(checkpoint_path, **lightning_checkpoint_kwargs).model_wrapper
model.eval()

### Initialize the CXL dataset (fold 0) and the Bristol dataset

In [None]:
from pathlib import Path

from gorillatracker.data.nlet import SupervisedKFoldDataset, SupervisedDataset
from gorillatracker.data.builder import build_onelet
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from gorillatracker.transform_utils import SquarePad

# Evaluation-time preprocessing. NOTE(rob2u): square padding and tensor conversion are
# applied inside the dataset, so only resizing and (ImageNet mean/std) normalisation go here.
transformations = Compose(
    [
        Resize((224, 224)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

cxl_split_dir = Path("/workspaces/gorillatracker/data/supervised/splits/cxl_faces_openset_seed42_square_kfold-5")
no_augmentation = {"aug_num_ops": 0, "aug_magnitude": 0}

# Fold 0 of the 5-fold CXL split. NOTE: the "_val" dataset reads the "test" partition.
cxl_dataset_fold_0_train = SupervisedKFoldDataset(
    cxl_split_dir, build_onelet, "train", 0, 5, transformations, **no_augmentation
)
cxl_dataset_fold_0_val = SupervisedKFoldDataset(
    cxl_split_dir, build_onelet, "test", 0, 5, transformations, **no_augmentation
)

# NOTE(rob2u): this directory has no val sub-directory, so requesting "val" uses all of its data.
bristol_dataset = SupervisedDataset(
    Path("/workspaces/gorillatracker/data/supervised/bristol/cross_encounter_validation/cropped_frames_square_filtered"),
    build_onelet,
    "val",
    transformations,
    **no_augmentation,
)

### Build the embedding dataframes

In [5]:
# Inputs are moved to `device` below, so make sure the model lives on the same device
# instead of assuming the checkpoint was already loaded onto the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def _embed_dataset(dataset, partition: str, dataset_name: str) -> pd.DataFrame:
    """Embed every sample of ``dataset`` with ``model``; return one dataframe row per sample.

    Each dataset item is read as ``item[0][0]`` (id), ``item[1][0]`` (image tensor) and
    ``item[2][0]`` (label). ``partition`` and ``dataset_name`` fill the constant columns.
    """
    labels, embeddings, ids = [], [], []
    with torch.no_grad():
        for sample in dataset:
            sample_id, sample_img, sample_label = sample[0][0], sample[1][0], sample[2][0]
            # Add a batch dimension of 1, embed, and bring the vector back as a numpy array.
            embedding = model(sample_img.unsqueeze(0).to(device)).cpu().squeeze(0).numpy()
            labels.append(sample_label)
            embeddings.append(embedding)
            ids.append(sample_id)

    return pd.DataFrame(
        {
            "label": labels,
            "embedding": embeddings,
            "id": ids,
            "partition": partition,
            "dataset": dataset_name,
        }
    )


bristol_df = _embed_dataset(bristol_dataset, "val", "bristol")
cxl_val_df = _embed_dataset(cxl_dataset_fold_0_val, "val", "cxl")
cxl_train_df = _embed_dataset(cxl_dataset_fold_0_train, "train", "cxl")

# Encode labels per dataframe (knn() requires encoded labels forming 0..C-1).
# NOTE: every dataframe has its own encoder, so encoded labels are NOT comparable across
# dataframes - encode the concatenated raw labels once before mixing train and val in knn().
label_encoder = LinearSequenceEncoder()
bristol_df["encoded_label"] = label_encoder.encode_list(bristol_df["label"].tolist())
label_encoder_2 = LinearSequenceEncoder()
cxl_val_df["encoded_label"] = label_encoder_2.encode_list(cxl_val_df["label"].tolist())
label_encoder_3 = LinearSequenceEncoder()
cxl_train_df["encoded_label"] = label_encoder_3.encode_list(cxl_train_df["label"].tolist())

### evaluate model

In [22]:
# k-NN (k=5, euclidean) on the CXL evaluation split against itself: no train embeddings,
# no cross-video restriction, and classes too small for a majority vote are filtered out.
cxl_val_predictions, cxl_val_labels, samples_left = knn(
    cxl_val_df,
    distance_metric="euclidean",
    k=5,
    use_filter=True,
    use_crossvideo_positives=False,
    use_train_embeddings=False,
)
samples_left = np.asarray(samples_left)

In [None]:
num_classes = len(np.unique(cxl_val_df["label"].tolist()))


def _cxl_val_accuracy(average: str) -> torch.Tensor:
    """Multiclass accuracy of the k-NN predictions using torchmetrics' ``average`` mode."""
    return tm.functional.accuracy(
        cxl_val_predictions,
        cxl_val_labels,
        task="multiclass",
        num_classes=num_classes,
        average=average,
    )


# Per-class accuracies (used by the following cells) plus two aggregated variants. The
# torchmetrics "macro" value is treated as "bad" here; macro accuracy is recomputed below
# over only the classes that still have samples after filtering.
accuracies = _cxl_val_accuracy("none")
bad_macro_accuracy = _cxl_val_accuracy("macro")
weighted_accuracy = _cxl_val_accuracy("weighted")

print("Bad Macro Accuracy: ", bad_macro_accuracy)
print("Weighted Accuracy: ", weighted_accuracy)

### Calculate the correct macro accuracy (only over classes that still have samples)

In [None]:
# Macro / weighted accuracy over only the classes that still have classified samples after
# knn() filtering.
encoded_labels = cxl_val_df["encoded_label"].to_numpy()
classified_labels = cxl_val_labels.numpy()  # knn() builds its labels as CPU tensors

# label -> (samples in the split, samples kept for classification, per-class accuracy)
label_2_count_acc = {}
for label in np.unique(encoded_labels):
    label_count = int(np.sum(encoded_labels == label))
    label_count_left = int(np.sum(classified_labels == label))
    acc = accuracies[label].item()
    label_2_count_acc[label] = (label_count, label_count_left, acc)

# Classes without classified samples are skipped entirely, so a NaN/0 accuracy reported for
# them cannot leak into either average.
kept = [(left, acc) for _, left, acc in label_2_count_acc.values() if left > 0]
if kept:
    macro_acc = sum(acc for _, acc in kept) / len(kept)
    weighted_acc = sum(acc * left for left, acc in kept) / sum(left for left, _ in kept)
else:
    macro_acc = weighted_acc = float("nan")  # nothing left to classify

print("Macro Accuracy: ", macro_acc)
print("Weighted Accuracy: ", weighted_acc)

In [None]:
# Per classified sample: how many same-class samples it could be matched against.
print(str(samples_left))

In [None]:
# Accuracy of the k-NN prediction, grouped by how many same-class samples were available.
# A sample goes into the first bucket whose upper bound is >= its count.
buckets = [5, 10, 20, 30, 10000]
true_rate_buckets = {upper: [] for upper in buckets}

for n_left, prediction, target in zip(samples_left, cxl_val_predictions, cxl_val_labels):
    upper_bound = next((upper for upper in buckets if n_left <= upper), None)
    if upper_bound is not None:
        is_correct = torch.argmax(prediction).item() == target.item()
        true_rate_buckets[upper_bound].append(is_correct)

# Mean correctness per bucket; empty buckets are shown as 0.0.
true_rate_buckets_avg = {}
for upper, hits in true_rate_buckets.items():
    true_rate_buckets_avg[upper] = np.mean(hits) if hits else 0.0

bar_positions = range(len(true_rate_buckets_avg))
plt.bar(
    bar_positions,
    list(true_rate_buckets_avg.values()),
    align="center",
    color="orange",
    edgecolor="black",
)
plt.xticks(bar_positions, ["3 - 5", "6 - 10", "11 - 20", "21 - 30", "> 30"])
plt.grid(axis="y")
plt.ylabel("True Positive Rate")
plt.xlabel("Number of samples used for classification")
sns.despine()
# plt.savefig("plots/macro/true_rate_knn5_CROSSVIDEO_fold0.pdf", bbox_inches="tight", dpi=500)
plt.show()

In [None]:
# Per-class accuracy vs. the average number of same-class samples that were available to the
# classified samples of that class.
class_ids = np.unique(cxl_val_df["encoded_label"].tolist())
classified_labels_np = cxl_val_labels.numpy()


def _mean_samples_left(label):
    """Average ``samples_left`` over the classified samples of ``label`` (0 if there are none)."""
    in_class = classified_labels_np == label
    if not in_class.any():
        return 0
    return samples_left[in_class].sum() / in_class.sum()


# Keep only classes that still have classified samples.
samples_left_avg, accuracies_samples_left = zip(
    *[
        (_mean_samples_left(label), accuracies[position].item())
        for position, label in enumerate(class_ids)
        if (classified_labels_np == position).any()
    ]
)

plt.scatter(samples_left_avg, accuracies_samples_left, color="darkred", marker="x", s=100)
plt.grid()
plt.xlabel("Average Samples Left for classification")
plt.ylabel("Accuracy")
sns.despine()
# plt.savefig("plots/macro/accuracy_vs_samples_left_knn5_CROSSVIDEO_fold0.pdf", bbox_inches="tight", dpi=500)
plt.show()

In [None]:
# Per-class accuracy vs. the number of samples of that class that were actually classified
# (i.e. left after knn() filtering).
from pathlib import Path

class_ids = np.unique(cxl_val_df["encoded_label"].tolist())
classified_labels_np = cxl_val_labels.numpy()
samples_total = [int(np.sum(classified_labels_np == label)) for label in class_ids]

# Keep only classes that still have classified samples; use plain floats like the cell above.
samples_total, accuracies_samples_total = zip(
    *[
        (count, accuracies[position].item())
        for position, count in enumerate(samples_total)
        if np.any(classified_labels_np == position)
    ]
)

plt.scatter(samples_total, accuracies_samples_total, color="darkred", marker="x", s=100)
plt.grid()
sns.despine()
plt.xlabel("Samples per class")
plt.ylabel("Accuracy")
output_path = Path("plots/macro/accuracy_vs_samples_classified_knn5_test.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories
plt.savefig(output_path, bbox_inches="tight", dpi=500)
plt.show()