In [1]:
from functools import partial
from itertools import islice
from typing import Any, Dict, List, Optional, Tuple, Union

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import PIL
import seaborn as sns
import sklearn
import torch
import torchmetrics as tm
import wandb
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from sklearn.manifold import TSNE
from torchmetrics.functional import pairwise_euclidean_distance
from torchvision.transforms import ToPILImage

import gorillatracker.type_helper as gtypes
from gorillatracker.utils.labelencoder import LinearSequenceEncoder
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, top_k_accuracy_score, roc_auc_score, f1_score, precision_score

In [2]:
def _macro_ovr_auroc(true_cols: np.ndarray, scores: np.ndarray) -> float:
    """Macro-average the one-vs-rest ROC AUC over every class that can be scored.

    Args:
        true_cols: Column index of the true class for each sample, shape (n_samples,).
        scores: Per-class scores, shape (n_samples, n_classes).

    Returns:
        Mean of the per-class binary ROC AUCs, or NaN if no class has both a
        positive and a negative sample.
    """
    per_class_auc = []
    for col in range(scores.shape[1]):
        is_positive = true_cols == col
        # A binary ROC AUC is undefined unless both outcomes are present.
        if is_positive.all() or not is_positive.any():
            continue
        per_class_auc.append(roc_auc_score(is_positive.astype(int), scores[:, col]))
    return float(np.mean(per_class_auc)) if per_class_auc else float("nan")


def knn_ssl(
    embeddings: torch.Tensor,
    labels: torch.Tensor,
    k: int = 1,
    negatives: Optional[Dict[int, List[int]]] = None,
) -> Dict[str, Any]:
    """Evaluate embeddings with a k-nearest-neighbour classifier restricted to known negatives.

    For every label, a neighbour index is fitted on the samples of that label
    plus the samples of its negative labels. Each sample of the label is then
    classified by a majority vote over its nearest neighbours in that subset,
    with the sample itself excluded.

    Args:
        embeddings: Tensor with one embedding per row (first dimension = samples).
        labels: 1-D tensor of integer labels, one per embedding.
        k: Number of neighbours that vote for each sample (capped by the
            subset size minus one).
        negatives: Maps a label to the labels it must be told apart from.
            Labels missing from the mapping, or whose negatives do not occur in
            ``labels``, are skipped. ``None`` (default) treats every other
            label as a negative.

    Returns:
        Dict with the keys ``accuracy``, ``accuracy_top5``, ``auroc``, ``f1``
        and ``precision``. ``f1`` and ``precision`` are macro-averaged,
        ``auroc`` is a one-vs-rest macro average over the neighbour-vote
        fractions, and ``accuracy_top5`` counts a hit when the true label is
        among the (at most) five most-voted labels. All values are NaN when no
        sample could be evaluated.

    Raises:
        ValueError: If ``k`` is smaller than 1 or the number of embeddings and
            labels differ.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if embeddings.shape[0] != labels.shape[0]:
        raise ValueError(
            f"embeddings and labels must have the same length, "
            f"got {embeddings.shape[0]} and {labels.shape[0]}"
        )

    # detach()/cpu() so tensors that live on an accelerator or track gradients convert cleanly.
    features = embeddings.detach().cpu().numpy()
    label_values = labels.detach().cpu().numpy()
    classes = np.unique(label_values)  # sorted, so searchsorted maps label -> column

    true_cols: List[int] = []
    pred_cols: List[int] = []
    score_rows: List[np.ndarray] = []

    for label in classes.tolist():
        if negatives is None:
            candidates = classes
        else:
            # Keep only candidate labels that actually occur in the data.
            candidates = np.intersect1d([*negatives.get(label, []), label], classes)
        if len(candidates) < 2:
            continue

        subset_positions = np.flatnonzero(np.isin(label_values, candidates))
        subset_labels = label_values[subset_positions]
        subset_features = features[subset_positions]

        # One extra neighbour because each query point is part of the fitted subset.
        n_neighbors = min(k + 1, len(subset_positions))
        knn = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto").fit(subset_features)

        query_positions = np.flatnonzero(subset_labels == label)
        _, neighbor_idx = knn.kneighbors(subset_features[query_positions])

        label_col = int(np.searchsorted(classes, label))
        for self_position, row in zip(query_positions, neighbor_idx):
            # Remove the query itself by index rather than by rank: with duplicate
            # embeddings the self-match is not guaranteed to come first.
            voters = row[row != self_position][: n_neighbors - 1]
            voter_cols = np.searchsorted(classes, subset_labels[voters])
            votes = np.bincount(voter_cols, minlength=len(classes))
            true_cols.append(label_col)
            # argmax resolves ties in favour of the smallest label.
            pred_cols.append(int(votes.argmax()))
            score_rows.append(votes / votes.sum())

    if not true_cols:
        nan = float("nan")
        return {'accuracy': nan, 'accuracy_top5': nan, 'auroc': nan, 'f1': nan, 'precision': nan}

    true_cols_arr = np.asarray(true_cols)
    scores = np.vstack(score_rows)
    y_true = classes[true_cols_arr]
    y_pred = classes[np.asarray(pred_cols)]

    accuracy = accuracy_score(y_true, y_pred)

    # Top-5 over vote fractions; labels without any vote never count as a hit,
    # so ties among zero-scored classes cannot inflate the metric.
    top5_cols = np.argsort(-scores, axis=1, kind="stable")[:, :5]
    top5_scores = np.take_along_axis(scores, top5_cols, axis=1)
    top5_hits = ((top5_cols == true_cols_arr[:, None]) & (top5_scores > 0)).any(axis=1)
    accuracy_top5 = float(top5_hits.mean())

    auroc = _macro_ovr_auroc(true_cols_arr, scores)
    f1 = f1_score(y_true, y_pred, average='macro')
    precision = precision_score(y_true, y_pred, average='macro')
    return {'accuracy': accuracy, 'accuracy_top5': accuracy_top5, 'auroc': auroc, 'f1': f1, 'precision': precision}